[英]In tf.function, how to convert dynamic tensor shape getted by tf.shape() to python values but not tensors itself?
我有一个来自 tf.dataset 的填充批次,因为每个填充批次的形状都不是固定的。所以我必须使用 tf.shape 方法来获取填充批次的动态形状。问题是如何转换由 tf 获得的张量形状.shape 到 tf.function 下的 python 值?
@tf.function
def train_step(padded_batch):
shape = tf.shape(padded_batch)
x = np.zeros(shape[0], shape[1])
如上面的代码,我想创建一个与padd_batch形状相同的numpy数组,但是'shape'是一个张量,在numpy中不能直接使用。功能。
我使用的tensorflow版本是tf2.0
假设你有一个名为a_tensor
的张量:
this_is_a_regular_non_tensor_shape = a_tensor.shape.as_list()
(顺便说一句:您似乎没有正确使用np.zeros
……您需要将形状作为单个元组/列表参数传递。不是每个维度的单独参数。例如:
shape = padded_batch.shape.as_list()
x = np.zeros(shape)
希望有帮助。)
如TF 文档中所述,
在@tf.function或 compat.v1 上下文中,在执行之前并非所有维度都是已知的。 因此,在为图形模式定义自定义层和模型时,更喜欢动态tf.shape(x)而不是静态 x.shape
你的代码没问题。 我只是用 tf.zeros 替换了 np.zeros。 @tf.function 装饰器意味着代码将以图形模式运行。 图形中不允许使用 numpy。 在 TF 2.x 中测试。
@tf.function
def train_step(padded_batch):
shape = tf.shape(padded_batch)
return tf.zeros((shape[0], shape[1]))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.