[英]Tensorflow: Filter a tensor out of a list of tensors
我正在使用一个数据集,我从每个 tfrecord 文件中解析出四个张量。 每隔一段时间,四个张量中的一个将是空的,我希望能够过滤掉这个张量并将其余的张量发送到 tf.data 管道中的下一步。 我将四个张量保存在字典中,我希望能够做这样的事情。
@tf.function
def filter_and_reshape(tensor_dict, shape):
return {k: tf.reshape(t, shape)
for k, t in tensor_dict.items() if not tf.equal(tf.size(t), 0)}
其中tensor_dict
是我刚刚从文件中解析的张量字典,但尚未恢复到原始形状。
不幸的是,这不起作用,因为tf.equal(tf.size(t), 0)
返回一个张量而不是 bool,并且签名似乎不能解决这个问题。
有没有其他方法可以做到这一点?
tf.data.Dataset
生成的所有元素tf.data.Dataset
必须具有相同的结构。 例如,如果数据集生成带有键“a”、“b”、“c”、“d”的字典,则数据集生成的所有元素将始终具有这四个键。 您无法生成具有 4 个键的某些元素和具有 3 个键的其他元素。
我建议更改您的模型代码以忽略可能为空的张量(如果它为空)。
要将 4 个可能为空的张量相加,您可以执行以下操作
import tensorflow as tf
tf.enable_v2_behavior()
a = tf.constant([1, 1, 1])
b = tf.constant([2, 2, 2])
c = tf.constant([], dtype=tf.int32)
d = tf.constant([7, 7, 7])
def add_py_fn(a, b, c, d):
s = tf.zeros((3,), dtype=tf.int32)
for t in [a, b, c, d]:
if tf.size(t) > 0:
s += t
return s
def add_fn(a, b, c, d):
return tf.py_function(add_py_fn, [a, b, c, d], tf.int32)
ds = tf.data.Dataset.from_tensors((a, b, c, d))
ds = ds.map(add_fn)
print(next(iter(ds)))
py_function 逃逸舱口是必需的,因为运行时不能告诉c
永远不会被添加到s
,否则它会抛出一个形状错误。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.