
How to compose a nested function g = fn(…(f3(f2(f1())))…) from a list of functions [f1, f2, f3, …, fn]

Question

Is there a readily available, Pythonic way to compose a nested function such as g = f3(f2(f1())) from a list of functions [f1, f2, f3], when the list may contain many more functions?

If there are only a few, I can write:

g = lambda x: f3(f2(f1(x)))

However, when I have dozens of functions, e.g. the layers of a deep neural network, this becomes unmanageable. I would prefer to use something that already exists rather than write another function to build g.


Update

Based on the answer from @Chris: for sequential neural network layers [batchnorm, matmul, activation, softmaxloss], each of which has a forward(X) method that computes the output passed to the next layer, the loss function L and the loss would be:

from functools import reduce

# Chain the forward methods in list order: the first layer is applied first.
L = reduce(lambda f, g: lambda X: g(f(X)), [layer.forward for layer in layers])   # Loss function
network_loss = L(X)
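To make the update above concrete, here is a minimal runnable sketch. The Scale, Shift and SumLoss classes are hypothetical stand-ins for real layers (they are not from any library); the only assumption carried over from the question is that every layer exposes a forward(X) method.

```python
from functools import reduce


# Hypothetical stand-in layers; each one only needs a forward(X) method.
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def forward(self, X):
        return [x * self.factor for x in X]


class Shift:
    def __init__(self, offset):
        self.offset = offset

    def forward(self, X):
        return [x + self.offset for x in X]


class SumLoss:
    def forward(self, X):
        return sum(X)


layers = [Scale(2), Shift(1), SumLoss()]

# Compose the bound forward methods so that layers[0] runs first.
L = reduce(lambda f, g: lambda X: g(f(X)), [layer.forward for layer in layers])

# [1, 2, 3] -> [2, 4, 6] -> [3, 5, 7] -> 15
network_loss = L([1, 2, 3])
```

Because layer.forward is a bound method, each composed step keeps a reference to its own layer object, so any state stored on the layer remains reachable.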

One way is to use functools.reduce:

from functools import reduce

f1 = lambda x: x+1
f2 = lambda x: x*2
f3 = lambda x: x+3
funcs = [f1, f2, f3]

g = reduce(lambda f, g: lambda x: g(f(x)), funcs)

Output:

g(1)==7 # ((1+1) * 2) + 3
g(2)==9 # ((2+1) * 2) + 3

Insight:

functools.reduce folds its second argument (funcs here) from left to right, combining two items at a time with its first argument (the lambda here).

In other words, it first chains f1 and f2 into f_21(x) = f2(f1(x)), then chains f_21 and f3 into f3(f_21(x)), which becomes g(x).
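The folding described above can be unrolled by hand. The sketch below names the combining lambda step (a name chosen here for illustration) and builds the same composition manually. It also shows one possible way to handle an empty list: reduce raises TypeError on an empty sequence unless it is given an initial value, so an identity function is passed as the initializer.

```python
from functools import reduce

f1 = lambda x: x + 1
f2 = lambda x: x * 2
f3 = lambda x: x + 3

# The combining step: given f and g, return "apply f, then g".
step = lambda f, g: lambda x: g(f(x))

# What reduce does internally for [f1, f2, f3]:
f_21 = step(f1, f2)        # f_21(x) == f2(f1(x))
g_manual = step(f_21, f3)  # g_manual(x) == f3(f2(f1(x)))

g = reduce(step, [f1, f2, f3])

# With an identity initializer, an empty list composes to "do nothing".
identity = lambda x: x
g_empty = reduce(step, [], identity)
```

Here g_manual and g behave identically, and g_empty simply returns its argument.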

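One problem with the reduce-based approach is that it introduces O(n) additional function calls: every call to the composed function passes through n - 1 nested wrapper lambdas on top of the n functions themselves, and the nesting depth grows with the length of the list, so a very long list can hit Python's recursion limit. An alternative is to define a single function that remembers the functions to compose; when called, it simply calls each function in sequence on the given argument.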

def compose(*args):
    """compose(f1, f2, ..., fn) == lambda x: fn(...(f2(f1(x)))...)"""

    def _(x):
        result = x
        for f in args:
            result = f(result)
        return result
    return _
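The difference in call depth between the two approaches can be observed directly. The following sketch assumes CPython with its recursion limit left at a typical value; the choice of n and the variable names are illustrative only. It composes more functions than the recursion limit allows to nest, then calls both versions.

```python
import sys
from functools import reduce


def compose(*funcs):
    """Loop-based composition: the call depth stays constant."""
    def composed(x):
        result = x
        for f in funcs:
            result = f(result)
        return result
    return composed


# Use more functions than the interpreter allows nested calls.
n = sys.getrecursionlimit() * 2
funcs = [lambda x: x + 1] * n

looped = compose(*funcs)
nested = reduce(lambda f, g: lambda x: g(f(x)), funcs)

# The loop version applies all n increments one after another.
looped_result = looped(0)

# The reduce version nests about n calls deep and exceeds the limit.
try:
    nested(0)
    nested_failed = False
except RecursionError:
    nested_failed = True
```

For a few dozen functions, such as the layers in the question, either approach works; the depth issue only appears for lists whose length approaches the recursion limit.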

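You can implement it yourself, but you could also try the compose package, written by @mtraceur, which implements this. It takes care of various details, such as forwarding the function signature correctly. Check the package documentation for the order in which it applies its arguments, since it may differ from the hand-written compose above; with three copies of the same function, as in the example below, the order does not matter.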

# Install from the shell first: pip install compose
from compose import compose

def doubled(x):
    return 2*x

octupled = compose(doubled, doubled, doubled)

print(octupled(1))
# 8
