TensorFlow: How to transform the input data for a TensorFlow LSTM?

I am trying to use TensorFlow for a simple classification task, and I have a question about input shapes.

If I use an LSTM for text classification (e.g. sentiment classification), I pad the data and then pass it through a word-embedding lookup before feeding it to the LSTM. After the lookup, the 2-dimensional data becomes 3-dimensional, i.e. the rank 2 matrix becomes a rank 3 tensor.

For example, if I have two texts:

import tensorflow as tf

text_seq=[[11,21,43,22,11,4,1,3,5,2,8],[4,2,11,4,11,0,0,0,0,0,0]]  #2x11 

# text_seq holds the word indices from the word_to_index dict

a=tf.get_variable('word_embedding',shape=[50,50],dtype=tf.float32,initializer=tf.random_uniform_initializer(-0.01,0.01))

lookup=tf.nn.embedding_lookup(a,text_seq)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(lookup).shape)

I get:

(2, 11, 50)

I can easily feed this to the LSTM, because the LSTM accepts rank 3 input.
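The rank change produced by the lookup can be reproduced with plain NumPy indexing. This is only a minimal sketch: `embedding_table` and `token_ids` are hypothetical names, and the random table is an assumed stand-in for the `word_embedding` variable above.

```python
import numpy as np

# Assumed stand-in for the 50x50 word_embedding variable (random values).
rng = np.random.default_rng(0)
embedding_table = rng.uniform(-0.01, 0.01, size=(50, 50)).astype(np.float32)

# Same padded word indices as text_seq: rank 2, shape (2, 11).
token_ids = np.array([[11, 21, 43, 22, 11, 4, 1, 3, 5, 2, 8],
                      [4, 2, 11, 4, 11, 0, 0, 0, 0, 0, 0]])

# Indexing the table with a rank 2 integer array replaces every index
# with its 50-dimensional row, giving a rank 3 array.
embedded = embedding_table[token_ids]
print(embedded.shape)  # (2, 11, 50)
```

Each padded position (index 0) simply picks up row 0 of the table, which is why the padded and unpadded sequences end up with the same shape.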

My problem is this: suppose I have numerical float data instead of text data, and I want to use an RNN for classification.

Suppose my data is:

import numpy as np

float_data=[[11.1,21.5,43.6,22.1,11.44],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1]]


labels=[1,2,3,4,5,6]  # one label per row of float_data (6x5)

batch_size=2

input_data_batch=[[11.1,21.5,43.6,22.1,11.44],[33.5,12.7,7.4,73.1,89.1]]  #2x5

# Should I reshape my data to make it rank 3 like this?
reshape_one=np.reshape(input_data_batch,[-1,batch_size,5])
print(reshape_one)

# Or like this?
reshape_two=np.reshape(input_data_batch,[batch_size,-1,5])
print(reshape_two)

Output:

First one, shape (1, 2, 5):

[[[11.1  21.5  43.6  22.1  11.44]
  [33.5  12.7   7.4  73.1  89.1 ]]]

Second one, shape (2, 1, 5):

[[[11.1  21.5  43.6  22.1  11.44]]

 [[33.5  12.7   7.4  73.1  89.1 ]]]

LSTMs and other sequence models can take input that is either time-major (i.e. the dimensions are time, batch, channel) or batch-major (the dimensions are batch, time, channel). I don't know which TensorFlow implementation you are using or which flags you are passing to it, so I can't tell from the code you provide whether you want batch-major or time-major inputs.
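The two layouts can be sketched in NumPy for the batch from the question. This is a hedged illustration, not a recommendation: the variable names are hypothetical, and whether each row is a 5-step sequence of single values or a single step with 5 features is an assumption that only the asker can settle.

```python
import numpy as np

input_data_batch = np.array([[11.1, 21.5, 43.6, 22.1, 11.44],
                             [33.5, 12.7, 7.4, 73.1, 89.1]], dtype=np.float32)

# Assumption A: each row is a sequence of 5 time steps with 1 feature per step.
batch_major_seq = input_data_batch[:, :, np.newaxis]      # (batch=2, time=5, channel=1)

# Assumption B: each row is a single time step with 5 features.
batch_major_single = input_data_batch[:, np.newaxis, :]   # (batch=2, time=1, channel=5)

# Swapping the first two axes converts batch-major to time-major.
time_major_seq = np.transpose(batch_major_seq, (1, 0, 2))        # (time=5, batch=2, channel=1)
time_major_single = np.transpose(batch_major_single, (1, 0, 2))  # (time=1, batch=2, channel=5)

for name, arr in [("batch_major_seq", batch_major_seq),
                  ("batch_major_single", batch_major_single),
                  ("time_major_seq", time_major_seq),
                  ("time_major_single", time_major_single)]:
    print(name, arr.shape)
```

Under assumption B, `batch_major_single` holds the same values as `reshape_two` from the question and `time_major_single` the same values as `reshape_one`, so the two reshapes differ only in which axis is read as the batch. With the TF 1.x API used in the question, `tf.nn.dynamic_rnn` has a `time_major` argument that defaults to `False`, meaning it expects batch-major input unless told otherwise.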
