
Python decorator with an optional argument (which is itself a function)

Note: I know that decorators with optional arguments consist of three nested functions. But the optional argument here is itself a function. Please read the complete post before marking this as a duplicate. I have already tried all the usual tricks for decorators with optional arguments, but I could not find one that takes a function as the argument.

I have a decorator that wraps errors:

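import sys
import traceback
from functools import wraps

def wrap_error(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            # Format the active exception and re-raise it as MyCustomError
            # (MyCustomError is defined elsewhere in the project).
            exc_msg = traceback.format_exception(*sys.exc_info())
            raise MyCustomError(exc_msg)

    return wrapper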

If the decorated function raises any exception, the decorator wraps it in MyCustomError. It is used like this:

@wrap_error
def foo():
    ...

Now I want to extend this decorator with an optional callback function, so that it can be used in both of these ways:

@wrap_error
def foo():
    ...

@wrap_error(callback)
def foo():
    ...

I know how to write decorators with optional arguments when the argument passed is not a function (based on an isfunction(func) check within the wrapper). But I am not sure how to handle this case.

Note: I cannot use @wrap_error() instead of @wrap_error. This decorator is used in many packages, and it is not possible to update all of them.

Here is the blocker. Consider the decorator used as:

@wrap_error(callback)   # equivalent to: foo = wrap_error(callback)(foo)
def foo():
    ...

So, at the time wrap_error(foo) is executed, we do not know whether its argument is the function to wrap or a callback that will be followed by a second call (as happens when we use @wrap_error(callback) instead of just @wrap_error).

If there is no callback, the wrapper inside wrap_error should return func(*args, **kwargs) so that the exception can be raised. Otherwise wrap_error has to return a function that accepts func at the next step, and if func() raises an exception, callback() is called in the except block.

To summarise the problem before attempting to answer it, you want a decorator that works correctly in both of the following contexts:

@decorator  # case 1
def some_func(...):
    ...

@decorator(some_callback)  # case 2
def some_func(...):
    ...

or, to unroll the @ syntax to clarify things:

some_func = decorator(some_func)  # case 1

some_func = decorator(some_callback)(some_func)  # case 2

The tricky issue here, as I see it, is that it's very hard for decorator to tell the difference between some_func and some_callback (and therefore between cases 1 and 2); both are (presumably) just callable objects.
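
To make the ambiguity concrete, here is a minimal sketch of the check that optional-argument decorators usually rely on. The helper name looks_like_function is hypothetical and introduced only for illustration; the point is that inspect.isfunction gives the same answer for the function to be wrapped and for the callback.

```python
import inspect


def some_func():
    return "result"


def some_callback():
    return "called back"


def looks_like_function(obj):
    # Hypothetical helper: the test an optional-argument decorator would
    # typically use to decide "was I called bare, or with an argument?".
    return inspect.isfunction(obj)


# Both objects pass the test, so the decorator cannot tell case 1 from case 2.
print(looks_like_function(some_func), looks_like_function(some_callback))

# A non-function argument (the usual optional-argument case) is easy to detect.
print(looks_like_function("just a string"))
```

The check only helps when the optional argument is something other than a function, which is exactly what the question rules out.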


One potential solution is to provide named arguments:

# imports at top of file, not in function definitions
from functools import wraps
import sys

def decorator(func=None, callback=None):
    # Case 1
    if func is not None:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)  # or whatever
        return wrapper
    # Case 2
    elif callback is not None: 
        def deco(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                return callback(f(*args, **kwargs))  # or whatever
            return wrapper
        return deco

This makes case 2 look slightly different:

@decorator(callback=some_callback)
def some_func(...):
    ...

But otherwise does what you want. Note that the option you say you can't use

@decorator()
def some_func(...):
    ...

won't work with this, as the decorator expects either func or callback to be supplied (it will return None otherwise, which isn't callable, so you'll get a TypeError ).
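
If supporting the empty-parentheses form ever becomes desirable, one possible variation (a sketch, not part of the code above) is to make callback keyword-only and always build the inner decorator, applying it immediately only when func is supplied. The placeholder behaviour of passing the result to callback is kept from the example above; the functions plain, empty_parens and with_callback are made up for the demonstration.

```python
from functools import wraps


def decorator(func=None, *, callback=None):
    def deco(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            result = f(*args, **kwargs)
            # Placeholder behaviour, mirroring the "or whatever" lines above.
            return callback(result) if callback is not None else result
        return wrapper

    # Bare @decorator: func is the decorated function, so wrap it right away.
    if func is not None:
        return deco(func)
    # @decorator() or @decorator(callback=...): hand back the real decorator.
    return deco


@decorator
def plain():
    return 1


@decorator()
def empty_parens():
    return 2


@decorator(callback=lambda value: value * 10)
def with_callback():
    return 3


print(plain(), empty_parens(), with_callback())
```

A callback passed positionally, as in @decorator(some_callback), would still be mistaken for the function to wrap, so the keyword form remains necessary.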

Since it is hard to tell decorator(func) from decorator(callback) , make two decorators:

from functools import wraps

class MyCustomError(Exception):
    def __init__(self):
        print('in MyCustomError')

# Common implementation
def wrap(func,cb=None):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            if cb is not None:
                cb()
            raise MyCustomError()
    return wrapper

# No parameters version
def wrap_error(func):
    return wrap(func)

# callback parameter version
def wrap_error_cb(cb):
    def deco(func):
        return wrap(func,cb)
    return deco

@wrap_error
def foo(a,b):
    print('in foo',a,b)
    raise Exception('foo exception')

def callback():
    print('in callback')

@wrap_error_cb(callback)
def bar(a):
    print('in bar',a)
    raise Exception('bar exception')

Check that foo and bar are correctly using functools.wraps :

>>> foo
<function foo at 0x0000000003F00400>
>>> bar
<function bar at 0x0000000003F00598>

Check that the wrapped functions work:

>>> foo(1,2)
in foo 1 2
in MyCustomError
Traceback (most recent call last):
  File "<interactive input>", line 1, in <module>
  File "C:\test.py", line 16, in wrapper
    raise MyCustomError()
MyCustomError
>>> bar(3)
in bar 3
in callback
in MyCustomError
Traceback (most recent call last):
  File "<interactive input>", line 1, in <module>
  File "C:\test.py", line 16, in wrapper
    raise MyCustomError()
MyCustomError

Updated

Here's a way to do it with the syntax you requested, but I think the above answer is clearer.

from functools import wraps

class MyCustomError(Exception):
    def __init__(self):
        print('in MyCustomError')

# Common implementation
def wrap(func,cb=None):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            if cb is not None:
                cb()
            raise MyCustomError()
    return wrapper

def wrap_error(func_or_cb):
    # If the function is tagged as a wrap_error_callback
    # return a decorator that returns the wrapped function
    # with a callback.
    if hasattr(func_or_cb,'cb'):
        def deco(func):
            return wrap(func,func_or_cb)
        return deco
    # Otherwise, return a wrapped function without a callback.
    return wrap(func_or_cb)

# decorator to tag callbacks so wrap_error can distinguish them
# from *regular* functions.
def wrap_error_callback(func):
    func.cb = True
    return func

### Examples of use

@wrap_error
def foo(a,b):
    print('in foo',a,b)
    raise Exception('foo exception')

@wrap_error_callback
def callback():
    print('in callback')

@wrap_error(callback)
def bar(a):
    print('in bar',a)
    raise Exception('bar exception')
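
As with the two-decorator version, it is worth checking that the tagged variant behaves as intended. The sketch below restates the definitions compactly so that it runs on its own; the calls list is a hypothetical stand-in for the print statements, MyCustomError is simplified to an empty class, and except Exception is used in place of the bare except.

```python
from functools import wraps


class MyCustomError(Exception):
    pass


def wrap(func, cb=None):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            if cb is not None:
                cb()
            raise MyCustomError()
    return wrapper


def wrap_error(func_or_cb):
    # Tagged callables are treated as callbacks; everything else is wrapped.
    if hasattr(func_or_cb, 'cb'):
        def deco(func):
            return wrap(func, func_or_cb)
        return deco
    return wrap(func_or_cb)


def wrap_error_callback(func):
    func.cb = True
    return func


calls = []  # hypothetical log, used here instead of print()


@wrap_error
def foo(a, b):
    raise Exception('foo exception')


@wrap_error_callback
def callback():
    calls.append('callback')


@wrap_error(callback)
def bar(a):
    raise Exception('bar exception')


for fn, args in ((foo, (1, 2)), (bar, (3,))):
    try:
        fn(*args)
    except MyCustomError:
        calls.append(fn.__name__ + ' raised MyCustomError')

print(calls)
```

Only bar triggers the callback, and both functions keep their original names thanks to functools.wraps. The cost of this design is that every callback has to be tagged with @wrap_error_callback; an untagged callback would silently be treated as the function to wrap.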
