I am trying to understand tf.contrib.seq2seq.BasicDecoder. Every example on the web just uses that wrapper, but I couldn't find an explanation of what tf.contrib.seq2seq.BasicDecoder actually does. I tried a simple example:
import numpy as np
import tensorflow as tf
from pprint import pprint
from tensorflow.python.framework import tensor_shape
from tensorflow.contrib.rnn import BasicRNNCell
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import BasicDecoder, BasicDecoderOutput
from tensorflow.contrib.seq2seq.python.ops.helper import TrainingHelper
from tensorflow.python.layers.core import Dense
sequence_length = [3, 4, 3, 1, 3]
batch_size = 5
max_time = 8
input_size = 7
hidden_size = 10
output_size = 3
inputs = np.random.randn(batch_size, max_time, input_size).astype(np.float32)
output_layer = Dense(output_size) # will get a trainable variable size [hidden_size x output_size]
dec_cell = BasicRNNCell(hidden_size)
helper = TrainingHelper(inputs, sequence_length)
decoder = BasicDecoder(
    cell=dec_cell,
    helper=helper,
    initial_state=dec_cell.zero_state(dtype=tf.float32, batch_size=batch_size),
    output_layer=output_layer)
first_finished, first_inputs, first_state = decoder.initialize()
(first_finished, first_inputs, first_state)
step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
    tf.constant(0), first_inputs, first_state)
(step_outputs, step_state, step_next_inputs, step_finished)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    results = sess.run({
        "batch_size": decoder.batch_size,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished})
    pprint(results)
The output is:
{'batch_size': 5,
'first_finished': array([False, False, False, False, True]),
'first_inputs': array([[-0.1305329 , 0.7027261 , -0.8157375 , 0.01787353, 2.3610914 ,
0.8905939 , -0.2685608 ],
[-1.1782284 , 1.6488065 , 0.58254075, 0.12861735, 0.47683764,
-2.05314 , -0.166469 ],
[ 0.8365086 , -1.7963833 , -2.5053551 , 2.3320568 , -0.357463 ,
-0.01917691, 0.5789354 ],
[-1.7942209 , -0.19699056, 0.42065838, -0.81790465, 2.5130792 ,
1.2232817 , 0.7819383 ],
[ 1.2460921 , -0.16332811, 0.70908403, -1.334465 , -0.10106717,
-0.26541698, -1.3249161 ]], dtype=float32),
'first_state': array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
'step_finished': array([False, False, False, True, True]),
'step_next_inputs': array([[ 1.3291198 , -0.15886226, 1.4437864 , 0.41159418, 0.55492574,
-0.90773547, 0.83662 ],
[ 1.0856647 , 2.3009017 , 1.2625048 , -0.7682241 , -0.58327836,
-1.2566029 , 0.32073924],
[ 0.2532574 , 1.3086783 , -0.6277142 , 1.8158357 , -0.9641214 ,
-0.4462067 , -0.11307725],
[ 0.48346692, -0.58842784, 0.4114005 , 0.23313236, -0.81712246,
-1.4564492 , 0.7117556 ],
[ 0.7588838 , -0.82005906, 0.663568 , 0.24783312, -1.4573535 ,
1.4284246 , -0.30952594]], dtype=float32),
'step_outputs': BasicDecoderOutput(rnn_output=array([[ 1.4097914 , -0.69918895, -1.2088122 ],
[-1.266958 , -0.8121094 , -0.03660662],
[ 0.40251616, -0.11823708, 0.23454508],
[ 1.3780088 , -0.86239576, -0.9247706 ],
[ 0.09462224, -0.14165601, 0.39751652]], dtype=float32), sample_id=array([0, 2, 0, 0, 2], dtype=int32)),
'step_state': array([[-0.19132493, 0.8753218 , 0.07888561, -0.6356789 , 0.72481483,
0.4161568 , 0.7337458 , 0.06502081, 0.20294249, -0.73887783],
[ 0.4778563 , 0.1592015 , -0.86701995, 0.8127028 , 0.09732129,
-0.9266094 , -0.5395306 , -0.8694291 , 0.87705237, -0.545192 ],
[ 0.66678804, 0.82219815, 0.9689762 , -0.9692538 , -0.3958014 ,
0.24547155, 0.05074365, 0.0893333 , -0.5242875 , 0.18463017],
[-0.8668696 , 0.9405894 , -0.69780034, -0.1462304 , 0.9349755 ,
0.41605997, 0.9185027 , -0.07991812, -0.5194315 , -0.5538262 ],
[ 0.47941405, -0.8954227 , -0.7062361 , 0.3774918 , 0.28503373,
0.617851 , -0.36548492, 0.2932893 , 0.3323133 , -0.35999647]],
dtype=float32)}
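As a quick sanity check on the printed step_outputs, the sample_id values look like the per-row argmax over rnn_output, which is what I would expect from TrainingHelper. The sketch below is plain Python with the numbers copied from the output above; the argmax helper is my own illustration, not part of TensorFlow.

```python
# Values copied from the printed step_outputs above (shape [batch_size, output_size]).
rnn_output = [
    [1.4097914, -0.69918895, -1.2088122],
    [-1.266958, -0.8121094, -0.03660662],
    [0.40251616, -0.11823708, 0.23454508],
    [1.3780088, -0.86239576, -0.9247706],
    [0.09462224, -0.14165601, 0.39751652],
]
sample_id = [0, 2, 0, 0, 2]


def argmax(row):
    """Hypothetical helper: index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)


# Recompute the sampled ids from the projected RNN output.
recomputed = [argmax(row) for row in rnn_output]
print(recomputed == sample_id)
```

So rnn_output is the cell output after the Dense projection (5 x 3 here), and sample_id is just the index of its largest entry per batch row.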
I understand that it returns the RNN output and sample_id, but I am confused about the time argument and the finished boolean output.
The parameters of the tf.contrib.seq2seq.BasicDecoder step function are:
step(
    time,
    inputs,
    state,
    name=None
)
Now what does time actually represent here? My sequence lengths are [3, 4, 3, 1, 3]. If I pass decoder.step(tf.constant(1), step_next_inputs, step_state)
the output is:
array([False, False, False, True, True]))}
So it means the 5th and 4th sequences are fully unrolled. I assumed this means I have to pass the sequence length as time, so I tried:
decoder.step(tf.constant(3), step_next_inputs, step_state)
I expected the output to be:
array([True, False, True, True, True]))}
but I am getting:
array([ True, True, True, True, True]))}
How does this work? Even if I pass an arbitrary value, it does not raise an error. Does that mean it can unroll an arbitrary number of times?
Here is a Google Colab notebook; you can run this code online in my notebook.
Please provide some information about this.
Thank you !
The time argument is the zero-based index of the step being decoded, not a sequence length. When you run decoder.step(tf.constant(3), step_next_inputs, step_state), the decoder treats it as the fourth step (steps 0 to 3). After step 3 even the longest sequence (length 4) is finished, so step_finished is array([ True, True, True, True, True]).
If you run decoder.step(tf.constant(2), step_next_inputs, step_state), you will get array([ True, False, True, True, True]) as expected.
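The rule behind these masks can be sketched in plain Python without TensorFlow. As far as I can tell, TrainingHelper reports a sequence as finished once time + 1 >= sequence_length, and initialize() reports it as finished when its length is 0. The function names below are my own, and the lengths are an assumption: the posted outputs (first_finished ending in True) match [3, 4, 3, 1, 0] rather than the [3, 4, 3, 1, 3] quoted in the question.

```python
# Assumption: lengths inferred from the posted outputs, not from the quoted code.
sequence_length = [3, 4, 3, 1, 0]


def initial_finished(lengths):
    """Mask before any step: only zero-length sequences are finished."""
    return [n == 0 for n in lengths]


def step_finished(time, lengths):
    """Mask after running step `time` (zero-based): time + 1 inputs consumed."""
    return [time + 1 >= n for n in lengths]


def steps_until_done(lengths):
    """Count steps a driver loop would run before every sequence is finished."""
    time = 0
    finished = initial_finished(lengths)
    while not all(finished):
        now = step_finished(time, lengths)
        finished = [a or b for a, b in zip(finished, now)]
        time += 1
    return time


for t in range(4):
    print(t, step_finished(t, sequence_length))
print("steps needed:", steps_until_done(sequence_length))
```

Because time is only compared against the lengths, step does not check that you call it in order, which is likely why an arbitrary value raises no error. In normal use you would not call step by hand; tf.contrib.seq2seq.dynamic_decode runs the loop and stops once every entry of the finished mask is True.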