简体   繁体   English

Tensorflow:按最大值过滤3D索引重复项

[英]Tensorflow: Filtering 3D Index duplicates by their maximum Values

I am trying to create a filter mask that removes duplicate Indices from an vector by comparing which of their respective values is greater. 我正在尝试创建一个过滤器蒙版,通过比较它们各自的较大值来从向量中删除重复的索引。

My current approach is: 我当前的方法是:

  1. Transform 3-D Index to 1-D 将3D索引转换为1D
  2. Check the 1-D Index for uniqueness 检查一维索引的唯一性
  3. Calculate the maximum values of each unique index 计算每个唯一索引的最大值
  4. Compare the maximum values with the original values. 将最大值与原始值进行比较。 If the same value exists, keep that 3-D Index. 如果存在相同的值,则保留该3-D索引。

I want to get an filter array so I can apply a boolean_mask to other tensors as well. 我想获得一个过滤器数组,以便可以将boolean_mask应用于其他张量。 For this example the mask should look the following way: [False True True True True] . 对于此示例,遮罩应如下所示: [False True True True True]

My current code kind of works unless the values themselves are also duplicated. 我当前的代码种类有效,除非值本身也被重复。 However this seems to be the case when I am using it therefore I need to find a better solution to it. 但是,当我使用它时似乎是这种情况,因此我需要找到一种更好的解决方案。

Here is an examplary of how my Code looks 这是我的代码外观的例证

import tensorflow as tf

# Dummy Input values with same Structure as the real
x_cells   = tf.constant([1,2,3,4,1], dtype=tf.int32)   # Index_1
y_cells   = tf.constant([4,4,4,4,4], dtype=tf.int32)   # Index_2
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32) # Index_3
iou_max   = tf.constant([1.,2.,3.,4.,5.], dtype=tf.float32) # Values

# my Output should be a mask that is [False True True True True]
# So if i filter this i get e.g. x_cells = [2,3,4,1] or iou_max = [2.,3.,4.,5.]

max_dim_y = tf.constant(10)
max_dim_x = tf.constant(20)
num_anchors = 5
stride = 32

# 1. Transforming the 3D-Index to 1D
tmp = tf.stack([x_cells, y_cells, iou_index], axis=1)
indices = tf.matmul(tmp, [[max_dim_y * num_anchors],     [num_anchors],[1]])

# 2. Looking for unique / duplicate indices
y, idx = tf.unique(tf.squeeze(indices))

# 3. Calculating the maximum values of each unique index.
# An function like unsorted_segment_argmax() would be awesome here
num_segments = tf.shape(y)[0]
ious = tf.unsorted_segment_max(iou_max, idx, num_segments)

iou_max_length = tf.shape(iou_max)[0]
ious_length = tf.shape(ious)[0]

# 4. Compare all max values to original values.
iou_max_tiled = tf.tile(iou_max, [ious_length])
iou_reshaped = tf.reshape(iou_max_tiled, [ious_length, iou_max_length])
iou_max_reshaped = tf.transpose(iou_reshaped)
filter_mask = tf.reduce_any(tf.equal(iou_max_reshaped, ious), -1)
filter_mask = tf.reshape(filter_mask, shape=[-1])

This code above will fail if we simply change the value of the iou_max Variable in the beginning to: 如果仅将开头的iou_max变量的值更改为以下代码,则以上代码将失败:

x_cells = tf.constant([1,2,3,4,1], dtype=tf.int32)
y_cells = tf.constant([4,4,4,4,4], dtype=tf.int32)
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32)
iou_max = tf.constant([2.,2.,3.,4.,5.], dtype=tf.float32)

My current workaround changed point 4 of my question: 我当前的解决方法更改了问题的第4点:

Basically I changed that I compare the tuples instead of the single values. 基本上,我改变了,我比较元组而不是单个值。 This leads to me beeing able to logically check if both, index AND value are in the remaining values from 3. 这使我能够逻辑上检查索引AND值是否都在3的剩余值中

# 4. Compare a Max Value and Indices with original values
rem_index_val_pair = tf.stack([ious, tf.cast(y, dtype=tf.float32)], axis=1)
orig_val_index_pair = tf.stack([iou_max, tf.cast(indices, dtype=tf.float32)], axis=1)

orig_val_index_pair_t = tf.tile(orig_val_index_pair, [1, ious_length])
orig_val_index_pair_s = tf.reshape(orig_val_index_pair_t, [iou_max_length, ious_length, 2])
filter_mask_1 = tf.equal(orig_val_index_pair_s, rem_index_val_pair)
filter_mask_2 = tf.reduce_all(filter_mask_1, -1)
filter_mask_3 = tf.reduce_any(filter_mask_2, -1)

# The orig_val_index_pair_s looks like the following
a =  [[[  2.  71.][  2.  71.][  2.  71.][  2.  71.]
     [[  2. 122.][  2. 122.][  2. 122.][  2. 122.]]
     [[  3. 173.][  3. 173.][  3. 173.][  3. 173.]]
     [[  4. 224.][  4. 224.][  4. 224.][  4. 224.]]
     [[  5.  71.][  5.  71.][  5.  71.][  5.  71.]]]
# I then compare it to the rem_max_val_pair which looks like this.
b =  [[  5.  71.][  2. 122.][  3. 173.][  4. 224.]]

# Using equal(a,b) will now compare each of the values resulting in:
c = [[[False  True][ True False][False False][False False]]
     [[False False][ True  True][False False][False False]]
     [[False False][False False][ True  True][False False]]
     [[False False][False False][False False][ True  True]]
     [[ True  True][False False][False False][False False]]]

# Using tf.reduce_all(c, -1) I can filter the bool pairs with a logical And. 
# (This kicks out my false positives from before).
# Afterwards I can check if the line has any true value by tf.reduce_any().

IMO this solution still is a dirty workaround. IMO这个解决方案仍然是一个肮脏的解决方法。 So if you have any better solution proposals please share them. 因此,如果您有更好的解决方案建议,请分享。 :) :)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM