This may be too basic a question, but what do the docs mean when they say the input to the GRU needs to be three-dimensional? The GRU docs for PyTorch state:
input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() for details.
https://pytorch.org/docs/stable/generated/torch.nn.GRU.html
Let us say I am trying to predict the next number in a sequence and have the following dataset:
n, label
1, 2
2, 3
3, 6
4, 10
...
If I window the data using the prior 2 inputs for consideration when guessing the next, the dataset becomes:
t-2, t-1, t, label
na, na, 1, 2
na, 1, 2, 3
1, 2, 3, 6
2, 3, 4, 10
...
where t-x just represents an input value from a prior time step.
So, when creating a sequential loader, it should create the following tensors for the row 1, 2, 3, 6:
inputs: tensor([[1,2,3]]) #shape(1,3)
labels: tensor([[6]]) #shape(1,1)
I currently understand the input shape as (# batches, # features per batch) and the output shape as (# batches, # output features per batch)
My question is, should that input tensor actually look like:
tensor([[[1],[2],[3]]])
Which represents (# batches, #prior inputs to consider, #features per input)
I guess what I am really trying to understand is why the input to a GRU has three dimensions in PyTorch. What does that third dimension fundamentally represent? And if I have a transformed dataset like the one above, how do I properly pass it to the model?
Edit: So the pattern present is:
1 + 1 = 2
2 + 1 = 3
3 + 2 + 1 = 6
4 + 3 + 2 + 1 = 10
I want it where t-2, t-1, and t represent the features at each time step used to help guess. For example, at every point in time there could be 2 features. The dimensions would be (1 batch size, 3 timesteps, 2 features).
My question is whether the GRU takes a flattened input:
(1 batch size, 3 time steps * 2 features per time step)
or the unflattened input:
(1 batch size, 3 time steps, 2 features per timestep)
I am currently under the impression that it is the second form, but would like to check my understanding.
The `nn.GRU` module works like the other PyTorch RNN modules. It takes a three-dimensional tensor of shape `(seq_len, batch, input_size)`, or `(batch, seq_len, input_size)` if the argument `batch_first` is set to `True`. I think the last dimension is what's bothering you.
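A minimal sketch of the two layouts (the sizes here, `input_size=10` and `hidden_size=16`, are made up for illustration):

```python
import torch
import torch.nn as nn

# Default layout: (seq_len, batch, input_size)
gru = nn.GRU(input_size=10, hidden_size=16)
x = torch.randn(3, 1, 10)  # seq_len=3, batch=1, 10 features per step
out, h = gru(x)
print(out.shape)  # torch.Size([3, 1, 16]) -- one hidden state per time step
print(h.shape)    # torch.Size([1, 1, 16]) -- final hidden state

# With batch_first=True the batch dimension comes first instead.
gru_bf = nn.GRU(input_size=10, hidden_size=16, batch_first=True)
x_bf = torch.randn(1, 3, 10)  # batch=1, seq_len=3, 10 features per step
out_bf, _ = gru_bf(x_bf)
print(out_bf.shape)  # torch.Size([1, 3, 16])
```

Either way, the last dimension is the number of features per time step.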
You explained you have your sequences set up like:
t-2 | t-1 | t | label |
---|---|---|---|
na | na | 1 | 2 |
na | 1 | 2 | 3 |
1 | 2 | 3 | 6 |
2 | 3 | 4 | 10 |
What you are missing is your input encoding: how will you represent your inputs and your labels? Feeding raw integers like this will not work well. What you probably need is to convert your data into one-hot encodings.
Imagine having 10 different labels, i.e. your vocabulary is made up of 10 elements. Converting to one-hot encodings is a straightforward process: take a vector of zeros whose length is the vocabulary size and put a 1 at the index corresponding to a particular label.
The vocabulary size would in turn be `input_size`. Given a `label`, this would look like:
encoding = torch.zeros(input_size)
encoding[label] = 1
label | one-hot encoding
---|---
0 | [1,0,0,..., 0, 0]
1 | [0,1,0,..., 0, 0]
... | ...
9 | [0,0,0,..., 0, 1]
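Applied to a whole sequence, this gives one one-hot row per time step (a sketch assuming a vocabulary of 10 labels):

```python
import torch

# One-hot encode the sequence 1, 2, 3 with a vocabulary of 10 labels.
input_size = 10
sequence = [1, 2, 3]

encodings = torch.zeros(len(sequence), input_size)
for t, label in enumerate(sequence):
    encodings[t, label] = 1  # 1 at the label's index, zeros elsewhere

print(encodings)
# tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])
```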
Therefore, your training point (input sequence: `1, 2, 3`, label: `6`) would translate to (input sequence: `[[0,1,0,0,0,0,0,0,0,0], [0,0,1,0,0,0,0,0,0,0], [0,0,0,1,0,0,0,0,0,0]]`, label: `6`). That is two-dimensional; add the extra dimension for the batch (see the first paragraph) and this makes three.
I intentionally left the label as it is because it's common for PyTorch loss functions (such as `nn.CrossEntropyLoss`) to require a class index instead of a one-hot encoding for the target (i.e. the label).
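For instance, a minimal sketch of that convention (the batch size and class count are made up):

```python
import torch
import torch.nn as nn

# nn.CrossEntropyLoss takes raw scores and an index target.
loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(1, 10)    # (batch, num_classes), e.g. a projected GRU output
target = torch.tensor([6])     # the label as a class index, not a one-hot vector
loss = loss_fn(logits, target)
print(loss.item())             # a positive scalar
```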
I figured it out. Essentially, a sequence length of 3 means the input to the model needs to be [[[1],[2],[3]], [[2],[3],[4]]] for a batch size of 2, a sequence length of 3, and 1 feature per time step. Each element of the sequence is the input at some time step t.
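That shape can be checked by passing it straight through a GRU (the hidden size of 8 is arbitrary):

```python
import torch
import torch.nn as nn

# A batch of 2 sequences, 3 time steps each, 1 feature per time step.
x = torch.tensor([[[1.], [2.], [3.]],
                  [[2.], [3.], [4.]]])  # shape (2, 3, 1)

gru = nn.GRU(input_size=1, hidden_size=8, batch_first=True)
out, h = gru(x)
print(out.shape)  # torch.Size([2, 3, 8]) -- one output per time step
print(h.shape)    # torch.Size([1, 2, 8]) -- final hidden state per sequence
```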