繁体   English   中英

如何保存Keras的培训历史以进行交叉验证(循环)?

[英]How to save the training history of Keras for a cross valication(a loop)?

对于交叉验证,如何保存不同训练集和交叉验证集的训练历史记录? 我认为'一个'附加模式的泡菜写将起作用,但实际上它没有用。 如果可能的话,请你告诉我保存所有模型的方法,现在我只能用model.save(file)保存最后训练过的模型。

historyfile = 'history.pickle'
f = open(historyfile,'w')
f.close()
ind = 0
save = {}
for train, test in kfold.split(input,output):
    ind = ind+1
    #create model
    model = model_FCN()
    # fit the model
    history = model.fit(input[list(train)], output[list(train)], batch_size = 16, epochs = 100, verbose =1, validation_data =(input[list(test)],output[list(test)]))
    #save to file 
    try:
        f = open(historyfile,'a') ## appending mode??
        save['cv'+ str(ind)]= history.history
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
        f.close()
    except Exception as e:
        print('Unable to save data to', historyfile, ':', e)

    scores = model.evaluate(MR_patch[list(test)], CT_patch[list(test)], verbose=0)
    print("%s: %.2f" % (model.metrics_names[1], scores[1]))
    cvscores.append(scores[1])
    print("cross validation stage: " + str(ind))

print("%.2f (+/- %.2f)" % (np.mean(cvscores), np.std(cvscores)))

要在某个列车的每个纪元后保存模型并验证数据,您可以使用Callback

例如:

from keras.callbacks import ModelCheckpoint
import os

output_directory = '' # here should be path to output directory    
model_checkpoint = ModelCheckpoint(os.path.join(output_directory , 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'))
model.fit(input[list(train)],
          output[list(train)],
          batch_size=16,
          epochs=100,
          verbose=1,
          validation_data=(input[list(test)],output[list(test)]),
          callbacks=[model_checkpoint])

每个纪元后,您的模型将保存在文件中。 有关此回调的更多信息,请参阅文档( https://keras.io/callbacks/

如果要保存每次折叠训练的模型,只需在for循环中添加model.save(file):

model.fit(input[list(train)],
          output[list(train)],
          batch_size=16,
          epochs=100,
          verbose=1,
          validation_data=(input[list(test)],output[list(test)]))
model.save(os.path.join(output_directory, 'fold_{}_model.hdf5'.format(ind)))

要保存历史记录:您可以保存历史记录一次,而不必将其附加到每个循环上的文件中。 在for循环之后,您应该获得带有键(折叠标记)和值(每个折叠的历史记录)的字典并保存此字典,如下所示:

f = open(historyfile, 'wb')
pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
f.close()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM