TensorFlow: Understanding tf.contrib.seq2seq.BasicDecoder

I am trying to understand tf.contrib.seq2seq.BasicDecoder. Every example on the web just uses this wrapper, but I couldn't find an explanation of what tf.contrib.seq2seq.BasicDecoder actually does. I tried a simple example:

import numpy as np
import tensorflow as tf
from pprint import pprint
from tensorflow.python.framework import tensor_shape
from tensorflow.contrib.rnn import BasicRNNCell

from tensorflow.contrib.seq2seq.python.ops.basic_decoder import BasicDecoder, BasicDecoderOutput
from tensorflow.contrib.seq2seq.python.ops.helper import TrainingHelper
from tensorflow.python.layers.core import Dense

sequence_length = [3, 4, 3, 1, 3]
batch_size = 5
max_time = 8
input_size = 7
hidden_size = 10
output_size = 3

inputs = np.random.randn(batch_size, max_time, input_size).astype(np.float32)

output_layer = Dense(output_size) # will get a trainable variable size [hidden_size x output_size]

dec_cell = BasicRNNCell(hidden_size)

helper = TrainingHelper(inputs, sequence_length)

decoder = BasicDecoder(
    cell=dec_cell,
    helper=helper,
    initial_state=dec_cell.zero_state(dtype=tf.float32, batch_size=batch_size),
    output_layer=output_layer)

# initialize() returns the finished flags, the inputs for time 0 and the initial state
first_finished, first_inputs, first_state = decoder.initialize()

# Run a single decoding step at time 0
step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
    tf.constant(0), first_inputs, first_state)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    results = sess.run({
        "batch_size": decoder.batch_size,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished})
    pprint(results)

Output:

{'batch_size': 5,
 'first_finished': array([False, False, False, False,  True]),
 'first_inputs': array([[-0.1305329 ,  0.7027261 , -0.8157375 ,  0.01787353,  2.3610914 ,
         0.8905939 , -0.2685608 ],
       [-1.1782284 ,  1.6488065 ,  0.58254075,  0.12861735,  0.47683764,
        -2.05314   , -0.166469  ],
       [ 0.8365086 , -1.7963833 , -2.5053551 ,  2.3320568 , -0.357463  ,
        -0.01917691,  0.5789354 ],
       [-1.7942209 , -0.19699056,  0.42065838, -0.81790465,  2.5130792 ,
         1.2232817 ,  0.7819383 ],
       [ 1.2460921 , -0.16332811,  0.70908403, -1.334465  , -0.10106717,
        -0.26541698, -1.3249161 ]], dtype=float32),
 'first_state': array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 'step_finished': array([False, False, False,  True,  True]),
 'step_next_inputs': array([[ 1.3291198 , -0.15886226,  1.4437864 ,  0.41159418,  0.55492574,
        -0.90773547,  0.83662   ],
       [ 1.0856647 ,  2.3009017 ,  1.2625048 , -0.7682241 , -0.58327836,
        -1.2566029 ,  0.32073924],
       [ 0.2532574 ,  1.3086783 , -0.6277142 ,  1.8158357 , -0.9641214 ,
        -0.4462067 , -0.11307725],
       [ 0.48346692, -0.58842784,  0.4114005 ,  0.23313236, -0.81712246,
        -1.4564492 ,  0.7117556 ],
       [ 0.7588838 , -0.82005906,  0.663568  ,  0.24783312, -1.4573535 ,
         1.4284246 , -0.30952594]], dtype=float32),
 'step_outputs': BasicDecoderOutput(rnn_output=array([[ 1.4097914 , -0.69918895, -1.2088122 ],
       [-1.266958  , -0.8121094 , -0.03660662],
       [ 0.40251616, -0.11823708,  0.23454508],
       [ 1.3780088 , -0.86239576, -0.9247706 ],
       [ 0.09462224, -0.14165601,  0.39751652]], dtype=float32), sample_id=array([0, 2, 0, 0, 2], dtype=int32)),
 'step_state': array([[-0.19132493,  0.8753218 ,  0.07888561, -0.6356789 ,  0.72481483,
         0.4161568 ,  0.7337458 ,  0.06502081,  0.20294249, -0.73887783],
       [ 0.4778563 ,  0.1592015 , -0.86701995,  0.8127028 ,  0.09732129,
        -0.9266094 , -0.5395306 , -0.8694291 ,  0.87705237, -0.545192  ],
       [ 0.66678804,  0.82219815,  0.9689762 , -0.9692538 , -0.3958014 ,
         0.24547155,  0.05074365,  0.0893333 , -0.5242875 ,  0.18463017],
       [-0.8668696 ,  0.9405894 , -0.69780034, -0.1462304 ,  0.9349755 ,
         0.41605997,  0.9185027 , -0.07991812, -0.5194315 , -0.5538262 ],
       [ 0.47941405, -0.8954227 , -0.7062361 ,  0.3774918 ,  0.28503373,
         0.617851  , -0.36548492,  0.2932893 ,  0.3323133 , -0.35999647]],
      dtype=float32)}

I understand that it returns the RNN output and sample_id, but I am confused about the time argument and the finished boolean output.
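As a quick check on sample_id: in the TF 1.x contrib source, TrainingHelper.sample takes the argmax of the (projected) rnn_output over the last axis and casts it to int32. TrainingHelper does not feed sample_id back in; its next inputs always come from the inputs tensor (teacher forcing). The plain-Python sketch below recomputes the ids from the printed rnn_output values; the argmax function is just an illustrative helper.

```python
# rnn_output values copied from the printed session output above
# (batch_size=5, output_size=3).
rnn_output = [
    [1.4097914, -0.69918895, -1.2088122],
    [-1.266958, -0.8121094, -0.03660662],
    [0.40251616, -0.11823708, 0.23454508],
    [1.3780088, -0.86239576, -0.9247706],
    [0.09462224, -0.14165601, 0.39751652],
]


def argmax(row):
    """Index of the largest value in a row (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


# Assumed equivalent of TrainingHelper.sample: argmax over the last axis.
sample_id = [argmax(row) for row in rnn_output]
print(sample_id)  # [0, 2, 0, 0, 2]
```

This matches the sample_id=array([0, 2, 0, 0, 2]) printed by the session.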

The parameters of the BasicDecoder step function are:

step(
    time,
    inputs,
    state,
    name=None
)
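To make the role of time concrete: in the TF 1.x contrib source, BasicDecoder.step(time, inputs, state) runs the cell once on (inputs, state), applies output_layer if one was given, asks the helper for sample_ids via helper.sample, then calls helper.next_inputs, and returns (outputs, next_state, next_inputs, finished). The cell never sees time; only the helper uses it. The plain-Python sketch below mirrors that order. ToyTrainingHelper, toy_cell, toy_output_layer and basic_decoder_step are hypothetical stand-ins, not TensorFlow API, and the helper assumes TrainingHelper's rules: an example is finished once time + 1 >= its sequence length, and zero inputs are returned once every example is finished.

```python
import math
from collections import namedtuple

# Same field names as tf.contrib.seq2seq.BasicDecoderOutput.
DecoderOutput = namedtuple("DecoderOutput", ["rnn_output", "sample_id"])


class ToyTrainingHelper:
    """Hypothetical pure-Python stand-in for TrainingHelper (teacher forcing).

    inputs[b][t] is the input vector of example b at time step t.
    """

    def __init__(self, inputs, sequence_length):
        self.inputs = inputs
        self.sequence_length = sequence_length

    def sample(self, time, outputs):
        # Argmax over the last axis; `time` is accepted but not used.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state):
        next_time = time + 1
        finished = [next_time >= n for n in self.sequence_length]
        if all(finished):
            # Every example is done: return zeros instead of indexing inputs,
            # so a large `time` does not cause an out-of-range read.
            width = len(self.inputs[0][0])
            next_inputs = [[0.0] * width for _ in self.inputs]
        else:
            # Teacher forcing: the next input is the ground-truth input at time + 1.
            next_inputs = [seq[next_time] for seq in self.inputs]
        return finished, next_inputs, state  # state is passed through unchanged


def toy_cell(inputs, state):
    """Hypothetical 1-unit RNN cell: h' = tanh(sum(x) + h); output is h'."""
    new_state = [[math.tanh(sum(x) + h[0])] for x, h in zip(inputs, state)]
    return new_state, new_state


def toy_output_layer(cell_outputs):
    """Hypothetical projection to two scores per example: [h, 0.3]."""
    return [[h[0], 0.3] for h in cell_outputs]


def basic_decoder_step(cell, output_layer, helper, time, inputs, state):
    """Order of operations inside BasicDecoder.step (TF 1.x contrib)."""
    cell_outputs, cell_state = cell(inputs, state)           # 1. run the cell once
    if output_layer is not None:
        cell_outputs = output_layer(cell_outputs)            # 2. e.g. the Dense layer
    sample_ids = helper.sample(time, cell_outputs)           # 3. sample ids
    finished, next_inputs, next_state = helper.next_inputs(  # 4. helper uses time here
        time, cell_outputs, cell_state)
    return DecoderOutput(cell_outputs, sample_ids), next_state, next_inputs, finished


# Batch of 2, max_time 3, input width 1.
inputs = [[[0.5], [1.0], [-1.0]],
          [[0.2], [0.3], [0.4]]]
helper = ToyTrainingHelper(inputs, sequence_length=[1, 3])
state0 = [[0.0], [0.0]]
first_inputs = [seq[0] for seq in inputs]

out, state1, next_in, finished = basic_decoder_step(
    toy_cell, toy_output_layer, helper, 0, first_inputs, state0)
print(out.sample_id, finished, next_in)  # [0, 1] [True, False] [[1.0], [0.3]]

# A large time is accepted: every example is finished, so zeros come back.
_, _, next_in, finished = basic_decoder_step(
    toy_cell, toy_output_layer, helper, 100, first_inputs, state0)
print(finished, next_in)  # [True, True] [[0.0], [0.0]]
```

Since the toy helper, like TrainingHelper, hands back zero inputs once every example is finished, step accepts an arbitrarily large time without an indexing error; step itself still only advances one time step.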

What does time actually represent here? With sequence_length = [3, 4, 3, 1, 3], if I call decoder.step(tf.constant(1), step_next_inputs, step_state), the finished output is:

array([False, False, False,  True,  True])

So the 4th and 5th sequences are already finished. I took this to mean that time should be a sequence length, so I tried:

decoder.step(tf.constant(3), step_next_inputs, step_state)

I expected the output to be:

array([ True, False,  True,  True,  True])

but I am getting:

array([ True,  True,  True,  True,  True])

How does this work? Also, passing an arbitrary value for time does not raise an error. Does that mean it can unroll an arbitrary number of times?

You can run this code online in my Google Colab notebook.

Any information about this would be appreciated.

Thank you!

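When you run decoder.step(tf.constant(3), step_next_inputs, step_state), time is the index of the current decoding step, so this is the 4th step (times 0 to 3, with time 3 being the last). TrainingHelper marks an example as finished when time + 1 >= its sequence length, and 3 + 1 = 4 is at least every length in [3, 4, 3, 1, 3], so step_finished is array([ True, True, True, True, True]).

If you run decoder.step(tf.constant(2), step_next_inputs, step_state), you get array([ True, False, True, True, True]) as expected, since 2 + 1 = 3 reaches every length except the 4 of the second sequence.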
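BasicDecoder.step runs exactly one time step; it does not loop by itself. In normal use, tf.contrib.seq2seq.dynamic_decode calls decoder.initialize() and then step with time = 0, 1, 2, ... until every example is finished (or maximum_iterations is reached), OR-ing each step's finished flags with the previous ones. The sketch below simulates only the finished flags of that loop for the lengths [3, 4, 3, 1, 3], assuming TrainingHelper's rule finished = time + 1 >= sequence_length; decode_trace is a hypothetical helper, not TensorFlow API, and maximum_iterations is left out.

```python
def decode_trace(sequence_length):
    """Hypothetical simulation of the finished flags in a dynamic_decode-style loop.

    Assumes TrainingHelper's rules: initialize() marks length-0 examples as
    finished, and step(time) marks an example finished once time + 1 >= length.
    Returns one (time, finished) pair per call to step().
    """
    finished = [0 >= n for n in sequence_length]  # decoder.initialize()
    trace = []
    time = 0
    while not all(finished):  # loop condition: stop once every example is done
        step_finished = [time + 1 >= n for n in sequence_length]  # decoder.step(time, ...)
        # Once an example has finished it stays finished.
        finished = [old or new for old, new in zip(finished, step_finished)]
        trace.append((time, finished))
        time += 1
    return trace


for time, finished in decode_trace([3, 4, 3, 1, 3]):
    print(time, finished)
# 0 [False, False, False, True, False]
# 1 [False, False, False, True, False]
# 2 [True, False, True, True, True]
# 3 [True, True, True, True, True]
```

The loop stops after four calls, i.e. max(sequence_length) steps, which is why calling step at time 3 already reports every example as finished.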
