
Decorator to alter function behavior

I have two unrelated functions that implement the same behavior in different ways. I'm wondering whether there is a way, probably via decorators, to handle this efficiently, so that I don't have to write the same logic over and over if the behavior is added elsewhere.

Essentially, I have two functions in two different classes that each take a flag called exact_match. Both functions check for some kind of equivalence between instances of the class they belong to. The exact_match flag forces the function to compare floats exactly instead of with a tolerance. You can see how I do this below.

def is_close(a, b, rel_tol=1e-09, abs_tol=0.0):
    # Same formula and defaults as math.isclose (Python 3.5+)
    return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)


def _equal(val_a, val_b):
    """Wrapper for equality test to send in place of is_close."""
    return val_a == val_b

# In the real code this is a @staticmethod of its class; shown here as a plain function.
def get_equivalence(obj_a, obj_b, check_name=True, exact_match=False):
    equivalence_func = is_close
    if exact_match:
        # If we're looking for an exact match, switch to the equality tester.
        equivalence_func = _equal

    if check_name:
        return obj_a.name == obj_b.name

    # Check minimum resolutions if they are specified
    if ('min_res' in obj_a and 'min_res' in obj_b
            and not equivalence_func(obj_a['min_res'], obj_b['min_res'])):
        return False

    # (further checks, such as the bounding-box comparison, are not shown)
    return False

As you can see, the standard procedure is to use is_close when we don't need an exact match and to swap in a different function when we do. Now another function needs the same logic of swapping out the comparison function. Is there a way to use decorators, or something similar, to handle this kind of logic when I know a specific function call may need to be swapped out?

No decorator needed; just pass the desired function as an argument to get_equivalence (which is now little more than a wrapper that applies the argument).

def make_eq_with_tolerance(rel_tol=1e-09, abs_tol=0.0):
    # Returns a comparator with the given tolerances baked in
    def _(a, b):
        return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
    return _

# This is just operator.eq, by the way
def _equal(val_a, val_b):
    return val_a == val_b

def same_name(a, b):
    return a.name == b.name

Now get_equivalence takes three arguments: the two objects to compare and a function that gets called on those two arguments.

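# Shown as a plain function so the example calls below work as written;
# inside its class it would keep the @staticmethod decorator.
def get_equivalence(obj_a, obj_b, equivalence_func):
    return equivalence_func(obj_a, obj_b)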

Some example calls:

get_equivalence(a, b, make_eq_with_tolerance())
get_equivalence(a, b, make_eq_with_tolerance(rel_tol=1e-12))  # Really tight tolerance
get_equivalence(a, b, _equal)
get_equivalence(a, b, same_name)
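The standard library already provides comparators like these, so the same calls can be written without custom helpers. This is a sketch assuming Python 3.5+ (for math.isclose): functools.partial fixes the tolerances in place of make_eq_with_tolerance, and operator.eq replaces _equal. The Grid namedtuple and its values are hypothetical stand-ins for the real objects.

```python
import functools
import math
import operator
from collections import namedtuple


def get_equivalence(obj_a, obj_b, equivalence_func):
    # The caller chooses how equality is tested; this only applies it.
    return equivalence_func(obj_a, obj_b)


def same_name(a, b):
    return a.name == b.name


# Match make_eq_with_tolerance() and make_eq_with_tolerance(rel_tol=1e-12)
# for finite inputs (math.isclose also treats inf as close to itself).
tolerant = functools.partial(math.isclose, rel_tol=1e-09, abs_tol=0.0)
tight = functools.partial(math.isclose, rel_tol=1e-12)

x, y = 0.1 + 0.2, 0.3
print(get_equivalence(x, y, tolerant))     # True
print(get_equivalence(x, y, tight))        # True
print(get_equivalence(x, y, operator.eq))  # False: 0.1 + 0.2 is 0.30000000000000004

# Hypothetical stand-in for the real objects
Grid = namedtuple("Grid", ["name", "min_res"])
g1, g2 = Grid("web", 0.5), Grid("web", 0.25)
print(get_equivalence(g1, g2, same_name))  # True
# Comparators compose: compare a single attribute with a tolerance
print(get_equivalence(g1, g2, lambda a, b: tolerant(a.min_res, b.min_res)))  # False
```

Because each comparator is an ordinary callable, a new comparison rule is just another function or partial, and get_equivalence itself never changes.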

I came up with an alternative solution that is perhaps less correct, but it lets me solve the problem the way I originally wanted to.

My solution uses a utility class that can be used either as a member of a class or as a mixin, to provide the comparison functions in a convenient way. Below, the functions _equals and is_close are defined elsewhere, since their implementations are beside the point.

class EquivalenceUtil(object):
    def __init__(self, equal_comparator=_equals, inexact_comparator=is_close):
        # Plain functions stored on the instance are not bound, so they are called as-is
        self.equals = equal_comparator
        self.default_comparator = inexact_comparator

    def check_equivalence(self, obj_a, obj_b, exact_match=False, **kwargs):
        # Exact comparison when requested, tolerant comparison otherwise
        if exact_match:
            return self.equals(obj_a, obj_b, **kwargs)
        return self.default_comparator(obj_a, obj_b, **kwargs)

It's a simple class that can be used like so:

class BBOX(object):
    _equivalence = EquivalenceUtil()

    def __init__(self, **kwargs):
        ...

    @classmethod
    def are_equivalent(cls, bbox_a, bbox_b, exact_match=False):
        """Test for equivalence between two BBOX objects."""
        bbox_list = bbox_a.as_list
        other_list = bbox_b.as_list
        # Compare the first three values pairwise; any mismatch means not equivalent
        for _index in range(0, 3):
            if not cls._equivalence.check_equivalence(bbox_list[_index],
                                                      other_list[_index],
                                                      exact_match=exact_match):
                return False
        return True

This solution hides more of how the comparison is done behind the scenes from the user, which is important for my project. It is also flexible: it can be reused in multiple places and ways within a class, and easily added to a new class.
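The utility is described as usable as a mixin as well as a class member, but the example uses it only as a member. A minimal sketch of the mixin form follows, under these assumptions: the comparators live in class attributes (wrapped in staticmethod so plain Python functions are not bound as methods), check_equivalence becomes a classmethod, and math.isclose and operator.eq stand in for is_close and _equals. EquivalenceMixin, Point and StrictPoint are hypothetical names.

```python
import math
import operator


class EquivalenceMixin(object):
    # staticmethod() keeps plain Python functions from being bound as methods.
    equal_comparator = staticmethod(operator.eq)
    inexact_comparator = staticmethod(math.isclose)

    @classmethod
    def check_equivalence(cls, val_a, val_b, exact_match=False, **kwargs):
        if exact_match:
            # operator.eq takes exactly two arguments, so kwargs are not forwarded
            return cls.equal_comparator(val_a, val_b)
        return cls.inexact_comparator(val_a, val_b, **kwargs)


class Point(EquivalenceMixin):
    def __init__(self, x, y):
        self.x, self.y = x, y

    @classmethod
    def are_equivalent(cls, p, q, exact_match=False):
        # Every coordinate pair must be equivalent
        return all(cls.check_equivalence(a, b, exact_match=exact_match)
                   for a, b in ((p.x, q.x), (p.y, q.y)))


class StrictPoint(Point):
    # A subclass can swap the tolerant comparator without touching the logic.
    inexact_comparator = staticmethod(operator.eq)


p, q = Point(0.1 + 0.2, 1.0), Point(0.3, 1.0)
print(Point.are_equivalent(p, q))                    # True
print(Point.are_equivalent(p, q, exact_match=True))  # False
print(StrictPoint.are_equivalent(p, q))              # False
```

Overriding a class attribute in a subclass, as StrictPoint does, changes the comparison for that class only, while the inherited are_equivalent stays the same.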

With this, the code from my original example becomes:

class TileGrid(object):
    _equivalence = EquivalenceUtil()

    def __init__(self, **kwargs):
        ...

    @classmethod
    def are_equivalent(cls, grid1, grid2, check_name=False, exact_match=False):
        if check_name:
            return grid1.name == grid2.name
        # Check minimum resolutions if they are specified
        if ('min_res' in grid1 and 'min_res' in grid2
                and not cls._equivalence.check_equivalence(grid1['min_res'], grid2['min_res'],
                                                           exact_match=exact_match)):
            return False

        # Compare the bounding boxes of the two grids if they exist in the grid
        if 'bbox' in grid1 and 'bbox' in grid2:
            return BBOX.are_equivalent(grid1.bbox, grid2.bbox, exact_match=exact_match)

        return False

I can't recommend this approach in the general case, because I can't help feeling there's a code smell to it, but it does exactly what I need and will solve a great many problems in my current codebase. We have specific requirements; this is a specific solution. chepner's solution is probably best for the general case of letting the user decide how a function should test equivalence.
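The decorator route the question originally asked about can combine both ideas. The following is a minimal sketch, not taken from either answer: the hypothetical with_comparator decorator accepts the exact_match keyword from the question and passes the chosen comparator into the wrapped function as a compare argument, much like chepner's equivalence_func. It assumes Python 3.5+, using math.isclose (same defaults as is_close) for the tolerant case and operator.eq for the exact case; same_min_res is a hypothetical example function.

```python
import functools
import math
import operator


def with_comparator(exact=operator.eq, inexact=math.isclose):
    """Hypothetical decorator: translate an exact_match flag into a comparator.

    The decorated function must accept a keyword argument named ``compare``.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, exact_match=False, **kwargs):
            # Choose the comparator from the flag and hand it to func.
            kwargs["compare"] = exact if exact_match else inexact
            return func(*args, **kwargs)
        return wrapper
    return decorator


@with_comparator()
def same_min_res(obj_a, obj_b, compare):
    # 'compare' is injected by the decorator; callers only pass exact_match.
    return compare(obj_a["min_res"], obj_b["min_res"])


a = {"min_res": 0.1 + 0.2}
b = {"min_res": 0.3}
print(same_min_res(a, b))                    # True (tolerant comparison)
print(same_min_res(a, b, exact_match=True))  # False (exact comparison)
```

Because the flag handling lives in one place, any other function that needs the same switch only has to accept a compare argument and be decorated.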

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM