[英]tf.where() takes extra time on first use for every new tensor
编辑:想澄清一下我的机器>我使用的是 Nvidia Xavier AGX(该程序也应该在 NX 上运行),这可能是性能时间缓慢的原因,我想知道是否有更快的选择。
我正在编写应该读取图像的代码,使用它从分段 model 中获取推断,然后为 output 着色并将其显示(/保存)为 rgb 图像。
我的 model 的 output 是一个形状为 (256, 352, 7) 的张量,其中 7 表示类的数量,每个值对应于 object class 的置信度。
我需要所有这些实时运行大约 50 毫秒,但少一点也很好。 我的 model 在 17 毫秒内运行了一个推理,但是当我运行 tf 命令时我的问题出现了。
我尝试将张量转换为 numpy 数组以便为其着色并更轻松地使用 cv2 操作,但转换 function 需要大约 80 毫秒。
相反,我尝试创建一个新的 numpy 数组,并使用tf.where()
命令为每个像素的每个 class 填充正确的 rgb 颜色(这是预先确定的)。
这是我的代码:`
tensor = tf.convert_to_tensor(input, dtype=tf.uint8)
output_model = model(tensor)
output = tf.argmax(output_model[0], axis=-1)
mask = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for i in range(len(colors)):
indices = tf.where(tf.equal(output, i))
mask[indices[:, 0], indices[:, 1]] = colors[i][::-1]
`
现在,当我针对某个张量运行此程序时,它第一次运行张量为 (256, 352) 的tf.where()
function 时,需要 ~80 毫秒,之后的 6 次需要 ~3-5 毫秒,然后对于每个新张量,第一次调用需要相同的 80 毫秒,尽管它们的形状和类型完全相同。
我读到这是因为 tensorflow 第一次为张量构建正确的结构然后每隔一段时间重复使用它,但它似乎为每个新张量重置。
我的问题是:是一种向 tensorflow 介绍我将要输入的张量的形状和类型的方法,因为它们的大小和类型都相同,因此导致我每次调用时的运行时间约为 4 毫秒tf.where()
。
我查看了tf,function
s 但我无法让它更快地工作:
`
@tf.function(input_signature=(tf.TensorSpec(shape=[HEIGHT, WIDTH], dtype=tf.int64), tf.TensorSpec(shape=[], dtype=tf.int64),))
def tfwhere(tensor, value):
return tf.where(tf.equal(tensor, value))
`
我还尝试先使用具有相同形状的“虚拟张量”,但得到了相同类型的结果。
任何帮助表示赞赏,谢谢!
您可以完全删除循环并实现vectorized
实现。
tf.squeeze(tf.matmul(tf.one_hot(output, n_classes), colors[None, None,...]))
考虑这个例子,
HEIGHT, WIDTH = 32, 32
n_classes = 7
output = np.random.randint(0, 8,size=(HEIGHT, WIDTH, 1), dtype=np.uint8)
#an random output with 0-7 as values.
mask = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
colors = np.ones((n_classes,3)) *np.arange(n_classes)[..., None]
#here the color is chosen such that its RGB values are class_ids, just for debugging purposes
#color[3] = (3,3,3) and so on.
此处one_hot
将零添加到除所选类别之外的其他类别,并且与matmul
的 matmul 将只获得从colors
矩阵中选择的特定值。
colored_out = tf.squeeze(tf.matmul(tf.one_hot(output, n_classes), colors[None, None,...]))
Checking output for a random index,
colored_out[7][7].numpy(), output[7][7]
#array([3., 3., 3.], array([3]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.