[英]TensorFlow dynamic_rnn input for regression
我被困试图将现有的张量流序列转换为序列分类器到回归器。
目前,我一直在处理tf.nn.dynamic_rnn()
的输入。 根据文档和其他答案,输入应为(batch_size, sequence_length, input_size)
的形状。 但是我的输入数据只有两个维度: (sequence_length, batch_size)
。
原始解决方案使用tf.nn.embedding_lookup()
作为将输入提供给dynamic_rnn()
的中间步骤。 如果我理解正确,我认为我不需要此步骤,因为我正在研究回归问题,而不是分类问题。
我需要embedding_lookup步骤吗? 如果是这样,为什么? 如果没有,如何将我的encoder_inputs
直接放入dynamic_rnn()
?
以下是该总体思路的一个可行的最小化示例:
import numpy as np
import tensorflow as tf
tf.reset_default_graph()
sess = tf.InteractiveSession()
PAD = 0
EOS = 1
VOCAB_SIZE = 10 # Don't think I should need this for regression?
input_embedding_size = 20
encoder_hidden_units = 20
decoder_hidden_units = encoder_hidden_units
LENGTH_MIN = 3
LENGTH_MAX = 8
VOCAB_LOWER = 2
VOCAB_UPPER = VOCAB_SIZE
BATCH_SIZE = 10
def get_random_sequences():
sequences = []
for j in range(BATCH_SIZE):
random_numbers = np.random.randint(3, 10, size=8)
sequences.append(random_numbers)
sequences = np.asarray(sequences).T
return(sequences)
def next_feed():
batch = get_random_sequences()
encoder_inputs_ = batch
eos = np.ones(BATCH_SIZE)
decoder_targets_ = np.hstack((batch.T, np.atleast_2d(eos).T)).T
decoder_inputs_ = np.hstack((np.atleast_2d(eos).T, batch.T)).T
#print(encoder_inputs_)
#print(decoder_inputs_)
return {
encoder_inputs: encoder_inputs_,
decoder_inputs: decoder_inputs_,
decoder_targets: decoder_targets_,
}
### "MAIN"
# Placeholders
encoder_inputs = tf.placeholder(shape=(LENGTH_MAX, BATCH_SIZE), dtype=tf.int32, name='encoder_inputs')
decoder_targets = tf.placeholder(shape=(LENGTH_MAX + 1, BATCH_SIZE), dtype=tf.int32, name='decoder_targets')
decoder_inputs = tf.placeholder(shape=(LENGTH_MAX + 1, BATCH_SIZE), dtype=tf.int32, name='decoder_inputs')
# Don't think I should need this for regression problems
embeddings = tf.Variable(tf.random_uniform([VOCAB_SIZE, input_embedding_size], -1.0, 1.0), dtype=tf.float32)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)
# Encoder RNN
encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
encoder_cell, encoder_inputs_embedded, # Throws 'ValueError: Shape (8, 10) must have rank at least 3' if encoder_inputs is used
dtype=tf.float32, time_major=True,
)
# Decoder RNN
decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
decoder_cell, decoder_inputs_embedded,
initial_state=encoder_final_state,
dtype=tf.float32, time_major=True, scope="plain_decoder",
)
decoder_logits = tf.contrib.layers.linear(decoder_outputs, VOCAB_SIZE)
decoder_prediction = tf.argmax(decoder_logits, 2)
# Loss function
loss = tf.reduce_mean(tf.squared_difference(decoder_logits, tf.one_hot(decoder_targets, depth=VOCAB_SIZE, dtype=tf.float32)))
train_op = tf.train.AdamOptimizer().minimize(loss)
sess.run(tf.global_variables_initializer())
max_batches = 5000
batches_in_epoch = 500
print('Starting train')
try:
for batch in range(max_batches):
feed = next_feed()
_, l = sess.run([train_op, loss], feed)
if batch == 0 or batch % batches_in_epoch == 0:
print('batch {}'.format(batch))
print(' minibatch loss: {}'.format(sess.run(loss, feed)))
predict_ = sess.run(decoder_prediction, feed)
for i, (inp, pred) in enumerate(zip(feed[encoder_inputs].T, predict_.T)):
print(' sample {}:'.format(i + 1))
print(' input > {}'.format(inp))
print(' predicted > {}'.format(pred))
if i >= 2:
break
print()
except KeyboardInterrupt:
print('training interrupted')
我在这里已经阅读了关于stackoverflow的类似问题,但是我仍然对如何解决这个问题感到困惑。
编辑:我想我应该澄清一下上面的代码效果很好,但是真正想要的输出应该模拟一个嘈杂的信号(例如,文本到语音),这就是为什么我认为我需要连续的输出值而不是单词或字母的原因。
如果您尝试连续进行操作,为什么不将输入占位符重塑为[BATCH, TIME_STEPS, 1]
的形状[BATCH, TIME_STEPS, 1]
并通过tf.expand_dims(input, 2)
在输入中添加一个额外的尺寸。 这样,您的输入将符合dynamic_rnn
期望的尺寸(实际上,在您的情况下,因为您正在执行time_major=True
您的输入应为[TIME_STEPS, BATCH, 1])
形状[TIME_STEPS, BATCH, 1])
我很想知道您接下来将如何处理将输出尺寸从像元大小切换为1的操作。
decoder_logits = tf.contrib.layers.linear(decoder_outputs, VOCAB_SIZE)
但是,由于您不再进行分类,因此VOCAB_SIZE
仅为1? 几天前我在这里问了类似的问题,但没有得到任何回应。 我这样做(使用1),但是不确定是否合适(似乎是在实践中进行排序,但不是很完美)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.