繁体   English   中英

如何在没有固定batch_size的情况下设置Tensorflow dynamic_rnn、zero_state?

[英]How to set Tensorflow dynamic_rnn, zero_state without a fixed batch_size?

根据 Tensorflow 的官方网站,( https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state ) zero_state 必须指定一个 batch_size。 我发现的许多示例都使用此代码:

    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)

对于训练步骤,可以固定批量大小。 但是,在预测时,测试集的形状可能与训练集的批次大小不同。 例如,我的一批训练数据的形状为 [100, 255, 128]。 批量大小为 100,具有 255 个步骤和 128 个输入。 而测试集是 [2000, 255, 128]。 我无法预测,因为在 dynamic_rnn(initial_state) 中,它已经设置了固定的 batch_size = 100。我该如何解决这个问题?

谢谢。

您可以将batch_size指定为占位符,而不是常量。 只需确保在feed_dict提供相关数字,这对于训练和测试会有所不同

重要的是,指定[]作为占位符的维度,因为如果指定None可能会出错,这是其他地方的惯例。 所以这样的事情应该有效:

batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)
# rest of your code
out = sess.run(outputs, feed_dict={batch_size:100})
out = sess.run(outputs, feed_dict={batch_size:10})

显然要确保批处理参数与输入的形状匹配[seq_len, batch_size, features]如果time_major设置为True ,则dynamic_rnn将解释为[batch_size, seq_len, features][seq_len, batch_size, features]

有一个相当简单的实现。 只需删除initial_state! 这是因为初始化过程可能会预先分配一个批量大小的内存。

正如@陈狗蛋所回答的, tf.compat.v1.nn.dynamic_rnn没有必要设置initial_state ,因为它是可选的。 你可以简单地这样做

outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, 
                                         X_inputs, 
                                         time_major=False, 
                                         dtype=tf.float32)

不要忘记设置dtype ,在这里我设置tf.float32 ,您可以设置dtype ,因为你需要。

正如tf.compat.v1.nn.rnn_cell.LSTMCell文档所说:

batch_size: int, float, or unit Tensor 表示批量大小

batch_size必须是显式值。 因此,为batch_size参数使用占位符是一种解决方法,但不是推荐的方法。 我建议你不要使用它,因为它在未来的版本中可能是一种无效的方式。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM