简体   繁体   English

获取给定批处理值的字典键 - Python

[英]Get dictionary keys for given batched values - Python

I defined a dictionary A and would like to find the keys given a batch of values a :我定义了一个字典A并想在给定一批值a情况下找到键:

def dictionary(r):
 return dict(enumerate(r))

def get_key(val, my_dict):
   for key, value in my_dict.items():
      if np.array_equal(val,value):
          return key
    

 # dictionary
 A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
 A = dictionary(A)

 a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
 keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)

The expected output should be: keys = [[1,2,3],[0,3,2]]预期的 output 应该是: keys = [[1,2,3],[0,3,2]]

Why am I getting None as an output?为什么我得到None作为 output?

JAX transforms like vmap work by tracing the function, meaning they replace the value with an abstract representation of the value to extract the sequence of operations encoded in the function (See How to think in JAX for a good intro to this concept). JAX 像vmap一样通过跟踪function 进行转换,这意味着它们将值替换为值的抽象表示,以提取在 function 中编码的操作序列(请参阅如何在 JAX 中思考以获得对此概念的良好介绍)。

What this means is that to work correctly with vmap , a function can only use JAX methods, not numpy methods, so your use of np.array_equal breaks the abstraction.这意味着要正确使用vmap , function 只能使用 JAX 方法,而不能使用 numpy 方法,因此您对np.array_equal的使用会破坏抽象。

Unfortunately, there's not really any replacement for it, because there's no mechanism to look up an abstract JAX value in a concrete Python dictionary.不幸的是,实际上并没有任何替代品,因为没有机制可以在具体的 Python 字典中查找抽象的 JAX 值。 If you want to do dict lookups of JAX values, you should avoid transforms and just use Python loops:如果你想对 JAX 值进行字典查找,你应该避免转换,只使用 Python 循环:

keys = jnp.array([[get_key(x, A) for x in row] for row in a])

On the other hand, I suspect this is more of an XY problem ;另一方面,我怀疑这更像是一个XY 问题 your goal is not to look up dictionary values within a jax transform, but rather to solve some problem.您的目标不是在 jax 转换中查找字典值,而是解决某些问题。 Perhaps you should ask a question about how to solve the problem, rather than how to get around an issue with the solution you have tried.也许你应该问一个关于如何解决问题的问题,而不是如何使用你尝试过的解决方案来解决问题。

But if you're willing to not directly use the dict, an alternative get_key implementation that is compatible with JAX might look something like this:但是,如果您不想直接使用字典,则与 JAX 兼容的替代get_key实现可能如下所示:

def get_key(val, my_dict):
  keys = jnp.array(list(my_dict.keys()))
  values = jnp.array(list(my_dict.values()))
  return keys[jnp.where((values == val).all(-1), size=1)]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM