简体   繁体   English

将数组元组转换为张量,然后将它们堆叠在 pytorch 中

[英]Convert tuple of arrays into tensors to then stack them in pytorch

I have this tuple called train, containing 2 arrays, first (10000,10), second (1000):我有这个名为 train 的元组,包含 2 个数组,第一个 (10000,10),第二个 (1000):

 (array([[0.0727882 , 0.82148589, 0.9932996 , ..., 0.9604997 , 0.48725072,
     0.87095636],
    [0.28299425, 0.94904277, 0.69887889, ..., 0.59392614, 0.96375439,
     0.23708264],
    [0.44746802, 0.46455956, 0.99537243, ..., 0.03077313, 0.60441346,
     0.5284877 ],
    ...,
    [0.74851845, 0.59469311, 0.20880812, ..., 0.82080042, 0.16033365,
     0.94729764],
    [0.56686195, 0.35784948, 0.15531381, ..., 0.95415527, 0.88907735,
     0.39981913],
    [0.61606041, 0.30158736, 0.65476444, ..., 0.0637397 , 0.76772078,
     0.85285724]]), array([ 9.78050432, 21.84804394, 13.14748592, ..., 17.86811178,
    14.94744237,  9.80791838]))

I've tried this to them stack them but there is a shape mismatch我试过用这个把它们叠起来,但形状不匹配

seq = torch.as_tensor(train[0], dtype=None, device=None)

label = torch.as_tensor(train[1], dtype=None, device=None)

#seq.size() = torch.Size([10000,10])
#label.size() = torch.Size([10000])

My goal is to stack 10000 tensors of len(10) with the 10000 tensors label.我的目标是将 len(10) 的 10000 张量与 10000 张量标签堆叠在一起。 Be able to treat a seq as single tensor like people do with images.能够像人们处理图像一样将 seq 视为单个张量。

Where one instance would look like this like this:其中一个实例看起来像这样:

[tensor(0.0727882 , 0.82148589, 0.9932996 , ..., 0.9604997 , 0.48725072,
     0.87095636]), tensor(9.78050432)]

Thanks you,谢谢,

Where/what is your error exactly?您的错误究竟在哪里/什么?

Because, to get your desired output it looks like you could just run:因为,要获得所需的输出,您似乎可以运行:

stack = [[seq[i],label[i]] for i in range(seq.shape[0])]

But, if you want a sequence of size [10000,11], then you need to expand the dims of the label tensor to be concatenatable (made that word up) along the second axis:但是,如果您想要一个大小为 [10000,11] 的序列,那么您需要将标签张量的维度扩展为可沿第二个轴连接(组成该词):

label = torch.unsqueeze(label,1)
stack = torch.cat([seq,label],1)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM