Keep training Keras model with loading and saving the weights

Since I cannot install h5py due to a package inconsistency, I am wondering whether it is possible to save and load the weights in Keras so I can keep training my model on new data. I know I can do the following:

   old_weights = model.get_weights()
   del model
   new_model.set_weights(old_weights)

where model is the old model and new_model is the new one. Here is a complete example:

for X, y in training_data:
    model = Sequential()
    model.add(Dense(20, activation='tanh', input_dim=Input))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse')
    model.fit(X, y, epochs=8, batch_size=16, shuffle=False, verbose=0)
    new_model = Sequential()
    new_model.add(Dense(20, activation='tanh', input_dim=Input))
    new_model.add(Dense(1))
    new_model.compile(optimizer='adam', loss='mse')
    old_weights = model.get_weights()
    del model
    new_model.set_weights(old_weights)
    model = new_model

After reading each training example (X and y are different at each iteration), I want to save the weights, load them again, and continue from the pre-trained model. I am not sure whether my code does that, since I define the optimizer and call model.compile again each time. Can anyone tell me whether this code saves the model and whether every iteration really starts from the pre-trained model?

You don't need to keep recompiling the model. Instead, just fit your model multiple times after loading your samples.

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(20, activation='tanh', input_dim=Input))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')
# load the data into training_data 
for data in training_data:  
    model.fit(data[0], data[1], epochs=8, batch_size=16, shuffle=False, verbose=0)
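If you also want to persist the weights between runs without h5py, note that model.get_weights() returns a plain list of NumPy arrays, so NumPy's own .npz format is enough. Below is a minimal sketch of that idea; the helper names save_weights_npz and load_weights_npz are made up for illustration, and the commented-out Keras calls assume a compiled model as in the snippet above.

```python
import numpy as np

def save_weights_npz(weights, path):
    # weights is the list of arrays returned by model.get_weights();
    # np.savez stores them as arr_0, arr_1, ... inside one .npz file.
    np.savez(path, *weights)

def load_weights_npz(path):
    # Rebuild the list in the original order so it can be passed
    # straight back to model.set_weights().
    with np.load(path) as data:
        return [data[f"arr_{i}"] for i in range(len(data.files))]

# Usage with a Keras model (assumed, not run here):
# save_weights_npz(model.get_weights(), "weights.npz")
# model.set_weights(load_weights_npz("weights.npz"))
```

Since set_weights only checks shapes, this works as long as the model you load into has the same architecture as the one you saved from.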
