Reading TensorFlow Dataset changes behaviour of `take()` and `skip()`
I am trying to inspect the labels inside my TensorFlow dataset. However, the values of the labels change to something unexpected after using take() and skip(), depending on whether I inspect the data or not. (It looks like some ones within the labels changed to zeros.) I do not see any way that my inspection function could change the dataset. What am I missing?
To reproduce the behaviour, change the LOOK_AT_DATA_TWICE variable.
# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)

def inspect_dataset(ds, msg="", print_all=True):
    sum_ones = 0
    sum_zeros = 0
    for (sig, label) in ds.as_numpy_iterator():
        if print_all:
            print(msg, label, np.histogram(label, bins=2)[0])
        sum_ones += np.sum(label)
        sum_zeros += np.sum(label - 1)
    print(msg, "SUM of ones=", sum_ones)
    print(msg, "SUM of zero=", sum_zeros)

all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])
print(f"all_pattern.shape={all_pattern.shape}")
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")

complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))

LOOK_AT_DATA_TWICE = True  # This changes the numbers output below
if LOOK_AT_DATA_TWICE:
    inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)

validation_split = 0.5
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)
inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)
Output with LOOK_AT_DATA_TWICE = True:
all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012
Output with LOOK_AT_DATA_TWICE = False:
all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997
When the dataset is exhausted (i.e., after you have iterated through it once), it will redo all the operations the next time you iterate over it. In your case, because you are shuffling, the shuffle for the first epoch will be different from the shuffle for the second. The same applies to train_ds and val_ds: each one iterates complete_ds separately and therefore sees its own shuffle, so the two splits can overlap and their label counts need not add up to 2000.
What this means is that your training set and testing set are actually not consistent between epochs.
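To see the effect on a smaller scale, here is a minimal sketch that uses the integers 0..9 as a stand-in dataset (an assumption for illustration, not the data from the question). It iterates the same shuffled dataset several times, then takes and skips from it:

```python
import tensorflow as tf

# Stand-in dataset: the integers 0..9, shuffled with the default
# reshuffle_each_iteration=True.
ds = tf.data.Dataset.range(10).shuffle(10)

# Each pass over ds draws a fresh shuffle, so the two orders usually differ.
first_pass = list(ds.as_numpy_iterator())
second_pass = list(ds.as_numpy_iterator())
print("first pass: ", first_pass)
print("second pass:", second_pass)

# take() and skip() each start their own pass over ds, so they work on
# different shuffles and the two halves can share elements.
val = set(ds.take(5).as_numpy_iterator())
train = set(ds.skip(5).as_numpy_iterator())
print("elements in both splits:", sorted(val & train))
```

Every pass still contains all ten elements; only the order changes, which is why the totals for complete_ds stay at 2000 while the split totals drift.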
You can set reshuffle_each_iteration=False in the call to shuffle to make the shuffle behave the same at each iteration. If you still want a different shuffle for your training set at each epoch, you should call shuffle again on the training split.
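# data, num_samples and train_size are placeholders for your own values
ds = tf.data.Dataset.from_tensor_slices(data)
# shuffle() requires a buffer_size; num_samples (the dataset size) gives a full shuffle
shuffled_ds = ds.shuffle(num_samples, reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
# reshuffle only the training split at each iteration
train_ds = train_ds.shuffle(train_size)
test_ds = shuffled_ds.skip(train_size)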
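As a quick check of the fix, the following sketch again uses the integers 0..9 as a stand-in dataset. The names n and split, and the fixed seed, are assumptions added here for reproducibility and are not part of the original answer. With reshuffle_each_iteration=False, take() and skip() should see the same order, so the two splits together should cover every element exactly once:

```python
import tensorflow as tf

n = 10      # hypothetical dataset size
split = 5   # hypothetical number of validation elements

ds = tf.data.Dataset.range(n)
# One fixed shuffle that is reused on every pass over the dataset.
shuffled = ds.shuffle(n, seed=42, reshuffle_each_iteration=False)

val = list(shuffled.take(split).as_numpy_iterator())
train = list(shuffled.skip(split).as_numpy_iterator())

print("val:  ", val)
print("train:", train)
print("disjoint:", not (set(val) & set(train)))
```

Applied to the question's code, this would correspond to passing reshuffle_each_iteration=False in the complete_ds.shuffle(len(all_labels)) call.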