簡體   English   中英

從 Tensorflow 中的張量中讀取文件名

[英]Reading in file names from a tensor in Tensorflow

上下文:我正在嘗試制作一個 GAN 以從大型數據集生成圖像,並且在加載訓練數據時遇到了 OOM 問題。 為了解決這個問題,我試圖傳入一個文件目錄列表,並僅在需要時將它們作為圖像讀入。

問題:我不知道如何從張量本身解析出文件名。 如果有人對如何將張量轉換回列表或以某種方式遍歷張量有任何見解。 或者,如果這是解決此問題的糟糕方法,請告訴我

相關代碼片段:

生成數據: 注意: make_file_list()返回我要讀取的所有圖像的文件名列表

data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)

培訓 function:

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)

訓練步驟:

@tf.function
def train_step(image_files):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = [load_img(filepath) for filepath in image_files]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

錯誤:

line 250, in train_step  *
        images = [load_img(filepath) for filepath in image_files]

    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature

刪除 train_step 上的@tf.function train_step器。 如果你用@tf.function裝飾你的train_step ,Tensorflow會嘗試將train_step中的train_step代碼轉換成一個執行圖,而不是在eager模式下運行。 執行圖提供了加速,但也對可以執行的運算符施加了一些限制(如錯誤所述)。

要將@tf.function保留在train_step上,您可以先在train function 中執行迭代和加載步驟,然后將已加載的圖像作為參數傳遞給train_step而不是嘗試直接在train_step中加載圖像

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

    for image_batch in dataset:
        images = [load_img(filepath) for filepath in image_batch ]
        gen_loss, disc_loss = train_step(images)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        ....

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM