I am trying to implement a two-layer bidirectional LSTM with torch.nn.LSTM.
I made a toy example: a batch of 3 tensors that are exactly the same (see my code below). I expected the outputs of the BiLSTM to be the same along the batch dimension, i.e. out[:, 0, :] == out[:, 1, :] == out[:, 2, :].
But that does not seem to be the case. In my experiments, the outputs were not the same 20% to 40% of the time. So I wonder where I went wrong.
# Python 3.6.6, PyTorch 0.4.1
import torch


def test(hidden_size, in_size):
    seq_len, batch = 4, 3
    bilstm = torch.nn.LSTM(input_size=in_size, hidden_size=hidden_size,
                           num_layers=2, bidirectional=True)

    # create a batch of 3 identical tensors
    a = torch.rand(seq_len, 1, in_size)  # (seq_len, 1, in_size)
    x = torch.cat((a, a, a), dim=1)      # (seq_len, batch, in_size)

    out, _ = bilstm(x)  # (seq_len, batch, n_direction * hidden_size)

    # expect the output to be the same along the batch dimension
    assert torch.equal(out[:, 0, :], out[:, 1, :])
    assert torch.equal(out[:, 1, :], out[:, 2, :])


if __name__ == '__main__':
    count, total = 0, 0
    for h_size in range(1, 51):
        for in_size in range(1, 51):
            total += 1
            try:
                test(h_size, in_size)
            except AssertionError:
                count += 1
    print('percentage of assertion error:', count / total)
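Rather than only counting assertion failures, it can help to measure how large the mismatches actually are. The sketch below reuses the question's setup; the helper name max_batch_discrepancy is hypothetical, and it compares every batch entry against entry 0 using the maximum absolute difference:

```python
import torch


def max_batch_discrepancy(hidden_size, in_size, seq_len=4, batch=3):
    """Largest absolute difference between batch entry 0 and any other entry."""
    bilstm = torch.nn.LSTM(input_size=in_size, hidden_size=hidden_size,
                           num_layers=2, bidirectional=True)
    a = torch.rand(seq_len, 1, in_size)
    x = a.repeat(1, batch, 1)  # identical copies along the batch dimension
    with torch.no_grad():
        out, _ = bilstm(x)
    ref = out[:, :1, :]  # keep the batch dim so it broadcasts against all entries
    return (out - ref).abs().max().item()


if __name__ == '__main__':
    worst = max(max_batch_discrepancy(h, i)
                for h in range(1, 11) for i in range(1, 11))
    print('largest difference between batch entries:', worst)
```

If the printed value is tiny, on the order of float32 rounding error, the mismatch is numerical noise rather than a bug in how the batch was built.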
What is confusing you is floating point precision. Floating point operations are slightly inexact, and results can differ by very small amounts depending on the order in which the operations are carried out. Use double precision instead:
torch.set_default_dtype(torch.float64)
Then you should see that the outputs are the same along the batch dimension.
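Two points behind this answer can be checked directly: floating point addition is not associative, and float results are usually compared with a tolerance rather than bit for bit with torch.equal. The following sketch switches to float64 as suggested and compares batch entries with torch.allclose; the helper name outputs_match is hypothetical, and the tolerance is simply torch.allclose's default:

```python
import torch

# Floating point addition is not associative: summing the same numbers in a
# different order can change the last bits of the result.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

torch.set_default_dtype(torch.float64)  # new tensors and LSTM weights are float64


def outputs_match(hidden_size, in_size, seq_len=4, batch=3):
    bilstm = torch.nn.LSTM(input_size=in_size, hidden_size=hidden_size,
                           num_layers=2, bidirectional=True)
    a = torch.rand(seq_len, 1, in_size)
    x = a.repeat(1, batch, 1)  # identical copies along the batch dimension
    with torch.no_grad():
        out, _ = bilstm(x)
    # compare within a tolerance instead of requiring bit-identical values
    return all(torch.allclose(out[:, 0, :], out[:, b, :])
               for b in range(1, batch))


print(outputs_match(16, 8))  # True
```

Switching to float64 shrinks the rounding error but may not by itself guarantee bit-identical results, so a tolerance-based comparison is the more robust test in either precision.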
I had the same issue with a GRU, and the following solved it for me.
Set a manual seed and put your model in evaluation mode before testing:
torch.manual_seed(42)
bilstm.eval()  # or: bilstm.train(False)
Source: LSTMcell and LSTM returning different outputs
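Evaluation mode only changes the output when the model contains layers that behave differently during training, such as dropout. A minimal sketch, assuming a hypothetical dropout=0.5 between the two LSTM layers (the question's model uses the default dropout=0, in which case train() and eval() behave the same), showing that repeated calls differ in train mode but agree in eval mode:

```python
import torch

torch.manual_seed(42)
# Hypothetical dropout=0.5: nn.LSTM applies it to the outputs of every layer
# except the last, so it needs num_layers >= 2 to have any effect.
bilstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
                       bidirectional=True, dropout=0.5)
x = torch.rand(4, 3, 8)  # (seq_len, batch, in_size)

with torch.no_grad():
    bilstm.train()   # dropout active: each call samples a new mask
    train_a, _ = bilstm(x)
    train_b, _ = bilstm(x)
    bilstm.eval()    # dropout disabled: repeated calls give the same output
    eval_a, _ = bilstm(x)
    eval_b, _ = bilstm(x)

print('train mode, two calls match:', torch.allclose(train_a, train_b))
print('eval mode, two calls match:', torch.allclose(eval_a, eval_b))
```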
In addition, I had to set the same seed before each call to the model (during testing). In your case:
torch.manual_seed(42)
out, _ = bilstm(x) # (seq_len, batch, n_direction * hidden_size)
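Reseeding before each call makes any random draws inside the forward pass repeat. The sketch below keeps a hypothetical dropout=0.5 active in train mode so the effect is visible; run_with_seed is a hypothetical helper, and the sketch assumes CPU execution, where nn.LSTM draws its dropout masks from the global generator that torch.manual_seed resets:

```python
import torch

bilstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
                       bidirectional=True, dropout=0.5)  # hypothetical dropout
bilstm.train()  # keep dropout active to show the effect of the seed
x = torch.rand(4, 3, 8)  # (seq_len, batch, in_size)


def run_with_seed(seed):
    torch.manual_seed(seed)  # reset the RNG so dropout draws the same mask
    with torch.no_grad():
        out, _ = bilstm(x)
    return out


same_a = run_with_seed(42)
same_b = run_with_seed(42)
print('same seed, outputs match:', torch.allclose(same_a, same_b))
```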