繁体   English   中英

Python 成员运算符“在” TensorFlow 数据集

[英]Python Membership Operators “In” TensorFlow Datasets

我正在使用 TensorFlow 数据集开发输入管道,数据集只有两列,我想根据值列表进行过滤,但我只能使用运算符等于“==”过滤数据集,当我尝试使用会员操作员“在”我收到以下错误。

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

在我的代码下面:

import numpy as np
import tensorflow as tf


# Load file
file_path = 'drive/My Drive/Datasets/category_catalog.csv.gz'

def get_dataset(file_path, batch_size=5, num_epochs=1, **kwargs):
    return tf.data.experimental.make_csv_dataset(
        file_path,
        batch_size=batch_size,
        na_value="?",
        num_epochs=num_epochs,
        ignore_errors=True,
        **kwargs
    )

raw_data = get_dataset(
    file_path,
    select_columns=['description', 'department'],
    compression_type='GZIP'
)

此过滤器有效:

@tf.function
def filter_fn(features):
    return features['department'] == 'MOVEIS'

ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)

Output:

next(iter(ds))

OrderedDict([('description', <tf.Tensor: shape=(2,), dtype=string, numpy=
              array([b'KIT DE COZINHA KITS PARANA 8 PORTAS GOLDEN EM MDP LINHO BRANCO E LINHO PRETO',
                     b'ARMARIO AEREO PARA COZINHA 1 PORTA HORIZONTAL EXCLUSIVE ITATIAIA PRETO MATTE'],
                    dtype=object)>),
             ('department',
              <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'MOVEIS', b'MOVEIS'], dtype=object)>)])

此过滤器不起作用:

@tf.function
def filter_fn(features):
    return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']

ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)

错误:

---------------------------------------------------------------------------
OperatorNotAllowedInGraphError            Traceback (most recent call last)
<ipython-input-52-52131b5369b6> in <module>()
      6 
      7 ds = raw_data.unbatch()
----> 8 ds = ds.filter(filter_fn)
      9 ds = ds.batch(2)

18 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

OperatorNotAllowedInGraphError: in user code:

    <ipython-input-52-52131b5369b6>:5 filter_fn  *
        return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:778 __bool__
        self._disallow_bool_casting()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:545 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:532 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

这里是 Colab 的链接,可以在其中运行并检查错误:

我为过滤器 function 尝试了以下内容:

tf.reduce_any(tf.math.equal(features['department'], ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']))

它似乎在您提供的 colab 中有效,但我不确定它是否是您想要的。

逻辑如下: math.equal运算符将给出一个大小为 3 的张量,其中每个条目是TrueFalse 第一个入口是部门是否是“FERRAMENTAS”等......然后reduce_any基本上会对这3入口张量进行逻辑或。 因此,如果部门是 3 个白名单之一,它将在 3 个条目张量中恰好有一个True条目,因此reduce_any output 将为True 在所有其他情况下它将为False

即使前面的答案解决了你的问题,我还是想为将来来这里的人提出一个更通用的答案。

假设您有一个任意张量x ,例如:

>>> x = tf.range(20)
>>> x
<tf.Tensor: shape=(20,), dtype=int32, numpy=
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19])>

如果我们想获取某些元素的位置,例如41114 ,我们可以将它们存储在张量y中:

>>> y = tf.constant([4, 11, 14])
>>> y
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 4, 11, 14])>

然后在值向量x和搜索到的元素向量y的转置之间使用相等运算。 结果将是一个具有 2 维( x的长度, y的长度)的布尔数组。 应该通过在轴0上使用tf.reduce_any将此数组缩减为与x长度相同的向量:

>>> tf.reduce_any(x == tf.reshape(y, (-1, 1)), axis=0)
<tf.Tensor: shape=(20,), dtype=bool, numpy=
array([False, False, False, False,  True, False, False, False, False,
       False, False,  True, False, False,  True, False, False, False,
       False, False])>

具有True值的位置是y的元素位于x内部的位置。

现在,如果您只想对yx的任何元素执行成员资格测试,则只需删除axis=0参数:

>>> tf.reduce_any(x == tf.reshape(y, (-1, 1)))
<tf.Tensor: shape=(), dtype=bool, numpy=True>

通过更改tf.reshape的第二个参数以向y添加另一个维度并遵循相同的逻辑,可以将此解决方案推广到xy的更高维度。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM