[英]Grid search function in Python
我正在尝试编写一个参数搜索函数来循环遍历其中一个参数,并重复调用具有与我正在搜索的其他参数相同的所有其他参数的函数。 这是一些示例代码:
def worker1(a, b, c):
return a + b + c
def worker2(d, e, f):
return d * e * f
def search(model, params):
res = []
# Loop over one of the parameters and repeatedly append to res
if model == 1:
res.append(worker1(**params))
elif model == 2:
res.append(worker2(**params))
return res
params = dict(a=1, b=2, c=3)
print search(1, params)
我有两个工人,根据我传递给search()
的model
标志的值调用它们。 我在这里试图解决的问题是在if语句上编写一个循环(在代码中注释),以通过仅更改一个参数来重复调用say worker1
。 我希望我的代码灵活-有时我想循环a
并保持b
和c
相同,但是有时我想循环b
并保持a
和c
相同。
我会打开任何建议的解决方案,但我想我会在params
字典中指定搜索参数。 例如,要循环a
1,2,3,4,我会说:
`params = dict(a=[1,2,3,4], b=2, c=3)`
如果我不必修改worker1
和worker2
的代码, worker1
也worker2
。
谢谢!
您也许可以使用itertools.product
来调用带有所有参数组合的工作人员:
http://docs.python.org/2/library/itertools.html#itertools.product
例如
from itertools import product
def worker1(a, b, c):
return a + b + c
def worker2(d, e, f):
return d * e * f
def search(model, *params):
res = []
# Loop over one of the parameters and repeatedly append to res
for current_params in product(*params):
if model == 1:
res.append(worker1(*current_params))
elif model == 2:
res.append(worker2(*current_params))
return res
print search(1, [1,2,3,4], [2], [3])
# more complicated combinations are possible:
print search(1, [1,2,3,4], [2,7,9], [3,13,23,43])
我避免使用关键字参数,因为您的工作程序函数采用了不同名称的args,所以这没有多大意义。
我假设您的辅助函数实际上看起来并不像上面的那样,您可以使用内置的sum
和reduce
函数进一步简化代码。
我不确定我是否理解这个问题。 检查这是否是您想要的(省略了model
参数):
>>> def worker1(a, b, c):
return a + b + c
>>> def search(params):
params = params.values()
var_param = filter(lambda p: type(p) == list, params)[0]
other_params = filter(lambda p: p != var_param, params)
return [worker1(x, *other_params) for x in var_param]
>>> search({'a':2, 'b':[3,4,5], 'c':3})
[8, 9, 10]
假设:
worker1()
参数是可交换的(顺序无关紧要)。 list
在上面的示例中, b
是您要循环遍历的变量参数
更新:
如果要保留函数worker1
的参数顺序:
def search(params):
params = params.items()
var_param = filter(lambda t: type(t[1]) == list, params)[0]
other_params = filter(lambda t: t != var_param, params)
var_param_key = var_param[0]
var_param_values = var_param[1]
return [worker1(**dict([(var_param_key, x)] + other_params)) for x in var_param_values]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.