Here is what I currently get with torch.eye(3, 4).
The matrix I get:
[[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0]]
Is there an easy way to transform it, or to build such a mask directly, in this format?
The matrix I want:
[[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]]
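You can do it with torch.diag, specifying the diagonal you want. With a 1-D input of length 3 and diagonal=1 it builds a square (4, 4) matrix, so [:-1] drops the last, all-zero row: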
>>> torch.diag(torch.tensor([1,1,1]), diagonal=1)[:-1]
tensor([[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
The diagonal argument works as follows:
- If diagonal = 0, it is the main diagonal.
- If diagonal > 0, it is above the main diagonal.
- If diagonal < 0, it is below the main diagonal.
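To make the three cases concrete, here is a small sketch that builds one matrix per case from the same 1-D input. The variable names (ones, main, above, below) are only illustrative and are not part of the original answer.

```python
import torch

# 1-D input: its values are placed on the chosen diagonal
ones = torch.tensor([1, 1, 1])

main = torch.diag(ones, diagonal=0)    # (3, 3), ones on the main diagonal
above = torch.diag(ones, diagonal=1)   # (4, 4), ones one step above the main diagonal
below = torch.diag(ones, diagonal=-1)  # (4, 4), ones one step below the main diagonal

# A non-zero offset grows the square matrix by one row and one column,
# so slice off the all-zero row to get back to shape (3, 4).
print(above[:-1])  # drops the last row
print(below[1:])   # drops the first row
```

Note that for a negative offset the all-zero row is the first one, so the slice changes from [:-1] to [1:].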
Here is another solution using torch.diagflat() with a positive offset, which shifts the diagonal above the main diagonal.
# diagonal values to fill
In [253]: diagonal_vals = torch.ones(3, dtype=torch.long)
# desired tensor but ...
In [254]: torch.diagflat(diagonal_vals, offset=1)
Out[254]:
tensor([[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 0, 0, 0]])
The above operation gives us a square matrix; however, we need a non-square matrix of shape (3, 4). So we'll just drop the last row with simple indexing:
# shape (3, 4) with 1's above the main diagonal
In [255]: torch.diagflat(diagonal_vals, offset=1)[:-1]
Out[255]:
tensor([[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
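If you need this for arbitrary shapes and offsets, one possible alternative is to skip the slicing and compare row and column indices directly. This is only a sketch: shifted_eye is a hypothetical helper name, not a PyTorch function, and the long dtype is assumed to match the examples above.

```python
import torch

def shifted_eye(rows, cols, offset=0, dtype=torch.long):
    """Hypothetical helper: (rows, cols) mask with ones where col - row == offset."""
    r = torch.arange(rows).unsqueeze(1)  # column vector of row indices, shape (rows, 1)
    c = torch.arange(cols).unsqueeze(0)  # row vector of column indices, shape (1, cols)
    # Broadcasting gives a (rows, cols) boolean grid; cast it to the requested dtype.
    return ((c - r) == offset).to(dtype)

mask = shifted_eye(3, 4, offset=1)
print(mask)
```

With offset=0 this should reproduce torch.eye(3, 4) (as integers), and with offset=1 it should match the sliced torch.diagflat result shown above.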