
From .tfrecord to tf.data.Dataset to tf.keras.Model.fit

I am attempting to use the Datasets API in TensorFlow (v2.0) to pass large amounts of data to a tf.keras.Model. Here is a simplified version of my dataset:

for rec in my_dataset:
    print(repr(rec))

$ {'feature0': <tf.Tensor: id=528, shape=(), dtype=float32, numpy=0.2963>,
'feature1': <tf.Tensor: id=618, shape=(), dtype=int64, numpy=0>,
'feature2': <tf.Tensor: id=620, shape=(), dtype=string, numpy=b'Inst1'>,
'target': <tf.Tensor: id=621, shape=(), dtype=int64, numpy=2>}
{'feature0': <tf.Tensor: id=528, shape=(), dtype=float32, numpy=0.4633>,
'feature1': <tf.Tensor: id=618, shape=(), dtype=int64, numpy=1>,
'feature2': <tf.Tensor: id=620, shape=(), dtype=string, numpy=b'Inst4'>,
'target': <tf.Tensor: id=621, shape=(), dtype=int64, numpy=0>}

...and so on. Each record in the my_dataset object is a dictionary with the features' (and target's) names as the keys and the associated tensors as the values. I created the dataset from several .tfrecord files, so I'm constrained in the sense that each tensor corresponds to a tf.train.Example (wrapper) object. The dataset precisely matches the format seen in the TensorFlow documentation (see, for example, the last code example in https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file ).
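For readers who want to reproduce a dataset of this shape, here is a minimal sketch. The schema in `feature_description`, the helper `make_example`, and the temporary file name are assumptions inferred from the printed records, not the setup actually used in the question.

```python
import os
import tempfile

import tensorflow as tf

# Assumed schema, inferred from the records printed in the question.
feature_description = {
    "feature0": tf.io.FixedLenFeature([], tf.float32),
    "feature1": tf.io.FixedLenFeature([], tf.int64),
    "feature2": tf.io.FixedLenFeature([], tf.string),
    "target": tf.io.FixedLenFeature([], tf.int64),
}


def make_example(f0, f1, f2, target):
    """Hypothetical helper: wrap one record in a tf.train.Example."""
    feature = {
        "feature0": tf.train.Feature(float_list=tf.train.FloatList(value=[f0])),
        "feature1": tf.train.Feature(int64_list=tf.train.Int64List(value=[f1])),
        "feature2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[f2])),
        "target": tf.train.Feature(int64_list=tf.train.Int64List(value=[target])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Write two sample records to a temporary .tfrecord file.
path = os.path.join(tempfile.mkdtemp(), "sample.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(make_example(0.2963, 0, b"Inst1", 2).SerializeToString())
    writer.write(make_example(0.4633, 1, b"Inst4", 0).SerializeToString())


def parse(serialized_example):
    # Returns a dict mapping feature names to scalar tensors.
    return tf.io.parse_single_example(serialized_example, feature_description)


my_dataset = tf.data.TFRecordDataset([path]).map(parse)
records = list(my_dataset)
```

Each element of `records` is a dictionary of scalar tensors keyed by feature name, which is the structure described above.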

I would like to use this dataset with Keras. The tf.keras.Model objects I'm working with all seem to expect, for their fit function, a tuple representing the feature vector (X) and the target (y). I think I could figure out how to transform the tensors from my dataset into NumPy arrays and pass them into the model that way, or iterate over the dataset using an iterator, but if I understand correctly, that seems to defeat the whole purpose of using the Datasets API to begin with (see, for example, https://www.tensorflow.org/guide/keras/overview#train_from_tfdata_datasets ).

My question: what is the appropriate way to transform my_dataset into a form that tf.keras.Model.fit() will accept? Or, if this is the wrong question, what fundamental concepts am I missing that keep me from asking the right one? (For example, should the .tfrecord Examples be structured differently? Or am I required to use an iterator instead of directly passing my_dataset to the model, as I'd prefer?)

Unfortunately, I was only able to find a workaround rather than an outright solution. Because tf.stack only works on tensors of the same data type, I need to transform all data into floats while processing the Examples (including one-hot encoding all strings), and then use tf.stack on the resulting tensors:

import tensorflow as tf

def process_example(serialized_example):
    feature_description = get_feature_desc()  # dictionary describing features and dtypes
    target_name = get_target_name()  # so we don't include the target in our feature vector
    parsed_example = tf.io.parse_single_example(serialized_example, feature_description)
    tensor_list = []
    # Iterating over the parsed dictionary yields feature names (keys), not tensors.
    for name in parsed_example:
        if name != target_name:
            # String features must be encoded numerically before this cast;
            # tf.cast does not convert strings to floats.
            tensor_list.append(tf.cast(parsed_example[name], tf.float32))
    X = tf.stack(tensor_list)  # every element now has dtype float32
    y = parsed_example[target_name]
    return X, y
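To show how a parsing function like the one above connects to Keras, here is a self-contained sketch of the full path from serialized Examples to `fit`. It relies on the fact that `tf.keras.Model.fit` accepts a `tf.data.Dataset` that yields `(inputs, targets)` tuples, so no manual iteration or NumPy conversion is needed. The schema, the vocabulary `FEATURE2_VOCAB`, the `serialize` helper, and the toy model are all assumptions made for illustration; the one-hot step is one possible way to encode the string feature, not necessarily the one used in the answer.

```python
import tensorflow as tf

# Assumed schema, modelled on the records printed in the question.
FEATURE_DESCRIPTION = {
    "feature0": tf.io.FixedLenFeature([], tf.float32),
    "feature1": tf.io.FixedLenFeature([], tf.int64),
    "feature2": tf.io.FixedLenFeature([], tf.string),
    "target": tf.io.FixedLenFeature([], tf.int64),
}
TARGET_NAME = "target"
# Hypothetical vocabulary for the string feature.
FEATURE2_VOCAB = tf.constant([b"Inst1", b"Inst4"])


def serialize(f0, f1, f2, target):
    """Hypothetical helper: build one serialized tf.train.Example."""
    feature = {
        "feature0": tf.train.Feature(float_list=tf.train.FloatList(value=[f0])),
        "feature1": tf.train.Feature(int64_list=tf.train.Int64List(value=[f1])),
        "feature2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[f2])),
        "target": tf.train.Feature(int64_list=tf.train.Int64List(value=[target])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def process_example(serialized_example):
    parsed = tf.io.parse_single_example(serialized_example, FEATURE_DESCRIPTION)
    y = parsed[TARGET_NAME]
    # Numeric features: cast to a common dtype, then stack into a vector.
    numeric = tf.stack([
        parsed["feature0"],
        tf.cast(parsed["feature1"], tf.float32),
    ])
    # String feature: compare against the vocabulary to get a one-hot vector.
    one_hot = tf.cast(tf.equal(parsed["feature2"], FEATURE2_VOCAB), tf.float32)
    X = tf.concat([numeric, one_hot], axis=0)
    return X, y


serialized = [
    serialize(0.2963, 0, b"Inst1", 2),
    serialize(0.4633, 1, b"Inst4", 0),
    serialize(0.8100, 1, b"Inst1", 1),
    serialize(0.0500, 0, b"Inst4", 2),
]

# In practice the source would be tf.data.TFRecordDataset(filenames).
dataset = (
    tf.data.Dataset.from_tensor_slices(serialized)
    .map(process_example)
    .batch(2)
)

# Toy classifier with three output classes (targets 0, 1, 2).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(3),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# The dataset yields (X, y) batches, so it can be passed to fit directly.
history = model.fit(dataset, epochs=1, verbose=0)
```

Batching happens after the map, so each element passed to the model is a pair of a `(batch, 4)` float tensor and a `(batch,)` integer tensor.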
