简体   繁体   English

Pytorch:按列拆分张量

[英]Pytorch: split a tensor by column

How can I split a tensor by column (axis = 1).如何按列(轴 = 1)拆分张量。 For example例如

"""
input:              result:
tensor([[1, 1],     (tensor([1, 2, 3, 1, 2, 3]), 
        [2, 1],      tensor([1, 1, 2, 2, 3, 3]))
        [3, 2],  
        [1, 2],
        [2, 3],
        [3, 3]])
"""

The solution I came out with is first transpose the input tensor, split it and then flatten each of the split tensor.我提出的解决方案是首先转置输入张量,将其拆分,然后展平每个拆分张量。 However, is there a simpler and more effective way on doing this?但是,有没有更简单、更有效的方法来做到这一点? Thank you谢谢

import torch
x = torch.LongTensor([[1,1],[2,1],[3,2],[1,2],[2,3],[3,3]])
x1, x2 = torch.split(x.T, 1)
x1 = torch.flatten(x1)
x2 = torch.flatten(x2)
x1, x2 # output

Simply do:只需这样做:

x1 = x[:, 0]
x2 = x[:, 1]
# x1: (tensor([1, 2, 3, 1, 2, 3]), x2: tensor([1, 1, 2, 2, 3, 3]))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM