簡體   English   中英

計算“torch.tensor”中條目之間的成對距離

[英]Calculating pairwise distances between entries in a `torch.tensor`

我正在嘗試實施 此處所示的流形 alignment 類型的損失。

給定張量embs

tensor([[ 0.0178,  0.0004, -0.0217,  ..., -0.0724,  0.0698, -0.0180],
        [ 0.0160,  0.0002, -0.0217,  ..., -0.0725,  0.0655, -0.0207],
        [ 0.0155, -0.0010, -0.0153,  ..., -0.0750,  0.0688, -0.0253],
        ...,
        [ 0.0130, -0.0113, -0.0078,  ..., -0.0805,  0.0634, -0.0241],
        [ 0.0120, -0.0047, -0.0135,  ..., -0.0846,  0.0722, -0.0230],
        [ 0.0120, -0.0048, -0.0142,  ..., -0.0843,  0.0734, -0.0246]],
       grad_fn=<AddmmBackward0>)

形狀(256,64)是 a.network 生成的一批嵌入,我想計算行條目之間的所有成對距離。 我試過torch.nn.PairwiseDistance但我不清楚它是否對我正在尋找的東西有用。

覺得很奇怪,沒有。 有,它被稱為torch.cdist但它“隱藏”在頂層。

>>> a = torch.rand((5,3))
>>> a
tensor([[0.0215, 0.0843, 0.3414],
        [0.9878, 0.5835, 0.3052],
        [0.0903, 0.7347, 0.0711],
        [0.9774, 0.8202, 0.7721],
        [0.7877, 0.9891, 0.4619]])
>>> torch.cdist(a,a)
tensor([[0.0000, 1.0883, 0.7077, 1.2809, 1.1918],
        [1.0883, 0.0000, 0.9398, 0.5236, 0.4787],
        [0.7077, 0.9398, 0.0000, 1.1339, 0.8390],
        [1.2809, 0.5236, 1.1339, 0.0000, 0.4010],
        [1.1918, 0.4787, 0.8390, 0.4010, 0.0000]])
>>> torch.nn.functional.pairwise_distance(a[0], a[2])
tensor(0.7077)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM