
Result type cast error when doing calculations with Pytorch model parameters

When I ran the code below:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1

a RuntimeError was reported:

RuntimeError: result type Float can't be cast to the desired output type Long

But when I changed params[var] *= 0.1 to params[var] = params[var] * 0.1, the error disappeared.

Why would this happen?

I thought params[var] *= 0.1 had the same effect as params[var] = params[var] * 0.1.

First, look at the first long-type parameter in densenet201: you will find features.norm0.num_batches_tracked, which counts the mini-batches seen during training and is used when calculating the running mean and variance of the BatchNormalization layers in the model. This parameter is a long-type (integer) tensor and cannot be float, because it behaves like a counter.
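You can check this yourself; a minimal sketch that scans the state_dict for the first non-floating-point entry:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
# Report the first integer-typed entry in the state_dict
for name, tensor in params.items():
    if not tensor.is_floating_point():
        print(name, tensor.dtype)  # features.norm0.num_batches_tracked torch.int64
        break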

Second, in PyTorch, there are two types of operations:

  • Non-in-place operations: the result of the computation is assigned to a new copy of the variable, e.g. x = x + 1 or x = x / 2. The memory location of x before the assignment is not the same as after it, because you now hold a copy of the original variable.
  • In-place operations: the computation is applied directly to the original copy of the variable, without making any copy, e.g. x += 1 or x /= 2.

Let's walk through your example to understand what happened:

  1. Non-in-place operation:

     model = torchvision.models.densenet201(num_classes=10)
     params = model.state_dict()
     name = 'features.norm0.num_batches_tracked'
     print(id(params[name]))     # 140247785908560
     params[name] = params[name] + 0.1
     print(id(params[name]))     # 140247785908368, a different tensor
     print(params[name].type())  # changed to torch.FloatTensor
  2. In-place operation:

     print(id(params[name]))     # 140247785908560
     params[name] += 1
     print(id(params[name]))     # 140247785908560, still the same tensor
     print(params[name].type())  # still torch.LongTensor
     params[name] += 0.1         # writing a float result into the long tensor raises the error
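One common way to avoid the error (a minimal sketch, not the only possible fix) is to scale only the floating-point entries of the state_dict and leave integer counters such as num_batches_tracked untouched:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    # Only floating-point tensors are scaled; integer counters such as
    # num_batches_tracked are skipped, so the in-place *= is safe.
    if params[var].is_floating_point():
        params[var] *= 0.1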

Finally, some remarks:

  • In-place operations save some memory, but can be problematic when computing derivatives because of an immediate loss of history. Hence, their use is discouraged. (source)
  • You should be cautious when you decide to use in-place operations, since they overwrite the original content.
  • If you use pandas, this is somewhat similar to inplace=True in pandas :) (a tiny illustration follows this list).
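For readers who know pandas, a small sketch of that analogy (the DataFrame here is made up for the example):

import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})
out = df.drop(columns=["a"])           # out-of-place: returns a new DataFrame, df unchanged
df.drop(columns=["a"], inplace=True)   # in-place: mutates df itself and returns None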

This is a good resource to read more about in-place operations (source), and see also this discussion (source).
