How do you send arguments to a generator function using tf.data.Dataset.from_generator()?

I would like to create several tf.data.Dataset objects using the from_generator() function. I would like to send an argument to the generator function (raw_data_gen). The idea is that the generator function will yield different data depending on the argument sent. In this way, raw_data_gen would be able to provide training, validation, or test data.

training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))

validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))

test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))

The error message I get when I try to call from_generator() in this way is:

TypeError: from_generator() got an unexpected keyword argument 'args'

Here is the raw_data_gen function, although I'm not sure you will need it, as my hunch is that the problem is with the call to from_generator():

def raw_data_gen(train_val_or_test):

    if train_val_or_test == 1:        
        #For every filename collected in the list
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 2:
        #For every filename collected in the list
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 3:
        #For every filename collected in the list
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()

You need to define a new function based on raw_data_gen that doesn't take any arguments. You can use the lambda keyword to do this.

training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...

Now we are passing from_generator a function that takes no arguments but simply acts as raw_data_gen with the argument set to 1. You can use the same scheme for the validation and test sets, passing 2 and 3 respectively.
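The zero-argument wrapper idea can be tried out without TensorFlow. The sketch below is only an illustration: split_gen and its tiny in-memory data are hypothetical stand-ins for raw_data_gen and the file dictionaries. It shows that a lambda or functools.partial returns a fresh generator on every call, and it flags one pitfall that may come up if the wrappers are built in a loop.

```python
from functools import partial


def split_gen(split):
    """Hypothetical stand-in for raw_data_gen: yields (sample, label) pairs."""
    data = {1: ["train_a", "train_b"], 2: ["val_a"], 3: ["test_a"]}
    if split not in data:
        raise ValueError(f"split must be 1, 2 or 3, got {split!r}")
    for label, sample in enumerate(data[split]):
        yield sample, label


# Two equivalent ways to build a callable that takes no arguments.
train_callable = lambda: split_gen(1)
val_callable = partial(split_gen, 2)

# Each call produces a new generator, so the data can be iterated repeatedly.
first_pass = list(train_callable())
second_pass = list(train_callable())

# Pitfall: lambdas created in a loop capture the variable, not its value,
# so every entry in `late` ends up using the last split (3).
late = [lambda: split_gen(s) for s in (1, 2, 3)]

# partial binds the value at creation time, so each entry keeps its own split.
bound = [partial(split_gen, s) for s in (1, 2, 3)]
```

If the three datasets are created in a loop rather than written out one by one, partial (or a default argument such as lambda s=s: split_gen(s)) avoids the late-binding surprise.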

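For TensorFlow 2.4, the argument can be passed through args, which must be a tuple (note the trailing comma for a single value):

training_dataset = tf.data.Dataset.from_generator(
     raw_data_gen,
     args=(1,),  # (1) is just the int 1; (1,) is a one-element tuple
     output_types=(tf.float32, tf.uint8),
     output_shapes=([None, 1], [None]))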
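As a rough, self-contained sketch of the args form, assuming TensorFlow 2.4 or later: toy_gen and its made-up sizes are hypothetical and stand in for raw_data_gen, and output_signature is used here in place of output_types/output_shapes. The values in args are handed to the generator as NumPy values, so the sketch converts the argument with int() before using it.

```python
import numpy as np
import tensorflow as tf


def toy_gen(split):
    """Hypothetical generator: yields (signal, label) pairs for a given split."""
    # `split` arrives as a NumPy scalar, so convert it to a plain int first.
    n_items = {1: 3, 2: 2, 3: 1}[int(split)]
    for i in range(n_items):
        signal = np.full((4,), i, dtype=np.float32)  # fake 4-sample "audio"
        yield signal, np.uint8(i)


def make_dataset(split):
    # args must be a sequence, hence the one-element tuple (split,).
    return tf.data.Dataset.from_generator(
        toy_gen,
        args=(split,),
        output_signature=(
            tf.TensorSpec(shape=(None,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.uint8),
        ),
    )


training_dataset = make_dataset(1)
validation_dataset = make_dataset(2)
test_dataset = make_dataset(3)

training_items = list(training_dataset.as_numpy_iterator())
```

The shapes in the signature describe what this toy generator yields (a 1-D signal and a scalar label); they would need adjusting to match the real data.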

