簡體   English   中英

加載預訓練的 keras 模型以在谷歌雲上繼續訓練

[英]Load pre-trained keras model for continued training on google cloud

我正在嘗試加載一個預訓練的 Keras 模型,以便在谷歌雲上繼續訓練。 它在本地工作,只需加載鑒別器和生成器

 model = load_model('myPretrainedModel.h5')

但顯然這在谷歌雲上不起作用,我嘗試使用與從我的谷歌存儲桶中讀取訓練數據相同的方法,使用:

fil = "gs://mygcbucket/myPretrainedModel.h5"    
f = BytesIO(file_io.read_file_to_string(fil, binary_mode=True))
return np.load(f)

但是,這似乎不適用於加載模型,運行該作業時出現以下錯誤。

ValueError:當allow_pickle=False時無法加載包含pickle數據的文件

添加allow_pickle=True ,引發另一個錯誤:

OSError:無法將文件 <_io.BytesIO 對象在 0x7fdf2bb42620> 解釋為泡菜

然后我嘗試了一些我發現的東西,因為有人建議我解決類似的問題,據我所知,它從存儲桶中臨時在本地重新保存模型(與作業運行的位置有關),然后加載它,使用:

fil = "gs://mygcbucket/myPretrainedModel.h5"  
model_file = file_io.FileIO(fil, mode='rb')
file_stream = file_io.FileIO(model_file, mode='r')
temp_model_location = './temp_model.h5'
temp_model_file = open(temp_model_location, 'wb')
temp_model_file.write(file_stream.read())
temp_model_file.close()
file_stream.close()
model = load_model(temp_model_location)
return model

但是,這會引發以下錯誤:

類型錯誤:預期的二進制或 unicode 字符串,得到 tensorflow.python.lib.io.file_io.FileIO 對象

我必須承認,我不太確定我需要做什么才能從我的存儲桶中實際加載一個預訓練的 keras 模型,以及在我在谷歌雲的訓練工作中的使用。 任何幫助深表感謝。

我建議使用 AI Platform Notebooks 來做到這一點。 使用此方法下載經過訓練的模型。 檢查代碼示例選項卡下的 Python 代碼。 在運行 Notebook 的 VM 中擁有模型后,您可以像在本地一樣加載它。 這里有一個使用tf.keras.models.load_model的示例。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM