
Dataset.from_generator: TypeError: `generator` must be callable

I am currently using a generator to produce my training and validation datasets with tf.data.Dataset.from_generator. I have a class method that takes care of this for me:

def build_dataset(self, batch_size=16, shuffle=16, validation=None):
    
    train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
    self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
    
    if validation is not None:
        val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
        self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)

The problem is that passing validation=validation to import_images calls the function and creates a generator object, which TensorFlow doesn't accept, and it gives me the error:

TypeError: `generator` must be callable.

Because I have to pass in validation to tell my generator whether to produce the training or the validation version, I would need to write two copies of the same generator. It also prevents me from passing in other arguments, such as the percentage of training and validation examples, meaning the generator has to be static. Any suggestions?
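
To illustrate the cause of the error and one possible way around it, here is a minimal sketch in plain Python. It assumes a stand-in import_images that yields dummy float pairs, and a hypothetical split argument that is not part of the original code. The idea is to bind the arguments with functools.partial so that what gets handed to from_generator is still a callable rather than an already-created generator object.

```python
import functools

def import_images(validation=None, split=0.8):
    # Hypothetical stand-in for the real generator: yields (image, label) pairs.
    # `split` is an assumed extra argument, not part of the original code.
    n_total = 10
    n_train = int(n_total * split)
    indices = range(n_train, n_total) if validation else range(n_train)
    for i in indices:
        yield float(i), float(i % 2)

# Calling the function returns a generator object, which is not callable;
# this is the kind of value that from_generator rejects.
gen_object = import_images(validation=None)
print(callable(gen_object))   # False

# functools.partial binds the arguments but defers the call, so the result
# is a callable that returns a fresh generator each time it is invoked.
train_gen = functools.partial(import_images, validation=None, split=0.8)
val_gen = functools.partial(import_images, validation=True, split=0.8)
print(callable(train_gen))    # True

# With TensorFlow installed, these callables would be passed in directly, e.g.
#   tf.data.Dataset.from_generator(train_gen, (tf.float32, tf.float32))
```

A lambda such as lambda: import_images(validation=validation) should work the same way. from_generator also accepts an args parameter for passing values to the generator, but those values are converted to tensors first, so it may not suit an argument that defaults to None.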

I recently encountered a similar problem, but I'm a beginner, so I'm not sure if this will help.

Try adding a __call__ method to your class.

Below is the original class, which raises TypeError: `generator` must be callable.

class DataGen:
  def __init__(self, files, data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self, files_name):
    data_path = os.path.join(self.data_path, files_name)
    arr_img, arr_mask = load_patch(data_path)
    return arr_img, arr_mask

  def getitem(self, index):
    _img, _mask = self.__load__(self.files[index])
    return _img, _mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr, mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr, mask_arr

Then I revised the code as below, adding a __call__ method that resets the index and returns the instance, and it worked for me.

class DataGen:
  def __init__(self, files, data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self, files_name):
    data_path = os.path.join(self.data_path, files_name)
    arr_img, arr_mask = load_patch(data_path)
    return arr_img, arr_mask

  def getitem(self, index):
    _img, _mask = self.__load__(self.files[index])
    return _img, _mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr, mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr, mask_arr
  
  def __call__(self):
    self.i = 0
    return self
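
Why the reset in __call__ matters can be seen without TensorFlow. The sketch below uses a hypothetical CountingGen class, a stripped-down analogue of DataGen that iterates over a plain list instead of loading files. from_generator expects a callable that returns an iterable, and it may call it again whenever the dataset is iterated anew, so each call should hand back an iterator that starts from the beginning.

```python
class CountingGen:
    # Hypothetical minimal analogue of DataGen: iterates over a list of items.
    def __init__(self, items):
        self.items = items
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= len(self.items):
            raise StopIteration
        item = self.items[self.i]
        self.i += 1
        return item

    def __call__(self):
        # Makes the instance callable and rewinds it for a new pass.
        self.i = 0
        return self

gen = CountingGen(["a", "b", "c"])
first_pass = list(gen())    # consumes the iterator to the end
second_pass = list(gen())   # the call rewinds it, so this pass is not empty
print(first_pass == second_pass)
```

Without the self.i = 0 line, the second pass would be empty because the index would still sit at the end of the list. One caveat of returning self: two passes running at the same time would share a single index, so this pattern seems safest when only one iteration is active at once.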

