簡體   English   中英

Tensorflow保存和恢復模型的問題

[英]Issue with Tensorflow save and restore model

我正在嘗試使用轉移學習方法。 這是我的代碼正在通過訓練數據學習的代碼的快照:

max_accuracy = 0.0
    saver = tf.train.Saver()
    for epoch in range(epocs):
        shuffledRange = np.random.permutation(n_train)
        y_one_hot_train = encode_one_hot(len(classes), Y_input)
        y_one_hot_validation = encode_one_hot(len(classes), Y_validation)
        shuffledX = X_input[shuffledRange,:]
        shuffledY = y_one_hot_train[shuffledRange]
        for Xi, Yi in iterate_mini_batches(shuffledX, shuffledY, mini_batch_size):
            sess.run(train_step,
                     feed_dict={bottleneck_tensor: Xi,
                                ground_truth_tensor: Yi})
            # Every so often, print out how well the graph is training.
            is_last_step = (i + 1 == FLAGS.how_many_training_steps)
            if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
                train_accuracy, cross_entropy_value = sess.run(
                  [evaluation_step, cross_entropy],
                  feed_dict={bottleneck_tensor: Xi,
                             ground_truth_tensor: Yi})
                validation_accuracy = sess.run(
                  evaluation_step,
                  feed_dict={bottleneck_tensor: X_validation,
                             ground_truth_tensor: y_one_hot_validation})
                print('%s: Step %d: Train accuracy = %.1f%%, Cross entropy = %f, Validation accuracy = %.1f%%' %
                    (datetime.now(), i, train_accuracy * 100, cross_entropy_value, validation_accuracy * 100))
                result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name))
                probs = sess.run(result_tensor,feed_dict={'pool_3/_reshape:0': Xi[0].reshape(1,2048)})
                if validation_accuracy > max_accuracy :
                   saver.save(sess, 'models/superheroes_model')
                   max_accuracy = validation_accuracy
                   print(probs)
            i+=1  

這是我的代碼,這里是我加載模型的地方:

def load_model () :
    sess=tf.Session()    
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('models/superheroes_model.meta')
    saver.restore(sess,tf.train.latest_checkpoint('models/'))
    sess.run(tf.global_variables_initializer())
    result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name))  
    X_feature = features[0].reshape(1,2048)        
    probs = sess.run(result_tensor,
                         feed_dict={'pool_3/_reshape:0': X_feature})
    print probs
    return sess  

因此,現在對於同一數據點,在培訓和測試時我得到的結果完全不同。 它甚至不接近。 在測試期間,由於我有4個班級,因此我的概率接近25%。 但是在訓練過程中,最高的上課率是90%。
保存或還原模型時是否有任何問題?

小心-您正在打電話

sess.run(tf.global_variables_initializer())

打電話后

saver.restore(sess,tf.train.latest_checkpoint('models/'))

我之前也做過類似的事情,我認為這會重置您所有訓練有素的權重/偏見/等。 在還原的模型中。

如果需要,請在還原模型之前調用初始化程序,並且如果需要初始化已還原模型中的特定內容,請單獨進行。

刪除sess.run(tf.global_variables_initializer())在函數load_model ,如果你這樣做,你的所有訓練的參數將與該會為每個班級1/4概率的初始值被替換

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM