簡體   English   中英

Unpickled tensorflow model 預測失敗

[英]Unpickled tensorflow model fails to make predictions

我看過這個問題這個問題,但既沒有真正解釋發生了什么,也沒有為我面臨的問題提供解決方案。

下面的代碼是我在更大范圍內嘗試做的事情的片段。 基本上,我正在創建一個 object,其中包含一個 tensorflow.keras model,我正在使用改編自此答案的技巧將其保存到 pickle 文件中。 我正在處理的實際 class 有幾個其他字段和方法,因此我更願意使其可以 pickle-able 並以靈活的方式進行。 請參閱下面的代碼以最低限度地重現該問題。 ReproduceProblem.py

import pickle
import numpy as np
import tempfile
import tensorflow as tf


def __getstate__(self):
    model_str = ""
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
        tf.keras.models.save_model(self, fd.name, overwrite=True)
        model_str = fd.read()
    d = {"model_str": model_str}
    return d


def __setstate__(self, state):
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
        fd.write(state["model_str"])
        fd.flush()
        model = tf.keras.models.load_model(fd.name)
    self.__dict__ = model.__dict__


class ContainsSequential:
    def __init__(self):
        self.other_field = "potato"
        self.model = tf.keras.models.Sequential()
        self.model.__getstate__ = lambda mdl=self.model: __getstate__(mdl)
        self.model.__setstate__ = __setstate__
        self.model.add(tf.keras.layers.Input(shape=(None, 3)))
        self.model.add(tf.keras.layers.LSTM(3, activation="relu", return_sequences=True))
        self.model.add(tf.keras.layers.Dense(3, activation="linear"))


# Now do the business:
tf.keras.backend.clear_session()
file_name = 'pickle_file.pckl'
instance = ContainsSequential()
instance.model.predict(np.random.rand(3, 1, 3))
print(instance.other_field)
with open(file_name, 'wb') as fid:
    pickle.dump(instance, fid)
with open(file_name, 'rb') as fid:
    restored_instance = pickle.load(fid)
print(restored_instance.other_field)
restored_instance.model.predict(np.random.rand(3, 1, 3))
print('Done')

雖然在行instance.model.predict(np.random.rand(3, 1, 3))上沒有失敗,但在行restored_instance.model.predict(np.random.rand(3, 1, 3)) , 3) 上確實失敗了restored_instance.model.predict(np.random.rand(3, 1, 3)) ,錯誤信息是:

  File "<path>\ReproduceProblem.py", line 52, in <module>
    restored_instance.model.predict(np.random.rand(3, 1, 3))
  File "<path>\Python\Python39\lib\site-packages\keras\engine\training.py", line 1693, in predict
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
  File "<path>\Python\Python39\lib\site-packages\keras\engine\training.py", line 716, in distribute_strategy
    return self._distribution_strategy or tf.distribute.get_strategy()
AttributeError: 'Sequential' object has no attribute '_distribution_strategy'

我對_distribution_strategy應該是什么一無所知,但在我的工作流程中,一旦我保存了文件,我就不需要再訓練它,只需使用它來進行預測或參考 class 的其他屬性。我嘗試將其設置為None並添加更多屬性,但沒有成功。

像這樣重新定義Tensorflow class 的方法是一種危險的方法:

self.model = tf.keras.models.Sequential()
self.model.__getstate__ = lambda mdl=self.model: __getstate__(mdl)
self.model.__setstate__ = __setstate__

我建議避免這種情況,而是重新定義自定義 class 的__getstate____setstate__方法。 這是一個工作示例:

import pickle
import numpy as np
import tempfile
import tensorflow as tf


class ContainsSequential:
    def __init__(self):
        self.other_field = "Potato"
        self.model = tf.keras.models.Sequential()
        self.model.add(tf.keras.layers.Input(shape=(None, 3)))
        self.model.add(tf.keras.layers.LSTM(3, activation="relu", return_sequences=True))
        self.model.add(tf.keras.layers.Dense(3, activation="linear"))
        
    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
            tf.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        d = {"model_str": model_str, "other_field": self.other_field}
        return d
    
    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
            fd.write(state["model_str"])
            fd.flush()
            model = tf.keras.models.load_model(fd.name)
        self.model = model
        self.other_field = state["other_field"]

和一個測試:

tf.keras.backend.clear_session()
file_name = 'pickle_file.pkl'
instance = ContainsSequential()

rnd = np.random.rand(3, 1, 3)
print(1, instance.model.predict(rnd))
with open(file_name, 'wb') as fid:
    pickle.dump(instance, fid)
with open(file_name, 'rb') as fid:
    r_instance = pickle.load(fid)
print(2, r_instance.model.predict(rnd))
print(r_instance.other_field)

您應該使用model.save('path/to/location')keras.models.load_model()而不是使用 pickle 來序列化/反序列化 tensorflow 模型。 這是推薦的做法,您可以在https://www.tensorflow.org/guide/keras/save_and_serialize查看文檔。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM