[英]How to copy a TF hub saved model architecture?
我想創建ESRGAN TF 集線器保存的 model架構的副本。
原因是直接使用 tf.saved_model.load() 加載它不適用於雲 TPU,而且我無權訪問 Google Cloud 存儲桶。
我嘗試創建一個 Keras Sequential model 然后創建它的副本並將副本保存為.h5格式
model = tf.keras.models.Sequential([Input(shape=(32,32,3)),
hub.KerasLayer("https://tfhub.dev/captain-pool/esrgan-tf2/1", trainable=False)])
model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
model_copy.save("model.h5")
但是當我加載保存的 model 時,它仍然從 tfhub.dev 站點獲得 model 架構,這是我想要克服的。 也可以在本地下載 tfhub model 文件。 但是從本地下載的文件中加載它們也無法在雲 TPU 上運行。
唯一可行的方法是將 model 架構保存在獨立的.h5 model 中,然后從中加載。 有什么辦法嗎?
如果您將 TPU 與 Tensorflow 2.3+ 一起使用,您可以使用load_options
參數從 TF Hub 加載 model。 代碼將如下所示:
...
with strategy.scope()
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost') layer =
hub.KerasLayer(..., load_options=load_options)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.