How can I filter tf.data.Dataset by specific values?

I create a dataset by reading TFRecords and map the values. I want to filter the dataset for specific values, but since the result is a dict of tensors, I am not able to get the actual value of a tensor or to check it with tf.cond() / tf.equal. How can I do that?

import tensorflow as tf

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()

I am answering my own question. I found the issue!

What I needed to do was tf.unstack() the label like this:

label = tf.unstack(features['label'])
label = label[0]

before passing it to tf.equal():

result = tf.reshape(tf.equal(label, 'some_label_value'), [])

I suppose the problem was that the label is defined as an array with one element of type string, tf.FixedLenFeature([1], tf.string), so in order to get the first and only element I had to unpack it (which creates a list) and then take the element at index 0. Correct me if I'm wrong.
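The shape difference described above can be checked directly. The following is only a sketch, assuming TensorFlow 2.x with eager execution (so the TF 1.x names tf.parse_single_example and tf.FixedLenFeature become tf.io.parse_single_example and tf.io.FixedLenFeature); the serialized record is made-up test data, not the asker's file.

```python
import tensorflow as tf

# Build one serialized tf.train.Example in memory (hypothetical test data).
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'some_label_value'])),
}))
serialized = example.SerializeToString()

# With shape [1], the parsed label is a vector holding one string.
vec = tf.io.parse_single_example(
    serialized, {'label': tf.io.FixedLenFeature([1], tf.string)})['label']

# Three ways to pull out the single element as a scalar tensor.
by_unstack = tf.unstack(vec)[0]
by_index = vec[0]
by_squeeze = tf.squeeze(vec, axis=0)

# Comparing the scalar gives a scalar boolean, which is what filter() needs.
is_match = tf.equal(by_index, 'some_label_value')
print(vec.shape, by_index.shape, is_match.shape)
```

Plain indexing (vec[0]) is usually the shortest of the three; tf.unstack is not wrong, it just builds a list first.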

I think you don't need to make label a 1-dimensional array in the first place.

With:

feature = {'label': tf.FixedLenFeature((), tf.string)}

you won't need to unstack the label in your filter_func.
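Put together, the scalar-shaped feature lets filter_func return the comparison directly. This is a minimal sketch assuming TensorFlow 2.x; make_example and the labels 'cat', 'dog' and 'bird' are hypothetical, and the records are kept in memory instead of being read from .tfrecord files.

```python
import tensorflow as tf

def make_example(label):
    """Serialize a tf.train.Example holding a single string label."""
    feature = {'label': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[label.encode('utf-8')]))}
    return tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()

# Stand-in for a TFRecordDataset: a dataset of serialized examples.
serialized = [make_example(l) for l in ['cat', 'dog', 'cat', 'bird']]
dataset = tf.data.Dataset.from_tensor_slices(serialized)

def mapping_func(serialized_example):
    # Shape () makes the parsed label a scalar string tensor.
    feature = {'label': tf.io.FixedLenFeature((), tf.string)}
    return tf.io.parse_single_example(serialized_example, feature)

def filter_func(features):
    # Scalar == scalar gives the scalar boolean that filter() expects.
    return tf.equal(features['label'], 'cat')

filtered = dataset.map(mapping_func).filter(filter_func)
labels = [f['label'].numpy().decode('utf-8') for f in filtered]
print(labels)
```

With real files, the from_tensor_slices line would be replaced by tf.data.TFRecordDataset(file_names); the rest should stay the same.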

Reading and filtering a dataset is very easy, and there is no need to unstack anything.

To read the dataset:

print(my_dataset, '\n\n')
# Print the first 3 records
for record in my_dataset.take(3):
    # The full record can be large, e.g. if it contains an image
    print(record)
    # Print a specific key
    print(record['key2'])

Filtering is equally simple:

my_filtereddataset = my_dataset.filter(_filtcond1)

where you define _filtcond1 however you want. Let us say there is a true/false boolean flag in your dataset; then:

@tf.function
def _filtcond1(x):
    return x['key_bool'] == 1

or even a lambda function:

my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)
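To try both styles of predicate without any files, here is a small sketch on an in-memory dataset, assuming TensorFlow 2.x. The keys key_bool and key_int come from the snippets above; the values are invented for illustration.

```python
import tensorflow as tf

# Hypothetical dataset of dicts, standing in for parsed TFRecords.
my_dataset = tf.data.Dataset.from_tensor_slices({
    'key_bool': [1, 0, 1, 1],
    'key_int': [5, 20, 14, 13],
})

def _filtcond1(x):
    # Keep records whose flag is set; must return a scalar boolean tensor.
    return tf.equal(x['key_bool'], 1)

flagged = my_dataset.filter(_filtcond1)

# Filters can be chained; this keeps flagged records with key_int > 13.
both = flagged.filter(lambda x: x['key_int'] > 13)

kept = [int(x['key_int'].numpy()) for x in both]
print(kept)
```

Dataset.filter traces the predicate into a graph by itself, so the @tf.function decorator is optional here. Also, as far as I know, `==` on tensors in TF 1.x compared object identity and not values, which is probably why the first attempt in the question failed; tf.equal behaves the same in both versions.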

If you are reading a dataset which you haven't created, or you are unaware of the keys (as seems to be the OP's case), you can use this to get an idea of the keys and structure first:

import json
from google.protobuf.json_format import MessageToJson

for raw_record in noidea_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    # print(example)  # very long if the record contains an image
    m = json.loads(MessageToJson(example))
    print(m['features']['feature'].keys())

Now you can proceed with the filtering.
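Besides the key names, it helps to know each feature's type before writing the parsing spec. A possible variant that skips the JSON round trip and reads the protobuf directly is sketched below, assuming TensorFlow 2.x; the record contents are hypothetical, and noidea_dataset is built in memory here where a real one would come from tf.data.TFRecordDataset.

```python
import tensorflow as tf

# Hypothetical record with one string feature and one integer feature.
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'cat'])),
    'key_int': tf.train.Feature(int64_list=tf.train.Int64List(value=[14])),
}))
noidea_dataset = tf.data.Dataset.from_tensor_slices(
    [example.SerializeToString()])

kinds = {}
for raw_record in noidea_dataset.take(1):
    parsed = tf.train.Example()
    parsed.ParseFromString(raw_record.numpy())
    for name, feature in parsed.features.feature.items():
        # Each Feature holds exactly one of bytes_list/float_list/int64_list.
        kinds[name] = feature.WhichOneof('kind')
print(kinds)
```

A bytes_list feature maps to tf.string in the parsing spec, int64_list to tf.int64, and float_list to tf.float32.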

You should try the apply function of tf.data.TFRecordDataset; see the TensorFlow documentation.

Otherwise, read the article "TFRecords for humans" to get a better understanding of TFRecords.
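For reference, Dataset.apply does no filtering itself; it takes a function from dataset to dataset and calls it, which is handy for packaging a filter as a reusable step. A minimal sketch, assuming TensorFlow 2.x, with keep_label as a hypothetical helper and invented labels:

```python
import tensorflow as tf

def keep_label(label_value):
    """Return a Dataset -> Dataset transformation (hypothetical helper)."""
    def transformation(ds):
        return ds.filter(lambda x: tf.equal(x['label'], label_value))
    return transformation

# In-memory stand-in for a parsed TFRecord dataset.
dataset = tf.data.Dataset.from_tensor_slices({'label': ['cat', 'dog', 'cat']})

# apply() simply calls the transformation on the dataset.
cats = dataset.apply(keep_label('cat'))
n_cats = sum(1 for _ in cats)
print(n_cats)
```

This is equivalent to calling dataset.filter(...) directly; apply only changes how the step is composed.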

But the most likely situation is that you can neither access nor modify a TFRecord; there is a request on GitHub about this topic (TFRecords request).

My advice is to keep things as simple as you can. Bear in mind that you are working with graphs and sessions.

In any case, if everything fails, try running the part of the code that does not work inside a TensorFlow session, as simply as you can; all these operations probably need to be done while a tf.Session is running.
