簡體   English   中英

如何將來自 tf.nn.top_k 的索引與 tf.gather_nd 一起使用?

[英]How to use indices from tf.nn.top_k with tf.gather_nd?

我正在嘗試使用從 tf.nn.top_k 返回的索引從第二個張量中提取值。

我已經嘗試使用 numpy 類型索引,以及直接使用 tf.gather_nd,但我注意到索引是錯誤的。

#  temp_attention_weights of shape [I, B, 1, J]
top_values, top_indices = tf.nn.top_k(temp_attention_weights, k=top_k)

# top_indices of shape [I, B, 1, top_k], base_encoder_transformed of shape [I, B, 1, J]

# I now want to extract from base_encoder_transformed top_indices
base_encoder_transformed = tf.gather_nd(base_encoder_transformed, indices=top_indices)  

# base_encoder_transformed should be of shape [I, B, 1, top_k]

我注意到 top_indices 的格式錯誤,但我似乎無法將其轉換為在 tf.gather_nd 中使用,其中最內層的維度用於索引 base_encoder_transformed 中的每個相應元素。 有人知道將 top_indices 轉換為正確格式的方法嗎?

top_indices將僅在最后一個軸上建立索引,您也需要為其余軸添加索引。 使用tf.meshgrid很容易:

import tensorflow as tf

# Example input data
I = 4
B = 3
J = 5
top_k = 2
x = tf.reshape(tf.range(I * B * J), (I, B, 1, J)) % 7
# Top K
top_values, top_indices = tf.nn.top_k(x, k=top_k)
# Make indices for the rest of axes
ii, jj, kk, _ = tf.meshgrid(
    tf.range(I),
    tf.range(B),
    tf.range(1),
    tf.range(top_k),
    indexing='ij')
# Stack complete index
index = tf.stack([ii, jj, kk, top_indices], axis=-1)
# Get the same values again
top_values_2 = tf.gather_nd(x, index)
# Test
with tf.Session() as sess:
    v1, v2 = sess.run([top_values, top_values_2])
    print((v1 == v2).all())
    # True

我沒有看到使用tf.gather_nd的理由。 使用tf.gatherbatch_dims參數有一個更簡單、更快(不需要使用tf.meshgrid )的解決方案。

import tensorflow as tf

# Example input data
I = 4
B = 3
J = 5
top_k = 2
x = tf.reshape(tf.range(I * B * J), (I, B, 1, J)) % 7
# Top K
top_values, top_indices = tf.nn.top_k(x, k=top_k)
#Gather indices along last axis
top_values_2 = tf.gather(x, top_indices, batch_dims = 3)

tf.reduce_all(top_values_2 == top_values).numpy()
#True

請注意,在這種情況下, batch_dims為 3,因為我們要從最后一個軸收集,並且 x 的秩為 4。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM