
How to get input data shape when model is in predicting phase?

I am trying to build a Seq2Seq model with TensorFlow 2.0.
It is not for NLP; it is for predicting time-series data.

Since a Seq2Seq model behaves differently in the training and prediction phases, I had to write the call function so that it branches on the phase.

In the training phase, the decoder takes the time-series data shifted by one step and padded with a leading zero (e.g. encoder input: [1,2,3,4], decoder input: [0,1,2,3]).
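The shifted decoder input from the example above can be produced with tf.pad. This is a minimal sketch, assuming tensors laid out as [batch, time_step, features]; make_decoder_input is a hypothetical helper, not part of the question's code.

```python
import tensorflow as tf

def make_decoder_input(encoder_input):
    """Hypothetical helper: shift [batch, time_step, features] right by one step.

    The last time step is dropped and one all-zero step is prepended on the
    time axis, which gives the teacher-forcing input used during training.
    """
    shifted = encoder_input[:, :-1, :]
    return tf.pad(shifted, paddings=[[0, 0], [1, 0], [0, 0]])

x = tf.constant([[[1.0], [2.0], [3.0], [4.0]]])  # shape [1, 4, 1]
print(tf.squeeze(make_decoder_input(x)).numpy())  # [0. 1. 2. 3.]
```

Because the slice and the padding only touch the time axis, the same call works for any batch size and number of features.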

In the prediction phase, the decoder first takes a single zero step as input, then feeds each of its outputs back in as the next input until it has produced a sequence as long as the input time series (e.g. decoder input at step 1: [0] --(decoder predicts 1)--> [1] --> ... --> final output: [0,1,2,3]).

Therefore, in the prediction phase (training=False), the call function reads the shape of the encoder input, creates a tf.zeros tensor from that shape, and passes it to the decoder as its first input.
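The prediction-phase loop described above can also be written so that it works inside a graph, which is how Keras runs fit and evaluate. This is a minimal sketch, not the question's class: encoder, decoder, dense and greedy_decode are hypothetical stand-alone names, and it assumes a single feature (as in input_shape=(None, 1)) so that the Dense(1) output can be fed straight back in. The shape tensor is indexed rather than unpacked. The per-step predictions are collected in a tf.TensorArray instead of growing a tensor with tf.concat from a None start, a pattern AutoGraph is likely to reject in a graph loop.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical stand-ins for the question's layers: 5 LSTM units, 1 feature.
encoder = layers.LSTM(5, return_state=True)
decoder = layers.LSTM(5, return_state=True, return_sequences=True)
dense = layers.Dense(1)
# Build the weights up front so no variables are created inside the graph loop.
encoder.build((None, None, 1))
decoder.build((None, None, 1))
dense.build((None, None, 5))

@tf.function
def greedy_decode(encoder_input):
    """Autoregressive decoding: each prediction becomes the next decoder input."""
    shape = tf.shape(encoder_input)        # [batch, time_step, features] as a tensor
    batch, time_step = shape[0], shape[1]  # index the shape tensor; do not unpack it
    _, state_h, state_c = encoder(encoder_input)
    step_input = tf.zeros([batch, 1, 1])   # first decoder input: one all-zero step
    outputs = tf.TensorArray(tf.float32, size=time_step)
    for t in tf.range(time_step):          # AutoGraph converts this into a tf.while_loop
        out, state_h, state_c = decoder(step_input, initial_state=[state_h, state_c])
        step_input = dense(out)            # [batch, 1, 1], fed back in on the next step
        outputs = outputs.write(t, step_input[:, 0, :])
    # stack() returns [time_step, batch, 1]; move the time axis to position 1.
    return tf.transpose(outputs.stack(), [1, 0, 2])

print(greedy_decode(tf.random.normal([4, 140, 1])).shape)
```

The output has the same [batch, time_step, 1] layout as the training branch, so the same loss can be applied to both.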

Here is part of the model class:

import tensorflow as tf
from tensorflow.keras import Model, layers

class Seq2Seq(Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # a subclassed Model must call this before creating layers
        self.encoder = layers.LSTM(5, return_state=True, input_shape=(None, 1))
        self.decoder = layers.LSTM(5, return_state=True, return_sequences=True,
                                   input_shape=(None, 1))
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None, **kwargs):
        decoded = None

        if training:
            # Teacher forcing: inputs holds the encoder input and the shifted decoder input.
            encoder_input, decoder_input = inputs[0], inputs[1]
            encoded, state_h, state_c = self.encoder(encoder_input)
            decoded, _, _ = self.decoder(decoder_input, initial_state=[state_h, state_c])
            decoded = self.dense(decoded)
        else:
            # Prediction: inputs is only the encoder input; decode one step at a time.
            encoder_input = inputs
            batch, time_step, features = tf.shape(encoder_input)  # the line reported in the traceback below
            encoded, state_h, state_c = self.encoder(encoder_input)
            decoder_input = tf.zeros(shape=[batch, 1, features])
            decoded = None
            for i in range(time_step):
                # Feed each prediction back in as the next decoder input.
                decoder_input, state_h, state_c = self.decoder(decoder_input, initial_state=[state_h, state_c])
                decoder_input = self.dense(decoder_input)
                if decoded is not None:
                    decoded = tf.concat([decoded, decoder_input], axis=1)
                else:
                    decoded = tf.identity(decoder_input)
        return decoded

In the training phase, inputs contains two tensors: encoder_input, with shape [batch_size, time_step, features], and decoder_input, with the same shape as encoder_input. In the prediction phase, however, inputs is only the encoder input, i.e. the same as encoder_input during training.

Training works, but validation raises an error:

model.fit(train_dataset, epochs=10, validation_data=(seq_test_data, seq_test_data))
# >>> train_dataset.element_spec
# (TensorSpec(shape=(2, None, 140, 1), dtype=tf.float32, name=None),
# TensorSpec(shape=(None, 140, 1), dtype=tf.float32, name=None))
# >>> seq_test_data.shape
# TensorShape([1000, 140, 1])
Epoch 1/10
5/5 [==============================] - ETA: 0s - MSE: 0.1487Traceback (most recent call last):
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-91afa9c067da>", line 1, in <module>
    runfile('/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder_example_using_ecg_data.py', wdir='/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source')
  File "/home/dbadmin/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/dbadmin/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder_example_using_ecg_data.py", line 112, in <module>
    history = ae_rnn_model.fit(train_dataset, epochs=10, validation_data=(seq_test_data, seq_test_data))
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1133, in fit
    return_dict=True)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1379, in evaluate
    tmp_logs = test_function(iterator)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:1224 test_function  *
        return step_function(self, iterator)
    /home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder.py:57 call  *
        batch, time_step, features = tf.shape(encoder_input)
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:503 __iter__
        self._disallow_iteration()
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:496 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))
    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

My guess is that the error occurs because the exact input shape cannot be known until actual data is passed in.

However, it is still not clear to me how to fix it.
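The guess above is close. Inside model.fit and model.evaluate, Keras runs each step inside a tf.function (see the def_function.py frames in the traceback), so call is traced into a graph. There, x.shape is the static shape, which may contain None, while tf.shape(x) returns a tensor whose values only exist when the graph runs. The line that actually fails is the tuple unpacking, because unpacking iterates over that tensor. A minimal sketch of the difference, using a hypothetical function dynamic_dims traced once with an input signature of [None, None, 1]:

```python
import tensorflow as tf

# Trace one graph that accepts any batch size and sequence length (1 feature).
@tf.function(input_signature=[tf.TensorSpec([None, None, 1], tf.float32)])
def dynamic_dims(x):
    # Static shape, known at trace time; this Python print runs only while tracing.
    print("static shape while tracing:", x.shape)
    # Dynamic shape: a 1-D int32 tensor whose values exist only at run time.
    # `batch, time_step, features = tf.shape(x)` would iterate over a tf.Tensor,
    # which is what the traceback rejects; tf.unstack with num=3 splits it instead.
    batch, time_step, features = tf.unstack(tf.shape(x), num=3)
    return tf.stack([batch, time_step, features])

print(dynamic_dims(tf.zeros([1000, 140, 1])).numpy())
```

Indexing the shape tensor (shape[0], shape[1], shape[2]) is an equivalent alternative to tf.unstack.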

If I understand it correctly, the call function also receives the (x, y) tuple in the evaluation phase. You should change encoder_input = inputs to encoder_input, decoder_input = inputs[0], inputs[1] for the training==False case as well. But it does make sense to first check the length of the input passed to the call function before separating it into encoder_input and decoder_input: the model's predict method, for example, would pass only one object to the call function.
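The check suggested above could look like the following minimal sketch. split_inputs is a hypothetical helper. It assumes the paired input is either a Python tuple/list or a single tensor with the encoder and decoder inputs stacked along axis 0 (as in the train_dataset element_spec of shape (2, None, 140, 1)), and that anything else is the encoder input alone. Because it inspects the Python type and the static rank, the decision is made in Python while tracing, not at graph run time.

```python
import tensorflow as tf

def split_inputs(inputs):
    """Hypothetical helper: return (encoder_input, decoder_input or None)."""
    if isinstance(inputs, (tuple, list)):
        # A Python pair: (encoder_input, decoder_input).
        return inputs[0], inputs[1]
    if inputs.shape.rank == 4:
        # Encoder and decoder inputs stacked on axis 0: [2, batch, time_step, features].
        return inputs[0], inputs[1]
    # Rank-3 tensor: only the encoder input was given (e.g. from predict).
    return inputs, None

enc = tf.zeros([8, 140, 1])
dec = tf.ones([8, 140, 1])
encoder_input, decoder_input = split_inputs(tf.stack([enc, dec]))
print(encoder_input.shape, decoder_input.shape)
```

Inside call, a None decoder input would then select the autoregressive branch.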
