繁体   English   中英

在 tf.function 中,如何将 tf.shape() 获得的动态张量形状转换为 python 值而不是张量本身?

[英]In tf.function, how to convert dynamic tensor shape getted by tf.shape() to python values but not tensors itself?

我有一个来自 tf.dataset 的填充批次,因为每个填充批次的形状都不是固定的。所以我必须使用 tf.shape 方法来获取填充批次的动态形状。问题是如何转换由 tf 获得的张量形状.shape 到 tf.function 下的 python 值?

@tf.function
def train_step(padded_batch):
    shape = tf.shape(padded_batch)
    x = np.zeros(shape[0], shape[1])

如上面的代码,我想创建一个与padd_batch形状相同的numpy数组,但是'shape'是一个张量,在numpy中不能直接使用。功能。

我使用的tensorflow版本是tf2.0

假设你有一个名为a_tensor的张量:

this_is_a_regular_non_tensor_shape = a_tensor.shape.as_list()

(顺便说一句:您似乎没有正确使用np.zeros ……您需要将形状作为单个元组/列表参数传递。不是每个维度的单独参数。例如:

shape = padded_batch.shape.as_list()
x = np.zeros(shape)

希望有帮助。)

TF 文档中所述

@tf.function或 compat.v1 上下文中,在执行之前并非所有维度都是已知的。 因此,在为图形模式定义自定义层和模型时,更喜欢动态tf.shape(x)而不是静态 x.shape

你的代码没问题。 我只是用 tf.zeros 替换了 np.zeros。 @tf.function 装饰器意味着代码将以图形模式运行。 图形中不允许使用 numpy。 在 TF 2.x 中测试。

@tf.function
def train_step(padded_batch):
    shape = tf.shape(padded_batch)
    return tf.zeros((shape[0], shape[1]))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM