简体   繁体   中英

How to write a Python decorator that updates a keyword argument?

The goal is to write a decorator that updates one keyword argument of the wrapped function. In the following code, `wrapper` attempts to update `kwarg1`:

import inspect
from functools import wraps

def override_me(arg, kwarg1="default kwarg1", kwarg2="default kwarg2"):
    """Print *arg* followed by the values of ``kwarg1`` and ``kwarg2``."""
    line = "override_me {} kwarg1={} kwarg2={}".format(arg, kwarg1, kwarg2)
    print(line)

def append_kwarg1(func):
    """Return a wrapper around *func* that appends ``"_patched!"`` to ``kwarg1``.

    The wrapper declares ``kwarg1`` as a keyword-only parameter whose default
    is *func*'s own default for ``kwarg1``.

    Known limitations of this first attempt:
      * A ``kwarg1`` passed positionally is collected into ``*args``. It is
        then sent to *func* again by keyword, so *func* raises
        ``TypeError: ... got multiple values for argument 'kwarg1'``.
      * *func*'s return value is discarded; the wrapper returns ``None``.
    """
    signature_params = inspect.signature(func).parameters
    fallback = signature_params["kwarg1"].default

    @wraps(func)
    def wrapper(*args, kwarg1=fallback, **kwargs):
        patched_value = kwarg1 + "_patched!"
        func(*args, kwarg1=patched_value, **kwargs)

    return wrapper

# Rebind the module-level name to the wrapped version.
override_me = append_kwarg1(override_me)


override_me("passed_arg")  # kwarg1 omitted: the wrapper's default (override_me's own default) is patched
override_me("passed_arg", kwarg1="passed_kwarg1_named")  # kwarg1 by keyword: patched
# kwarg1 passed positionally: the value is collected into the wrapper's *args,
# and the wrapper also passes kwarg1=... by keyword, so override_me receives
# kwarg1 twice.
override_me("passed_arg", "passed_kwarg1_as_arg") # TypeError: override_me() got multiple values for argument 'kwarg1'

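This, however, fails when `kwarg1` is passed as a positional argument. The positional value is collected into the wrapper's `*args`, while the wrapper also passes `kwarg1=` explicitly, so `override_me` receives `kwarg1` twice.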

Edit: As clarified in the comments, the signature of `override_me` cannot be changed (think: an external module).

A working solution relies on the fact that even a keyword argument has a position (index) in the function's signature. By comparing that index with the number of positional arguments passed by the caller, we can determine whether the argument of interest was passed positionally or by keyword. We then update it in either `args` or `kwargs`:


import inspect                                                                                                                                                                                         
from functools import wraps

def override_me(arg, kwarg1="default kwarg1", kwarg2="default kwarg2"):
    """Stand-in for an external function whose signature cannot be changed.

    Prints *arg* and both keyword-argument values on a single line.
    """
    template = "override_me {arg} kwarg1={kw1} kwarg2={kw2}"
    print(template.format(arg=arg, kw1=kwarg1, kw2=kwarg2))

def append_kwarg1(func):
    """Wrap *func* so that its ``kwarg1`` argument has ``"_patched!"`` appended.

    ``kwarg1`` is located by its index in *func*'s signature. The wrapper
    therefore patches it whether the caller passes it positionally, passes
    it by keyword, or omits it; when omitted, *func*'s own default is
    patched. All other arguments are forwarded unchanged, and *func*'s
    return value is returned.

    Raises:
        ValueError: at decoration time, if *func* has no ``kwarg1`` parameter.
    """
    params = inspect.signature(func).parameters
    if "kwarg1" not in params:
        raise ValueError(f"{func!r} has no 'kwarg1' parameter")
    kwarg1_param = params["kwarg1"]
    kwarg1_index = list(params).index("kwarg1")
    # A keyword-only kwarg1 (e.g. one declared after *args) can never be
    # filled positionally, so counting positional args must not apply to it.
    kwarg1_positional = kwarg1_param.kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def update(v):
        return v + "_patched!"

    @wraps(func)
    def wrapper(*args, **kwargs):
        if kwarg1_positional and len(args) > kwarg1_index:
            # kwarg1 arrived positionally: rebuild args with the patched value.
            args = (
                *args[:kwarg1_index],
                update(args[kwarg1_index]),
                *args[kwarg1_index + 1 :],
            )
        elif "kwarg1" in kwargs:
            kwargs["kwarg1"] = update(kwargs["kwarg1"])
        elif kwarg1_param.default is not inspect.Parameter.empty:
            # Omitted: patch func's own default and pass it by keyword.
            kwargs["kwarg1"] = update(kwarg1_param.default)
        # Otherwise kwarg1 is required but missing; let func raise its usual
        # missing-argument TypeError instead of failing inside update().
        return func(*args, **kwargs)

    return wrapper


# Rebind the module-level name to the wrapped version.
override_me = append_kwarg1(override_me)

# Each call prints kwarg1 with "_patched!" appended, however kwarg1 is supplied:
override_me("passed_arg", kwarg1="passed_kwarg1_named")  # by keyword
override_me("passed_arg")  # omitted: override_me's own default is patched
override_me("passed_arg", "passed_kwarg1_as_arg")  # positionally
override_me("passed_arg", "passed_kwarg1_as_arg", "passed_kwarg2_as_arg")  # kwarg2 is forwarded unchanged

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM