I am currently using a generator to produce my training and validation datasets with `tf.data.Dataset.from_generator`. I have a class method that takes care of this for me:
def build_dataset(self, batch_size=16, shuffle=16, validation=None):
    """Build the training (and optionally validation) ``tf.data`` pipelines.

    ``tf.data.Dataset.from_generator`` requires a *callable* that returns an
    iterable, so that TF can create a fresh generator each time the dataset
    is iterated. Passing an already-created generator object, as in
    ``import_images(validation=validation)``, raises
    ``TypeError: `generator` must be callable``. A zero-argument closure
    binds ``validation`` (or any other keyword arguments you want to
    forward) while still handing TF a callable.

    Args:
        batch_size: Number of elements per batch in both pipelines.
        shuffle: Shuffle-buffer size for the training pipeline.
        validation: Forwarded unchanged to ``import_images``. When it is not
            ``None``, a validation pipeline is also built.

    Side effects:
        Sets ``self.train_dataset`` (shuffled, repeated indefinitely,
        batched, prefetched). When ``validation`` is not ``None``, also sets
        ``self.val_dataset`` (batched, prefetched, single pass).
    """
    output_types = (tf.float32, tf.float32)

    def _source():
        # TF invokes this callable each time the dataset is iterated, so
        # each pass gets a brand-new generator instead of an exhausted one.
        return import_images(validation=validation)

    # NOTE(review): the training and validation pipelines read from the same
    # source with identical arguments (as in the original code). If
    # import_images is meant to yield different splits, give each
    # from_generator its own closure with distinguishing arguments;
    # confirm against import_images.
    train_dataset = tf.data.Dataset.from_generator(_source, output_types)
    self.train_dataset = (
        train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
    )

    if validation is not None:
        val_dataset = tf.data.Dataset.from_generator(_source, output_types)
        self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)
The problem is that calling my `import_images` generator with `validation=validation` creates a generator object, which TensorFlow doesn't accept, and it gives me the error:
TypeError: `generator` must be callable.
Because I have to pass in `validation` to tell my generator to produce separate training and validation versions, I have to create two versions of the same generator. It also doesn't let me pass in other arguments to control the percentage of training and validation examples, which means the generator has to be static. Any suggestions?
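I recently encountered a similar problem, but I'm a beginner, so I'm not sure whether this will help.
Try adding a `__call__` method to your class. `from_generator` needs a callable that returns an iterable each time the dataset is iterated. If the class stores its arguments in `__init__` and defines `__call__`, you can pass the instance itself.
Below is the original class, which raises `TypeError: generator must be callable`.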
class DataGen:
    """Single-pass iterator over ``(image, mask)`` pairs loaded from disk.

    Each ``next()`` loads the next name in ``files``, joined onto
    ``data_path``, via ``load_patch``. The instance is its own iterator and
    is exhausted after one pass. This version defines no ``__call__``, so the
    instance is not callable.
    """

    def __init__(self, files, data_path):
        self.i = 0  # position of the next entry of ``files`` to load
        self.files = files
        self.data_path = data_path

    def __load__(self, files_name):
        """Load one file by name and return its ``(image, mask)`` pair."""
        image, mask = load_patch(os.path.join(self.data_path, files_name))
        return image, mask

    def getitem(self, index):
        """Return the ``(image, mask)`` pair for ``files[index]``."""
        image, mask = self.__load__(self.files[index])
        return image, mask

    def __iter__(self):
        return self

    def __next__(self):
        # Guard clause: stop once every file has been produced.
        if self.i >= len(self.files):
            raise StopIteration()
        image, mask = self.getitem(self.i)
        self.i += 1
        return image, mask
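I then revised the code as shown below, and it worked for me. The only change is the new `__call__` method, which rewinds the index to 0 and returns the instance.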
class DataGen:
    """Rewindable iterator over ``(image, mask)`` pairs loaded from disk.

    Each ``next()`` loads the next name in ``files``, joined onto
    ``data_path``, via ``load_patch``. Calling the instance rewinds it to
    the first file and returns it. This makes the instance usable as the
    callable argument of ``tf.data.Dataset.from_generator``.
    """

    def __init__(self, files, data_path):
        self.i = 0  # position of the next entry of ``files`` to load
        self.files = files
        self.data_path = data_path

    def __load__(self, files_name):
        """Load one file by name and return its ``(image, mask)`` pair."""
        image, mask = load_patch(os.path.join(self.data_path, files_name))
        return image, mask

    def getitem(self, index):
        """Return the ``(image, mask)`` pair for ``files[index]``."""
        image, mask = self.__load__(self.files[index])
        return image, mask

    def __iter__(self):
        return self

    def __next__(self):
        # Guard clause: stop once every file has been produced.
        if self.i >= len(self.files):
            raise StopIteration()
        image, mask = self.getitem(self.i)
        self.i += 1
        return image, mask

    def __call__(self):
        """Rewind to the first file and return this same iterator.

        Note that every call returns the *same* object, so two iterations
        running at once would share one position counter.
        """
        self.i = 0
        return self
The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.