
model = tf.keras.models.load_model()


I saved an MLP regression model with code like this:

#define model
model = Sequential()
model.add(Dense(80, input_dim=2, kernel_initializer='normal', activation='relu'))
model.add(Dense(60, kernel_initializer='normal', activation='relu'))
model.add(Dense(40, kernel_initializer='normal', activation='relu'))
model.add(Dense(20, kernel_initializer='normal', activation='relu'))
model.add(Dense(10, kernel_initializer='normal', activation='relu'))
model.add(Dense(5, kernel_initializer='normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.summary()
model.compile(loss='mse', optimizer='adam', metrics=[rmse])



# train model, test callback option
history = model.fit(X_train, Y_train, epochs=75, batch_size=1, verbose=2, callbacks=[callback])
#history = model.fit(X_train, Y_train, epochs=60, batch_size=1, verbose=2)

# plot metrics
plt.plot(history.history['rmse'])
plt.title('kW RMSE vs. Epoch')
plt.show()


model.save('./saved_model/kwSummer')

But when I try to load the saved model:

new_model = tf.keras.models.load_model('./saved_model/kwSummer')

# Check its architecture
new_model.summary()

I get the following error when loading the model. Does anyone have ideas on what to try?

ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.

I have been experimenting with using Python 3.7 to train the model and then IPython (Anaconda, Python 3.8) to load it. Could this have anything to do with the issue, for example two different versions of TensorFlow?

Edit: here is the whole script

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend

from datetime import datetime
import matplotlib.dates as mdates
import seaborn as sns
import math


df = pd.read_csv('./colabData.csv', index_col='Date', parse_dates=True)

print(df.info())



# Keep the learning rate at 0.001 for epochs 0 and 1,
# then decay it exponentially (by about 1% per epoch).
def scheduler(epoch):
  if epoch < 1:
    return 0.001
  else:
    return 0.001 * tf.math.exp(0.01 * (1 - epoch))

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)


# function to calculate RMSE
def rmse(y_true, y_pred):
    return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))




dfTrain = df.copy()

# split into input (X) and output (Y) variables
X = dfTrain.drop(columns=['kW'])
Y = dfTrain['kW']

#define training & testing data set
offset = int(X.shape[0] * 0.8)
X_train, Y_train = X[:offset], Y[:offset]
X_test, Y_test = X[offset:], Y[offset:]


#define model
model = Sequential()
model.add(Dense(80, input_dim=2, kernel_initializer='normal', activation='relu'))
model.add(Dense(60, kernel_initializer='normal', activation='relu'))
model.add(Dense(40, kernel_initializer='normal', activation='relu'))
model.add(Dense(20, kernel_initializer='normal', activation='relu'))
model.add(Dense(10, kernel_initializer='normal', activation='relu'))
model.add(Dense(5, kernel_initializer='normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.summary()
model.compile(loss='mse', optimizer='adam', metrics=[rmse])



# train model, test callback option
history = model.fit(X_train, Y_train, epochs=75, batch_size=1, verbose=2, callbacks=[callback])
#history = model.fit(X_train, Y_train, epochs=60, batch_size=1, verbose=2)

# plot metrics
plt.plot(history.history['rmse'])
plt.title('kW RMSE vs. Epoch')
plt.show()

model.save('./saved_model/kwSummer')
print('[INFO] Saved model to drive')
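To see what the script's two helper functions actually compute, here is a minimal sketch that re-implements them with plain Python and NumPy instead of TensorFlow. The names `lr_schedule` and `keras_style_rmse` are hypothetical stand-ins for `scheduler` and `rmse`; the example assumes, as in the script, a single output column (the `Dense(1)` layer).

```python
import math

import numpy as np


def lr_schedule(epoch):
    # Same formula as scheduler() in the script, using math.exp instead of tf.math.exp.
    if epoch < 1:
        return 0.001
    return 0.001 * math.exp(0.01 * (1 - epoch))


def keras_style_rmse(y_true, y_pred):
    # Mirrors rmse(): mean over the last axis, then square root, one value per sample.
    return np.sqrt(np.mean(np.square(y_pred - y_true), axis=-1))


# The learning rate stays at 0.001 for epochs 0 and 1, then decays slowly.
for epoch in (0, 1, 2, 74):
    print(epoch, lr_schedule(epoch))

# Single-column targets and predictions, shaped like the Dense(1) output.
y_true = np.array([[1.0], [2.0], [3.0]])
y_pred = np.array([[2.0], [2.0], [5.0]])

per_sample = keras_style_rmse(y_true, y_pred)  # equals |error| with one column
reported = per_sample.mean()                   # average of the per-sample values
dataset_rmse = np.sqrt(np.mean(np.square(y_pred - y_true)))

print(per_sample)    # [1. 0. 2.]
print(reported)      # 1.0 (the mean absolute error)
print(dataset_rmse)  # about 1.291
```

Two things fall out of this. Over 75 epochs the learning rate only drops to roughly 4.8e-4, so the decay is gentle. And with `batch_size=1` every batch holds one sample, so, assuming Keras averages the per-batch metric values over the epoch (as it does for function metrics passed to `compile()`), the number logged under `'rmse'` behaves like the mean absolute error rather than the RMSE over the epoch.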

Because the model uses a custom object (the `rmse` metric), you must load it with the `custom_objects` argument. The error log tells you this as well:

In addition, please use the `custom_objects` arg when calling `load_model()`.

Try the following:

# rmse must be defined (or imported) in the script that loads the model
new_model = tf.keras.models.load_model('./saved_model/kwSummer',
                                       custom_objects={"rmse": rmse})
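A minimal, self-contained sketch of the same fix, written for TensorFlow 2.x. Assumptions: it builds a tiny two-input model on random data instead of the question's CSV; it writes `rmse` with `tf.sqrt`, `tf.reduce_mean` and `tf.square` rather than `backend` (same math); and it saves to a hypothetical `kwSummer.keras` path, because recent Keras versions require a `.keras` (or `.h5`) extension, while older TF 2.x releases also accept that path.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def rmse(y_true, y_pred):
    # Same custom metric as in the question, written with tf ops.
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true), axis=-1))


# A tiny stand-in for the question's MLP: 2 inputs, 1 output.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(loss="mse", optimizer="adam", metrics=[rmse])

x = np.random.rand(8, 2).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "kwSummer.keras")
model.save(path)

# Map the metric name stored at save time back to the Python function.
restored = tf.keras.models.load_model(path, custom_objects={"rmse": rmse})

original_pred = model.predict(x, verbose=0)
restored_pred = restored.predict(x, verbose=0)
print(np.allclose(original_pred, restored_pred))
```

The dictionary key has to match the name Keras recorded for the metric, which for a plain function is its name (`rmse`, the same key that shows up in `history.history`). If you only need predictions, `tf.keras.models.load_model(path, compile=False)` skips compiling the loaded model and is a common workaround; whether it avoids this particular error for a SavedModel directory may depend on the TensorFlow version.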

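Could I suggest running the code in Google Colab? That may help show whether this is a code problem or a compatibility problem, since Colab gives you a known, preinstalled TensorFlow environment; it has solved many ML problems for me.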
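Whether you stay local or move to Colab, a quick way to settle the version question is to print the Python and TensorFlow versions in both the training and the loading environment and compare them. A minimal sketch; `environment_report` is a hypothetical helper name, and it avoids `importlib.metadata` so it also runs on Python 3.7.

```python
import platform


def environment_report():
    # Collect the versions that matter when a model is saved in one
    # environment and loaded in another.
    try:
        import tensorflow as tf
        tf_version = tf.__version__
    except ImportError:
        tf_version = "not installed"
    return {"python": platform.python_version(), "tensorflow": tf_version}


print(environment_report())
```

Run it once in each environment; if the TensorFlow versions differ, that is worth ruling out separately from the `custom_objects` fix.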


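Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.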

 