簡體   English   中英

Python ResNet50: model.save() NotImplementedError

[英]Python ResNet50: model.save() NotImplementedError

我的目標是保存(然后加載)重新發送的 model。我已經按照教程進行操作,最終得到了一個可以學習的 model,但是當我嘗試保存它時,它會出錯。

我發現了這個類似的 stackoverflow 問題,但對於我的生活,我無法弄清楚如何解決它。

我看過的另一件事是來自 Keras.io 的這篇文章,但我使用的是 Sequential() model 而不是一些自定義的。 我不確定這個 get_config function 應該在哪里。

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt


DATASET_PATH = "/XX/dataset"
CLASS_NAMES = ["0", "1", "2", "3", "4"]

img_height,img_width=180,180
batch_size=32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  DATASET_PATH,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  DATASET_PATH,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

resnet_model = Sequential()

pretrained_model= tf.keras.applications.ResNet50(include_top=False,
       input_shape=(180,180,3),
       pooling='avg',classes=5,
       weights='imagenet')

for layer in pretrained_model.layers:
    layer.trainable=False

resnet_model.add(pretrained_model)
resnet_model.add(Flatten())
resnet_model.add(Dense(512, activation='relu'))
resnet_model.add(Dense(5, activation='softmax'))

resnet_model.summary()

resnet_model.compile(optimizer=Adam(lr=0.001),loss='sparse_categorical_crossentropy',metrics=['accuracy'])

epochs=1
history = resnet_model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

resnet_model.save("/XX/test.h5", save_format="h5")

和錯誤:

NotImplementedError: 
Layer ModuleWrapper has arguments ['self', 'module', 'method_name']
in `__init__` and therefore must override `get_config()`.

Example:

class CustomLayer(keras.layers.Layer):
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2

    def get_config(self):
        config = super().get_config()
        config.update({
            "arg1": self.arg1,
            "arg2": self.arg2,
        })
        return config

問題出在這一行

from tensorflow.python.keras.layers import Dense, Flatten

如果您將其替換為此它應該可以解決您的問題

from tensorflow.keras.layers import Dense, Flatten

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM