
model = tf.keras.models.load_model()

I saved an MLP regression model with code like this:

#define model
model = Sequential()
model.add(Dense(80, input_dim=2, kernel_initializer='normal', activation='relu'))
model.add(Dense(60, kernel_initializer='normal', activation='relu'))
model.add(Dense(40, kernel_initializer='normal', activation='relu'))
model.add(Dense(20, kernel_initializer='normal', activation='relu'))
model.add(Dense(10, kernel_initializer='normal', activation='relu'))
model.add(Dense(5, kernel_initializer='normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.summary()
model.compile(loss='mse', optimizer='adam', metrics=[rmse])



# train model, test callback option
history = model.fit(X_train, Y_train, epochs=75, batch_size=1, verbose=2, callbacks=[callback])
#history = model.fit(X_train, Y_train, epochs=60, batch_size=1, verbose=2)

# plot metrics
plt.plot(history.history['rmse'])
plt.title('kW RMSE Vs Epoch')
plt.show()


model.save('./saved_model/kwSummer')

But when I try to load the saved model:

new_model = tf.keras.models.load_model('./saved_model/kwSummer')

# Check its architecture
new_model.summary()

I get the following error when loading the model. Does anyone have ideas for what to try?

ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.

I have been experimenting with using Python 3.7 to train the model and then IPython on Anaconda Python 3.8 to load the model. Could that have anything to do with the issue, for example two different versions of TensorFlow?

Edit: here is the whole script

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend

from datetime import datetime
import matplotlib.dates as mdates
import seaborn as sns
import math


df = pd.read_csv('./colabData.csv', index_col='Date', parse_dates=True)

print(df.info())



# This function keeps the learning rate at 0.001 for the first epoch
# and decreases it exponentially after that.
def scheduler(epoch):
  if epoch < 1:
    return 0.001
  else:
    return 0.001 * tf.math.exp(0.01 * (1 - epoch))

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)


# function to calculate RMSE (root mean squared error)
def rmse(y_true, y_pred):
    return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))




dfTrain = df.copy()

# split into input (X) and output (Y) variables
X = dfTrain.drop(['kW'], axis=1)
Y = dfTrain['kW']

#define training & testing data set
offset = int(X.shape[0] * 0.8)
X_train, Y_train = X[:offset], Y[:offset]
X_test, Y_test = X[offset:], Y[offset:]


#define model
model = Sequential()
model.add(Dense(80, input_dim=2, kernel_initializer='normal', activation='relu'))
model.add(Dense(60, kernel_initializer='normal', activation='relu'))
model.add(Dense(40, kernel_initializer='normal', activation='relu'))
model.add(Dense(20, kernel_initializer='normal', activation='relu'))
model.add(Dense(10, kernel_initializer='normal', activation='relu'))
model.add(Dense(5, kernel_initializer='normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.summary()
model.compile(loss='mse', optimizer='adam', metrics=[rmse])



# train model, test callback option
history = model.fit(X_train, Y_train, epochs=75, batch_size=1, verbose=2, callbacks=[callback])
#history = model.fit(X_train, Y_train, epochs=60, batch_size=1, verbose=2)

# plot metrics
plt.plot(history.history['rmse'])
plt.title('kW RMSE Vs Epoch')
plt.show()

model.save('./saved_model/kwSummer')
print('[INFO] Saved model to drive')

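Because your model uses a custom object (the rmse metric), you have to pass it to load_model() through the custom_objects argument. The error message says so as well: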

In addition, please use the `custom_objects` arg when calling `load_model()`.

Try the following:

new_model = tf.keras.models.load_model('./saved_model/kwSummer',
                                       custom_objects={"rmse": rmse})
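For anyone who wants to try this in isolation, here is a minimal sketch of the save/load round trip. It assumes a TensorFlow 2.x install. The tiny two-layer network, the temporary file name demo_model.h5 and the HDF5 format are illustrative choices so the snippet runs on its own; the question itself saves to a SavedModel directory. It also shows compile=False, which loads the architecture and weights without restoring the loss, optimizer or metrics, so no custom object is needed if you only want predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def rmse(y_true, y_pred):
    # Same metric as in the question, written with tf ops.
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true), axis=-1))


# Tiny stand-in model (illustrative, not the network from the question).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(loss="mse", optimizer="adam", metrics=[rmse])

# Hypothetical file name in a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "demo_model.h5")
model.save(path)

# Option 1: tell Keras how to resolve the saved name "rmse".
restored = tf.keras.models.load_model(path, custom_objects={"rmse": rmse})

# Option 2: skip restoring the compile state (loss, optimizer, metrics).
inference_only = tf.keras.models.load_model(path, compile=False)

# The restored model should reproduce the original predictions.
x = np.zeros((3, 2), dtype="float32")
same_output = np.allclose(
    model.predict(x, verbose=0), restored.predict(x, verbose=0)
)
print("predictions match:", same_output)
```

Note that the rmse function has to be defined (or imported) in the script that loads the model, not only in the script that trained it.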

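May I suggest running the code in Google Colab? That could help show whether this is a problem in the code or a compatibility problem. Colab provides a consistent, preinstalled environment, and it has resolved many of the ML issues I have run into.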
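To check the compatibility angle directly, one option is to print the interpreter and package versions in both the training and the loading environment and compare them. The sketch below is only an illustration: the helper name installed_versions and the package list are my own choices, and importlib.metadata requires Python 3.8 or newer (on 3.7, printing tf.__version__ after importing TensorFlow is the simpler route).

```python
import sys
from importlib import metadata


def installed_versions(packages):
    """Return {name: version or None} for the interpreter and the given packages."""
    report = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


# Run this in each environment and compare the output.
for name, version in installed_versions(["tensorflow", "keras", "h5py"]).items():
    print(f"{name}: {version}")
```

If the TensorFlow versions differ between the two environments, retrying the load with matching versions is a reasonable next step.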

