
How to load and map a list of dictionaries/JSONs with tf.data.Dataset

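I have a dataset whose records are stored as a list of dictionaries, and the dictionaries can be fairly complex. I would like to load this list through the TensorFlow Dataset API. How can I do that? I tried the following, but it does not work: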

import tensorflow as tf
import json

LABELS_IDS = ["cat", "dog", "animal"]

def parse_record(record):
    image = tf.io.read_file(record["_file"])
    image = tf.image.decode_jpeg(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [224, 224])
    image = tf.image.random_flip_left_right(image, seed=None)

    labels = []
    for element in record["_categories"]:
        if element in LABELS_IDS:
            labels.append(LABELS_IDS.index(element))

    one_hot_labels = tf.reduce_sum(tf.one_hot(labels, len(LABELS_IDS)), axis=0)
    return image, one_hot_labels

records = [{"_file":"images/test.jpg", "_categories": ["cat", "animal"]}]
    
train_x = tf.data.Dataset.from_tensor_slices(records).map(parse_record)

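Edit:

I found the answer: instead of mapping over the dictionaries directly, you can first extract the file paths and the label vectors with separate helper functions, and then map over the resulting tuple: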

LABELS_IDS = ["cat", "dog", "animal"]
records = [{"_file":"images/test.jpg", "_categories": ["cat", "animal"]}]

def _load_files(records):
    return [record["_file"] for record in records]

def _load_labels(records):
    vectors = []
    for record in records:
        labels = []
        for element in record["_categories"]:
            if element in LABELS_IDS:
                labels.append(LABELS_IDS.index(element))

        # Sum the one-hot rows into a single multi-hot vector for this record.
        one_hot = tf.reduce_sum(tf.one_hot(labels, len(LABELS_IDS)), axis=0)
        vectors.append(one_hot.numpy())
    return vectors

def _load_data(file_path, label):
    image = tf.io.read_file(file_path)
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label

data = (
  _load_files(records),
  _load_labels(records)
)

train_x = tf.data.Dataset.from_tensor_slices(data).map(_load_data)
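
The label vectors above are built with tf.one_hot and then converted back with .numpy(), which relies on eager execution. As a minimal sketch, the same multi-hot encoding can be done in plain Python before the dataset is created. The helper name multi_hot is hypothetical and not part of the original post. Unlike the reduce_sum version, it sets each position to 1.0 at most once, even if a category is repeated in a record.

```python
LABELS_IDS = ["cat", "dog", "animal"]

def multi_hot(categories, label_ids=LABELS_IDS):
    """Return a fixed-length 0/1 vector marking which known labels are present."""
    present = set(categories)
    return [1.0 if label in present else 0.0 for label in label_ids]

records = [{"_file": "images/test.jpg", "_categories": ["cat", "animal"]}]

# One vector per record; every vector has len(LABELS_IDS) entries.
label_vectors = [multi_hot(record["_categories"]) for record in records]
print(label_vectors)  # [[1.0, 0.0, 1.0]]
```

Because every vector has the same length, the resulting list can be passed to tf.data.Dataset.from_tensor_slices together with the file paths, in the same tuple form used above.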

For the benefit of the community, I am adding @Cospel's answer here:

LABELS_IDS = ["cat", "dog", "animal"]
records = [{"_file":"images/test.jpg", "_categories": ["cat", "animal"]}]

def _load_files(records):
    return [record["_file"] for record in records]

def _load_labels(records):
    vectors = []
    for record in records:
        labels = []
        for element in record["_categories"]:
            if element in LABELS_IDS:
                labels.append(LABELS_IDS.index(element))

        # Sum the one-hot rows into a single multi-hot vector for this record.
        one_hot = tf.reduce_sum(tf.one_hot(labels, len(LABELS_IDS)), axis=0)
        vectors.append(one_hot.numpy())
    return vectors

def _load_data(file_path, label):
    image = tf.io.read_file(file_path)
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label

data = (
  _load_files(records),
  _load_labels(records)
)

train_x = tf.data.Dataset.from_tensor_slices(data).map(_load_data)
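
The attempt in the question most likely fails because from_tensor_slices needs inputs it can convert to tensors of uniform shape, and a list of dictionaries holding variable-length string lists is not one. Besides a tuple of lists, from_tensor_slices also accepts a dictionary of equal-length columns, which keeps the field names available inside map. The following is a rough sketch of that transposition in plain Python. The names records_to_columns, encoders and encode_categories, as well as the two image paths, are hypothetical and introduced only for illustration.

```python
def records_to_columns(records, encoders=None):
    """Transpose a list of dicts into a dict of equal-length lists.

    `encoders` optionally maps a key to a function that turns a
    variable-length value into a fixed-length one.
    """
    encoders = encoders or {}
    columns = {}
    for record in records:
        for key, value in record.items():
            encode = encoders.get(key)
            columns.setdefault(key, []).append(encode(value) if encode else value)
    return columns

LABELS_IDS = ["cat", "dog", "animal"]

# Hypothetical sample records; the paths are placeholders.
records = [
    {"_file": "images/a.jpg", "_categories": ["cat", "animal"]},
    {"_file": "images/b.jpg", "_categories": ["dog"]},
]

def encode_categories(categories):
    # Fixed-length 0/1 vector, one slot per entry in LABELS_IDS.
    return [1.0 if label in categories else 0.0 for label in LABELS_IDS]

columns = records_to_columns(records, {"_categories": encode_categories})
print(columns["_categories"])  # [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```

Passing columns to tf.data.Dataset.from_tensor_slices should then yield one dictionary per record, so a mapped function can read element["_file"] and element["_categories"] by name. The sketch assumes every record carries the same keys.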
