繁体   English   中英

连接循环内生成的 N 个 pytorch 张量(形状相同)

[英]Concatenate N pytorch tensors (of the same shape) generated from within loop

从循环中返回相同形状的张量,我想尽可能简洁地连接它们,并且尽可能以 Python 方式/pytorchly 方式连接它们。

当前解决方案:

import torch

for object_id in object_ids:
    
    dataset = Dataset(object_id)

    image_tensor = dataset.get_random_image_tensor()

    if 'concatenated_image_tensors' in locals():
        concatenated_image_tensors = torch.cat((merged_image_tensors, image_tensor))
    else:
        concatenated_image_tensors = image_tensor

有没有更好的办法?

一个好的方法是首先附加到一个 python列表,然后在末尾连接整个列表 否则,每次调用torch.cat时,您最终都会在内存中移动数据。

all_img = []
for object_id in object_ids:
    dataset = Dataset(object_id)
    image_tensor = dataset.get_random_image_tensor()
    all_img.append(image_tensor)

all_img = torch.cat(all_img)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM