
Tensorflow: how to retain file names in tf.data.Dataset from_generator?

I am struggling with the following. I am creating a tf.data.Dataset using the from_generator method. It works great, but after prediction I would like to investigate which samples were misclassified and why. For that, I need to retrieve the file names in the same order they were fed to the model. How can I do it?

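import random

import tensorflow as tf

# list_files_in_directory() and read_file() are the asker's own helpers (not shown)
def make_dataset(directory):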
    """ Makes a dataset from generator. """

    def generator():
        files = list_files_in_directory(directory)
        random.Random(42).shuffle(files)
        print(f'Files in {directory}: {len(files)}')

        for fn in files:
            X, y = read_file(fn)
            yield X, [y]   # here I lose info about fn

    def get_shapes():
        X, _ = next(generator())
        return tf.TensorShape(X.shape), tf.TensorShape(1)

    return (
        tf.data.Dataset
            .from_generator(
                generator,
                output_types=(tf.float64, tf.uint8),
                output_shapes=get_shapes())
            .batch(128, drop_remainder=True)
            .prefetch(256))

model.fit(make_dataset(directory_train))
y_pred = model.predict(make_dataset(directory_test))

# here: what is the most elegant way to retain the input filenames for y_pred?

P.S. The shuffle in the generator is needed to ensure input data randomness, since the files are read lazily.
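Because the generator reseeds with random.Random(42) on every call, the file order it yields is reproducible, as long as list_files_in_directory returns the files in a stable order (an assumption; os.listdir, for example, returns entries in arbitrary order). A minimal sketch of replaying that order outside TensorFlow; replay_file_order is a hypothetical helper and the file names are made up. It also drops the trailing files that batch(128, drop_remainder=True) discards:

```python
import random


def replay_file_order(files, seed=42, batch_size=128, drop_remainder=True):
    """Hypothetical helper: reproduce the order in which generator() yields files."""
    ordered = list(files)  # copy, so the caller's list is not shuffled in place
    random.Random(seed).shuffle(ordered)  # same seed and same input -> same permutation
    if drop_remainder:
        # batch(..., drop_remainder=True) discards the last partial batch
        n_kept = (len(ordered) // batch_size) * batch_size
        ordered = ordered[:n_kept]
    return ordered


# Made-up file names standing in for list_files_in_directory(directory)
files = [f"sample_{i:03d}.npy" for i in range(300)]
order_a = replay_file_order(files)
order_b = replay_file_order(files)
print(order_a == order_b, len(order_a))
```

Since batch and prefetch keep the element order, the rows of model.predict(make_dataset(directory_test)) should then line up with replay_file_order(list_files_in_directory(directory_test)). Note that drop_remainder=True silently skips up to 127 test files.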

One possibility is to make your generator return the filename and pass it as a debug input to your model.

An end-to-end example.

Let's train a simple linear regression. Here is the definition of the model:

import string

import numpy as np
import tensorflow as tf

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mse', optimizer='sgd')

Let's generate some data

Now, let's generate some random data and associate each sample with a dummy filename:

# generating some random data
filenames = np.array([[f"{s}.txt"] for s in string.ascii_lowercase])
X = np.random.uniform(size=(26, 1))
y = np.random.uniform(size=(26, 1))
data = np.concatenate((filenames, X, y), axis=1)

Let's look at the first element of the data:

>>> data[0]
array(['a.txt', '0.36798830850651043', '0.5976948635618315'], dtype='<U32')
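The '<U32' dtype above shows that np.concatenate turned every column, including X and y, into text. The dataset still works because, as far as I can tell, from_generator converts each yielded value to the declared output_types (tf.float64) with a numpy-style cast, which parses these strings as floats. A small numpy illustration of that conversion, using the row printed above:

```python
import numpy as np

# The first row printed above: every field, numeric or not, is a string ('<U32')
row = np.array(['a.txt', '0.36798830850651043', '0.5976948635618315'], dtype='<U32')

# Casting the numeric fields back to float64, as the declared output_types require
feature, label = row[1:].astype(np.float64)
print(row.dtype.kind, feature, label)
```

Keeping the filenames in a separate array from X and y would avoid this string round trip altogether.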

Training the model

We train the model on our data. For simplicity, let's use our arrays X and y.

model.fit(X,y)

Creating our debug generator and our debugging model

So, what we want is to add a debug input to our model whose only purpose is to receive the filename and output it directly, without any transformation. To do that, we use the Functional API to wrap the existing model into our debug model:

debug_input = tf.keras.Input(shape=(), dtype=tf.string)
debug_model = tf.keras.Model([debug_input, model.input], [debug_input, model.output])

Now we need a dataset generator that yields a tuple (filename, feature), together with the label, to feed our debug model:

def debug_gen(data):
    # shuffle the rows in place (note: this also reorders the array passed in, here the global `data`)
    np.random.shuffle(data)
    for filename, feature, label in data:
        yield (filename, feature), label


debug_ds = tf.data.Dataset.from_generator(
    lambda: debug_gen(data),
    output_types=((tf.string, tf.float64), tf.float64),
    output_shapes=((tf.TensorShape(()), tf.TensorShape(())), tf.TensorShape(())),
).batch(1)
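One side effect of debug_gen is that np.random.shuffle reorders the global data array in place. A hypothetical variant, not part of the original answer, that keeps the filenames, X and y as separate arrays and shuffles a permutation of indices instead, so the inputs are never mutated and the numeric columns keep their float dtype. It yields scalars so that it should match the output_shapes declared above; debug_gen_indexed is an illustrative name:

```python
import string

import numpy as np


def debug_gen_indexed(filenames, X, y, seed=None):
    """Yield ((filename, feature), label) in shuffled order without mutating the inputs."""
    rng = np.random.default_rng(seed)
    for i in rng.permutation(len(filenames)):
        # scalar feature and label, to match TensorShape(()) in output_shapes
        yield (filenames[i], X[i, 0]), y[i, 0]


filenames = [f"{s}.txt" for s in string.ascii_lowercase]
X = np.random.uniform(size=(26, 1))
y = np.random.uniform(size=(26, 1))
samples = list(debug_gen_indexed(filenames, X, y, seed=0))
print(samples[0])
```

With the answer's (26, 1) filenames array, it could be plugged in as lambda: debug_gen_indexed(filenames[:, 0], X, y) in place of lambda: debug_gen(data).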

Now, if we call predict on one item from our dataset, we should get the filename as an output, as well as the prediction:

>>> debug_model.predict(debug_ds.take(1))
[array([b'a.txt'], dtype=object), array([[-0.7604195]], dtype=float32)]
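Because the filenames travel through the model next to the predictions, the two output arrays stay aligned even though debug_gen shuffles. A minimal sketch of turning that output into a ranked list of the worst samples, which is what the question is after. worst_predictions and labels_by_file are hypothetical names; the arrays below are placeholder values in the same layout as the predict output above, not real results. The labels could, for instance, be built as {fn: float(label) for fn, _, label in data}, and calling debug_model.predict(debug_ds) without take(1) should return the arrays for the whole dataset:

```python
import numpy as np


def worst_predictions(outputs, labels_by_file, k=3):
    """Rank samples by absolute error, joining on the filename returned by the debug model."""
    names, preds = outputs
    rows = []
    for name, pred in zip(names, preds[:, 0]):
        fn = name.decode("utf-8")  # tf.string outputs come back as bytes, e.g. b'a.txt'
        rows.append((fn, float(pred), abs(float(pred) - labels_by_file[fn])))
    rows.sort(key=lambda r: r[2], reverse=True)  # largest error first
    return rows[:k]


# Placeholder values standing in for debug_model.predict(debug_ds) (not real results)
outputs = [np.array([b"a.txt", b"b.txt", b"c.txt"], dtype=object),
           np.array([[0.1], [0.8], [0.4]], dtype=np.float32)]
labels_by_file = {"a.txt": 0.9, "b.txt": 0.75, "c.txt": 0.0}
print(worst_predictions(outputs, labels_by_file, k=2))
```

Joining on the filename rather than on the row position is what makes this robust to the shuffle inside the generator.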

Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address.

© 2020-2024 STACKOOM.COM