Python - Generator within another generator

I have to use the output of a generator in another generator.

Below is the code.

Here, generator 2 takes the output of generator 1 as its input, and the final output comes from generator 2.

I am trying to use something like the code below. Can anyone suggest a solution?

def sub_gen(data):
    for r in res_gen():
        yield each train_datagen(r)

Generator 1

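import numpy as np
from multiprocessing import Pool

# file_list and gen_patches are defined elsewhere in the original script
def res_gen(num_threads=4):
    while True:
        for i in range(0, len(file_list), num_threads):
            # use multiple processes to speed up patch generation;
            # the with-block shuts the pool down after each chunk
            with Pool(num_threads) as p:
                # slicing stops at the end of the list, so no min() is needed
                patch = p.map(gen_patches, file_list[i:i + num_threads])
            res = []
            for x in patch:
                res += x  # concatenate the patches from each file in the chunk
            res1 = np.array(res)
            # add a trailing channel axis: (N, H, W) -> (N, H, W, 1)
            res1 = res1.reshape((res1.shape[0], res1.shape[1], res1.shape[2], 1))
            res1 = res1.astype('float32') / 255.0  # scale pixel values to [0, 1]
            yield res1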
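The chunk-then-map loop in res_gen can be tried in isolation. Below is a minimal stdlib sketch, assuming a hypothetical `fake_patches` stand-in for gen_patches (which is not shown in the question) and using `multiprocessing.dummy.Pool`, a thread-backed pool with the same `map` API, so it runs without an `if __name__ == "__main__"` guard:

```python
from multiprocessing.dummy import Pool  # thread pool with the same API as multiprocessing.Pool


def fake_patches(name):
    # hypothetical stand-in for gen_patches: returns a list of "patches" per file
    return [f"{name}-p{k}" for k in range(2)]


def chunked_patches(file_list, num_threads=4):
    """Yield the flattened patches for each chunk of num_threads files (one pass)."""
    for i in range(0, len(file_list), num_threads):
        chunk = file_list[i:i + num_threads]  # slicing never runs past the end
        with Pool(num_threads) as p:
            per_file = p.map(fake_patches, chunk)  # map preserves input order
        res = []
        for x in per_file:
            res += x  # concatenate the patch lists of all files in the chunk
        yield res


files = ["a", "b", "c", "d", "e"]
for res in chunked_patches(files, num_threads=2):
    print(res)
```

With five files and two workers the last chunk holds a single file, which shows why the `min(...)` guard in the original slice is unnecessary.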

Generator 2

import numpy as np

# sigma (noise level in 0-255 pixel units) is defined elsewhere in the original script
def train_datagen(res1, batch_size=4):
    indices = list(range(res1.shape[0]))
    while True:
        np.random.shuffle(indices)  # reshuffle on every pass over the chunk
        for i in range(0, len(indices), batch_size):
            ge_batch_y = res1[indices[i:i + batch_size]]  # clean images
            noise = np.random.normal(0, sigma / 255.0, ge_batch_y.shape)
            ge_batch_x = ge_batch_y + noise  # input image = clean image + noise
            yield ge_batch_x, ge_batch_y
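train_datagen follows a common shuffle-then-slice batching pattern: reshuffle the index list at the start of each pass, then slice it into batch_size pieces (the last batch may be shorter). A minimal pure-Python sketch of that pattern, with a hypothetical `batches` function, plain lists instead of NumPy arrays and no noise step; `itertools.islice` takes a finite number of batches from the infinite generator:

```python
import random
from itertools import islice


def batches(items, batch_size=4, seed=None):
    """Yield shuffled batches of items forever, reshuffling on every pass."""
    rng = random.Random(seed)  # local RNG so the example is reproducible
    indices = list(range(len(items)))
    while True:
        rng.shuffle(indices)  # new order for each pass over the data
        for i in range(0, len(indices), batch_size):
            yield [items[j] for j in indices[i:i + batch_size]]


data = list(range(10))
first_pass = list(islice(batches(data, batch_size=4, seed=0), 3))
print(first_pass)  # three batches of sizes 4, 4, 2 covering all 10 items once
```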

I'm pretty sure the only syntax issue in your short sub_gen generator is that you've written yield each instead of yield from. The latter expects an iterable after it (often another generator) and yields each of its values, just like an explicit for loop would.
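For plain iteration, the two forms produce the same values. A minimal sketch comparing them (function names are illustrative):

```python
def inner():
    yield 1
    yield 2
    yield 3


def with_loop():
    for value in inner():  # re-yield each value by hand
        yield value


def with_yield_from():
    yield from inner()  # delegate to inner() directly


print(list(with_loop()))        # [1, 2, 3]
print(list(with_yield_from()))  # [1, 2, 3]
```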

So I think your code should be:

def sub_gen(data):
    for r in res_gen():
        yield from train_datagen(r)

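Beyond the syntax, there is a logic problem to watch for: train_datagen contains a while True loop, so it never finishes on its own. yield from only hands control back to the for loop in sub_gen once the inner generator is exhausted, so as written sub_gen will keep producing batches from the first chunk that res_gen yields and never move on to the next one. If you want sub_gen to cycle through all the chunks, remove the while True loop from train_datagen so it makes one shuffled pass per chunk (res_gen already loops forever).

Let's test the yield from version with much simpler generator functions: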

def foo():
    yield [1, 2]
    yield [3, 4]

def bar(iterable):
    for x in iterable:
        yield 10+x
        yield 20+x

def baz():
    for iterable in foo():
        yield from bar(iterable)

for value in baz():   # use the top-level generator!
    print(value)      # prints 11, 21, 12, 22, 13, 23, 14, 24 each on its own line
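The effect of train_datagen's while True loop can be reproduced with toy generators like the ones above (all names here are hypothetical stand-ins): an endless inner generator keeps the outer loop on its first chunk, while a single-pass one lets it advance. `itertools.islice` takes a finite prefix of the endless case:

```python
from itertools import islice


def chunks():
    # stand-in for res_gen: yields "chunks" of data
    yield [1, 2]
    yield [3, 4]


def endless_batches(chunk):
    # stand-in for train_datagen with while True: loops forever over one chunk
    while True:
        for x in chunk:
            yield x


def one_pass_batches(chunk):
    # same, but finishes after a single pass
    for x in chunk:
        yield x


def sub_gen(make_batches):
    for r in chunks():
        yield from make_batches(r)  # resumes the for loop only when exhausted


print(list(islice(sub_gen(endless_batches), 6)))  # [1, 2, 1, 2, 1, 2]
print(list(sub_gen(one_pass_batches)))            # [1, 2, 3, 4]
```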
