繁体   English   中英

加载 Tensorflow keras Model (.h5) 时出错

[英]Error loading Tensorflow keras Model (.h5)

我训练了一个数字图像并制作了一个 model 文件。

对应的酱汁如下。

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense 
from tensorflow.keras.models import load_model
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Flatten, Convolution2D, MaxPooling2D
from tensorflow.keras.layers import Dropout, Activation, Dense
from tensorflow.keras.layers import Conv2D

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

path = '/Users/animalman/Documents/test/cnn/'

trainPath = os.listdir(path+'train')
testPath = os.listdir(path+'test')

categories = ["5"]
length = len(categories)

width = 28
height = 28
label = [1 for i in range(length)]
X = []
Y = []
for idx, categorie in enumerate(categories):
    label = [0 for i in range(length)]
    label[idx] = 1
    
    fileDir = path + 'train' + '/' + categorie + '/'
    for t, dir, f in os.walk(fileDir):
        for filename in f:
            print(fileDir + filename)
            img = cv2.imread(fileDir + filename)
            img = cv2.resize(img, None, fx=width/img.shape[0],fy=height/img.shape[1])
            X.append(img)
            Y.append(label)
X = np.array(X)
Y = np.array(Y)
X_train, X_test,Y_train, Y_test = train_test_split(X,Y)
xy = (X_train, X_test, Y_train, Y_test)


X_train = X_train.astype("float") / 256
X_test  = X_test.astype("float")  / 256

model = Sequential()
model.add(Conv2D(16, (3, 3), input_shape=X_train.shape[1:], padding='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))

model.add(Conv2D(64, (3, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten()) 
model.add(Dense(512))  
model.add(Activation('relu'))
model.add(Dropout(0.5))

model.add(Dense(length))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])


hdf5_file = "./7obj-model.h5"
if os.path.exists(hdf5_file):
    model.load_weights(hdf5_file)
else:
    model.fit(X_train, Y_train, batch_size=32, epochs=1)
    model.save_weights(hdf5_file)

然后我带来了已保存的 model 文件。

loaded_model = tf.keras.models.load_model(hdf5_file)

但是,这是发生错误的地方。 什么原因?

回溯(最近一次通话最后):文件“/Users/animalman/Documents/test/train.py”,第 112 行,loaded_model = tf.keras.models.load_model(hdf5_file) 文件“/Users/animalman/opt/anaconda3 /lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py”,第 206 行,在 load_model 返回 hdf5_format.load_model_from_hdf5(文件路径,custom_objects,文件“/Users/animalman/opt/anaconda3/lib /python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py”,第 181 行,在 load_model_from_hdf5 中引发 ValueError('No model found in config file.') ValueError: No model found in config file.

本文所述,您的h5文件仅包含权重。 您需要将 model 架构保存在 json 文件中,然后使用model_from_json加载 model 配置,因此,您可以使用load_weights加载权重

另一种选择可能是通过将最后一行替换为

model.save("model.h5")

然后加载,你可以利用

model = load_model('model.h5')

添加到@Oscar 响应,对于较小和简单的模型, 'h5'格式就足够了,但对于具有 custom_layers 或自定义指标的复杂模型(功能和子类),最好以'tf'格式保存(也称为 SavedModel 格式)

在此处查看有关Keras 网页的更详细指南

Keras SavedModel 格式限制:

SavedModel 进行的跟踪以生成层调用函数的图形使 SavedModel 比 H5 更具可移植性,但它也有缺点。

可能比 H5 更慢、更笨重。 无法序列化从掩码参数生成的操作(即,如果使用 layer(..., mask=mask_value) 调用图层,则掩码参数不会保存到 SavedModel)。 不将覆盖的 train_step() 保存在子类模型中。 使用掩码或具有自定义训练循环的自定义对象仍然可以从 SavedModel 保存和加载,除非它们必须覆盖 get_config()/from_config(),并且在加载时必须将类传递给 custom_objects 参数。

H5 限制:

通过 model.add_loss() 和 model.add_metric() 添加的外部损失和指标不会保存(与 SavedModel 不同)。 如果您在 model 上有此类损失和指标,并且您想恢复训练,则需要在加载 model 后自行添加这些损失。 请注意,这不适用于通过 self.add_loss() 和 self.add_metric() 在层内创建的损失/指标。 只要图层被加载,这些损失和指标就会被保留,因为它们是图层调用方法的一部分。 自定义层等自定义对象的计算图不包含在保存的文件中。 在加载时,Keras 将需要访问这些对象的 Python 类/函数,以重建 model。 请参阅自定义对象。 不支持预处理层。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM