[英]In tf.function, how to convert dynamic tensor shape getted by tf.shape() to python values but not tensors itself?
I have a padded batch from tf.dataset, because every padded batch's shape is not fixed.So I have to use tf.shape method to get the dynamic shape of padded batch.The question is how can I convert the tensor shape getted by tf.shape to python values under tf.function?我有一个来自 tf.dataset 的填充批次,因为每个填充批次的形状都不是固定的。所以我必须使用 tf.shape 方法来获取填充批次的动态形状。问题是如何转换由 tf 获得的张量形状.shape 到 tf.function 下的 python 值?
@tf.function
def train_step(padded_batch):
shape = tf.shape(padded_batch)
x = np.zeros(shape[0], shape[1])
As the above code, I want to create a numpy array as the same shape of padded_batch,but 'shape' is a tensor, it can't be used directly in numpy.If there is someway to convert tensor to python values under tf.function.如上面的代码,我想创建一个与padd_batch形状相同的numpy数组,但是'shape'是一个张量,在numpy中不能直接使用。功能。
The tensorflow version I use is tf2.0我使用的tensorflow版本是tf2.0
assuming you have a tensor named a_tensor
:假设你有一个名为
a_tensor
的张量:
this_is_a_regular_non_tensor_shape = a_tensor.shape.as_list()
(BTW: you don't seem to be using np.zeros
correctly...you need to pass the shape as a single tuple/list argument. Not separate arguments for each dimension. For instance: (顺便说一句:您似乎没有正确使用
np.zeros
……您需要将形状作为单个元组/列表参数传递。不是每个维度的单独参数。例如:
shape = padded_batch.shape.as_list()
x = np.zeros(shape)
Hope that helps.)希望有帮助。)
As described in TF documents ,如TF 文档中所述,
within @tf.function or within a compat.v1 context, not all dimensions may be known until execution time.
在@tf.function或 compat.v1 上下文中,在执行之前并非所有维度都是已知的。 Hence when defining custom layers and models for graph mode, prefer the dynamic tf.shape(x) over the static x.shape
因此,在为图形模式定义自定义层和模型时,更喜欢动态tf.shape(x)而不是静态 x.shape
Your code was ok.你的代码没问题。 I just replaced np.zeros with tf.zeros.
我只是用 tf.zeros 替换了 np.zeros。 The @tf.function decorator means the code will run in graph mode .
@tf.function 装饰器意味着代码将以图形模式运行。 numpy is not allowed within graph.
图形中不允许使用 numpy。 Tested in TF 2.x.
在 TF 2.x 中测试。
@tf.function
def train_step(padded_batch):
shape = tf.shape(padded_batch)
return tf.zeros((shape[0], shape[1]))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.