[英]Load pre-trained keras model for continued training on google cloud
我正在嘗試加載一個預訓練的 Keras 模型,以便在谷歌雲上繼續訓練。 它在本地工作,只需加載鑒別器和生成器
model = load_model('myPretrainedModel.h5')
但顯然這在谷歌雲上不起作用,我嘗試使用與從我的谷歌存儲桶中讀取訓練數據相同的方法,使用:
fil = "gs://mygcbucket/myPretrainedModel.h5"
f = BytesIO(file_io.read_file_to_string(fil, binary_mode=True))
return np.load(f)
但是,這似乎不適用於加載模型,運行該作業時出現以下錯誤。
ValueError:當allow_pickle=False時無法加載包含pickle數據的文件
添加allow_pickle=True
,引發另一個錯誤:
OSError:無法將文件 <_io.BytesIO 對象在 0x7fdf2bb42620> 解釋為泡菜
然后我嘗試了一些我發現的東西,因為有人建議我解決類似的問題,據我所知,它從存儲桶中臨時在本地重新保存模型(與作業運行的位置有關),然后加載它,使用:
fil = "gs://mygcbucket/myPretrainedModel.h5"
model_file = file_io.FileIO(fil, mode='rb')
file_stream = file_io.FileIO(model_file, mode='r')
temp_model_location = './temp_model.h5'
temp_model_file = open(temp_model_location, 'wb')
temp_model_file.write(file_stream.read())
temp_model_file.close()
file_stream.close()
model = load_model(temp_model_location)
return model
但是,這會引發以下錯誤:
類型錯誤:預期的二進制或 unicode 字符串,得到 tensorflow.python.lib.io.file_io.FileIO 對象
我必須承認,我不太確定我需要做什么才能從我的存儲桶中實際加載一個預訓練的 keras 模型,以及在我在谷歌雲的訓練工作中的使用。 任何幫助深表感謝。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.