繁体   English   中英

tf.where() 在首次使用每个新张量时需要额外的时间

[英]tf.where() takes extra time on first use for every new tensor

编辑:想澄清一下我的机器>我使用的是 Nvidia Xavier AGX(该程序也应该在 NX 上运行),这可能是性能时间缓慢的原因,我想知道是否有更快的选择。

我正在编写应该读取图像的代码,使用它从分段 model 中获取推断,然后为 output 着色并将其显示(/保存)为 rgb 图像。

我的 model 的 output 是一个形状为 (256, 352, 7) 的张量,其中 7 表示类的数量,每个值对应于 object class 的置信度。

我需要所有这些实时运行大约 50 毫秒,但少一点也很好。 我的 model 在 17 毫秒内运行了一个推理,但是当我运行 tf 命令时我的问题出现了。

我尝试将张量转换为 numpy 数组以便为其着色并更轻松地使用 cv2 操作,但转换 function 需要大约 80 毫秒。

相反,我尝试创建一个新的 numpy 数组,并使用tf.where()命令为每个像素的每个 class 填充正确的 rgb 颜色(这是预先确定的)。

这是我的代码:`

tensor = tf.convert_to_tensor(input, dtype=tf.uint8)

output_model = model(tensor)
output = tf.argmax(output_model[0], axis=-1)
mask = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)

for i in range(len(colors)):
    indices = tf.where(tf.equal(output, i))
    mask[indices[:, 0], indices[:, 1]] = colors[i][::-1]

`

现在,当我针对某个张量运行此程序时,它第一次运行张量为 (256, 352) 的tf.where() function 时,需要 ~80 毫秒,之后的 6 次需要 ~3-5 毫秒,然后对于每个新张量,第一次调用需要相同的 80 毫秒,尽管它们的形状和类型完全相同。

我读到这是因为 tensorflow 第一次为张量构建正确的结构然后每隔一段时间重复使用它,但它似乎为每个新张量重置。

我的问题是:是一种向 tensorflow 介绍我将要输入的张量的形状和类型的方法,因为它们的大小和类型都相同,因此导致我每次调用时的运行时间约为 4 毫秒tf.where()

我查看了tf,function s 但我无法让它更快地工作:

`

@tf.function(input_signature=(tf.TensorSpec(shape=[HEIGHT, WIDTH], dtype=tf.int64), tf.TensorSpec(shape=[], dtype=tf.int64),))
def tfwhere(tensor, value):
    return tf.where(tf.equal(tensor, value))

`

我还尝试先使用具有相同形状的“虚拟张量”,但得到了相同类型的结果。

任何帮助表示赞赏,谢谢!

您可以完全删除循环并实现vectorized实现。

tf.squeeze(tf.matmul(tf.one_hot(output, n_classes), colors[None, None,...]))

考虑这个例子,

HEIGHT, WIDTH = 32, 32
n_classes = 7

output = np.random.randint(0, 8,size=(HEIGHT, WIDTH, 1), dtype=np.uint8)
#an random output with 0-7 as values.

mask = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
colors = np.ones((n_classes,3)) *np.arange(n_classes)[..., None]
#here the color is chosen such that its RGB values are class_ids, just for debugging purposes
#color[3] = (3,3,3) and so on.

此处one_hot将零添加到除所选类别之外的其他类别,并且与matmul的 matmul 将只获得从colors矩阵中选择的特定值。

colored_out  = tf.squeeze(tf.matmul(tf.one_hot(output, n_classes), colors[None, None,...]))

Checking output for a random index,
colored_out[7][7].numpy(), output[7][7]
 #array([3., 3., 3.],  array([3] 

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM