繁体   English   中英

Tensorflow:从张量列表中过滤出一个张量

[英]Tensorflow: Filter a tensor out of a list of tensors

我正在使用一个数据集,我从每个 tfrecord 文件中解析出四个张量。 每隔一段时间,四个张量中的一个将是空的,我希望能够过滤掉这个张量并将其余的张量发送到 tf.data 管道中的下一步。 我将四个张量保存在字典中,我希望能够做这样的事情。

@tf.function
def filter_and_reshape(tensor_dict, shape):
    return {k: tf.reshape(t, shape)
            for k, t in tensor_dict.items() if not tf.equal(tf.size(t), 0)}

其中tensor_dict是我刚刚从文件中解析的张量字典,但尚未恢复到原始形状。

不幸的是,这不起作用,因为tf.equal(tf.size(t), 0)返回一个张量而不是 bool,并且签名似乎不能解决这个问题。

有没有其他方法可以做到这一点?

tf.data.Dataset生成的所有元素tf.data.Dataset必须具有相同的结构。 例如,如果数据集生成带有键“a”、“b”、“c”、“d”的字典,则数据集生成的所有元素将始终具有这四个键。 您无法生成具有 4 个键的某些元素和具有 3 个键的其他元素。

我建议更改您的模型代码以忽略可能为空的张量(如果它为空)。

要将 4 个可能为空的张量相加,您可以执行以下操作

import tensorflow as tf
tf.enable_v2_behavior()

a = tf.constant([1, 1, 1])
b = tf.constant([2, 2, 2])
c = tf.constant([], dtype=tf.int32)
d = tf.constant([7, 7, 7])

def add_py_fn(a, b, c, d):
  s = tf.zeros((3,), dtype=tf.int32)
  for t in [a, b, c, d]:
    if tf.size(t) > 0:
      s += t
  return s

def add_fn(a, b, c, d):
  return tf.py_function(add_py_fn, [a, b, c, d], tf.int32)

ds = tf.data.Dataset.from_tensors((a, b, c, d))
ds = ds.map(add_fn)
print(next(iter(ds)))

py_function 逃逸舱口是必需的,因为运行时不能告诉c永远不会被添加到s ,否则它会抛出一个形状错误。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM