
Reading a TensorFlow Dataset changes the behaviour of `take()` and `skip()`

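I am trying to inspect the labels inside a TensorFlow dataset. However, after using take() and skip(), the label values change to something unexpected, depending on whether I inspect the data beforehand. (It looks as if some of the ones in the labels turn into zeros.) I cannot see any way that my inspection function could change the dataset. What am I missing?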

To reproduce the behaviour, change the LOOK_AT_DATA_TWICE variable.

# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)


def inspect_dataset(ds, msg="", print_all=True):
    sum_ones = 0
    sum_zeros = 0
    for (sig, label) in ds.as_numpy_iterator():
        if print_all:
            print(msg, label, np.histogram(label, bins=2)[0])
        sum_ones += np.sum(label)
        sum_zeros += np.sum(label - 1)

    print(msg, "SUM of ones=", sum_ones)
    print(msg, "SUM of zero=", sum_zeros)


all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])

print(f"all_pattern.shape={all_pattern.shape}")
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")

complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))

LOOK_AT_DATA_TWICE = True  # This changes the numbers output below
if LOOK_AT_DATA_TWICE:
    inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)

validation_split=0.5
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)

inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)

Output with LOOK_AT_DATA_TWICE = True

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]

complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012

Output with LOOK_AT_DATA_TWICE = False

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
   
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997

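When the dataset is exhausted (that is, after you have iterated over it once), the next iteration re-runs the whole pipeline from the start. Because your pipeline includes a shuffle, the order produced on the first pass differs from the order produced on the second.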

This means your training set and test set are not actually consistent between epochs.
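A minimal sketch of this effect, assuming a toy dataset of the integers 0 to 99 in place of the question's data (the names `n`, `val_ids`, and `train_ids` are illustrative only). Because `take()` and `skip()` are each iterated separately, each one sees its own shuffle order, so the two halves can share some elements and miss others. This is consistent with the question's output, where the ones counted in the two splits add up to 1985 and 2034 rather than 2000.

```python
import tensorflow as tf

n = 100  # illustrative size, not taken from the question

# Shuffle with reshuffling enabled, as in the question's pipeline.
ds = tf.data.Dataset.range(n).shuffle(n, reshuffle_each_iteration=True)

# Two passes over the same dataset object: each pass reshuffles.
first_pass = [int(x) for x in ds.as_numpy_iterator()]
second_pass = [int(x) for x in ds.as_numpy_iterator()]
print("same order on both passes:", first_pass == second_pass)

# take() and skip() are iterated independently, so each draws a new order.
val_ids = {int(x) for x in ds.take(n // 2).as_numpy_iterator()}
train_ids = {int(x) for x in ds.skip(n // 2).as_numpy_iterator()}

overlap = val_ids & train_ids                      # elements in both splits
missing = set(range(n)) - (val_ids | train_ids)    # elements in neither split
print("elements in both splits:", len(overlap))
print("elements in neither split:", len(missing))
```

Since both splits hold `n // 2` elements, every element that lands in both splits leaves exactly one other element in neither, which is why the label counts in the question drift away from an even split.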

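You can pass reshuffle_each_iteration=False to the shuffle call so that the shuffle produces the same order on every iteration. If you still want the training set to be shuffled differently on each epoch, call shuffle again on the training split.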

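import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(data)
# shuffle() requires a buffer size; the full cardinality gives a uniform shuffle
shuffled_ds = ds.shuffle(ds.cardinality(), reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
# reshuffle only the training split on every epoch
train_ds = train_ds.shuffle(train_size)
test_ds = shuffled_ds.skip(train_size)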
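One way to check that the fix behaves as described is sketched below, again assuming a toy dataset of integers rather than the question's data (`n`, `train_size`, and the `*_epoch*` names are illustrative). With `reshuffle_each_iteration=False` on the first shuffle, the split membership should stay fixed, while the second shuffle still varies the order within the training split.

```python
import tensorflow as tf

n = 100          # illustrative dataset size
train_size = 70  # illustrative split point

ds = tf.data.Dataset.range(n)
# Fix the shuffle order once so take() and skip() see the same sequence.
shuffled_ds = ds.shuffle(n, reshuffle_each_iteration=False)
# Reshuffle only within the training split on each pass.
train_ds = shuffled_ds.take(train_size).shuffle(train_size)
test_ds = shuffled_ds.skip(train_size)

train_epoch1 = [int(x) for x in train_ds.as_numpy_iterator()]
train_epoch2 = [int(x) for x in train_ds.as_numpy_iterator()]
test_ids = {int(x) for x in test_ds.as_numpy_iterator()}

print("train membership stable:", set(train_epoch1) == set(train_epoch2))
print("train order differs:", train_epoch1 != train_epoch2)
print("splits disjoint:", set(train_epoch1).isdisjoint(test_ids))
print("splits cover everything:", set(train_epoch1) | test_ids == set(range(n)))
```

The same check could be run on the question's dataset by comparing label sums: with a fixed first shuffle, the ones counted in the two splits should add up to the total of 2000.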

