[英]How to save a specific variable in TensorFlow?
我建立了一个网络来测试要保存的模型。 这是我的代码:
import tensorflow as tf
import numpy as np
import time
dimensions=100
batch_size=128
def add_layer(inputs, in_size, out_size, activation_function=None):
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
def f(batch_size,val,dims):
a = np.zeros(batch_size,dtype=np.int32)+val
b = np.zeros((batch_size, dims))
b[np.arange(batch_size), a] = 1
return b
xs = tf.placeholder(tf.float32, [None, dimensions])
ys = tf.placeholder(tf.float32, [None, 43])
l1 = add_layer(xs, dimensions, 64, activation_function=None)
l2 = add_layer(l1, 64, 64, activation_function=tf.nn.sigmoid)
prediction = add_layer(l2, 64, 43, activation_function=None)
loss = tf.reduce_mean(tf.square(ys - prediction))
train_step = tf.train.AdamOptimizer(0.003).minimize(loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range(100):
start_time = time.time()
X = f(batch_size=batch_size,val=step,dims=dimensions)
y = np.random.rand(batch_size,43)
sess.run(train_step, feed_dict={xs:X, ys:y})
duration = time.time()-start_time
if step%10 == 0:
loss_value = sess.run(loss, feed_dict={xs: X, ys: y})
format_str = ('step %d,loss=%5.2f (%.1f examples/sec;%.3f sec/batch)')
print(format_str %(step,loss_value,batch_size/duration,float(duration)))
saver = tf.train.Saver()
save_path = saver.save(sess, "./save_net.ckpt")
sess.close()
它将所有变量保存到“ ./save_net.ckpt”。
但是我只想节省l1层的重量和偏差。 怎么做?
以及如何在TensorFlow中提取这些变量?
您应该看一下tensorflow文档。 变量
尤其是有关选择要保存和还原的变量的部分
在你的情况下
您应该将名称传递给创建权重和偏差的函数,以便声明为
Weights = tf.Variable(tf.random_normal([in_size, out_size]), name=weights_name)
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = biases_name)
接着
saver = tf.train.Saver({"l1_wieghts": "l1_weights_name",
"l1_biases": "l1_biases_name",
"l2_weights":"l2_weights_names",
"l2_biases":"l2_biases_name"})
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.