
Keras: TensorFlow 1.3 model fails under TensorFlow 1.4 or later (wrong predictions)

I have a model trained with Keras 2.0.6-tf on TensorFlow 1.3, using the tensorflow.contrib Python API. It works like a charm.

However, when I load the model in a TensorFlow 1.4 (or later) environment, the predictions are constant, i.e. wrong. There is no error message at all.

All I do is:

from tensorflow.contrib.keras.api.keras.models import load_model

model = load_model(..)
predictions  = model.predict(input, batch_size=batch_size)
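A quick way to confirm the "constant predictions" symptom, independent of TensorFlow, is to measure how much each output column varies across a batch. The helper below is a hypothetical sketch (the name `predictions_look_constant` and the `atol` tolerance are not from the post); it assumes the NumPy array returned by `model.predict()` has shape `(num_samples, num_outputs)`, and uses synthetic arrays in place of real model output:

```python
import numpy as np


def predictions_look_constant(predictions, atol=1e-6):
    """Return True if every row of `predictions` is (nearly) identical.

    Hypothetical helper: `predictions` is assumed to be the array from
    model.predict(), shaped (num_samples, num_outputs). A near-zero spread
    across samples for every output means the model ignores its input.
    """
    predictions = np.asarray(predictions)
    if predictions.shape[0] < 2:
        raise ValueError("need at least two samples to compare")
    # Peak-to-peak range (max - min) of each output column across samples.
    spread = np.ptp(predictions, axis=0)
    return bool(np.all(spread <= atol))


# Synthetic arrays standing in for model.predict() output.
healthy = np.array([[0.1, 0.9], [0.7, 0.3]])
broken = np.array([[0.4, 0.6], [0.4, 0.6], [0.4, 0.6]])
print(predictions_look_constant(healthy))  # False
print(predictions_look_constant(broken))   # True
```

Running the same check on predictions from both TensorFlow versions, on identical inputs, separates a genuinely collapsed model from one that is merely less accurate.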

Loading the model architecture and the weights separately, instead of loading the model's .h5 file in one go, makes no difference.

Is this a known issue? If so, is there a workaround?

Thanks for your help.

Here is the model's h5 file. In case it helps solve this puzzle, here is the model summary:

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_3 (InputLayer)             (None, 40, 256, 1)    0                                            
____________________________________________________________________________________________________
BN0 (BatchNormalization)         (None, 40, 256, 1)    4           input_3[0][0]                    
____________________________________________________________________________________________________
Conv1 (Conv2D)                   (None, 40, 256, 16)   96          BN0[0][0]                        
____________________________________________________________________________________________________
BN1 (BatchNormalization)         (None, 40, 256, 16)   64          Conv1[0][0]                      
____________________________________________________________________________________________________
Conv2 (Conv2D)                   (None, 40, 256, 16)   1296        BN1[0][0]                        
____________________________________________________________________________________________________
BN2 (BatchNormalization)         (None, 40, 256, 16)   64          Conv2[0][0]                      
____________________________________________________________________________________________________
Conv3 (Conv2D)                   (None, 40, 256, 16)   1296        BN2[0][0]                        
____________________________________________________________________________________________________
average_pooling2d_9 (AveragePool (None, 8, 256, 16)    0           Conv3[0][0]                      
____________________________________________________________________________________________________
BN3 (BatchNormalization)         (None, 8, 256, 16)    64          average_pooling2d_9[0][0]        
____________________________________________________________________________________________________
Conv4.1 (Conv2D)                 (None, 8, 256, 24)    12312       BN3[0][0]                        
____________________________________________________________________________________________________
Conv4.2 (Conv2D)                 (None, 8, 256, 24)    24600       BN3[0][0]                        
____________________________________________________________________________________________________
Conv4.3 (Conv2D)                 (None, 8, 256, 24)    36888       BN3[0][0]                        
____________________________________________________________________________________________________
Conv4.4 (Conv2D)                 (None, 8, 256, 24)    49176       BN3[0][0]                        
____________________________________________________________________________________________________
Conv4.5 (Conv2D)                 (None, 8, 256, 24)    73752       BN3[0][0]                        
____________________________________________________________________________________________________
Conv4.6 (Conv2D)                 (None, 8, 256, 24)    98328       BN3[0][0]                        
____________________________________________________________________________________________________
Concat.Conv4 (Concatenate)       (None, 8, 256, 144)   0           Conv4.1[0][0]                    
                                                                   Conv4.2[0][0]                    
                                                                   Conv4.3[0][0]                    
                                                                   Conv4.4[0][0]                    
                                                                   Conv4.5[0][0]                    
                                                                   Conv4.6[0][0]                    
____________________________________________________________________________________________________
Conv4.1x1 (Conv2D)               (None, 8, 256, 36)    5220        Concat.Conv4[0][0]               
____________________________________________________________________________________________________
average_pooling2d_10 (AveragePoo (None, 4, 256, 36)    0           Conv4.1x1[0][0]                  
____________________________________________________________________________________________________
BN4 (BatchNormalization)         (None, 4, 256, 36)    144         average_pooling2d_10[0][0]       
____________________________________________________________________________________________________
Conv5.1 (Conv2D)                 (None, 4, 256, 24)    27672       BN4[0][0]                        
____________________________________________________________________________________________________
Conv5.2 (Conv2D)                 (None, 4, 256, 24)    55320       BN4[0][0]                        
____________________________________________________________________________________________________
Conv5.3 (Conv2D)                 (None, 4, 256, 24)    82968       BN4[0][0]                        
____________________________________________________________________________________________________
Conv5.4 (Conv2D)                 (None, 4, 256, 24)    110616      BN4[0][0]                        
____________________________________________________________________________________________________
Conv5.5 (Conv2D)                 (None, 4, 256, 24)    165912      BN4[0][0]                        
____________________________________________________________________________________________________
Conv5.6 (Conv2D)                 (None, 4, 256, 24)    221208      BN4[0][0]                        
____________________________________________________________________________________________________
Concat.Conv5 (Concatenate)       (None, 4, 256, 144)   0           Conv5.1[0][0]                    
                                                                   Conv5.2[0][0]                    
                                                                   Conv5.3[0][0]                    
                                                                   Conv5.4[0][0]                    
                                                                   Conv5.5[0][0]                    
                                                                   Conv5.6[0][0]                    
____________________________________________________________________________________________________
Conv5.1x1 (Conv2D)               (None, 4, 256, 36)    5220        Concat.Conv5[0][0]               
____________________________________________________________________________________________________
average_pooling2d_11 (AveragePoo (None, 2, 256, 36)    0           Conv5.1x1[0][0]                  
____________________________________________________________________________________________________
BN5 (BatchNormalization)         (None, 2, 256, 36)    144         average_pooling2d_11[0][0]       
____________________________________________________________________________________________________
Conv6.1 (Conv2D)                 (None, 2, 256, 24)    27672       BN5[0][0]                        
____________________________________________________________________________________________________
Conv6.2 (Conv2D)                 (None, 2, 256, 24)    55320       BN5[0][0]                        
____________________________________________________________________________________________________
Conv6.3 (Conv2D)                 (None, 2, 256, 24)    82968       BN5[0][0]                        
____________________________________________________________________________________________________
Conv6.4 (Conv2D)                 (None, 2, 256, 24)    110616      BN5[0][0]                        
____________________________________________________________________________________________________
Conv6.5 (Conv2D)                 (None, 2, 256, 24)    165912      BN5[0][0]                        
____________________________________________________________________________________________________
Conv6.6 (Conv2D)                 (None, 2, 256, 24)    221208      BN5[0][0]                        
____________________________________________________________________________________________________
Concat.Conv6 (Concatenate)       (None, 2, 256, 144)   0           Conv6.1[0][0]                    
                                                                   Conv6.2[0][0]                    
                                                                   Conv6.3[0][0]                    
                                                                   Conv6.4[0][0]                    
                                                                   Conv6.5[0][0]                    
                                                                   Conv6.6[0][0]                    
____________________________________________________________________________________________________
Conv6.1x1 (Conv2D)               (None, 2, 256, 36)    5220        Concat.Conv6[0][0]               
____________________________________________________________________________________________________
average_pooling2d_12 (AveragePoo (None, 1, 256, 36)    0           Conv6.1x1[0][0]                  
____________________________________________________________________________________________________
BN6 (BatchNormalization)         (None, 1, 256, 36)    144         average_pooling2d_12[0][0]       
____________________________________________________________________________________________________
Conv7.1 (Conv2D)                 (None, 1, 256, 24)    27672       BN6[0][0]                        
____________________________________________________________________________________________________
Conv7.2 (Conv2D)                 (None, 1, 256, 24)    55320       BN6[0][0]                        
____________________________________________________________________________________________________
Conv7.3 (Conv2D)                 (None, 1, 256, 24)    82968       BN6[0][0]                        
____________________________________________________________________________________________________
Conv7.4 (Conv2D)                 (None, 1, 256, 24)    110616      BN6[0][0]                        
____________________________________________________________________________________________________
Conv7.5 (Conv2D)                 (None, 1, 256, 24)    165912      BN6[0][0]                        
____________________________________________________________________________________________________
Conv7.6 (Conv2D)                 (None, 1, 256, 24)    221208      BN6[0][0]                        
____________________________________________________________________________________________________
Concat.Conv7 (Concatenate)       (None, 1, 256, 144)   0           Conv7.1[0][0]                    
                                                                   Conv7.2[0][0]                    
                                                                   Conv7.3[0][0]                    
                                                                   Conv7.4[0][0]                    
                                                                   Conv7.5[0][0]                    
                                                                   Conv7.6[0][0]                    
____________________________________________________________________________________________________
Conv7.1x1 (Conv2D)               (None, 1, 256, 36)    5220        Concat.Conv7[0][0]               
____________________________________________________________________________________________________
BN7 (BatchNormalization)         (None, 1, 256, 36)    144         Conv7.1x1[0][0]                  
____________________________________________________________________________________________________
flatten_3 (Flatten)              (None, 9216)          0           BN7[0][0]                        
____________________________________________________________________________________________________
dropout_3 (Dropout)              (None, 9216)          0           flatten_3[0][0]                  
____________________________________________________________________________________________________
dense_7 (Dense)                  (None, 64)            589888      dropout_3[0][0]                  
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 64)            256         dense_7[0][0]                    
____________________________________________________________________________________________________
dense_8 (Dense)                  (None, 64)            4160        batch_normalization_5[0][0]      
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 64)            256         dense_8[0][0]                    
____________________________________________________________________________________________________
dense_9 (Dense)                  (None, 256)           16640       batch_normalization_6[0][0]      
====================================================================================================
Total params: 2,921,684
Trainable params: 2,921,042
Non-trainable params: 642
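As a sanity check on the summary (this breakdown is not in the original post and assumes Keras' default BatchNormalization settings, `center=True` and `scale=True`), each BatchNormalization layer stores four values per channel: trainable `gamma` and `beta`, plus non-trainable `moving_mean` and `moving_variance`. Summing the moving statistics over all ten BatchNormalization layers reproduces the 642 non-trainable parameters exactly, so that count is precisely the inference-time state that has to survive saving and loading:

```python
# Channel counts of each BatchNormalization layer, read off the last
# dimension of the "Output Shape" column in the summary above.
bn_channels = {
    "BN0": 1,
    "BN1": 16, "BN2": 16, "BN3": 16,
    "BN4": 36, "BN5": 36, "BN6": 36, "BN7": 36,
    "batch_normalization_5": 64,
    "batch_normalization_6": 64,
}

# Assuming Keras defaults, each BatchNormalization channel holds 4 values:
# gamma and beta (trainable), moving_mean and moving_variance (non-trainable).
total_bn = sum(4 * c for c in bn_channels.values())
non_trainable = sum(2 * c for c in bn_channels.values())

total_params = 2_921_684  # "Total params" from the summary

print(total_bn)                      # 1284, sum of the BN rows' Param # column
print(non_trainable)                 # 642, matches "Non-trainable params"
print(total_params - non_trainable)  # 2921042, matches "Trainable params"
```

Because these moving statistics are never touched by gradient updates, they are restored only if the loading code explicitly copies them back, which makes them a natural suspect when predictions collapse after loading.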

This is how I eventually got the Keras/TF 1.3 model working with Keras/TF > 1.3:

In the TensorFlow 1.3 environment:

import tensorflow as tf
from tensorflow.contrib.keras.python.keras import backend
from tensorflow.contrib.keras.python.keras.models import load_model

name = 'my_model_name'
model = load_model('{}.h5'.format(name))

# save state using TensorFlow
saver = tf.train.Saver()
saver.save(backend.get_session(), '{}_weights.tf'.format(name))
backend.clear_session()

In the TensorFlow > 1.3 environment (I used 1.10.1):

import tensorflow as tf
from tensorflow.python.keras import backend  # <- different import!
from tensorflow.contrib.keras.api.keras.models import load_model

name = 'my_model_name'

# first load model architecture
model = load_model('{}.h5'.format(name))

# then load correct state using TensorFlow
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
sess = backend.get_session()
sess.run(tf.variables_initializer(all_variables))

# create a list of variables that does not include the state of
# the used Adam optimizer (it's missing in the .h5 file).
# however, I believe THAT WAS NOT THE ISSUE.
var_list = [v for v in all_variables if "Adam" not in v.name]
saver = tf.train.Saver(var_list=var_list)
saver.restore(sess, '{}_weights.tf'.format(name))

# now save the whole model again using Keras (this time the correct way)
model.save('{}_new.h5'.format(name))
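To verify that the re-saved `{name}_new.h5` really behaves like the original, one option (an addition, not part of the original answer) is to store a reference set of predictions in the TensorFlow 1.3 environment and compare against it after loading the new file under the newer TensorFlow, using the same inputs. The helpers `save_reference` and `matches_reference`, the file name, and the tolerance `atol=1e-5` are hypothetical; the sketch uses synthetic arrays in place of real `model.predict()` output:

```python
import os
import tempfile

import numpy as np


def save_reference(predictions, path):
    """Store predictions produced in the TF 1.3 environment.

    `path` should end in ".npy"; otherwise np.save appends that suffix
    and np.load(path) in matches_reference would not find the file.
    """
    np.save(path, np.asarray(predictions))


def matches_reference(predictions, path, atol=1e-5):
    """Compare predictions from the newer environment with the stored ones."""
    reference = np.load(path)
    predictions = np.asarray(predictions)
    # A shape mismatch short-circuits to False before np.allclose is called.
    return reference.shape == predictions.shape and bool(
        np.allclose(predictions, reference, atol=atol)
    )


# Synthetic stand-ins: `ref` plays the TF 1.3 output on a fixed input batch.
path = os.path.join(tempfile.mkdtemp(), "my_model_name_reference.npy")
ref = np.array([[0.2, 0.8], [0.6, 0.4]])
save_reference(ref, path)

print(matches_reference(ref + 1e-7, path))             # True: within tolerance
print(matches_reference(np.full_like(ref, 0.5), path))  # False: collapsed output
```

Comparing against stored outputs rather than only checking for "non-constant" predictions also catches subtler restore errors that shift outputs without flattening them.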

The workaround is based on this post. Apparently the problem lies in how Keras (not TensorFlow) restores the state of a saved model.
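The answer does not pin down which part of the state Keras restored incorrectly. One plausible mechanism, assumed here for illustration and not confirmed by the post, involves BatchNormalization: at inference time it normalizes with its stored `moving_mean` and `moving_variance`, so if those stay at their initial values (0 and 1) instead of the trained ones, activations can become very large and saturate downstream units, which would yield the kind of constant prediction described in the question. The NumPy sketch below uses hypothetical numbers for a single channel and a sigmoid as a stand-in for whatever layers follow:

```python
import numpy as np


def batch_norm_inference(x, gamma, beta, moving_mean, moving_variance, eps=1e-3):
    """Inference-mode batch normalization for one channel.

    Uses the stored moving statistics rather than batch statistics;
    eps=1e-3 mirrors the Keras BatchNormalization default.
    """
    return gamma * (x - moving_mean) / np.sqrt(moving_variance + eps) + beta


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical raw inputs for a channel whose training data had
# mean ~100 and variance ~400.
x = np.array([80.0, 100.0, 120.0])

# State correctly restored: the moving statistics learned during training.
restored = batch_norm_inference(x, 1.0, 0.0, moving_mean=100.0, moving_variance=400.0)

# State left at Keras' initial values (moving_mean=0, moving_variance=1).
not_restored = batch_norm_inference(x, 1.0, 0.0, moving_mean=0.0, moving_variance=1.0)

print(np.round(sigmoid(restored), 3))      # distinct values around 0.27, 0.5, 0.73
print(np.round(sigmoid(not_restored), 3))  # every input saturates to 1.0
```

With the trained statistics the three inputs map to clearly different outputs; with the initial statistics they all land on the same saturated value, regardless of the input.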
