![](/img/trans.png)
[英]How to perform efficient membership search on very large datasets in Python
[英]Python Membership Operators “In” TensorFlow Datasets
我正在使用 TensorFlow 数据集开发输入管道,数据集只有两列,我想根据值列表进行过滤,但我只能使用运算符等于“==”过滤数据集,当我尝试使用会员操作员“在”我收到以下错误。
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.
在我的代码下面:
import numpy as np
import tensorflow as tf
# Load file
file_path = 'drive/My Drive/Datasets/category_catalog.csv.gz'
def get_dataset(file_path, batch_size=5, num_epochs=1, **kwargs):
return tf.data.experimental.make_csv_dataset(
file_path,
batch_size=batch_size,
na_value="?",
num_epochs=num_epochs,
ignore_errors=True,
**kwargs
)
raw_data = get_dataset(
file_path,
select_columns=['description', 'department'],
compression_type='GZIP'
)
此过滤器有效:
@tf.function
def filter_fn(features):
return features['department'] == 'MOVEIS'
ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)
Output:
next(iter(ds))
OrderedDict([('description', <tf.Tensor: shape=(2,), dtype=string, numpy=
array([b'KIT DE COZINHA KITS PARANA 8 PORTAS GOLDEN EM MDP LINHO BRANCO E LINHO PRETO',
b'ARMARIO AEREO PARA COZINHA 1 PORTA HORIZONTAL EXCLUSIVE ITATIAIA PRETO MATTE'],
dtype=object)>),
('department',
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'MOVEIS', b'MOVEIS'], dtype=object)>)])
此过滤器不起作用:
@tf.function
def filter_fn(features):
return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']
ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)
错误:
---------------------------------------------------------------------------
OperatorNotAllowedInGraphError Traceback (most recent call last)
<ipython-input-52-52131b5369b6> in <module>()
6
7 ds = raw_data.unbatch()
----> 8 ds = ds.filter(filter_fn)
9 ds = ds.batch(2)
18 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
OperatorNotAllowedInGraphError: in user code:
<ipython-input-52-52131b5369b6>:5 filter_fn *
return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:778 __bool__
self._disallow_bool_casting()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:545 _disallow_bool_casting
"using a `tf.Tensor` as a Python `bool`")
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:532 _disallow_when_autograph_enabled
" decorating it directly with @tf.function.".format(task))
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.
这里是 Colab 的链接,可以在其中运行并检查错误:
我为过滤器 function 尝试了以下内容:
tf.reduce_any(tf.math.equal(features['department'], ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']))
它似乎在您提供的 colab 中有效,但我不确定它是否是您想要的。
逻辑如下: math.equal
运算符将给出一个大小为 3 的张量,其中每个条目是True
或False
。 第一个入口是部门是否是“FERRAMENTAS”等......然后reduce_any
基本上会对这3入口张量进行逻辑或。 因此,如果部门是 3 个白名单之一,它将在 3 个条目张量中恰好有一个True
条目,因此reduce_any
output 将为True
。 在所有其他情况下它将为False
。
即使前面的答案解决了你的问题,我还是想为将来来这里的人提出一个更通用的答案。
假设您有一个任意张量x
,例如:
>>> x = tf.range(20)
>>> x
<tf.Tensor: shape=(20,), dtype=int32, numpy=
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19])>
如果我们想获取某些元素的位置,例如4
、 11
、 14
,我们可以将它们存储在张量y
中:
>>> y = tf.constant([4, 11, 14])
>>> y
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 4, 11, 14])>
然后在值向量x
和搜索到的元素向量y
的转置之间使用相等运算。 结果将是一个具有 2 维( x
的长度, y
的长度)的布尔数组。 应该通过在轴0
上使用tf.reduce_any
将此数组缩减为与x
长度相同的向量:
>>> tf.reduce_any(x == tf.reshape(y, (-1, 1)), axis=0)
<tf.Tensor: shape=(20,), dtype=bool, numpy=
array([False, False, False, False, True, False, False, False, False,
False, False, True, False, False, True, False, False, False,
False, False])>
具有True
值的位置是y
的元素位于x
内部的位置。
现在,如果您只想对y
中x
的任何元素执行成员资格测试,则只需删除axis=0
参数:
>>> tf.reduce_any(x == tf.reshape(y, (-1, 1)))
<tf.Tensor: shape=(), dtype=bool, numpy=True>
通过更改tf.reshape
的第二个参数以向y
添加另一个维度并遵循相同的逻辑,可以将此解决方案推广到x
和y
的更高维度。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.