Function as argument of another function
I have some code in which a function is passed as an argument to another function. When I run it, I get this error:
h() missing 1 required positional argument: 'id'
Here is where I pass the function h as an argument to another function (the train method of the Engine class):
engine.train(h, train_loader, opt.epochs, optimizer)
h is defined like this:
def h(sample, id):
    inputs = utils.cast(sample[0], opt.dtype).detach()
    targets = utils.cast(sample[1], 'long')
    if opt.teacher_id != '':
        if id == 1:
            y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, sample[2], range(opt.ngpu))
            loss_groups = [v.sum() for v in loss_groups]
            [m.add(v.item()) for m, v in zip(meters_at, loss_groups)]
            return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) + sum(loss_groups), y_s
        if id == 2:
            y_s1, y_t1, loss_groups1 = utils.data_parallel(f1, inputs, params1, sample[2], range(opt.ngpu))
            loss_groups1 = [v.sum() for v in loss_groups1]
            [m.add(v.item()) for m, v in zip(meters_at, loss_groups1)]
            return utils.distillation(y_s1, y_t1, targets, opt.temperature, opt.alpha) + sum(loss_groups1), y_s1
        if id == 3:
            y_s2, y_t2, loss_groups2 = utils.data_parallel(f2, inputs, params2, sample[2], range(opt.ngpu))
            loss_groups2 = [v.sum() for v in loss_groups2]
            [m.add(v.item()) for m, v in zip(meters_at, loss_groups2)]
            return utils.distillation(y_s2, y_t2, targets, opt.temperature, opt.alpha) + sum(loss_groups2), y_s2
    else:
        if id == 1:
            y = utils.data_parallel(f, inputs, params, sample[2], range(opt.ngpu))[0]
        if id == 2:
            y = utils.data_parallel(f1, inputs, params1, sample[2], range(opt.ngpu))[0]
        if id == 3:
            y = utils.data_parallel(f2, inputs, params2, sample[2], range(opt.ngpu))[0]
        return F.cross_entropy(y, targets), y
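
The error message suggests that engine.train calls the function it receives with a single argument (the sample), so id is never supplied. Below is a minimal sketch of that situation and one possible way around it. ToyEngine and the simplified body of h are hypothetical stand-ins for illustration, not the real Engine class or the real loss code; the assumption is only that the engine invokes its callback as network(sample).

```python
from functools import partial


class ToyEngine:
    """Hypothetical stand-in for the engine: it calls the callback with ONE argument."""

    def train(self, network, iterator, maxepoch, optimizer=None):
        results = []
        for _ in range(maxepoch):
            for sample in iterator:
                # Only the sample is passed, so `network` must accept one argument.
                results.append(network(sample))
        return results


def h(sample, id):
    # Placeholder body: the real h would compute (loss, output) for model `id`.
    return sample * id, id


engine = ToyEngine()
train_loader = [1, 2, 3]  # toy data in place of a real DataLoader

# Reproduces the error: h needs `id`, but the engine supplies only `sample`.
error_message = None
try:
    engine.train(h, train_loader, 1)
except TypeError as exc:
    error_message = str(exc)

# Possible fix: bind `id` ahead of time so the engine sees a one-argument callable.
out_partial = engine.train(partial(h, id=2), train_loader, 1)
out_lambda = engine.train(lambda sample: h(sample, 2), train_loader, 1)
```

Under that assumption, the original call could be written as engine.train(functools.partial(h, id=1), train_loader, opt.epochs, optimizer), repeated with a different id for each model. Both partial and the lambda leave h itself unchanged.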