
How to retrieve intermediary state in TensorFlow RNN

I am running an RNN on a signal in fixed-size segments. The following code lets me preserve the final state of the previous batch to initialize the initial state of the next batch:

rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=init_state)

This works when the batches are non-overlapping. For example, my first batch processes samples 0:124, and final_state is the state after this processing. Then, the next batch processes samples 124:256, setting init_state to final_state.
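The non-overlapping case can be sketched framework-free. This is a NumPy stand-in for an RNN cell (the weights W and U are made up for illustration, not the TF API); it shows that carrying final_state into the next batch's init_state is equivalent to one long run over the whole signal:

```python
import numpy as np

# Hypothetical tanh RNN cell: state and weights exist only to illustrate
# the state hand-off, they are not part of the original TF graph.
rng = np.random.default_rng(0)
W = 0.1 * rng.normal(size=(4, 4))   # state -> state weights
U = 0.1 * rng.normal(size=(1, 4))   # input -> state weights

def run_rnn(state, xs):
    for x in xs:
        state = np.tanh(state @ W + x @ U)
    return state

signal = rng.normal(size=(256, 1))
init_state = np.zeros((1, 4))

# Batch 1: samples 0:124; its final state seeds batch 2: samples 124:256.
final_state = run_rnn(init_state, signal[:124])
final_state = run_rnn(final_state, signal[124:256])

# Identical to processing the whole signal in one pass.
full_state = run_rnn(init_state, signal)
assert np.allclose(final_state, full_state)
```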

My question is how to retrieve an intermediary state when the batches are overlapping. First I process samples 0:124, then 10:134, then 20:144, so the hop size is 10. I would like to retrieve not final_state but the state after processing the first 10 samples.

Is it possible in TF to keep the intermediary state? The documentation shows that the return value consists only of the final state.

The image shows the issue I am facing due to state discontinuity. In my program, the RNN segment length is 215 and the hop length is 20.

[Image: sample results]

Update: the easiest solution turned out to be what David Parks described:

rnn_outputs_one, mid_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs_one, initial_state=rnn_tuple_state)
rnn_outputs_two, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs_two, initial_state=mid_state)
rnn_outputs = rnn_outputs_one + rnn_outputs_two

and

prev_state = sess.run(mid_state)

Now, after just a few iterations, the results look much better.
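The two-call split for overlapping windows can also be sketched in NumPy (the cell weights are illustrative only): run the cell over the first `hop` samples of each window to get mid_state, finish the window from there, and seed the next window with mid_state.

```python
import numpy as np

# Made-up tanh cell standing in for the shared static_rnn cell.
rng = np.random.default_rng(1)
W = 0.1 * rng.normal(size=(4, 4))
U = 0.1 * rng.normal(size=(1, 4))

def run_rnn(state, xs):
    for x in xs:
        state = np.tanh(state @ W + x @ U)
    return state

signal = rng.normal(size=(200, 1))
window, hop = 124, 10

state = np.zeros((1, 4))                         # init_state of the current window
for start in range(0, 3 * hop, hop):
    seg = signal[start:start + window]
    mid_state = run_rnn(state, seg[:hop])        # first 10 samples
    final_state = run_rnn(mid_state, seg[hop:])  # remaining 114 samples
    state = mid_state                            # next window starts hop samples later

# After k hops, the carried state matches one continuous run over samples 0:k*hop,
# so consecutive windows see no state discontinuity.
assert np.allclose(state, run_rnn(np.zeros((1, 4)), signal[:3 * hop]))
```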

In TensorFlow, the only things that are kept after returning from a call to sess.run are variables. You should create a variable for the state, then use tf.assign to assign the result from your RNN cell to that variable. You can then use that variable in the same way as any other tensor.

If you need to initialize the variable to something other than 0, you can call sess.run once with a placeholder and tf.assign specifically to set up the variable.


Added detail:

If you need an intermediate state, let's say you ran timesteps 0:124 and you want the state at step 10, you should split that up into two static_rnn calls: one that processes the first 10 timesteps and a second that continues processing the next 114 timesteps. This shouldn't affect training and backpropagation as long as you use the same cell (LSTM or other) in both static_rnn functions. The cell is where your weights are defined, and those must remain constant. Your gradient will flow backwards through the second call and then, appropriately, through the first.

So, I came here earlier looking for an answer, but I ended up creating one.

Similar to the posters above about making the state assignable...

When you build your graph, make a list of per-step states, like:

my_states = [None] * (sequence_length + 1)

my_states[0] = cell.zero_state(batch_size, tf.float32)

for step in range(sequence_length):
    # feed each step's input and the previous state; keep every new state
    cell_out, my_states[step + 1] = cell(rnn_inputs[step], my_states[step])

Then, outside of your graph, after the sess.run() you say:

new_states = my_states[1:]

model.my_states = new_states

This example steps 1 timestep at a time, but it could easily be adapted to steps of 10: just slice the list of states after sess.run() and use those as the initial states.
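The "keep every step's state, then slice" idea can be sketched in NumPy as well (again with made-up weights standing in for the cell): record the state after every timestep, then pick the entry at index `hop` as the next window's initial state.

```python
import numpy as np

# Illustrative tanh cell; only the list-of-states bookkeeping matters here.
rng = np.random.default_rng(2)
W = 0.1 * rng.normal(size=(4, 4))
U = 0.1 * rng.normal(size=(1, 4))

def unroll(state, xs):
    states = [state]                 # states[t] = state after t steps
    for x in xs:
        state = np.tanh(state @ W + x @ U)
        states.append(state)
    return states

signal = rng.normal(size=(60, 1))
hop = 10

states = unroll(np.zeros((1, 4)), signal[:30])
next_init = states[hop]              # state after the first 10 samples

# Equivalent to running just the first `hop` samples directly.
assert np.allclose(next_init, unroll(np.zeros((1, 4)), signal[:hop])[-1])
```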

Good luck!
