[英]Python ResNet50: model.save() NotImplementedError
我的目標是保存(然后加載)重新發送的 model。我已經按照本教程進行操作,最終得到了一個可以學習的 model,但是當我嘗試保存它時,它會出錯。
我發現了這個類似的 stackoverflow 問題,但對於我的生活,我無法弄清楚如何解決它。
我看過的另一件事是來自 Keras.io 的這篇文章,但我使用的是 Sequential() model 而不是一些自定義的。 我不確定這個 get_config function 應該在哪里。
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
DATASET_PATH = "/XX/dataset"
CLASS_NAMES = ["0", "1", "2", "3", "4"]
img_height,img_width=180,180
batch_size=32
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
DATASET_PATH,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
DATASET_PATH,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
resnet_model = Sequential()
pretrained_model= tf.keras.applications.ResNet50(include_top=False,
input_shape=(180,180,3),
pooling='avg',classes=5,
weights='imagenet')
for layer in pretrained_model.layers:
layer.trainable=False
resnet_model.add(pretrained_model)
resnet_model.add(Flatten())
resnet_model.add(Dense(512, activation='relu'))
resnet_model.add(Dense(5, activation='softmax'))
resnet_model.summary()
resnet_model.compile(optimizer=Adam(lr=0.001),loss='sparse_categorical_crossentropy',metrics=['accuracy'])
epochs=1
history = resnet_model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
resnet_model.save("/XX/test.h5", save_format="h5")
和錯誤:
NotImplementedError:
Layer ModuleWrapper has arguments ['self', 'module', 'method_name']
in `__init__` and therefore must override `get_config()`.
Example:
class CustomLayer(keras.layers.Layer):
def __init__(self, arg1, arg2):
super().__init__()
self.arg1 = arg1
self.arg2 = arg2
def get_config(self):
config = super().get_config()
config.update({
"arg1": self.arg1,
"arg2": self.arg2,
})
return config
問題出在這一行
from tensorflow.python.keras.layers import Dense, Flatten
如果您將其替換為此它應該可以解決您的問題
from tensorflow.keras.layers import Dense, Flatten
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.