簡體   English   中英

PyTorch 中的 Concat 張量

[英]Concat tensors in PyTorch

我有一個名為[128, 4, 150, 150]形狀data的張量[128, 4, 150, 150]其中 128 是批量大小,4 是通道數,最后兩個維度是高度和寬度。 我有另一個稱為[128, 1, 150, 150]形狀的fake張量。

我想從data的第二維中刪除最后一個list/array 數據的形狀現在是[128, 3, 150, 150] 並將其與fake連接起來,給出連接的輸出維度為[128, 4, 150, 150]

基本上,換句話說,我想將data的前 3 個維度與fake以給出一個 4 維張量。

我正在使用 PyTorch 並遇到了函數torch.cat()torch.stack()

這是我編寫的示例代碼:

fake_combined = []
        for j in range(batch_size):
            fake_combined.append(torch.stack((data[j][0].to(device), data[j][1].to(device), data[j][2].to(device), fake[j][0].to(device))))
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
fake_combined = fake_combined.to(device)

但是我在行中遇到錯誤:

fake_combined = torch.tensor(fake_combined, dtype=torch.float32)

錯誤是:

ValueError: only one element tensors can be converted to Python scalars

另外,如果我打印fake_combined的形狀,我得到的輸出為[128,]而不是[128, 4, 150, 150]

當我打印fake_combined[0]的形狀時,我得到的輸出為[4, 150, 150] ,這是預期的。

所以我的問題是,為什么我不能使用torch.tensor()將列表轉換為張量。 我錯過了什么嗎? 有沒有更好的方法來做我打算做的事情?

任何幫助將不勝感激! 謝謝!

@rollthedice32 的答案非常好。 出於教育目的,這里使用torch.cat

a = torch.rand(128, 4, 150, 150)
b = torch.rand(128, 1, 150, 150)

# Cut out last dimension
a = a[:, :3, :, :]
# Concatenate in 2nd dimension
result = torch.cat([a, b], dim=1)
print(result.shape)
# => torch.Size([128, 4, 150, 150])

您也可以只分配給該特定維度。

orig = torch.randint(low=0, high=10, size=(2,3,2,2))
fake = torch.randint(low=111, high=119, size=(2,1,2,2))
orig[:,[2],:,:] = fake

原版之前

tensor([[[[0, 1],
      [8, 0]],

     [[4, 9],
      [6, 1]],

     [[8, 2],
      [7, 6]]],


    [[[1, 1],
      [8, 5]],

     [[5, 0],
      [8, 6]],

     [[5, 5],
      [2, 8]]]])

偽造的

tensor([[[[117, 115],
      [114, 111]]],


    [[[115, 115],
      [118, 115]]]])

原版之后

tensor([[[[  0,   1],
      [  8,   0]],

     [[  4,   9],
      [  6,   1]],

     [[117, 115],
      [114, 111]]],


    [[[  1,   1],
      [  8,   5]],

     [[  5,   0],
      [  8,   6]],

     [[115, 115],
      [118, 115]]]])

希望這可以幫助! :)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM