
Injecting “global imports” into Python functions

Short, but complete, summary

I want to allow users of my function (a class factory) to inject or overwrite global imports when they call it (longer explanation of the rationale below). But there are about 10 different variables that could be passed in, and handling each one adds a number of very repetitive lines to the code. (Granted, it also makes the function more complicated to call. :P) Right now, I'm doing something like the following (simplified). To make it runnable, I'm using a dummy class, but in the actual script I'd be using `import pkg1`, etc. I figured this was clearer and shorter than a class factory.

class Dummy(object): pass

pkg1, pkg2 = Dummy(), Dummy()
pkg1.average = lambda *args: sum(args) / len(args)
pkg2.get_lengths = lambda *args: list(map(len, args))


def get_average(*args, **kwargs):
    # Use the override if one was passed, otherwise fall back to the module-level default
    average = kwargs.get("average") or pkg1.average
    get_lengths = kwargs.get("get_lengths") or pkg2.get_lengths
    return average(*get_lengths(*args))

adjusted_length = lambda *args: list(map(len, args)) + [15]
print(get_average([1, 2], [10, 4, 5, 6]) == 3)  # True
print(get_average([1, 2], [10, 4, 5, 6], get_lengths=adjusted_length) == 7)  # True
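One way to cut down the repetition for ~10 injectable names, while keeping the same call-time fallback as above, is to keep every default in one dictionary and merge the keyword arguments over it. This is a minimal sketch; `DEFAULT_DEPS`, `resolve_deps`, `default_average`, and `default_get_lengths` are hypothetical names standing in for the real `pkg1`/`pkg2` imports.

```python
def default_average(*args):
    return sum(args) / len(args)

def default_get_lengths(*args):
    return [len(a) for a in args]

# One table of injectable names and their defaults (hypothetical stand-ins)
DEFAULT_DEPS = {
    "average": default_average,
    "get_lengths": default_get_lengths,
}

def resolve_deps(overrides):
    # Reject misspelled names instead of silently ignoring them
    unknown = set(overrides) - set(DEFAULT_DEPS)
    if unknown:
        raise TypeError("unexpected keyword argument(s): " + ", ".join(sorted(unknown)))
    deps = dict(DEFAULT_DEPS)
    # Treat an explicit None as "use the default"
    deps.update((name, value) for name, value in overrides.items() if value is not None)
    return deps

def get_average(*args, **overrides):
    deps = resolve_deps(overrides)
    return deps["average"](*deps["get_lengths"](*args))
```

Adding an eleventh injectable name then means adding one dictionary entry rather than another `kwargs.get(...) or ...` line in every function.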

Related SO questions

This Stack Overflow post, "Modifying locals in Python", seemed particularly relevant, and initially I wanted to just overwrite locals by storing to the `locals()` dictionary, but (1) it didn't seem to work, and (2) it seemed like a bad idea. So I'm wondering if there's another way to do it.
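A minimal demonstration of why writing to `locals()` does not work inside a function: `locals()` returns a dictionary built from the function's local variables, and assigning into that dictionary does not rebind the real local variable in CPython (in 3.13, PEP 667 made this snapshot behaviour explicit). `try_to_rebind` is a hypothetical name used only for this illustration.

```python
def try_to_rebind():
    x = 1
    # A dictionary view of the local variables at this point
    snapshot = locals()
    # This changes the dictionary, not the local variable `x`
    snapshot["x"] = 2
    return x

result = try_to_rebind()
```

Here `result` is still `1`.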

This one looked promising ("Adding an object to another module's globals in python"), but I'm not really sure how to access the globals for the current file in the same way as for a module. (And this question, "python: mutating `globals` to dynamically put things in scope", doesn't really apply, since I'm ultimately using this to define classes.)
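For the current file, `globals()` returns the dictionary backing the module's global namespace (for a normally imported module it is the same object as `sys.modules[__name__].__dict__`), so the "other module's globals" trick can be applied to the current file too. The sketch below assumes a hypothetical helper, `overridden_globals`, that rebinds module-level names temporarily and restores them afterwards. Note that the rebinding is visible to every caller while the `with` block runs, so it is not thread-safe.

```python
from contextlib import contextmanager

def average(*args):
    return sum(args) / len(args)

def get_lengths(*args):
    return [len(a) for a in args]

def get_average(*args):
    # `average` and `get_lengths` are global names, resolved at call time
    return average(*get_lengths(*args))

_MISSING = object()

@contextmanager
def overridden_globals(**overrides):
    # globals() is the dict backing this module's global namespace
    namespace = globals()
    saved = {name: namespace.get(name, _MISSING) for name in overrides}
    namespace.update(overrides)
    try:
        yield
    finally:
        # Restore the previous bindings (or remove names that did not exist)
        for name, value in saved.items():
            if value is _MISSING:
                del namespace[name]
            else:
                namespace[name] = value

with overridden_globals(get_lengths=lambda *a: [len(x) for x in a] + [15]):
    inside = get_average([1, 2], [10, 4, 5, 6])
outside = get_average([1, 2], [10, 4, 5, 6])
```

Inside the block the override is used (`inside == 7`); afterwards the original `get_lengths` is back (`outside == 3`).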

I guess I could wrap everything in an `exec` statement (like in this post: "globals and locals in python exec()"), but that's fiddly and makes error checking, linting, etc. much harder.
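For comparison, the `exec` route looks roughly like this: the dictionary passed to `exec()` becomes the global namespace of any function defined by the executed source, so injected names resolve there. `GET_AVERAGE_SOURCE` and `build_get_average` are hypothetical names. Because the function body lives in a string, linters and IDEs cannot check it, which is the drawback mentioned above.

```python
# The function body kept as a string; tools cannot lint or type-check it
GET_AVERAGE_SOURCE = """
def get_average(*args):
    return average(*get_lengths(*args))
"""

def build_get_average(average, get_lengths):
    # This dict becomes the new function's globals
    namespace = {"average": average, "get_lengths": get_lengths}
    exec(GET_AVERAGE_SOURCE, namespace)
    return namespace["get_average"]

mean = lambda *args: sum(args) / len(args)
get_average = build_get_average(mean, lambda *args: [len(a) for a in args])
adjusted_get_average = build_get_average(mean, lambda *args: [len(a) for a in args] + [15])
```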

So here's what I'd like to do. (Note: I would have used `from pkg1 import average` and `from pkg2 import get_lengths`, but I wanted the example to be clearer; you need to copy `pkg1`, `pkg2`, and `adjusted_length` from above to run this.)

average = pkg1.average
get_lengths = pkg2.get_lengths

def get_average(*args, **kwargs):
    localvars = locals()
    for k in ("get_lengths", "average"):
        if kwargs.get(k) is not None:
            # Attempted override. It has no effect: `average` and `get_lengths`
            # are never assigned in this function body, so Python looks them
            # up in the module globals, not in the locals() dict.
            localvars[k] = kwargs[k]
    return average(*get_lengths(*args))

print(get_average([1, 2], [10, 4, 5, 6]) == 3)  # True
print(get_average([1, 2], [10, 4, 5, 6], get_lengths=adjusted_length) == 7)  # False, the result is still 3
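A different technique, not taken from the original post, that does what the snippet above is trying to do without `exec`: build a copy of the function whose globals dictionary contains the overrides, using `types.FunctionType(code, globals, name, argdefs, closure)`. `with_globals` is a hypothetical helper. The copy receives a snapshot of `func.__globals__`, so later changes to the module's globals are not visible to it.

```python
import types

def average(*args):
    return sum(args) / len(args)

def get_lengths(*args):
    return [len(a) for a in args]

def get_average(*args):
    return average(*get_lengths(*args))

def with_globals(func, **overrides):
    """Return a copy of `func` whose global name lookups see `overrides` first."""
    new_globals = dict(func.__globals__)  # snapshot of the module globals
    new_globals.update(overrides)
    clone = types.FunctionType(func.__code__, new_globals, func.__name__,
                               func.__defaults__, func.__closure__)
    clone.__kwdefaults__ = func.__kwdefaults__  # keep keyword-only defaults
    return clone

adjusted_get_average = with_globals(
    get_average, get_lengths=lambda *a: [len(x) for x in a] + [15])
```

The original `get_average` is left untouched, so only callers holding `adjusted_get_average` see the override.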

Rationale for my specific use-case

Right now, I'm trying to write a class factory that dynamically generates classes (to use as SQLAlchemy mixins), but I want to allow users of my class to pass in alternate constructors so they can use SQLAlchemy adapters, etc.

For example, Flask-SQLAlchemy provides the same interface as SQLAlchemy, but adds a custom object (`db`) that wraps all the SQLAlchemy objects to provide more features.

Answer

You could use arguments with default values to pass the functions in. This is effectively what you are doing, but cleaner. I've used a single `lists` argument instead of `*args` because it is easier to deal with when the function also takes other arguments. You'll have to enclose your lists in a tuple to pass them to `get_average`.

The built-in function `sorted` works like this (its optional `key` argument lets the caller pass in a function), so it should be easy for Python programmers to understand.
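As a quick illustration of the `sorted` analogy: called without `key`, it uses the default ordering; passing `key=len` injects a different function without changing how `sorted` is called otherwise. The word list is just sample data.

```python
words = ["pear", "fig", "banana"]

by_default = sorted(words)         # default (alphabetical) ordering
by_length = sorted(words, key=len) # caller-supplied function: sort by length
```

Here `by_default` is `['banana', 'fig', 'pear']` and `by_length` is `['fig', 'pear', 'banana']`.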

def get_average(lists, average=pkg1.average, get_lengths=pkg2.get_lengths):
    return average(*get_lengths(*lists))

print(get_average(([1, 2], [10, 4, 5, 6])))
print(get_average(([1, 2], [10, 4, 5, 6]), get_lengths=adjusted_length))
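One detail worth knowing with this approach: default values are evaluated once, when the `def` statement runs, so rebinding `pkg1.average` later does not affect the default. If call-time lookup matters (as in the question's `kwargs.get(...) or pkg1.average` version), the usual idiom is a `None` default resolved inside the body. A minimal sketch; `average_early` and `average_late` are hypothetical names.

```python
class Dummy(object): pass

pkg1 = Dummy()
pkg1.average = lambda *args: sum(args) / len(args)

def average_early(values, average=pkg1.average):
    # The default was captured once, when this `def` ran
    return average(*values)

def average_late(values, average=None):
    # Look the default up on every call instead
    if average is None:
        average = pkg1.average
    return average(*values)

# Rebind the "global import" after both functions were defined
pkg1.average = lambda *args: max(args)

early = average_early([1, 2, 3])  # still uses the original mean
late = average_late([1, 2, 3])    # picks up the new max()
```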

If you have many keyword arguments, you could package them in an object:

class GetAverageContext(object):
    def __init__(self, average=pkg1.average, get_lengths=pkg2.get_lengths):
        self.average = average
        self.get_lengths = get_lengths

DefaultGetAverageContext = GetAverageContext()

def get_average(lists, context=DefaultGetAverageContext):
    return context.average(*context.get_lengths(*lists))
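Applied to the class-factory use case from the question, the same idea means passing the whole namespace of constructors as one argument, for example the `sqlalchemy` module or a Flask-SQLAlchemy `db` object, both of which expose `Column` and `Integer`. The sketch below is hypothetical (`make_id_mixin`, `FakeColumn`, and `fake_sa` are illustrative names) and uses a stand-in namespace so it runs without SQLAlchemy installed.

```python
from types import SimpleNamespace

def make_id_mixin(sa):
    """Build a mixin class whose column objects come from the injected `sa`."""
    class IdMixin(object):
        # Resolved from the argument, not from a module-level import
        id = sa.Column(sa.Integer, primary_key=True)
    return IdMixin

# Stand-in namespace so the sketch runs without SQLAlchemy installed
class FakeColumn(object):
    def __init__(self, type_, **options):
        self.type_ = type_
        self.options = options

fake_sa = SimpleNamespace(Column=FakeColumn, Integer="INTEGER")
Mixin = make_id_mixin(fake_sa)
```

With the real libraries, the assumed calls would be `make_id_mixin(sqlalchemy)` after `import sqlalchemy`, or `make_id_mixin(db)` with a Flask-SQLAlchemy `db` object.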
