Using Python itertools to manage nested for loops

I am trying to use itertools.product to manage the bookkeeping of some nested for loops, where the number of nested loops is not known in advance. Below is a specific example with two nested for loops. The choice of two is only for clarity; what I need is a solution that works for an arbitrary number of loops.

This question provides an extension/generalization of the question appearing here: Efficient algorithm for evaluating a 1-d array of functions on a same-length 1d numpy array

Now I am extending the above technique using an itertools trick I learned here: Iterating over an unknown number of nested loops in python

Preamble:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table  = []
for items in product(*joint):
    f = trivial_functional(*items)
    func_table.append(f)

At the end of the above itertools loop, I have a 12-element flat list of functions, func_table, each element having been built by trivial_functional.

Question:

Suppose I am given a pair of integers, (i_1, i_2), where these integers are to be interpreted as indices into idx1 and idx2, respectively. How can I use itertools.product to determine the corresponding element of func_table?

I know how to hack the answer by writing my own function that mimics the itertools.product bookkeeping, but surely there is a built-in feature of itertools.product that is intended for exactly this purpose?

I don't know of a way of calculating the flat index other than doing it yourself. Fortunately this isn't that difficult:

def product_flat_index(factors, indices):
  # Horner-style accumulation: the last factor varies fastest in product()
  index = 0
  for factor, i in zip(factors, indices):
    index = index * len(factor) + i
  return index

>>> product_flat_index(joint, (2, 1))
7
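
One way to sanity-check a hand-written flat index is to compare it against the position that enumerate() assigns to each index tuple when iterating over product(). The sketch below relies on product() varying its last factor fastest; the helper name flat_index and the third factor idx3 are made up for illustration and are not part of the question.

```python
from itertools import product

def flat_index(lengths, indices):
    """Row-major position of `indices` in a product over factors of the given lengths."""
    position = 0
    for length, i in zip(lengths, indices):
        position = position * length + i
    return position

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
idx3 = [8, 9]  # hypothetical third factor, added to exercise more than two loops
joint = [idx1, idx2, idx3]
lengths = [len(factor) for factor in joint]

# Enumerate every index tuple in the same order product(*joint) visits the values.
index_tuples = product(*(range(n) for n in lengths))
mismatches = [
    (pos, idx)
    for pos, idx in enumerate(index_tuples)
    if flat_index(lengths, idx) != pos
]
print(mismatches)  # an empty list means the formula agrees with product()'s ordering
```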

An alternative approach is to store the results in a nested array in the first place, making translation unnecessary, though this is more complex:

from functools import reduce
from operator import getitem, setitem, itemgetter

def get_items(container, indices):
  return reduce(getitem, indices, container)

def set_items(container, indices, value):
  c = reduce(getitem, indices[:-1], container)
  setitem(c, indices[-1], value)

def initialize_table(lengths):
  if len(lengths) == 1: return [0] * lengths[0]
  # Build each sub-table separately so that rows never share inner lists
  return [initialize_table(lengths[1:]) for _ in range(lengths[0])]

func_table = initialize_table(list(map(len, joint)))
for items in product(*map(enumerate, joint)):
  f = trivial_functional(*map(itemgetter(1), items))
  set_items(func_table, list(map(itemgetter(0), items)), f)

>>> get_items(func_table, (2, 1)) # same as func_table[2][1]
<function>

Several of the answers were quite useful; thanks to everyone for the solutions.

It turns out that if I recast the problem slightly with NumPy, I can accomplish the same bookkeeping and solve the problem I was trying to solve much faster than with the pure Python solutions. The trick is just to use NumPy's reshape method together with the normal multi-dimensional array indexing syntax.

Here's how this works. We just convert func_table into a Numpy array, and reshape it:

import numpy as np

component_dimensions = [len(idx1), len(idx2)]
func_table = np.array(func_table, dtype=object).reshape(component_dimensions)

Now func_table can be used to return the correct function not just for a single 2d point, but for a full array of 2d points:

dim1_pts = [3,1,2,1,3,3,1,3,0]
dim2_pts = [0,1,2,1,2,0,1,2,1]
func_array = func_table[dim1_pts, dim2_pts]

As usual, Numpy to the rescue!
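
If only the flat position is needed, for example to index the original 1-d func_table list, NumPy's np.ravel_multi_index appears to do the same row-major bookkeeping, and np.unravel_index inverts it. A minimal sketch, assuming the 4-by-3 shape from the question; the point lists here are shortened examples rather than the ones above.

```python
import numpy as np

shape = (4, 3)  # (len(idx1), len(idx2)) from the question

# Single point: indices (2, 1) map to flat position 2 * 3 + 1.
flat = np.ravel_multi_index((2, 1), shape)
print(flat)  # 7

# Several points at once: one sequence of indices per dimension.
dim1_pts = [3, 1, 2]
dim2_pts = [0, 1, 2]
flat_pts = np.ravel_multi_index((dim1_pts, dim2_pts), shape)
print(flat_pts.tolist())  # [9, 4, 8]

# unravel_index recovers the per-dimension indices from the flat positions.
back = np.unravel_index(flat_pts, shape)
print([axis.tolist() for axis in back])  # [[3, 1, 2], [0, 1, 2]]
```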

This is a little messy, but here you go:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [enumerate(idx1), enumerate(idx2)]

func_map  = {}
for indexes, items in map(lambda x: zip(*x), product(*joint)):
    f = trivial_functional(*items)
    func_map[indexes] = f

print(func_map[(2, 0)](5)) # 40 = (3+5)*5

I'd suggest using enumerate() in the right place:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table  = []
for items in product(*joint):
    f = trivial_functional(*items)
    func_table.append(f)

From what I understood from your comments and your code, func_table is simply indexed by the position at which a given input occurs in the product sequence. You can access it again using:

some_value = 5  # placeholder argument, chosen only for illustration
for index, items in enumerate(product(*joint)):
    # because of the append(), index is the position in func_table
    # of the function created from the tuple items
    func_table[index](some_value)
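
To go from a tuple of indices straight to the matching function, the same enumerate() idea can be applied to the index ranges instead of the values. This is only a sketch: the dictionary name position_of is made up here, and it assumes the table is small enough that storing one entry per combination is acceptable.

```python
from itertools import product

def trivial_functional(i, j):
    return lambda x: (i + j) * x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table = [trivial_functional(*items) for items in product(*joint)]

# position_of maps an index tuple such as (2, 1) to its position in func_table.
# The ranges are iterated in the same order as the values in product(*joint).
index_tuples = product(*(range(len(factor)) for factor in joint))
position_of = {idx: pos for pos, idx in enumerate(index_tuples)}

f = func_table[position_of[(2, 1)]]  # built from idx1[2] = 3 and idx2[1] = 6
print(f(10))  # (3 + 6) * 10 = 90
```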
