簡體   English   中英

規范化 tf.data.Dataset

[英]Normalize tf.data.Dataset

我有一個 tf.data.Dataset 圖像的輸入形狀(批量大小,128, 128, 2)和目標形狀(批量大小,128, 128, 1),其中輸入是 2 通道圖像(復值具有兩個通道的圖像表示實部和虛部),目標是 1 通道圖像(實值圖像)。 我需要先從輸入和目標圖像中刪除它們的平均圖像,然后將它們縮放到 (0,1) 范圍來標准化輸入和目標圖像。 如果我沒記錯的話, tf.data.Dataset 一次只能處理一批,而不是整個數據集。 所以我從'remove_mean'py_function中的批次中的每個圖像中刪除批次的平均圖像,然后通過減去其最小值並除以其最大值和最小值的差將每個圖像縮放到(0,1) py_function 'linear_scaling'。 但是在應用函數之前和之后從數據集中打印輸入圖像的最小值和最大值后,圖像值沒有變化。 任何人都可以建議這可能出了什么問題嗎?

def remove_mean(image, target):
    image_mean = np.mean(image, axis=0)
    target_mean = np.mean(target, axis=0)
    image = image - image_mean
    target = target - target_mean
    return image, target

def linear_scaling(image, target):
    image_min = np.ndarray.min(image, axis=(1,2), keepdims=True)
    image_max = np.ndarray.max(image, axis=(1,2), keepdims=True)
    image = (image-image_min)/(image_max-image_min)

    target_min = np.ndarray.min(target, axis=(1,2), keepdims=True)
    target_max = np.ndarray.max(target, axis=(1,2), keepdims=True)
    target = (target-target_min)/(target_max-target_min)
    return image, target

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))


Output -

tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)

map不是就地操作,因此當您執行train_dataset.map(....)時,您的train_dataset不會更改。

train_dataset = train_dataset.map(...)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM