[英]How to get tensor from PyTorch split()
PyTorch's split
function returns back a tuple of tensors. PyTorch 的
split
function 返回一个张量元组。 But I need to batch matrix multiply the result.但我需要批量矩阵乘以结果。 Is there an easy way to split a tensor and get back a tensor?
有没有一种简单的方法可以拆分张量并取回张量? This is what I tried:
这是我试过的:
m = [[2, 3, 5, 7],
[11, 13, 17, 19],
[23, 29, 31, 37],
[41, 43, 47, 53]]
m_split = torch.tensor(m).split(2, dim=1)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)
This gives me an error because m_split
is a tuple of tensors rather than being a tensor.这给了我一个错误,因为
m_split
是张量的元组而不是张量。 Is there a view
or reshape
call I can make instead?我可以改为进行
view
或reshape
调用吗?
i think you can do as following我想你可以做如下
m = [[2, 3, 5, 7],
[11, 13, 17, 19],
[23, 29, 31, 37],
[41, 43, 47, 53]]
m_split = torch.tensor(m).tensor_split(2, dim=1)
m_split=torch.stack(list(m_split), dim=0)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.