
Confused Regarding PyTorch GRU Docs

This may be too basic of a question, but what do the docs mean by the input to the GRU needs to be 3 dimensional? The GRU docs for PyTorch state:

input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() for details.

https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

Let us say I am trying to predict the next # in a sequence and have the following dataset:

n, label
1, 2
2, 3
3, 6
4, 10
...

If I window the data using the prior 2 inputs for consideration when guessing the next, the dataset becomes:

t-2, t-1, t, label
na, na, 1, 2
na, 1, 2, 3
1, 2, 3, 6
2, 3, 4, 10
...

where t-x just represents an input value from x time steps before the current one.

So, when creating a sequential loader, it should create the following tensor for the line 1,2,3,6:

inputs: tensor([[1,2,3]]) #shape(1,3)
labels: tensor([[6]])     #shape(1,1)

I currently understand the input shape as (# batches, # features per batch) and the output shape as (# batches, # output features per batch).

My question is, should that input tensor actually look like:

tensor([[[1],[2],[3]]])

Which represents (# batches, #prior inputs to consider, #features per input)

I guess what I am really trying to understand is why the input to a GRU has 3 dimensions in PyTorch. What does that 3rd dimension fundamentally represent? And if I have a transformed dataset like the above, how do I properly pass it to the model?

Edit: So the pattern present is:

1 + 1 = 2
2 + 1 = 3
3 + 2 + 1 = 6
4 + 3 + 2 + 1 = 10

I want it where t-2, t-1, and t represent the features at each time step used to help guess. For example, at every point in time there could be 2 features. The dimensions would be (1 batch size, 3 timesteps, 2 features).

My question is whether the GRU takes a flattened input:

(1 batch size, 3 time steps * 2 features per time step)

or the unflattened input:

(1 batch size, 3 time steps, 2 features per timestep)

I am currently under the impression that it is the 2nd input, but would like to check my understanding.
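To sanity-check this, here is a minimal sketch (the hidden size of 8 is an arbitrary choice for illustration) showing that nn.GRU with batch_first=True accepts the unflattened shape directly:

```python
import torch
import torch.nn as nn

# 1 sequence in the batch, 3 time steps, 2 features per time step
x = torch.randn(1, 3, 2)

# input_size matches the per-time-step feature count, not the flattened length
gru = nn.GRU(input_size=2, hidden_size=8, batch_first=True)
out, h = gru(x)

print(out.shape)  # one 8-dim hidden state per time step: (1, 3, 8)
print(h.shape)    # final hidden state: (1, 1, 8)
```

A flattened (1, 6) tensor would be rejected here, since nn.GRU expects three dimensions.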

The nn.GRU module works like the other PyTorch RNN modules. It takes a three-dimensional tensor of shape (seq_len, batch, input_size), or (batch, seq_len, input_size) if the argument batch_first is set to True. I think the last dimension is what's bothering you.

You explained you have your sequences set up like:

t-2 t-1 t label
na na 1 2
na 1 2 3
1 2 3 6
2 3 4 10

What you are missing is your input encoding: how will you represent your inputs and your labels? Feeding raw integers like this will not work well. What you probably need is to convert your data into one-hot encodings.

Imagine having 10 different labels, i.e. your vocabulary is made up of 10 elements. Converting to one-hot encodings is a straightforward process: take a vector of zeros whose length is the vocabulary size and put a 1 at the index corresponding to a particular label.

That vocabulary size would, in turn, be... input_size. Given a label, this would look like:

encoding = torch.zeros(input_size)
encoding[label] = 1
label   one-hot encoding
0       [1, 0, 0, ..., 0, 0]
1       [0, 1, 0, ..., 0, 0]
...     ...
9       [0, 0, 0, ..., 0, 1]

Therefore, your training point (input sequence: 1, 2, 3, label: 6) would translate to (input sequence: [[0,1,0,0,0,0,0,0,0,0], [0,0,1,0,0,0,0,0,0,0], [0,0,0,1,0,0,0,0,0,0]], label: 6). That is two-dimensional; add the extra dimension for the batch (see the first paragraph), and this makes three.

I intentionally left the label as it is, because it's common for PyTorch loss functions (such as nn.CrossEntropyLoss) to require a class index instead of a one-hot encoding for the target (i.e. the label).
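Putting the above together, an end-to-end sketch might look like the following (the hidden size of 16 and the final Linear projection are illustrative choices, not part of the question):

```python
import torch
import torch.nn as nn

input_size = 10                        # vocabulary size
seq = torch.tensor([1, 2, 3])          # input sequence
label = torch.tensor([6])              # target kept as a class index

# One-hot encode the sequence -> (seq_len, input_size)
x = torch.zeros(len(seq), input_size)
x[torch.arange(len(seq)), seq] = 1
x = x.unsqueeze(0)                     # add batch dim -> (1, 3, 10)

gru = nn.GRU(input_size, hidden_size=16, batch_first=True)
head = nn.Linear(16, input_size)       # project hidden state to vocab logits

out, h = gru(x)
logits = head(out[:, -1])              # logits from the last time step: (1, 10)
loss = nn.CrossEntropyLoss()(logits, label)
```

Note that nn.CrossEntropyLoss consumes the integer label directly, while the inputs are one-hot encoded.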

I figured it out. Essentially, a sequence length of 3 means the input to the model needs to be [[[1],[2],[3]], [[2],[3],[4]]] for a batch size of 2, a sequence length of 3, and 1 input feature per time step. Each element along the sequence dimension is the input at some time step t.
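As a concrete check (the hidden size of 4 is chosen arbitrarily), that batch passes through nn.GRU as-is:

```python
import torch
import torch.nn as nn

# batch size 2, sequence length 3, 1 feature per time step
x = torch.tensor([[[1.], [2.], [3.]],
                  [[2.], [3.], [4.]]])  # shape (2, 3, 1)

gru = nn.GRU(input_size=1, hidden_size=4, batch_first=True)
out, h = gru(x)

print(out.shape)  # (2, 3, 4): one hidden state per batch element per time step
```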
