
Unpickled tensorflow model fails to make predictions

I've seen this question and this one, but neither actually explains what is going on, nor offers a solution to the problem I'm facing.

The code below is a snippet from what I'm trying to do in a larger context. Basically, I'm creating an object that contains a tensorflow.keras model and saving it to a file with pickle, using a trick adapted from this answer. The actual class I'm working on has several other fields and methods, which is why I'd prefer to make it picklable, and to do so in a flexible manner. The code below reproduces the problem minimally.

ReproduceProblem.py:

import pickle
import numpy as np
import tempfile
import tensorflow as tf


def __getstate__(self):
    model_str = ""
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
        tf.keras.models.save_model(self, fd.name, overwrite=True)
        model_str = fd.read()
    d = {"model_str": model_str}
    return d


def __setstate__(self, state):
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
        fd.write(state["model_str"])
        fd.flush()
        model = tf.keras.models.load_model(fd.name)
    self.__dict__ = model.__dict__


class ContainsSequential:
    def __init__(self):
        self.other_field = "potato"
        self.model = tf.keras.models.Sequential()
        self.model.__getstate__ = lambda mdl=self.model: __getstate__(mdl)
        self.model.__setstate__ = __setstate__
        self.model.add(tf.keras.layers.Input(shape=(None, 3)))
        self.model.add(tf.keras.layers.LSTM(3, activation="relu", return_sequences=True))
        self.model.add(tf.keras.layers.Dense(3, activation="linear"))


# Now do the business:
tf.keras.backend.clear_session()
file_name = 'pickle_file.pckl'
instance = ContainsSequential()
instance.model.predict(np.random.rand(3, 1, 3))
print(instance.other_field)
with open(file_name, 'wb') as fid:
    pickle.dump(instance, fid)
with open(file_name, 'rb') as fid:
    restored_instance = pickle.load(fid)
print(restored_instance.other_field)
restored_instance.model.predict(np.random.rand(3, 1, 3))
print('Done')

While it does not fail on the line instance.model.predict(np.random.rand(3, 1, 3)), it does fail on the line restored_instance.model.predict(np.random.rand(3, 1, 3)). The error message is:

  File "<path>\ReproduceProblem.py", line 52, in <module>
    restored_instance.model.predict(np.random.rand(3, 1, 3))
  File "<path>\Python\Python39\lib\site-packages\keras\engine\training.py", line 1693, in predict
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
  File "<path>\Python\Python39\lib\site-packages\keras\engine\training.py", line 716, in distribute_strategy
    return self._distribution_strategy or tf.distribute.get_strategy()
AttributeError: 'Sequential' object has no attribute '_distribution_strategy'

I don't have the slightest idea what _distribution_strategy should be, but in my workflow, once I've saved the file I don't need to train the model anymore; I just use it to make predictions or consult other attributes of the class. I've tried setting it to None and adding more attributes, but with no success.

Redefining methods of a TensorFlow class like this is a dangerous approach:

self.model = tf.keras.models.Sequential()
self.model.__getstate__ = lambda mdl=self.model: __getstate__(mdl)
self.model.__setstate__ = __setstate__
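
A plausible explanation for the error, sketched in plain Python with a hypothetical Plain class (no TensorFlow involved): pickle does find a __getstate__ that was patched onto the instance when dumping, but on load it builds a fresh object through the class, and that fresh object carries no patched __setstate__. The default behaviour then just copies the state dict into __dict__, so every original attribute is missing. This is an assumption about what happens to the Keras model as well, not something verified against the Keras source.

```python
import pickle


class Plain:
    """Hypothetical stand-in for the patched model object."""

    def __init__(self):
        self.value = 42


obj = Plain()
# Patch the hooks onto the instance, as the question does with the model.
obj.__getstate__ = lambda: {"only": "this"}
obj.__setstate__ = lambda state: setattr(obj, "called", True)

restored = pickle.loads(pickle.dumps(obj))

# The restored object holds only what __getstate__ returned.
print(restored.__dict__)
# The original attribute is gone, much like _distribution_strategy.
print(hasattr(restored, "value"))
# The patched __setstate__ was never invoked.
print(hasattr(obj, "called"))
```

If the same thing happens to the Sequential model, the restored object would contain model_str and nothing else, which matches the AttributeError in the traceback.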

I'd recommend avoiding this kind of patching and defining the __getstate__ and __setstate__ methods on the custom class instead. Here is a working example:

import pickle
import numpy as np
import tempfile
import tensorflow as tf


class ContainsSequential:
    def __init__(self):
        self.other_field = "Potato"
        self.model = tf.keras.models.Sequential()
        self.model.add(tf.keras.layers.Input(shape=(None, 3)))
        self.model.add(tf.keras.layers.LSTM(3, activation="relu", return_sequences=True))
        self.model.add(tf.keras.layers.Dense(3, activation="linear"))
        
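    def __getstate__(self):
        # Save the model to a temporary HDF5 file and keep the file's raw bytes.
        # Note: delete=False leaves the temporary file on disk.
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
            tf.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        # Every other field to be pickled has to be listed here explicitly.
        d = {"model_str": model_str, "other_field": self.other_field}
        return d

    def __setstate__(self, state):
        # Write the bytes back to a temporary file and let Keras rebuild the model.
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
            fd.write(state["model_str"])
            fd.flush()
            model = tf.keras.models.load_model(fd.name)
        self.model = model
        self.other_field = state["other_field"]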

And a test:

tf.keras.backend.clear_session()
file_name = 'pickle_file.pkl'
instance = ContainsSequential()

rnd = np.random.rand(3, 1, 3)
print(1, instance.model.predict(rnd))
with open(file_name, 'wb') as fid:
    pickle.dump(instance, fid)
with open(file_name, 'rb') as fid:
    r_instance = pickle.load(fid)
print(2, r_instance.model.predict(rnd))
print(r_instance.other_field)
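
Since the real class is said to have several other fields, listing each one by hand in __getstate__ may get tedious. A minimal sketch of a more flexible variant follows: copy __dict__, swap only the model for its saved bytes, and restore everything else unchanged. FakeModel, Container, threshold and model_bytes are hypothetical names used so the sketch runs without TensorFlow. With Keras one would presumably call tf.keras.models.save_model(model, path) and tf.keras.models.load_model(path) in place of FakeModel.save and FakeModel.load, with a file suffix the installed version accepts.

```python
import os
import pickle
import tempfile


class FakeModel:
    """Hypothetical stand-in for a model that can only be saved to a path."""

    def __init__(self, weights):
        self.weights = weights

    def save(self, path):
        with open(path, "w") as fh:
            fh.write(",".join(str(w) for w in self.weights))

    @staticmethod
    def load(path):
        with open(path) as fh:
            return FakeModel([float(w) for w in fh.read().split(",")])

    def __reduce__(self):
        # Simulates an object that pickle cannot handle on its own.
        raise TypeError("FakeModel cannot be pickled directly")


class Container:
    def __init__(self):
        self.other_field = "potato"
        self.threshold = 0.5
        self.model = FakeModel([1.0, 2.0, 3.0])

    def __getstate__(self):
        state = self.__dict__.copy()  # keeps every ordinary field
        model = state.pop("model")    # the model is handled separately
        fd, path = tempfile.mkstemp(suffix=".txt")
        os.close(fd)
        try:
            model.save(path)
            with open(path, "rb") as fh:
                state["model_bytes"] = fh.read()
        finally:
            os.remove(path)           # do not leave temporary files behind
        return state

    def __setstate__(self, state):
        state = dict(state)           # avoid mutating the caller's dict
        model_bytes = state.pop("model_bytes")
        fd, path = tempfile.mkstemp(suffix=".txt")
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(model_bytes)
            self.model = FakeModel.load(path)
        finally:
            os.remove(path)
        self.__dict__.update(state)   # restore all remaining fields


restored = pickle.loads(pickle.dumps(Container()))
print(restored.other_field, restored.threshold, restored.model.weights)
```

New fields added to the class later are picked up automatically, as long as they are picklable themselves.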

Instead of using pickle to serialize and deserialize tensorflow models, you should use model.save('path/to/location') and keras.models.load_model(). This is the recommended practice; see the documentation at https://www.tensorflow.org/guide/keras/save_and_serialize.


