簡體   English   中英

索引兩次后如何更新 Pytorch 中的張量?

[英]How do I update a tensor in Pytorch after indexing twice?

我知道如何在索引到張量的一部分后更新張量,如下所示:

import torch

b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b] = 2
b
# tensor([0, 2, 0, 2], dtype=torch.uint8)

但是有沒有辦法在索引兩次后更新原始張量? 例如

i = 1
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b][i] = 2
b
# tensor([0, 1, 0, 1], dtype=torch.uint8)

我想要的是btensor([0, 1, 0, 2]) 有沒有辦法做到這一點?

我知道我能做到

masked = b[b]
masked[i] = 2
b[b] = masked
b
# tensor([0, 1, 0, 2], dtype=torch.uint8)

但有沒有更好的方法? 看來這一定是低效的; 如果masked非常大,當我真的只更改了一個時,我會更新b許多位置。

(如果與索引兩次不同的方法會更好地工作,我遇到的一般問題是如何更改原始張量在該張量的掩碼版本的第i個位置處的值。)

我從這里采用了另一個解決方案,並將其與您的解決方案進行了比較:

解決方案:

b[b.nonzero()[i]] = 2

運行時比較:

import torch as t
import numpy as np
import timeit


if __name__ == "__main__":

    np.random.seed(12345)
    b = t.tensor(np.random.randint(0,2, [1000]), dtype=t.uint8)
    # inconvenient way to think of a random index halfway that is 1.
    halfway = np.array(list(range(len(b))))[b == 1][len(b[b == 1]) //2]

    runs = 100000

    elapsed1 = timeit.timeit("mask=b[b]; mask[halfway] = 2; b[b] = mask", 
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (original): {:.6f} ms per call".format(elapsed1 / runs))

    elapsed2 = timeit.timeit("b[b.nonzero()[halfway]]=2",
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (improved): {:.6f} ms per call".format(elapsed2 / runs))

結果:

Time taken (original): 0.000096 ms per call
Time taken (improved): 0.000047 ms per call

長度為100000向量的結果

Time taken: 0.010284 ms per call
Time taken: 0.003667 ms per call

因此,解決方案僅相差 2 倍。我不確定這是否是最佳解決方案,但根據您的規模(以及您調用該函數的頻率),它應該可以讓您大致了解正在查看的內容。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM