簡體   English   中英

如何連接兩個 Tensorflow 數據集?

[英]How to concatenate two Tensorflow DataSets?

我正在嘗試加載然后擴充一些圖像( 160 x 160 x 3 )數據集,其中圖像存儲在文件夾中,文件夾名稱對我來說是 label。 正在應用多種轉換來生成數據副本,並且需要將它們concatenated (or stacked may be) ,以便將數據連接起來並將它們存儲回磁盤。

下面是我能夠編寫的最簡單的可重現片段,我無法append/concatenate/stack這兩個數據集。

def some_transformation(image, label):
    # do something like rotation, clipping, noise add etc.
    return image, label

userA = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)

userB = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)

print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
transformed_userA_w_label.concatenate(transformed_userB_w_label)

Output的打印語句如下:

User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>

預期: 6張圖片

Output ds - <ConcatenateDataset shapes: ((6, 160, 160, 3), (6, 2)), types: (tf.float64, tf.float64)>

這里的關鍵問題是tf.data.Dataset.from_tensorstf.data.Dataset.from_tensor_slices的使用。

  • tf.data.Dataset.from_tensors([t1,t2,t3]) - 創建一個數據集,其中列表的每個元素都作為數據點給出
  • tf.data.Dataset.from_tensor_slices(t) - 創建一個數據集,其中一個元素是在第一個軸上索引的項目

根據您擁有的數據(即 3 張尺寸為 160x160x3 的圖像,即3x160x160x3 ),您需要使用第二種方法。 否則,您的所有 3 張圖像都將作為單個數據點(這可能不是您想要的)。

轉到第二個問題,您展示的 output,

User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>

它只是顯示單個元素的外觀。 因此,即使您的代碼正確,您也不會看到您想要的6 要查看元素的數量,您必須迭代數據集。 在您的情況下,您將看到2 (因為此數據集將所有 3 個圖像視為單個數據點)。

因此,要修復您的代碼,請執行此操作,

def some_transformation(image, label):
    # do something like rotation, clipping, noise add etc.
    return image, label

userA = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)

userB = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)

print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
concat_ds = transformed_userA_w_label.concatenate(transformed_userB_w_label)
print(concat_ds)

for i,ii in enumerate(concat_ds):
  print(i)

你會看到i的值被打印了 6 次。 這就是你需要的。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM