簡體   English   中英

獲取給定批處理值的字典鍵 - Python

[英]Get dictionary keys for given batched values - Python

我定義了一個字典A並想在給定一批值a情況下找到鍵:

def dictionary(r):
 return dict(enumerate(r))

def get_key(val, my_dict):
   for key, value in my_dict.items():
      if np.array_equal(val,value):
          return key
    

 # dictionary
 A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
 A = dictionary(A)

 a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
 keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)

預期的 output 應該是: keys = [[1,2,3],[0,3,2]]

為什么我得到None作為 output?

JAX 像vmap一樣通過跟蹤function 進行轉換,這意味着它們將值替換為值的抽象表示,以提取在 function 中編碼的操作序列(請參閱如何在 JAX 中思考以獲得對此概念的良好介紹)。

這意味着要正確使用vmap , function 只能使用 JAX 方法,而不能使用 numpy 方法,因此您對np.array_equal的使用會破壞抽象。

不幸的是,實際上並沒有任何替代品,因為沒有機制可以在具體的 Python 字典中查找抽象的 JAX 值。 如果你想對 JAX 值進行字典查找,你應該避免轉換,只使用 Python 循環:

keys = jnp.array([[get_key(x, A) for x in row] for row in a])

另一方面,我懷疑這更像是一個XY 問題 您的目標不是在 jax 轉換中查找字典值,而是解決某些問題。 也許你應該問一個關於如何解決問題的問題,而不是如何使用你嘗試過的解決方案來解決問題。

但是,如果您不想直接使用字典,則與 JAX 兼容的替代get_key實現可能如下所示:

def get_key(val, my_dict):
  keys = jnp.array(list(my_dict.keys()))
  values = jnp.array(list(my_dict.values()))
  return keys[jnp.where((values == val).all(-1), size=1)]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM