简体   繁体   English

在 tf.function 中,如何将 tf.shape() 获得的动态张量形状转换为 python 值而不是张量本身?

[英]In tf.function, how to convert dynamic tensor shape getted by tf.shape() to python values but not tensors itself?

I have a padded batch from tf.dataset, because every padded batch's shape is not fixed.So I have to use tf.shape method to get the dynamic shape of padded batch.The question is how can I convert the tensor shape getted by tf.shape to python values under tf.function?我有一个来自 tf.dataset 的填充批次,因为每个填充批次的形状都不是固定的。所以我必须使用 tf.shape 方法来获取填充批次的动态形状。问题是如何转换由 tf 获得的张量形状.shape 到 tf.function 下的 python 值?

@tf.function
def train_step(padded_batch):
    shape = tf.shape(padded_batch)
    x = np.zeros(shape[0], shape[1])

As the above code, I want to create a numpy array as the same shape of padded_batch,but 'shape' is a tensor, it can't be used directly in numpy.If there is someway to convert tensor to python values under tf.function.如上面的代码,我想创建一个与padd_batch形状相同的numpy数组,但是'shape'是一个张量,在numpy中不能直接使用。功能。

The tensorflow version I use is tf2.0我使用的tensorflow版本是tf2.0

assuming you have a tensor named a_tensor :假设你有一个名为a_tensor的张量:

this_is_a_regular_non_tensor_shape = a_tensor.shape.as_list()

(BTW: you don't seem to be using np.zeros correctly...you need to pass the shape as a single tuple/list argument. Not separate arguments for each dimension. For instance: (顺便说一句:您似乎没有正确使用np.zeros ……您需要将形状作为单个元组/列表参数传递。不是每个维度的单独参数。例如:

shape = padded_batch.shape.as_list()
x = np.zeros(shape)

Hope that helps.)希望有帮助。)

As described in TF documents ,TF 文档中所述

within @tf.function or within a compat.v1 context, not all dimensions may be known until execution time.@tf.function或 compat.v1 上下文中,在执行之前并非所有维度都是已知的。 Hence when defining custom layers and models for graph mode, prefer the dynamic tf.shape(x) over the static x.shape因此,在为图形模式定义自定义层和模型时,更喜欢动态tf.shape(x)而不是静态 x.shape

Your code was ok.你的代码没问题。 I just replaced np.zeros with tf.zeros.我只是用 tf.zeros 替换了 np.zeros。 The @tf.function decorator means the code will run in graph mode . @tf.function 装饰器意味着代码将以图形模式运行。 numpy is not allowed within graph.图形中不允许使用 numpy。 Tested in TF 2.x.在 TF 2.x 中测试。

@tf.function
def train_step(padded_batch):
    shape = tf.shape(padded_batch)
    return tf.zeros((shape[0], shape[1]))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 类型错误:tf.shape 不是 function - TypeError: tf.shape is not a function 如何使用“ tf.shape()”获得“静态”形状? - How to get 'static' shape using 'tf.shape()'? tf.shape() 在张量流中得到错误的形状 - tf.shape() get wrong shape in tensorflow 如何在 tf.function(图形模式)中展平梯度(张量列表) - How to flatten a gradient (list of tensors) in tf.function (graph mode) 检查张量中是否包含元素(带有'@tf.function'的Python Tensorflow2) - Check if an element is contained in a tensor (Python Tensorflow2 with '@tf.function') 如何访问在 tf.function 中更新的张量值(例如指标)? - How do I access Tensor values (e.g. Metrics) which are updated within a tf.function? 我如何在Tensorflow中调整未知尺寸的图像的大小(tf.shape(input)方法不起作用) - How do I resize image with unknown size in Tensorflow(tf.shape(input) method doesn't work) 'Tensor' 对象在 TF 2.0 的 tf.function 中没有属性 'numpy' - 'Tensor' object has no attribute 'numpy' in tf.function in TF 2.0 tensorflow 2.0 中的 x.shape 和 tf.shape() 有什么区别? - What is the difference between x.shape and tf.shape() in tensorflow 2.0? TF 2.0:如何将 Tensor("mul_1:0", shape=(1000, 64), dtype=float32) 转换为 Numpy 数组 - TF 2.0: How to convert Tensor("mul_1:0", shape=(1000, 64), dtype=float32) to Numpy array
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM