[英]Reading in file names from a tensor in Tensorflow
上下文:我正在嘗試制作一個 GAN 以從大型數據集生成圖像,並且在加載訓練數據時遇到了 OOM 問題。 為了解決這個問題,我試圖傳入一個文件目錄列表,並僅在需要時將它們作為圖像讀入。
問題:我不知道如何從張量本身解析出文件名。 如果有人對如何將張量轉換回列表或以某種方式遍歷張量有任何見解。 或者,如果這是解決此問題的糟糕方法,請告訴我
相關代碼片段:
生成數據: 注意: make_file_list()
返回我要讀取的所有圖像的文件名列表
data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)
培訓 function:
def train(dataset, epochs):
plot_iteration = []
gen_loss_l = []
disc_loss_l = []
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
gen_loss, disc_loss = train_step(image_batch)
訓練步驟:
@tf.function
def train_step(image_files):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
images = [load_img(filepath) for filepath in image_files]
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
錯誤:
line 250, in train_step *
images = [load_img(filepath) for filepath in image_files]
OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature
刪除 train_step 上的@tf.function
train_step
器。 如果你用@tf.function裝飾你的train_step
,Tensorflow會嘗試將train_step中的train_step
代碼轉換成一個執行圖,而不是在eager模式下運行。 執行圖提供了加速,但也對可以執行的運算符施加了一些限制(如錯誤所述)。
要將@tf.function
保留在train_step
上,您可以先在train
function 中執行迭代和加載步驟,然后將已加載的圖像作為參數傳遞給train_step
而不是嘗試直接在train_step
中加載圖像
def train(dataset, epochs):
plot_iteration = []
gen_loss_l = []
disc_loss_l = []
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
images = [load_img(filepath) for filepath in image_batch ]
gen_loss, disc_loss = train_step(images)
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
....
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.