
Given a function that outputs a list, is it possible in Python to extract a function for each component?

For example, say I have a vector function mapping R² to R², such as:

fun = lambda x1, x2: [x1**2 + 1, x2**2 - x1]

I'd like something that allows me to do this:

for f in components(fun):
    print(f(2,3))  # Print first 5, then 7

Note: I'm not asking how to iterate over the components of a single output, which is trivial (for val in fun(2, 3):), but how to iterate over the functions that compute each component of the output. Is this possible?

You can do this with a small trick, although you need to state the expected number of components explicitly, since there is no general way to tell how many outputs a Python function will return (unless you "probe" the function with test values, which is also possible but more complex):

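def components(fun, n):
    for i in range(n):
        # Bind i as a default argument so each lambda keeps its own index,
        # even if the generator is consumed all at once (e.g. list(...))
        yield lambda *args, _i=i, **kwargs: fun(*args, **kwargs)[_i]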

Then your loop could be:

for f in components(fun, 2):
    print(f(2,3))
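The probing approach mentioned above could look roughly like this. It is a sketch that assumes fun returns a sequence of the same length for every input, so a single call with arbitrary sample arguments is enough to count the components; components_probe is a hypothetical name, not part of any library:

```python
def components_probe(fun, *sample_args):
    """Yield one function per output component of ``fun``.

    Hypothetical helper: the number of components is inferred by calling
    ``fun`` once with ``sample_args``, assuming the output length never
    depends on the input.
    """
    n = len(fun(*sample_args))  # probe call: count the outputs
    for i in range(n):
        # Bind i as a default so each lambda keeps its own index
        yield lambda *args, _i=i, **kwargs: fun(*args, **kwargs)[_i]


fun = lambda x1, x2: [x1**2 + 1, x2**2 - x1]

results = [f(2, 3) for f in components_probe(fun, 0, 0)]
print(results)  # [5, 7]
```

The cost is one extra call to fun, and the result is only as reliable as the assumption that the output length is fixed.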

If you want to avoid repeating the computation for each component, you can use some kind of memoization. In Python 3 you can use functools.lru_cache:

from functools import lru_cache

def components(fun, n):
    # lru_cache needs hashable arguments; tune it with a maxsize parameter
    fun = lru_cache()(fun)
    for i in range(n):
        yield lambda *args, _i=i, **kwargs: fun(*args, **kwargs)[_i]
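To check that the cache really avoids the repeated computation, you can count how often the underlying function runs. This is a sketch: the calls counter is illustrative only, and lru_cache only works when all arguments are hashable (numbers and tuples are fine, lists and NumPy arrays are not):

```python
from functools import lru_cache

calls = 0  # counts how often the underlying function actually runs


def fun(x1, x2):
    global calls
    calls += 1
    return [x1**2 + 1, x2**2 - x1]


def components(fun, n):
    cached = lru_cache()(fun)  # all arguments must be hashable
    for i in range(n):
        # Each component indexes into the same cached result
        yield lambda *args, _i=i, **kwargs: cached(*args, **kwargs)[_i]


values = [f(2, 3) for f in components(fun, 2)]
print(values, calls)  # [5, 7] 1
```

Both components are served from a single call to fun. By default lru_cache() keeps the 128 most recent results; pass maxsize to change that.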

I don't think it's possible to do exactly what you ask, since in this case there are not two different functions: there is one function that takes two arguments and returns a list.

You could define two separate functions and combine them in one larger function, so that you still have access to each component function on its own:

fun_1 = lambda x1: x1**2 + 1
fun_2 = lambda x1, x2: x2**2 - x1
fun = lambda x1, x2: [fun_1(x1), fun_2(x1, x2)]
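Following the same idea, a small helper can build the combined function from its parts and keep the parts accessible. This is a hypothetical sketch (vector_function and the components attribute are made-up names); unlike fun_1 above, each component here takes the full argument list so that all of them can be called the same way:

```python
def vector_function(*component_fns):
    """Combine scalar component functions into one list-valued function.

    Hypothetical helper: the combined function keeps its parts in a
    ``components`` attribute so they can be iterated over directly.
    """
    def combined(*args):
        return [g(*args) for g in component_fns]

    combined.components = component_fns
    return combined


# Each component takes (x1, x2), even if it ignores one of them
fun = vector_function(
    lambda x1, x2: x1**2 + 1,
    lambda x1, x2: x2**2 - x1,
)

print(fun(2, 3))  # [5, 7]
for f in fun.components:
    print(f(2, 3))  # 5, then 7
```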

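I found that calling such a function with SymPy symbols instead of numbers turns its output into a list of SymPy expressions, each of which can be converted back into a numeric function, so I ended up doing the following: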

from inspect import signature
import sympy

def components(function):
    # Determine the number of arguments the function takes
    n_args = len(signature(function).parameters)
    # Allocate n symbols x1, x2, ..., xn
    symbols = [sympy.Symbol("x" + str(i)) for i in range(1, n_args + 1)]
    # Calling the function on symbols gives a list of symbolic expressions
    expression_list = function(*symbols)

    # Convert each expression into a numeric function once, and yield it
    for expr in expression_list:
        yield sympy.lambdify(symbols, expr)

fun = lambda x1, x2: [x1**2 + 1, x2**2 - x1]

for f in components(fun):
    print(f(2,3))  # prints 5, 7
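One caveat of this approach: len(signature(function).parameters) counts every parameter, including *args and keyword-only ones, so a function written as lambda *xs: ... would get a single symbol. A stricter count could look like the sketch below (count_positional is a hypothetical helper that uses only the standard library inspect module):

```python
from inspect import Parameter, signature


def count_positional(function):
    """Count the named positional parameters of ``function``.

    Raises TypeError for ``*args`` functions, since no fixed number of
    symbols can be allocated for them. Keyword-only parameters are ignored.
    """
    params = signature(function).parameters.values()
    if any(p.kind is Parameter.VAR_POSITIONAL for p in params):
        raise TypeError("cannot infer the arity of a *args function")
    return sum(
        p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
        for p in params
    )


print(count_positional(lambda x1, x2: [x1**2 + 1, x2**2 - x1]))  # 2
```

The approach also assumes the function body only uses operations that SymPy symbols support; code that branches on an argument's value with if, or calls math.sqrt, will typically raise a TypeError when called on symbols.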
