簡體   English   中英

2D數組作為Pytorch中的索引

[英]2d array as index in Pytorch

我想使用一組規則來“增長”矩陣。

規則示例:

0->[[1,1,1],[0,0,0],[2,2,2]],
1->[[2,2,2],[2,2,2],[2,2,2]],
2->[[0,0,0],[0,0,0],[0,0,0]]

增長矩陣的示例:

[[0]]->[[1,1,1],[0,0,0],[2,2,2]]->
[[2,2,2,2,2,2,2,2,2],[2,2,2,2,2,2,2,2,2],[2,2,2,2,2,2,2,2,2],
[1,1,1,1,1,1,1,1,1],[0,0,0,0,0,0,0,0,0],[2,2,2,2,2,2,2,2,2],
[0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0]]

這是我一直試圖在Pytorch中工作的代碼

rules = np.random.randint(256,size=(10,256,3,3,3))
rules_tensor = torch.randint(256,size=(10,
            256, 3, 3, 3),
            dtype=torch.uint8, device = torch.device('cuda'))

rules = rules[0]
rules_tensor = rules_tensor[0]

seed = np.array([[128]])
seed_tensor = seed_tensor = torch.cuda.ByteTensor([[128]])

decode = np.empty((3**3, 3**3, 3))
decode_tensor = torch.empty((3**3,
                3**3, 3), dtype=torch.uint8,
                device = torch.device('cuda'))

for i in range(3):
    grow = seed
    grow_tensor = seed_tensor
    for j in range(1,4):
        grow = rules[grow,:,:,i].reshape(3**j,-1)
        grow_tensor = rules_tensor[grow_tensor,:,:,i].reshape(3**j,-1)

    decode[..., i] = grow
    decode_tensor[..., i] = grow_tensor

我似乎無法在這一行中以與Numpy中相同的方式選擇索引:

grow = rules[grow,:,:,i].reshape(3**j,-1)

有沒有辦法在Pytorch中執行以下操作?

您可以考慮使用torch.index_select() ,在重塑結果之前展平索引張量:

碼:

import torch
import numpy as np

rules_np = np.array([
    [[1,1,1],[0,0,0],[2,2,2]],  # for value 0
    [[2,2,2],[2,2,2],[2,2,2]],  # for value 1
    [[0,0,0],[0,0,0],[0,0,0]]]) # for value 2, etc.
rules = torch.from_numpy(rules_np).long()
rule_shape = rules[0].shape

seed = torch.zeros(1).long()
num_growth = 2
print("Seed:")
print(seed)

grow = seed
for i in range(num_growth):
    grow = (torch.index_select(rules, 0, grow.view(-1))
            .view(grow.shape + rule_shape)
            .squeeze())
    print("Growth #{}:".format(i))
    print(grow)

日志:

Seed:
tensor([ 0])
Growth #0:
tensor([[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]])
Growth #1:
tensor([[[[ 2,  2,  2], [ 2,  2,  2], [ 2,  2,  2]],
         [[ 2,  2,  2], [ 2,  2,  2], [ 2,  2,  2]],
         [[ 2,  2,  2], [ 2,  2,  2], [ 2,  2,  2]]],

        [[[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]],
         [[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]],
         [[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]]],

        [[[ 0,  0,  0], [ 0,  0,  0], [ 0,  0,  0]],
         [[ 0,  0,  0], [ 0,  0,  0], [ 0,  0,  0]],
         [[ 0,  0,  0], [ 0,  0,  0], [ 0,  0,  0]]]])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM