Result type cast error when doing calculations with PyTorch model parameters
When I ran the code below:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1
a RuntimeError was reported:
RuntimeError: result type Float can't be cast to the desired output type Long
But when I changed params[var] *= 0.1 to params[var] = params[var] * 0.1, the error disappeared.
Why would this happen?
I thought params[var] *= 0.1 had the same effect as params[var] = params[var] * 0.1.
First, look at the first long-type parameter in densenet201: features.norm0.num_batches_tracked. It holds the number of mini-batches seen during training, which are used to calculate the mean and variance when the model contains a BatchNormalization layer. This parameter is a long-type number and cannot be a float type because it behaves like a counter.
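To see which entries are affected, one option is to filter the state_dict by dtype. The sketch below uses a tiny stand-in model (an assumption made to keep the example small; the question uses densenet201) and lists every entry that is not floating point.

```python
import torch
from torch import nn

# Tiny stand-in model (an assumption; the question uses densenet201).
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))

# Collect every state_dict entry whose dtype is not floating point.
integer_entries = {
    name: tensor.dtype
    for name, tensor in model.state_dict().items()
    if not tensor.is_floating_point()
}
print(integer_entries)
```

On this stand-in model, the only entry that should show up is the BatchNorm counter, num_batches_tracked.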
Second, in PyTorch, there are two types of operations: non-in-place and in-place.
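The two kinds, out-of-place (a new tensor is created) and in-place (the existing tensor is modified), can be sketched on a plain integer tensor. The variable names here are illustrative only, and the float32 result assumes PyTorch's default dtype has not been changed.

```python
import torch

counter = torch.tensor(0)            # int64, like num_batches_tracked
original_ptr = counter.data_ptr()

# Out-of-place: a new tensor is created, so its dtype can be promoted.
scaled = counter * 0.1
print(scaled.dtype)                  # float32 under the default dtype

# In-place with an integer operand: same storage, dtype unchanged.
counter += 1
same_storage = counter.data_ptr() == original_ptr

# In-place with a float operand: the result cannot be written back
# into an int64 tensor, so PyTorch raises a RuntimeError.
raised = False
try:
    counter *= 0.1
except RuntimeError as err:
    raised = True
    print(err)
```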
Let's move to your example to understand what happened:
Non-in-place operation:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
name = 'features.norm0.num_batches_tracked'
print(id(params[name]))  # 140247785908560
params[name] = params[name] + 0.1
print(id(params[name]))  # 140247785908368 (a different object)
print(params[name].type())  # changed to torch.FloatTensor
In-place operation:
# params is a freshly created model.state_dict() here
print(id(params[name]))  # 140247785908560
params[name] += 1
print(id(params[name]))  # 140247785908560 (the same object)
print(params[name].type())  # still torch.LongTensor
params[name] += 0.1  # this tries to change the original tensor's type to float, so you get an error
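If the goal is to keep the in-place update from the question, one possible approach is to skip the integer entries and scale only floating-point tensors. This is a sketch under that assumption, not the only fix, and it again uses a small stand-in model instead of densenet201.

```python
import torch
from torch import nn

# Tiny stand-in model (an assumption; the question uses densenet201).
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))
params = model.state_dict()
weight_before = params["0.weight"].clone()

for name, tensor in params.items():
    # Counters such as num_batches_tracked are int64; leave them alone.
    if tensor.is_floating_point():
        tensor *= 0.1  # in-place, so the model's own tensors change too
```

Because state_dict() returns tensors that share storage with the model, the in-place scaling is visible on the model itself without calling load_state_dict.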
Finally, some remarks:
In-place operations overwrite the original tensor, much like inplace=True in pandas :). This is a good resource to read more about in-place operations (source); see also this discussion (source).