簡體   English   中英

加載 Tensorflow keras Model (.h5) 時出錯

[英]Error loading Tensorflow keras Model (.h5)

我訓練了一個數字圖像並制作了一個 model 文件。

對應的醬汁如下。

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense 
from tensorflow.keras.models import load_model
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Flatten, Convolution2D, MaxPooling2D
from tensorflow.keras.layers import Dropout, Activation, Dense
from tensorflow.keras.layers import Conv2D

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

path = '/Users/animalman/Documents/test/cnn/'

trainPath = os.listdir(path+'train')
testPath = os.listdir(path+'test')

categories = ["5"]
length = len(categories)

width = 28
height = 28
label = [1 for i in range(length)]
X = []
Y = []
for idx, categorie in enumerate(categories):
    label = [0 for i in range(length)]
    label[idx] = 1
    
    fileDir = path + 'train' + '/' + categorie + '/'
    for t, dir, f in os.walk(fileDir):
        for filename in f:
            print(fileDir + filename)
            img = cv2.imread(fileDir + filename)
            img = cv2.resize(img, None, fx=width/img.shape[0],fy=height/img.shape[1])
            X.append(img)
            Y.append(label)
X = np.array(X)
Y = np.array(Y)
X_train, X_test,Y_train, Y_test = train_test_split(X,Y)
xy = (X_train, X_test, Y_train, Y_test)


X_train = X_train.astype("float") / 256
X_test  = X_test.astype("float")  / 256

model = Sequential()
model.add(Conv2D(16, (3, 3), input_shape=X_train.shape[1:], padding='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))

model.add(Conv2D(64, (3, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten()) 
model.add(Dense(512))  
model.add(Activation('relu'))
model.add(Dropout(0.5))

model.add(Dense(length))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])


hdf5_file = "./7obj-model.h5"
if os.path.exists(hdf5_file):
    model.load_weights(hdf5_file)
else:
    model.fit(X_train, Y_train, batch_size=32, epochs=1)
    model.save_weights(hdf5_file)

然后我帶來了已保存的 model 文件。

loaded_model = tf.keras.models.load_model(hdf5_file)

但是,這是發生錯誤的地方。 什么原因?

回溯(最近一次通話最后):文件“/Users/animalman/Documents/test/train.py”,第 112 行,loaded_model = tf.keras.models.load_model(hdf5_file) 文件“/Users/animalman/opt/anaconda3 /lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py”,第 206 行,在 load_model 返回 hdf5_format.load_model_from_hdf5(文件路徑,custom_objects,文件“/Users/animalman/opt/anaconda3/lib /python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py”,第 181 行,在 load_model_from_hdf5 中引發 ValueError('No model found in config file.') ValueError: No model found in config file.

本文所述,您的h5文件僅包含權重。 您需要將 model 架構保存在 json 文件中,然后使用model_from_json加載 model 配置,因此,您可以使用load_weights加載權重

另一種選擇可能是通過將最后一行替換為

model.save("model.h5")

然后加載,你可以利用

model = load_model('model.h5')

添加到@Oscar 響應,對於較小和簡單的模型, 'h5'格式就足夠了,但對於具有 custom_layers 或自定義指標的復雜模型(功能和子類),最好以'tf'格式保存(也稱為 SavedModel 格式)

在此處查看有關Keras 網頁的更詳細指南

Keras SavedModel 格式限制:

SavedModel 進行的跟蹤以生成層調用函數的圖形使 SavedModel 比 H5 更具可移植性,但它也有缺點。

可能比 H5 更慢、更笨重。 無法序列化從掩碼參數生成的操作(即,如果使用 layer(..., mask=mask_value) 調用圖層,則掩碼參數不會保存到 SavedModel)。 不將覆蓋的 train_step() 保存在子類模型中。 使用掩碼或具有自定義訓練循環的自定義對象仍然可以從 SavedModel 保存和加載,除非它們必須覆蓋 get_config()/from_config(),並且在加載時必須將類傳遞給 custom_objects 參數。

H5 限制:

通過 model.add_loss() 和 model.add_metric() 添加的外部損失和指標不會保存(與 SavedModel 不同)。 如果您在 model 上有此類損失和指標,並且您想恢復訓練,則需要在加載 model 后自行添加這些損失。 請注意,這不適用於通過 self.add_loss() 和 self.add_metric() 在層內創建的損失/指標。 只要圖層被加載,這些損失和指標就會被保留,因為它們是圖層調用方法的一部分。 自定義層等自定義對象的計算圖不包含在保存的文件中。 在加載時,Keras 將需要訪問這些對象的 Python 類/函數,以重建 model。 請參閱自定義對象。 不支持預處理層。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM