简体   繁体   English

Pytorch:如何通过python字典中的张量(键)访问张量(值)

[英]Pytorch: How to access tensor(values) by tensor(keys) in python dictionary

I have a dictionary with tensor keys and tensor values.我有一个带有张量键和张量值的字典。 I want to access the values by the keys.我想通过键访问值。

from torch import tensor
x = {tensor(0): [tensor(1)], tensor(1): [tensor(0)]}
for i in x.keys():
  print(i, x[i]) 

Returns:返回:

tensor(0) [tensor(1)]
tensor(1) [tensor(0)]

But when i try to access the values without looping through the keys,但是当我尝试访问值而不遍历键时,

try:
    print(x[tensor(0)])

except:
    print(Exception)
    print(x[0])

Throws Exception:抛出异常:

 KeyError                                   Traceback (most recent call last)
 <ipython-input-34-746d28dcd450> in <module>()
  6 try:
  ----> 7   print(x[tensor(0)])
  8 

  KeyError: tensor(0)

  During handling of the above exception, another exception occurred:

  KeyError                                  Traceback (most recent call last)
  <ipython-input-34-746d28dcd450> in <module>()
  9 except:
  10   print(Exception)
  ---> 11   print(x[0])
  12 continue

  KeyError: 0

In PyTorch, hashes of Tensors are a function of their id , not the actual value.在 PyTorch 中,张量的哈希值是其id的函数,而不是实际值。 Because Python dictionaries use the hashes for lookup, lookup fails.因为 Python 字典使用哈希值进行查找,查找失败。 See this Github discussion .请参阅此 Github 讨论

In [4]: hash(tensor(0)) == hash(tensor(0))                                      
Out[4]: False

In [5]: hash(tensor(0))                                                         
Out[5]: 4364730928

In [6]: hash(tensor(0))                                                         
Out[6]: 4362187312

In [7]: hash(tensor(0))                                                         
Out[7]: 4364733808

In order to achieve what you want, you could either use plain Python integers as keys, or use an Embedding object as x .为了实现您想要的,您可以使用纯 Python 整数作为键,或者使用Embedding对象作为x

There's at least three issues here.这里至少有三个问题。

  1. If x is a dictionary of tensor keys, then of course x[0] will not work.如果xtensor键的字典,那么x[0]当然不起作用。 0 is not a key of it. 0不是它的关键。 Hence the inner KeyError that occured during the other exception.因此,在另一个异常期间发生了内部KeyError
  2. Not actually relevant to your errors, but print(Exception) is almost surely not what you want.实际上与您的错误无关,但print(Exception)几乎肯定不是您想要的。 It prints the class object (if that is the right term) of the class Exception .它打印类Exception类对象(如果这是正确的术语)。 You likely rather meant你可能更想说
    except Exception as e: print(e)
    or more specifically, except KeyError (otherwise it will just catch every kind of exception).或者更具体地说, except KeyError (否则它只会捕获各种异常)。
  3. The real thing: you don't want to use a tensor as key in the first place.真实的事情:您首先不想使用tensor作为键。 It's a mutable type, compared by reference, not by value.它是一种可变类型,通过引用进行比较,而不是通过值进行比较。 Every tensor(something) call will create a new object, hashing to a different value than the tensor(something) you specified as key.每个tensor(something)调用都会创建一个新对象,散列到与您指定为键的tensor(something)不同的值。

    Use the actual integers instead.请改用实际的整数。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM