[英]Tensorflow how to generate unbalanced combined data sets
我對新數據集API(tensorflow 1.4)有疑問。 我有兩個數據集,我需要創建一個組合的不平衡數據集,即每個批次應該包含來自第一個數據集的一定數量的元素和來自第二個數據集的一定數量的元素。 例如,
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([1,1,1,1,1,1]
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([2,2,2,2,2,2]))
假設批量大小為4,我希望組合數據集中的批處理看起來像[1,1,1,2]。 我知道如何使用zip和flat_map生成一個平衡的數據集,但是我對此感到茫然。
提前致謝!
為了解決這個問題,我的解決方案是單獨批處理數據集,壓縮它們,然后在生成的數據集上映射tf.concat
運算符。
在你的例子中它會給出類似的東西(我重命名了第二個數據集dataset2
):
def concat(*tensor_list):
return tf.concat(tensor_list, axis=0)
zipped_ds = tf.data.Dataset.zip((dataset1.batch(3), dataset2))
unbalanced_ds = zipped_ds.map(concat)
如果數據集是張量的嵌套結構,則可以使用以下版本的concat:
def concat(*ds_elements):
#Create one empty list for each component of the dataset
lists = [[] for _ in ds_elements[0]]
for element in ds_elements:
for i, tensor in enumerate(element):
#For each element, add all its component to the associated list
lists[i].append(tensor)
#Concatenate each component list
return tuple(tf.concat(l, axis=0) for l in lists)
如果所有數據集元素(要組合的數據集的一部分)是僅與最外層維度(相對批處理大小)不同的張量,則起作用。 它為數據集元素的每個組件構建一個列表,並將這些組件彼此連接起來。
其中處理一級嵌套。 如果你需要更多,你可以使用recurrence來解壓縮嵌套的嵌套,但它可能會給出一個不那么干凈的計算圖...
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.