简体   繁体   English

如何从 PyTorch split() 获取张量

[英]How to get tensor from PyTorch split()

PyTorch's split function returns back a tuple of tensors. PyTorch 的split function 返回一个张量元组。 But I need to batch matrix multiply the result.但我需要批量矩阵乘以结果。 Is there an easy way to split a tensor and get back a tensor?有没有一种简单的方法可以拆分张量并取回张量? This is what I tried:这是我试过的:

m = [[2, 3, 5, 7],
     [11, 13, 17, 19],
     [23, 29, 31, 37],
     [41, 43, 47, 53]]
m_split = torch.tensor(m).split(2, dim=1)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)

This gives me an error because m_split is a tuple of tensors rather than being a tensor.这给了我一个错误,因为m_split是张量的元组而不是张量。 Is there a view or reshape call I can make instead?我可以改为进行viewreshape调用吗?

i think you can do as following我想你可以做如下

m = [[2, 3, 5, 7],
     [11, 13, 17, 19],
     [23, 29, 31, 37],
     [41, 43, 47, 53]]
m_split = torch.tensor(m).tensor_split(2, dim=1)
m_split=torch.stack(list(m_split), dim=0)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM