How to load a trained TensorFlow model
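I'm having all sorts of trouble loading a TensorFlow model so that I can test it on some new data. When I trained the model, I saved it like this: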
save_model_file = 'my_saved_model'
saver = tf.train.Saver()
save_path = saver.save(sess, save_model_file)
This seems to create the following files:
my_saved_model.meta
checkpoint
my_saved_model.index
my_saved_model.data-00000-of-00001
I don't know which of these files I should be paying attention to.
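For reference, here is a rough guide to what each of those files holds, written as a small pure-Python sketch. The helper name `describe_checkpoint_file` is made up for illustration and is not part of TensorFlow; the roles described are those of the standard `tf.train.Saver` checkpoint layout.

```python
import re


def describe_checkpoint_file(filename, prefix="my_saved_model"):
    """Return the role a file written by saver.save(sess, prefix) plays.

    Hypothetical helper for illustration only; not a TensorFlow API.
    """
    if filename == "checkpoint":
        # Small text file listing the most recent checkpoint prefix;
        # it is what tf.train.latest_checkpoint() reads.
        return "bookkeeping"
    if filename == prefix + ".meta":
        # Serialized graph structure, read by tf.train.import_meta_graph().
        return "graph"
    if filename == prefix + ".index":
        # Index that maps variable names to their location in the data shards.
        return "variable index"
    if re.fullmatch(re.escape(prefix) + r"\.data-\d{5}-of-\d{5}", filename):
        # The saved variable values themselves, read by saver.restore().
        return "variable values"
    return "unrelated"


for name in ["my_saved_model.meta", "checkpoint",
             "my_saved_model.index", "my_saved_model.data-00000-of-00001"]:
    print(name, "->", describe_checkpoint_file(name))
```

In practice only two strings are passed to TensorFlow: the `.meta` file name when importing the graph, and the bare prefix (`my_saved_model`, with no extension) when restoring the variables.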
Now that the model is trained, I can't seem to load or use it without an exception being thrown. Here is what I'm doing:
def neural_net_data_input(data_shape):
    theshape = (None,) + tuple(data_shape)
    return tf.placeholder(tf.float32, shape=theshape, name='x')

def neural_net_label_input(n_out):
    return tf.placeholder(tf.float32, shape=(None, n_out), name='one_hot_labels')

def neural_net_keep_prob_input():
    return tf.placeholder(tf.float32, name='keep_prob')

def do_generate_network(x):
    #
    # here is where i generate the network layer by layer.
    # this code works fine so i am not showing it here
    #
    pass
#
# Now I want to restore the model
#
tf.reset_default_graph()
input_data_shape = (32, 32, 1)
final_num_outputs = 43
graph1 = tf.Graph()
with graph1.as_default():
    x = neural_net_data_input(input_data_shape)
    one_hot_labels = neural_net_label_input(final_num_outputs)
    keep_prob = neural_net_keep_prob_input()
    logits = do_generate_network(x)
    # Name logits Tensor, so that it can be loaded from disk after training
    logits = tf.identity(logits, name='logits')
    #
    # accuracy: we use this for validation testing
    #
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
################################
# Evaluate
################################
new_data = myutils.load_pickle_file(SOME_DATA_FILE_NAME)
new_features = new_data['features']
new_one_hot_labels = new_data['labels']
print('Evaluating on new data...')
with tf.Session(graph=graph1) as sess:
    # Initializing the variables
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_model_file)
    new_acc = sess.run(accuracy, feed_dict={x: new_features, one_hot_labels: new_one_hot_labels, keep_prob: 1.})
    print('Testing Accuracy For New Images: {}'.format(new_acc))
But when I do this, I get:
TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.
So I tried moving the graph construction inside the session, like this:
################################
# Evaluate
################################
print('Evaluating on web data...')
with tf.Session() as sess:
    x = neural_net_data_input(input_data_shape)
    one_hot_labels = neural_net_label_input(final_num_outputs)
    keep_prob = neural_net_keep_prob_input()
    logits = do_generate_network(x)
    # Name logits Tensor, so that it can be loaded from disk after training
    logits = tf.identity(logits, name='logits')
    #
    # accuracy: we use this for validation testing
    #
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
    sess.run(tf.global_variables_initializer())
    my_save_dir = "/home/carnd/CarND-Traffic-Sign-Classifier-Project"
    load_model_meta_file = os.path.join(my_save_dir, "my_saved_model.meta")
    load_model_path = os.path.join(my_save_dir, "my_saved_model")
    new_saver = tf.train.import_meta_graph(load_model_meta_file)
    new_saver.restore(sess, load_model_path)
    web_acc = sess.run(accuracy, feed_dict={x: web_features, one_hot_labels: web_one_hot_labels, keep_prob: 1.})
    print('Testing Accuracy For Web Images: {}'.format(web_acc))
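Now it runs without raising an error, but the accuracy it prints is 0.02! The data I'm feeding in is the same data that gave 95% accuracy during training. So it looks like I'm somehow loading my model incorrectly.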
What am I doing wrong?
Steps to load a trained model:
Load the graph: you can load the graph with tf.train.import_meta_graph(). Example code:
model_path = "my_saved_model"
inference_graph = tf.Graph()
with tf.Session(graph=inference_graph) as sess:
    # Load the graph with the trained states
    loader = tf.train.import_meta_graph(model_path + '.meta')
    loader.restore(sess, model_path)
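Get the tensors: use get_tensor_by_name() to fetch the tensors needed for inference. So, when you build the model, make sure you give those tensors names so that they can be looked up at inference time.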
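# Get the tensors by the names they were given in the model
# (in the question's code the label placeholder is named 'one_hot_labels', not 'y')
_accuracy = inference_graph.get_tensor_by_name('accuracy:0')
_x = inference_graph.get_tensor_by_name('x:0')
_y = inference_graph.get_tensor_by_name('y:0')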
Test: this can be done using the loaded tensors, while the session from the first step is still open.
sess.run(_accuracy, feed_dict={_x: ... , _y: ...})
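Put together, the three steps can be sketched end to end as below. This is a minimal illustration, not the asker's network: the toy weights, shapes and temporary directory are made up, and the `tf1 = tf.compat.v1` alias assumes a TensorFlow 2.x install (on TensorFlow 1.x, use `tf` directly and drop the `disable_eager_execution()` call). The tensor names `x` and `logits` follow the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TensorFlow 2.x; on 1.x use `tf` directly
tf1.disable_eager_execution()

save_dir = tempfile.mkdtemp()  # illustrative location only
model_path = os.path.join(save_dir, "my_saved_model")

# Training side: build a toy graph with *named* tensors and save it.
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=(None, 4), name="x")
    # Constant weights stand in for values learned during training.
    w = tf1.get_variable("w", shape=(4, 3),
                         initializer=tf1.constant_initializer(2.0))
    logits = tf.identity(tf.matmul(x, w), name="logits")
    saver = tf1.train.Saver()
    with tf1.Session(graph=train_graph) as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, model_path)

# Inference side: do NOT rebuild the network and do NOT run the initializer.
# Import the graph from the .meta file into an empty graph, then restore.
inference_graph = tf.Graph()
with inference_graph.as_default():
    with tf1.Session(graph=inference_graph) as sess:
        loader = tf1.train.import_meta_graph(model_path + ".meta")
        loader.restore(sess, model_path)
        # Look the tensors up by the names given at build time.
        _x = inference_graph.get_tensor_by_name("x:0")
        _logits = inference_graph.get_tensor_by_name("logits:0")
        restored = sess.run(
            _logits, feed_dict={_x: np.ones((2, 4), dtype=np.float32)})

print(restored)
```

A plausible reading of the 0.02 accuracy in the question (roughly 1/43, i.e. chance level for 43 classes) is that the network was rebuilt and freshly initialized in the same graph into which import_meta_graph() then loaded a second copy, so the restored weights went to the imported copy while `accuracy` was still wired to the untrained one. Importing into an empty graph and fetching tensors by name, as in the sketch, avoids having two copies. For the question's model, the names to fetch would be 'x:0', 'one_hot_labels:0', 'keep_prob:0' and 'accuracy:0'.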