繁体   English   中英

如何从HKL或其他任何已保存的keras配置中重建keras模型?

[英]How to rebuild a keras model from a saved keras config in a hkl or any?

我有以下代码,用于训练模型并将其保存到hickle文件中(也可以是任何类型的文件)

from keras import Sequential
from keras.layers import Dense
from keras.models import load_model

import hickle as hkl

import numpy as np

class Model:
    def __init__(self, data=None):
        self.data = data
        self.metrics = []
        self.model = self.__build_model()

    def __build_model(self):
        model = Sequential()
        model.add(Dense(4, activation='relu', input_shape=(3,)))
        model.add(Dense(1, activation='relu'))
        model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
        return model

    def train(self, epochs):
        self.model.fit(self.data[:, :-1], self.data[:,-1], epochs=epochs)
        return self

    def test(self, data):
        self.metrics = self.model.evaluate(data[:, :-1], data[:, -1])
        return self

    def predict(self, input):
        return self.model.predict(input)

    def save(self, path):
      data = {'metrics': self.metrics, 'k_model': self.model.get_config()}
      hkl.dump(data, path, mode='w')
      return self

    def load(self, path):
      data = hkl.load('model.hkl')
      self.metrics = data['metrics']
      self.model = Sequential.from_config(data['k_model'])
      return self

def train():
    train_data = np.random.rand(1000, 4)
    test_data = np.random.rand(100, 4)
    print("TRAINING, TESTING & SAVING..")
    model = Model(train_data)\
                .train(epochs=5)\
                .test(test_data)\
                .save('./model.hkl')
    print('metrics: ', model.metrics)
    conf = model.model.get_config()
    print("type: ", type(conf))
    print("length: ", len(conf))    


if __name__ == '__main__':
    train()

    print('USING SAVED MODEL..')
    model = Model()
    model.load('./model.hkl')
    print(model.metrics) 

这会打印错误

    TypeError: type object argument after ** must be a mapping, not PyContainer

怎么了? 错误形式是keras还是来自ckle锁?

NB。 在这里,我只是保存指标,但是它可以包含任何其他附加信息

谢谢。

我在另一个问题上回答了类似的问题。 这就是我总是保存keras模型的方式:

model.save('model.h5')
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)

用于加载保存的模型:

model = load_model('model_w1.h5')

打印摘要:

model.summary()

要再次训练,您可以在加载后直接使用fit 如果要在某些应用程序中使用它,则首先使其成为全局( 无需为每个预测重新加载 )负载模型,然后使其权重为:

def load_model():

    global model

    json_file = open('model.json', 'r')
    model_json = json_file.read()
    model = model_from_json(model_json)
    model.load_weights("model.h5")
    model._make_predict_function()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM