[英]Pytorch: split a tensor by column
How can I split a tensor by column (axis = 1).如何按列(轴 = 1)拆分张量。 For example
例如
"""
input: result:
tensor([[1, 1], (tensor([1, 2, 3, 1, 2, 3]),
[2, 1], tensor([1, 1, 2, 2, 3, 3]))
[3, 2],
[1, 2],
[2, 3],
[3, 3]])
"""
The solution I came out with is first transpose the input tensor, split it and then flatten each of the split tensor.我提出的解决方案是首先转置输入张量,将其拆分,然后展平每个拆分张量。 However, is there a simpler and more effective way on doing this?
但是,有没有更简单、更有效的方法来做到这一点? Thank you
谢谢
import torch
x = torch.LongTensor([[1,1],[2,1],[3,2],[1,2],[2,3],[3,3]])
x1, x2 = torch.split(x.T, 1)
x1 = torch.flatten(x1)
x2 = torch.flatten(x2)
x1, x2 # output
Simply do:只需这样做:
x1 = x[:, 0]
x2 = x[:, 1]
# x1: (tensor([1, 2, 3, 1, 2, 3]), x2: tensor([1, 1, 2, 2, 3, 3]))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.