
Reading TensorFlow Dataset changes behaviour of `take()` and `skip()`

I am trying to inspect the labels inside my TensorFlow dataset. However, the label values change to something unexpected after using `take()` and `skip()`, depending on whether I inspect the data beforehand or not. (It looks as if some ones within the labels changed to zeros.) I do not see any way that my inspection function could change the dataset. What am I missing?

To reproduce the behaviour, change the `LOOK_AT_DATA_TWICE` variable.

# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)


def inspect_dataset(ds, msg="", print_all=True):
    sum_ones = 0
    sum_zeros = 0
    for (sig, label) in ds.as_numpy_iterator():
        if print_all:
            print(msg, label, np.histogram(label, bins=2)[0])
        sum_ones += np.sum(label)
        sum_zeros += np.sum(label - 1)

    print(msg, "SUM of ones=", sum_ones)
    print(msg, "SUM of zero=", sum_zeros)


all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])

print(f"all_pattern.shape={all_pattern.shape}")
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")

complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))

LOOK_AT_DATA_TWICE = True  # This changes the numbers output below
if LOOK_AT_DATA_TWICE:
    inspect_dataset(complete_ds, msg="complete_ds in generation", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in generation", print_all=False)

validation_split = 0.5
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)

inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)

Output with `LOOK_AT_DATA_TWICE = True`:

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]

complete_ds in generation SUM of ones= 2000
complete_ds in generation SUM of zero= -2000
complete_ds in generation SUM of ones= 2000
complete_ds in generation SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012

Output with `LOOK_AT_DATA_TWICE = False`:

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
   
complete_ds in generation SUM of ones= 2000
complete_ds in generation SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997

When the dataset is exhausted (i.e., after you have iterated through it once), it will redo all of its operations. In your case, because you are shuffling, the shuffle for the first epoch will be different from the shuffle for the second.

What this means is that your training set and test set are actually not consistent between epochs.
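A minimal sketch of the inconsistency (assuming TensorFlow 2.x; the dataset and sizes are toy values): with the default `reshuffle_each_iteration=True`, the `skip()` and `take()` pipelines each trigger their own reshuffle when iterated, so the two "halves" generally come from different permutations and can overlap.

```python
import tensorflow as tf

# Small toy dataset so the effect is easy to see.
# reshuffle_each_iteration defaults to True.
ds = tf.data.Dataset.range(10).shuffle(10, seed=0)

train = set(ds.skip(5).as_numpy_iterator())
val = set(ds.take(5).as_numpy_iterator())

# Each split has the right size, but because each iteration reshuffles,
# the two sets generally overlap instead of partitioning the data.
print(len(train), len(val))  # 5 5
print(sorted(train & val))   # usually non-empty
```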

You can pass `reshuffle_each_iteration=False` to the call to `shuffle()` to make the shuffle behave the same at each iteration. If you still want a different shuffle of your training set on each epoch, you should call `shuffle()` on it again.

ds = tf.data.Dataset.from_tensor_slices(data)
# Fix the shuffle order so take() and skip() see the same permutation.
shuffled_ds = ds.shuffle(buffer_size=len(data), reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
train_ds = train_ds.shuffle(buffer_size=train_size)  # reshuffle only the train set each epoch
test_ds = shuffled_ds.skip(train_size)
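To check that this actually yields a consistent, disjoint split, here is a small self-contained sketch (TensorFlow 2.x assumed; the seed and toy sizes are illustrative):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
# With a fixed seed and reshuffle_each_iteration=False, every iteration
# over `shuffled` sees the same permutation.
shuffled = ds.shuffle(10, seed=0, reshuffle_each_iteration=False)

train = shuffled.skip(5)
val = shuffled.take(5)

train_1 = set(train.as_numpy_iterator())
val_1 = set(val.as_numpy_iterator())
train_2 = set(train.as_numpy_iterator())  # second epoch

assert train_1 & val_1 == set()            # disjoint split
assert train_1 | val_1 == set(range(10))   # together they cover everything
assert train_1 == train_2                  # stable across epochs
```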

