簡體   English   中英

如何復制 TF hub 保存的 model 架構?

[英]How to copy a TF hub saved model architecture?

我想創建ESRGAN TF 集線器保存的 model架構的副本。

原因是直接使用 tf.saved_model.load() 加載它不適用於雲 TPU,而且我無權訪問 Google Cloud 存儲桶。

我嘗試創建一個 Keras Sequential model 然后創建它的副本並將副本保存為.h5格式

model = tf.keras.models.Sequential([Input(shape=(32,32,3)),
        hub.KerasLayer("https://tfhub.dev/captain-pool/esrgan-tf2/1", trainable=False)])
model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
model_copy.save("model.h5")

但是當我加載保存的 model 時,它仍然從 tfhub.dev 站點獲得 model 架構,這是我想要克服的。 也可以在本地下載 tfhub model 文件。 但是從本地下載的文件中加載它們也無法在雲 TPU 上運行。

唯一可行的方法是將 model 架構保存在獨立的.h5 model 中,然后從中加載。 有什么辦法嗎?

如果您將 TPU 與 Tensorflow 2.3+ 一起使用,您可以使用load_options參數從 TF Hub 加載 model。 代碼將如下所示:

...
with strategy.scope()
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost') layer = 
  hub.KerasLayer(..., load_options=load_options)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM