
How to get input data shape when model is in predicting phase?

I am trying to build a Seq2Seq model using TensorFlow 2.0.
It is not for NLP but for predicting time-series data.

Since a Seq2Seq model behaves differently in the training and predicting phases, I needed to write the call function so that it branches on the phase.

In the training phase, the decoder takes the time series shifted by one step and padded with a leading zero (e.g. encoder input: [1,2,3,4], decoder input: [0,1,2,3]).
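A minimal NumPy sketch of how such a shifted decoder input could be built from a batch of encoder inputs of shape (batch, time_step, features). The function name `make_decoder_input` is hypothetical, not something from the question:

```python
import numpy as np


def make_decoder_input(encoder_input):
    """Shift each series one step to the right along the time axis, padding with zero.

    encoder_input: array of shape (batch, time_step, features).
    """
    # Drop the last time step, then prepend a single zero time step.
    return np.pad(encoder_input[:, :-1, :], ((0, 0), (1, 0), (0, 0)))


x = np.array([1, 2, 3, 4], dtype=np.float32).reshape(1, 4, 1)
print(make_decoder_input(x)[0, :, 0])  # [0. 1. 2. 3.]
```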

In the predicting phase, the decoder first takes a single zero tensor and then takes its own output as the next input, until it has produced a sequence as long as the input time series (e.g. decoder input: [0] -- (decoder predicts 1) --> [1] -- ... --> final output: [0,1,2,3]).
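The feedback loop described above, reduced to plain Python. `step_fn` is a hypothetical stand-in for one decoder-plus-dense step; with the toy rule `x + 1` the values fed to the decoder form [0, 1, 2, 3] and the values it emits form [1, 2, 3, 4]:

```python
def greedy_decode(step_fn, start, num_steps):
    """Feed each prediction back in as the next input.

    step_fn is a hypothetical stand-in for one decoder + dense step.
    Returns (inputs_fed, outputs_produced).
    """
    inputs_fed, outputs = [], []
    current = start
    for _ in range(num_steps):
        inputs_fed.append(current)
        current = step_fn(current)  # the prediction becomes the next input
        outputs.append(current)
    return inputs_fed, outputs


fed, out = greedy_decode(lambda x: x + 1, 0, 4)
print(fed, out)  # [0, 1, 2, 3] [1, 2, 3, 4]
```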

Therefore, in the predicting phase (training=False), the call function gets the shape of the encoder input, creates a tf.zeros tensor from that shape, and passes it to the decoder as its first input.

Here is part of the model class:

import tensorflow as tf
from tensorflow.keras import Model, layers


class Seq2Seq(Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The encoder returns its final hidden and cell states to seed the decoder.
        self.encoder = layers.LSTM(5, return_state=True, input_shape=(None, 1))
        self.decoder = layers.LSTM(5, return_state=True, return_sequences=True,
                                   input_shape=(None, 1))
        # Projects each decoder step back to one feature.
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None, **kwargs):
        if training:
            # Teacher forcing: the zero-padded, shifted series goes to the decoder.
            encoder_input, decoder_input = inputs[0], inputs[1]
            encoded, state_h, state_c = self.encoder(encoder_input)
            decoded, _, _ = self.decoder(decoder_input, initial_state=[state_h, state_c])
            decoded = self.dense(decoded)
        else:
            # Autoregressive decoding: start from zeros and feed each output back in.
            encoder_input = inputs
            # This unpacking is the line that fails during validation (see the traceback).
            batch, time_step, features = tf.shape(encoder_input)
            encoded, state_h, state_c = self.encoder(encoder_input)
            decoder_input = tf.zeros(shape=[batch, 1, features])
            decoded = None
            for i in range(time_step):
                decoder_input, state_h, state_c = self.decoder(decoder_input, initial_state=[state_h, state_c])
                decoder_input = self.dense(decoder_input)
                if decoded is not None:
                    decoded = tf.concat([decoded, decoder_input], axis=1)
                else:
                    decoded = tf.identity(decoder_input)
        return decoded

In the training phase, inputs holds two tensors: encoder_input, with shape [batch_size, time_step, features], and decoder_input, with the same shape as encoder_input. In the predicting phase, however, inputs is just the encoder_input from training.
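To illustrate the training-time layout: assuming the two inputs are stacked along a new leading axis (which matches the `(2, None, 140, 1)` element spec shown below), indexing the first axis recovers each input. The sizes and random data are placeholders, not the question's data:

```python
import numpy as np

# Placeholder sizes and random data, chosen only for illustration.
batch_size, time_step, features = 8, 140, 1
rng = np.random.default_rng(0)
encoder_input = rng.random((batch_size, time_step, features), dtype=np.float32)
# Decoder input: the same series shifted right by one step, with a leading zero.
decoder_input = np.pad(encoder_input[:, :-1, :], ((0, 0), (1, 0), (0, 0)))

# Stacking along a new leading axis gives the (2, batch, time_step, features) layout.
inputs = np.stack([encoder_input, decoder_input])
print(inputs.shape)  # (2, 8, 140, 1)

# inputs[0] and inputs[1] recover the two tensors, as call() does during training.
assert np.array_equal(inputs[0], encoder_input)
assert np.array_equal(inputs[1], decoder_input)
```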

Training works well, but validation raises an error:

model.fit(train_dataset, epochs=10, validation_data=(seq_test_data, seq_test_data))
# >>> train_dataset.element_spec
# (TensorSpec(shape=(2, None, 140, 1), dtype=tf.float32, name=None),
# TensorSpec(shape=(None, 140, 1), dtype=tf.float32, name=None))
# >>> seq_test_data.shape
# TensorShape([1000, 140, 1])
Epoch 1/10
5/5 [==============================] - ETA: 0s - MSE: 0.1487Traceback (most recent call last):
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-91afa9c067da>", line 1, in <module>
    runfile('/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder_example_using_ecg_data.py', wdir='/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source')
  File "/home/dbadmin/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/dbadmin/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder_example_using_ecg_data.py", line 112, in <module>
    history = ae_rnn_model.fit(train_dataset, epochs=10, validation_data=(seq_test_data, seq_test_data))
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1133, in fit
    return_dict=True)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1379, in evaluate
    tmp_logs = test_function(iterator)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:1224 test_function  *
        return step_function(self, iterator)
    /home/dbadmin/woosung/DWT-AutoEncoder-Comparison/source/autoencoder.py:57 call  *
        batch, time_step, features = tf.shape(encoder_input)
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:503 __iter__
        self._disallow_iteration()
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:496 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    /home/dbadmin/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))
    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

I guess the error occurs because the exact input shape is not known until actual data is passed in.

However, it is still not clear to me how to fix it.
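A small sketch of what the traceback points at, assuming the `[None, 140, 1]` input shape from the question: inside a `tf.function`, `tf.shape()` returns a symbolic 1-D tensor, and the Python tuple unpacking in `call` tries to iterate over it, which AutoGraph rejects. Indexing that tensor, and reading statically known dimensions from `.shape`, avoids the iteration. The function name `start_token` is hypothetical:

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 140, 1], dtype=tf.float32)])
def start_token(encoder_input):
    # Static shape, known at trace time: TensorShape([None, 140, 1]).
    static_shape = encoder_input.shape
    # Dynamic shape: a 1-D int32 tensor whose values exist only at run time.
    dynamic_shape = tf.shape(encoder_input)
    # Index instead of `batch, time_step, features = dynamic_shape`,
    # which is the unpacking that raised OperatorNotAllowedInGraphError.
    batch = dynamic_shape[0]
    features = static_shape[-1]  # a plain Python int (1)
    return tf.zeros(tf.stack([batch, 1, features]))


print(start_token(tf.ones([3, 140, 1])).shape)  # (3, 1, 1)
```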

If I understand it correctly, the call function also receives the (x, y) tuple in the evaluation phase. You should change encoder_input = inputs to encoder_input, decoder_input = inputs[0], inputs[1] for the training==False case as well. That said, it makes sense to first check the length of the input to the call function before splitting it into encoder_input and decoder_input: the model's predict method, for example, passes only one object to call.
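One possible variant of that check, folded into a complete `call`; it is a sketch, not a confirmed fix. It branches on the structure of `inputs` (a tuple/list, a stacked rank-4 tensor like the question's `train_dataset` elements, or a single encoder tensor) rather than on `training`, indexes `tf.shape()` for the batch size, and assumes the time and feature dimensions are statically known (140 and 1 in the question), with one feature so the dense output can be fed back to the decoder. The `units` parameter and `_split_inputs` helper are additions for illustration:

```python
import tensorflow as tf
from tensorflow.keras import Model, layers


class Seq2Seq(Model):
    def __init__(self, units=5, **kwargs):
        super().__init__(**kwargs)
        self.encoder = layers.LSTM(units, return_state=True)
        self.decoder = layers.LSTM(units, return_state=True, return_sequences=True)
        self.dense = layers.Dense(1)

    def _split_inputs(self, inputs):
        """Return (encoder_input, decoder_input or None) based on the input structure."""
        if isinstance(inputs, (tuple, list)):
            return inputs[0], inputs[1]
        if len(inputs.shape) == 4:
            # Stacked layout (2, batch, time_step, features), as in the question.
            return inputs[0], inputs[1]
        return inputs, None

    def call(self, inputs, training=False):
        encoder_input, decoder_input = self._split_inputs(inputs)
        _, state_h, state_c = self.encoder(encoder_input)

        if decoder_input is not None:
            # Teacher forcing: decode the shifted series in one pass.
            decoded, _, _ = self.decoder(decoder_input, initial_state=[state_h, state_c])
            return self.dense(decoded)

        # Autoregressive decoding. The batch size may be unknown while tracing,
        # so index tf.shape() instead of unpacking it.
        batch = tf.shape(encoder_input)[0]
        # Assumed static, so range() below is an ordinary Python loop unrolled at trace time.
        time_step = encoder_input.shape[1]
        features = encoder_input.shape[2]
        step_input = tf.zeros(tf.stack([batch, 1, features]))
        outputs = []
        for _ in range(time_step):
            step_output, state_h, state_c = self.decoder(
                step_input, initial_state=[state_h, state_c])
            step_input = self.dense(step_output)  # shape (batch, 1, 1), fed back next step
            outputs.append(step_input)
        return tf.concat(outputs, axis=1)


model = Seq2Seq()
enc = tf.random.normal([4, 7, 1])
print(model(enc).shape)  # (4, 7, 1)
```

If the time dimension can vary between batches, `encoder_input.shape[1]` is `None` and this loop cannot be unrolled; a `tf.while_loop` or `tf.TensorArray` based loop would be needed instead.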

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 