簡體   English   中英

為什么將輸入和模型強制轉換為float16不起作用?

[英]Why casting input and model to float16 doesn't work?

我正在嘗試將輸入和深度學習模型更改為flaot16,因為我使用的是T4 GPU,它們在fp16上的運行速度更快。 這是代碼的一部分:我首先有我的模型,然后制作了一些虛擬數據點,以期弄清楚首先要弄清楚的數據轉換(我在整個批次中都運行了該錯誤,並且得到了相同的錯誤)。

model = CRNN().to(device)
model = model.type(torch.cuda.HalfTensor)

data_recon = torch.from_numpy(data_recon)
data_truth = torch.from_numpy(data_truth)

dummy = data_recon[0:1,:,:,:,:] # Gets just one batch
dummy = dummy.to(device)
dummy = dummy.type(torch.cuda.HalfTensor)

model(dummy)

這是我得到的錯誤:

> --------------------------------------------------------------------------- 
RuntimeError                              Traceback (most recent call
> last) <ipython-input-27-1fe8ecc524aa> in <module>
> ----> 1 model(dummy)
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py
> in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> <ipython-input-12-06f39f9304a1> in forward(self, inputs, test)
>      57 
>      58             net['t%d_x0'%(i-1)] = net['t%d_x0'%(i-1)].view(times, batch, self.filter_size, width,
> height)
> ---> 59             net['t%d_x0'%i] = self.bcrnn(inputs, net['t%d_x0'%(i-1)], test)
>      60             net['t%d_x0'%i] = net['t%d_x0'%i].view(-1, self.filter_size, width, height)
>      61 
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py
> in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> <ipython-input-11-b687949e9ce5> in forward(self, inputs,
> input_iteration, test)
>      31         hidden = initial_hidden
>      32         for i in range(times):
> ---> 33             hidden = self.CRNN(inputs[i], input_iteration[i], hidden)
>      34             output_forward.append(hidden)
>      35         output_forward = torch.cat(output_forward)
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py
> in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> <ipython-input-10-15c0b221226b> in forward(self, inputs,
> hidden_iteration, hidden)
>      23     def forward(self, inputs, hidden_iteration, hidden):
>      24         in_to_hid = self.i2h(inputs)
> ---> 25         hid_to_hid = self.h2h(hidden)
>      26         ih_to_ih = self.ih2ih(hidden_iteration)
>      27 
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py
> in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py in
> forward(self, input)
>     336                             _pair(0), self.dilation, self.groups)
>     337         return F.conv2d(input, self.weight, self.bias, self.stride,
> --> 338                         self.padding, self.dilation, self.groups)
>     339 
>     340 
> 
> RuntimeError: Input type (torch.cuda.FloatTensor) and weight type
> (torch.cuda.HalfTensor) should be the same

檢查您對CRNN的實現。 我的猜測是您在模型中存儲了“隱藏”狀態張量,但不是作為“緩沖區”而是作為常規張量存儲。 因此,將模型轉換為float16時,隱藏狀態仍為float32並導致此錯誤。

嘗試將隱藏狀態存儲為模塊中的register_buffer有關更多信息,請參見register_buffer )。
另外,您可以通過重載模型的.to .to()方法將模塊中的任何成員張量顯式轉換為float16。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM