[英]Tensorflow: Saving a model to model.pb, to visualize it later
我找到了以下代码片段来可视化保存到*.pb
文件的模型:
model_filename ='saved_model.pb'
with tf.Session() as sess:
with gfile.FastGFile(path_to_model_pb, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='.'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
现在我正在努力创建saved_model.pb
。 如果我的session.run看起来像这样:
_, cr_loss = sess.run([train_op,cross_entropy_loss],
feed_dict={input_image: images,
correct_label: gt_images,
keep_prob: KEEP_PROB,
learning_rate: LEARNING_RATE}
)
如何保存包含在图形train_op
到saved_model.pb
?
最简单的方法是使用tf.train.write_graph
。 通常,您只需要执行以下操作:
tf.train.write_graph(my_graph, path_to_model_pb,
'saved_model.pb', as_text=False)
如果您使用默认图形或任何其他tf.Graph
(或tf.GraphDef
)对象, my_graph
可以是tf.get_default_graph()
。
请注意,这会保存图形定义,可以将其可视化,但如果您有变量,则除非您先冻结图形 ,否则它们的值将不会保存在那里(因为它们仅在会话对象中,而不是图形本身)。
我将逐步介绍此问题:
为了可视化变量,如权重,偏差使用tf.summary.histogram
weights = {
'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
tf.summary.histogram("weight1", weights['h1'])
tf.summary.histogram("weight2", weights['h2'])
tf.summary.histogram("weight3", weights['out'])
biases = {
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_classes]))
}
tf.summary.histogram("bias1", biases['b1'])
tf.summary.histogram("bias2", biases['b2'])
tf.summary.histogram("bias3", biases['out'])
cost = tf.sqrt(tf.reduce_mean(tf.squared_difference(pred, y)))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
tf.summary.scalar('rmse', cost)
然后培训包括以下代码。
summaries = tf.summary.merge_all()
with tf.Session() as sess:
sess.run(init)
# Get data
writer = tf.summary.FileWriter("histogram_example", sess.graph)
# Training cycle
# Run optimization op (backprop) and cost op (to get loss value)
summ, p, _, c = sess.run([summ, pred, optimizer, cost], feed_dict={x: batch_x,
y: batch_y,})
writer.add_summary(summ, global_step=epoch*total_batch+i)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.