Custom Keras Data Generator with yield
I am trying to create a custom data generator and don't know how to integrate the yield keyword with an infinite loop inside the __getitem__ method.
EDIT: After the answer I realized that the code I am using is a Sequence, which doesn't need a yield statement.
Currently I am returning multiple images with a return statement:
import cv2
import numpy as np
import tensorflow


class DataGenerator(tensorflow.keras.utils.Sequence):
    def __init__(self, files, labels, batch_size=32, shuffle=True, random_state=42):
        'Initialization'
        super().__init__()
        self.files = files
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.random_state = random_state
        self.on_epoch_end()

    def __len__(self):
        'Number of full batches per epoch (a trailing partial batch is dropped)'
        return int(np.floor(len(self.files) / self.batch_size))

    def __getitem__(self, index):
        'Return batch number `index`'
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        files_batch = [self.files[k] for k in indexes]
        y = [self.labels[k] for k in indexes]
        # Generate data
        x = self.__data_generation(files_batch)
        return x, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.files))
        if self.shuffle:
            # Note: reseeding with the same value gives the same order every epoch
            np.random.seed(self.random_state)
            np.random.shuffle(self.indexes)

    def __data_generation(self, files):
        'Load (and optionally augment) the images of one batch'
        imgs = []
        for img_file in files:
            # -1 reads the image unchanged (cv2.IMREAD_UNCHANGED)
            img = cv2.imread(img_file, -1)
            ###############
            # Augment image
            ###############
            imgs.append(img)
        return imgs
In this article I saw that yield is used in an infinite loop. I don't quite understand that syntax. How is the loop escaped?
You are using the Sequence API, which works a bit differently from plain generators. In a generator function, you would use the yield keyword to perform iteration inside a while True: loop, so each time Keras asks the generator for the next item, it gets a batch of data, and the generator automatically wraps around at the end of the data. The loop is never escaped: execution pauses at each yield, and the generator simply stops being asked for batches once training is finished.
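To make the "how is the loop escaped" part concrete, here is a minimal sketch in plain Python with no Keras involved. The name batch_generator and the toy data are made up for illustration. It shows that the consumer, not the generator, decides how many batches are produced.

```python
def batch_generator(items, batch_size):
    """Yield batches forever, wrapping around at the end of the data."""
    while True:
        for start in range(0, len(items), batch_size):
            # Execution pauses here until the consumer asks for the next batch
            yield items[start:start + batch_size]


gen = batch_generator(list(range(5)), batch_size=2)

# Pull exactly four batches; the infinite loop never runs ahead of the consumer.
batches = [next(gen) for _ in range(4)]
print(batches)
```

The fourth batch starts again from the beginning of the data, which is the wrap-around behaviour described above.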
But in a Sequence, the __getitem__ method takes an index parameter, so no iteration or yield is required: Keras performs the iteration for you. It is designed this way so that the sequence can run in parallel using multiprocessing, which is not possible with old generator functions.
So you are doing things the right way; no change is needed.
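As a rough illustration of the index-based protocol, here is a toy stand-in written in plain Python. ToySequence is a hypothetical class, not the Keras one; it only mimics the __len__ and __getitem__ pair that a Sequence exposes. Unlike the question's code it uses ceil, so the last partial batch is kept.

```python
import math


class ToySequence:
    """Hypothetical stand-in for a Sequence: only __len__ and __getitem__."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch, including a trailing partial batch
        return math.ceil(len(self.data) / self.batch_size)

    def __getitem__(self, index):
        # Any batch can be computed directly from its index, in any order
        start = index * self.batch_size
        return self.data[start:start + self.batch_size]


seq = ToySequence(list(range(7)), batch_size=3)

# The caller drives the iteration; the class itself contains no loop or yield.
epoch = [seq[i] for i in range(len(seq))]
print(epoch)
```

Because each batch depends only on its index, separate workers can each be handed different indexes, which is what makes parallel loading straightforward.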
Example of a generator in Keras:
def datagenerator(images, labels, batchsize, mode="train"):
    while True:
        # Restart from the first batch at the beginning of every pass
        start = 0
        end = batchsize
        while start < len(images):
            # load your images from numpy arrays or read from directory
            x = images[start:end]
            y = labels[start:end]
            yield x, y
            start += batchsize
            end += batchsize
Keras wants you to have the infinite loop running in the generator.
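Since the generator never ends, the length of an epoch has to be given from the outside. The sketch below uses toy lists instead of real images (the data and the simplified datagenerator are assumptions for illustration). It shows how a steps-per-epoch value can be derived so that one epoch covers the data exactly once.

```python
import math


def datagenerator(images, labels, batchsize):
    """Simplified version of the generator above, for illustration only."""
    while True:
        for start in range(0, len(images), batchsize):
            yield images[start:start + batchsize], labels[start:start + batchsize]


images = list(range(10))          # toy stand-ins for real images
labels = [i % 2 for i in images]  # toy labels
batchsize = 4

# Batches needed to see every sample once, counting the last partial batch
steps_per_epoch = math.ceil(len(images) / batchsize)

gen = datagenerator(images, labels, batchsize)
seen = []
for _ in range(steps_per_epoch):
    x, y = next(gen)
    seen.extend(x)

# With a compiled Keras model, the equivalent call would look roughly like:
# model.fit(gen, steps_per_epoch=steps_per_epoch, epochs=5)
```

After steps_per_epoch calls, seen contains every sample once, and the next call to the generator starts over from the first batch.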
If you want to learn about Python generators, then the link in the comments is actually a good place to start.