[英]Get dictionary keys for given batched values - Python
I defined a dictionary A
and would like to find the keys given a batch of values a
:我定义了一个字典
A
并想在给定一批值a
情况下找到键:
def dictionary(r):
return dict(enumerate(r))
def get_key(val, my_dict):
for key, value in my_dict.items():
if np.array_equal(val,value):
return key
# dictionary
A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
A = dictionary(A)
a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)
The expected output should be: keys = [[1,2,3],[0,3,2]]
预期的 output 应该是:
keys = [[1,2,3],[0,3,2]]
Why am I getting None
as an output?为什么我得到
None
作为 output?
JAX transforms like vmap
work by tracing the function, meaning they replace the value with an abstract representation of the value to extract the sequence of operations encoded in the function (See How to think in JAX for a good intro to this concept). JAX 像
vmap
一样通过跟踪function 进行转换,这意味着它们将值替换为值的抽象表示,以提取在 function 中编码的操作序列(请参阅如何在 JAX 中思考以获得对此概念的良好介绍)。
What this means is that to work correctly with vmap
, a function can only use JAX methods, not numpy methods, so your use of np.array_equal
breaks the abstraction.这意味着要正确使用
vmap
, function 只能使用 JAX 方法,而不能使用 numpy 方法,因此您对np.array_equal
的使用会破坏抽象。
Unfortunately, there's not really any replacement for it, because there's no mechanism to look up an abstract JAX value in a concrete Python dictionary.不幸的是,实际上并没有任何替代品,因为没有机制可以在具体的 Python 字典中查找抽象的 JAX 值。 If you want to do dict lookups of JAX values, you should avoid transforms and just use Python loops:
如果你想对 JAX 值进行字典查找,你应该避免转换,只使用 Python 循环:
keys = jnp.array([[get_key(x, A) for x in row] for row in a])
On the other hand, I suspect this is more of an XY problem ;另一方面,我怀疑这更像是一个XY 问题; your goal is not to look up dictionary values within a jax transform, but rather to solve some problem.
您的目标不是在 jax 转换中查找字典值,而是解决某些问题。 Perhaps you should ask a question about how to solve the problem, rather than how to get around an issue with the solution you have tried.
也许你应该问一个关于如何解决问题的问题,而不是如何使用你尝试过的解决方案来解决问题。
But if you're willing to not directly use the dict, an alternative get_key
implementation that is compatible with JAX might look something like this:但是,如果您不想直接使用字典,则与 JAX 兼容的替代
get_key
实现可能如下所示:
def get_key(val, my_dict):
keys = jnp.array(list(my_dict.keys()))
values = jnp.array(list(my_dict.values()))
return keys[jnp.where((values == val).all(-1), size=1)]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.