[英]How to set Tensorflow dynamic_rnn, zero_state without a fixed batch_size?
根据 Tensorflow 的官方网站,( https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state ) zero_state 必须指定一个 batch_size。 我发现的许多示例都使用此代码:
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
initial_state=init_state, time_major=False)
对于训练步骤,可以固定批量大小。 但是,在预测时,测试集的形状可能与训练集的批次大小不同。 例如,我的一批训练数据的形状为 [100, 255, 128]。 批量大小为 100,具有 255 个步骤和 128 个输入。 而测试集是 [2000, 255, 128]。 我无法预测,因为在 dynamic_rnn(initial_state) 中,它已经设置了固定的 batch_size = 100。我该如何解决这个问题?
谢谢。
您可以将batch_size
指定为占位符,而不是常量。 只需确保在feed_dict
提供相关数字,这对于训练和测试会有所不同
重要的是,指定[]
作为占位符的维度,因为如果指定None
可能会出错,这是其他地方的惯例。 所以这样的事情应该有效:
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
initial_state=init_state, time_major=False)
# rest of your code
out = sess.run(outputs, feed_dict={batch_size:100})
out = sess.run(outputs, feed_dict={batch_size:10})
显然要确保批处理参数与输入的形状匹配[seq_len, batch_size, features]
如果time_major
设置为True
,则dynamic_rnn
将解释为[batch_size, seq_len, features]
或[seq_len, batch_size, features]
有一个相当简单的实现。 只需删除initial_state! 这是因为初始化过程可能会预先分配一个批量大小的内存。
正如@陈狗蛋所回答的, tf.compat.v1.nn.dynamic_rnn
没有必要设置initial_state
,因为它是可选的。 你可以简单地这样做
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell,
X_inputs,
time_major=False,
dtype=tf.float32)
不要忘记设置dtype
,在这里我设置tf.float32
,您可以设置dtype
,因为你需要。
正如tf.compat.v1.nn.rnn_cell.LSTMCell
的文档所说:
batch_size: int, float, or unit Tensor 表示批量大小
batch_size
必须是显式值。 因此,为batch_size
参数使用占位符是一种解决方法,但不是推荐的方法。 我建议你不要使用它,因为它在未来的版本中可能是一种无效的方式。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.