How to get input data shape when model is in predicting phase?

I am trying to build a Seq2Seq model with TensorFlow 2.0.
It is not for NLP but for predicting time-series data.

Because a Seq2Seq model behaves differently in the training and prediction phases, I need to write the call function according to the phase.

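In the training phase, the decoder takes the time-series data shifted by one step and padded with a zero (e.g. encoder input: [1, 2, 3, 4], decoder input: [0, 1, 2, 3]).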
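As a minimal sketch of the shifted decoder input described above, using plain Python lists rather than tensors. The helper name shift_right and its pad_value parameter are hypothetical and not part of the original code.

```python
def shift_right(sequence, pad_value=0):
    """Build a teacher-forcing decoder input: pad in front, drop the last step."""
    if not sequence:
        return []
    # The output has the same length as the input sequence.
    return [pad_value] + list(sequence[:-1])


encoder_input = [1, 2, 3, 4]
decoder_input = shift_right(encoder_input)
print(decoder_input)  # [0, 1, 2, 3]
```

With batched tensors the same idea would apply along the time axis; this sketch only illustrates the alignment between the two inputs.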

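In the prediction phase, the decoder first takes a single zero tensor and then feeds its own output back in as input until it has produced a sequence of the same length as the input time series (e.g. decoder input: [0] --(decoder predicts 1)--> [1] -- ... --> final output: [0, 1, 2, 3]).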
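The feedback loop described above can be sketched in plain Python. This is only an illustration under assumptions: autoregressive_decode and toy_step are hypothetical names, and toy_step is a stand-in for the real decoder, not a model.

```python
def autoregressive_decode(step, state, num_steps, start_value=0):
    """Feed each output back in as the next input for num_steps steps."""
    outputs = []
    current = start_value
    for _ in range(num_steps):
        # step returns the next output and the updated recurrent state.
        current, state = step(current, state)
        outputs.append(current)
    return outputs


def toy_step(value, state):
    # Stand-in for the decoder: emits value + 1 and counts calls in the state.
    return value + 1, state + 1


print(autoregressive_decode(toy_step, state=0, num_steps=4))  # [1, 2, 3, 4]
```

In the real model the state would be the LSTM's [state_h, state_c] pair, initialised from the encoder.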

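So in the prediction phase (training=False), the call function reads the encoder input's shape, builds a tf.zeros tensor from that shape, and passes it to the decoder as input.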

Here are some parts of the model class.

import tensorflow as tf
from tensorflow.keras import Model, layers


class Seq2Seq(Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder = layers.LSTM(5, return_state=True, input_shape=(None, 1))
        self.decoder = layers.LSTM(5, return_state=True, return_sequences=True,
                                   input_shape=(None, 1))
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None, **kwargs):
        decoded = None

        if training:
            # Teacher forcing: the decoder sees the shifted target sequence.
            encoder_input, decoder_input = inputs[0], inputs[1]
            encoded, state_h, state_c = self.encoder(encoder_input)
            decoded, _, _ = self.decoder(decoder_input, initial_state=[state_h, state_c])
            decoded = self.dense(decoded)
        else:
            encoder_input = inputs
            # This is the line the traceback points at.
            batch, time_step, features = tf.shape(encoder_input)
            encoded, state_h, state_c = self.encoder(encoder_input)
            # Start decoding from a single all-zero time step.
            decoder_input = tf.zeros(shape=[batch, 1, features])
            decoded = None
            for i in range(time_step):
                # Feed the previous output back in as the next decoder input.
                decoder_input, state_h, state_c = self.decoder(decoder_input, initial_state=[state_h, state_c])
                decoder_input = self.dense(decoder_input)
                if decoded is not None:
                    decoded = tf.concat([decoded, decoder_input], axis=1)
                else:
                    decoded = tf.identity(decoder_input)
        return decoded

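In the training phase, inputs contains two pieces of input data. One is encoder_input with shape [batch_size, time_step, features], and the other is decoder_input with the same shape as encoder_input. In the prediction phase, however, inputs is the same as the training encoder_input.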

Training works fine, but an error occurs during validation.

model.fit(train_dataset, epochs=10, validation_data=(seq_test_data, seq_test_data))
# >>> train_dataset.element_spec
# (TensorSpec(shape=(2, None, 140, 1), dtype=tf.float32, name=None),
# TensorSpec(shape=(None, 140, 1), dtype=tf.float32, name=None))
# >>> seq_test_data.shape
# TensorShape([1000, 140, 1])
Epoch 1/10
5/5 [==============================] - ETA: 0s - MSE: 0.1487Traceback (most recent call last):
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-91afa9c067da>", line 1, in <module>
    runfile('/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder_example_using_ecg_data.py', wdir='/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source')
  File "/home/dbadmin/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/dbadmin/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder_example_using_ecg_data.py", line 112, in <module>
    history = ae_rnn_model.fit(train_dataset, epochs=10, validation_data=(seq_test_data, seq_test_data))
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1133, in fit
    return_dict=True)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1379, in evaluate
    tmp_logs = test_function(iterator)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:1224 test_function  *
        return step_function(self, iterator)
    /home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder.py:57 call  *
        batch, time_step, features = tf.shape(encoder_input)
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:503 __iter__
        self._disallow_iteration()
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:496 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))
    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

My guess is that this error occurs because the exact input shape cannot be obtained before I feed in concrete data.

However, it is still unclear to me how to fix it.
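The traceback points at the tuple-unpacking of tf.shape(encoder_input), which iterates over a symbolic tensor inside the graph that Keras builds for evaluation. One possible workaround, sketched here under the assumption that the time dimension is fixed in the input spec (140 in the element_spec shown earlier), is to index the shape tensor instead of unpacking it. The function name initial_decoder_input is hypothetical and is not part of the original model.

```python
import tensorflow as tf


# The input_signature mimics the symbolic input traced during fit/evaluate:
# unknown batch size, 140 time steps, 1 feature (assumed from the element_spec).
@tf.function(input_signature=[tf.TensorSpec([None, 140, 1], tf.float32)])
def initial_decoder_input(encoder_input):  # hypothetical helper name
    # Index the shape tensor; unpacking it would iterate over a tf.Tensor.
    dynamic_shape = tf.shape(encoder_input)
    batch, features = dynamic_shape[0], dynamic_shape[2]
    return tf.zeros([batch, 1, features])


zeros = initial_decoder_input(tf.ones([4, 140, 1]))
print(zeros.shape)  # (4, 1, 1)
```

For the loop length, encoder_input.shape[1] gives the static value as a Python int when it is known, so a plain Python range over it can be unrolled at trace time. If the time dimension is None, a different loop construction would be needed.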

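If I understand correctly, the call function also receives an (x, y) tuple during the evaluation phase. So for the training == False case you should also change encoder_input = inputs to encoder_input, decoder_input = inputs[0], inputs[1].

However, it does make sense to first check the length of the input passed to the call function before splitting it into encoder_input and decoder_input. The model's predict method, for example, passes only a single object to the call function.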
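The length check suggested above could look roughly like the following sketch. It assumes that two inputs arrive as a list or tuple and a single input arrives as a bare object; split_inputs is a hypothetical helper, and plain strings stand in for tensors.

```python
def split_inputs(inputs):
    """Return (encoder_input, decoder_input); decoder_input is None if absent."""
    if isinstance(inputs, (list, tuple)):
        if len(inputs) == 2:
            return inputs[0], inputs[1]
        if len(inputs) == 1:
            return inputs[0], None
        raise ValueError(f"expected 1 or 2 inputs, got {len(inputs)}")
    # A single object, e.g. what predict would pass to call.
    return inputs, None


print(split_inputs(("enc", "dec")))  # ('enc', 'dec')
print(split_inputs("enc"))           # ('enc', None)
```

Inside call, the returned decoder_input being None would then select the autoregressive branch regardless of how the model was invoked.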

Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please cite this site's URL or the original source.