
Why doesn't casting the input and model to float16 work?

I'm trying to cast the inputs and a deep learning model to float16, since I'm using a T4 GPU, which runs much faster with fp16. Here's part of the code: I first create my model, then make a dummy data point so I can get the data casting figured out first (I also ran it with the whole batch and got the same error).

model = CRNN().to(device)
model = model.type(torch.cuda.HalfTensor)

data_recon = torch.from_numpy(data_recon)
data_truth = torch.from_numpy(data_truth)

dummy = data_recon[0:1,:,:,:,:] # Take a single sample (batch of size 1)
dummy = dummy.to(device)
dummy = dummy.type(torch.cuda.HalfTensor)

model(dummy)

And here's the error I get:

> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-27-1fe8ecc524aa> in <module>
> ----> 1 model(dummy)
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> <ipython-input-12-06f39f9304a1> in forward(self, inputs, test)
>      57 
>      58             net['t%d_x0'%(i-1)] = net['t%d_x0'%(i-1)].view(times, batch, self.filter_size, width, height)
> ---> 59             net['t%d_x0'%i] = self.bcrnn(inputs, net['t%d_x0'%(i-1)], test)
>      60             net['t%d_x0'%i] = net['t%d_x0'%i].view(-1, self.filter_size, width, height)
>      61 
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> <ipython-input-11-b687949e9ce5> in forward(self, inputs, input_iteration, test)
>      31         hidden = initial_hidden
>      32         for i in range(times):
> ---> 33             hidden = self.CRNN(inputs[i], input_iteration[i], hidden)
>      34             output_forward.append(hidden)
>      35         output_forward = torch.cat(output_forward)
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> <ipython-input-10-15c0b221226b> in forward(self, inputs, hidden_iteration, hidden)
>      23     def forward(self, inputs, hidden_iteration, hidden):
>      24         in_to_hid = self.i2h(inputs)
> ---> 25         hid_to_hid = self.h2h(hidden)
>      26         ih_to_ih = self.ih2ih(hidden_iteration)
>      27 
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
>     491             result = self._slow_forward(*input, **kwargs)
>     492         else:
> --> 493             result = self.forward(*input, **kwargs)
>     494         for hook in self._forward_hooks.values():
>     495             hook_result = hook(self, input, result)
> 
> /opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
>     336                             _pair(0), self.dilation, self.groups)
>     337         return F.conv2d(input, self.weight, self.bias, self.stride,
> --> 338                         self.padding, self.dilation, self.groups)
>     339 
>     340 
> 
> RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

Check your implementation of CRNN. My guess is that the "hidden" state tensor is stored in the model as a regular tensor rather than as a "buffer". Casting the model to float16 converts only its parameters and registered buffers, so the hidden state stays float32 and causes this error.
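A minimal sketch of that failure mode, runnable on CPU. `PlainHidden` and its layer sizes are hypothetical stand-ins, not the asker's CRNN; the only point is that a tensor kept as an ordinary attribute is not touched when the module is cast.

```python
import torch
import torch.nn as nn


class PlainHidden(nn.Module):
    """Hypothetical module that keeps its hidden state as a plain attribute."""

    def __init__(self):
        super().__init__()
        self.h2h = nn.Conv2d(4, 4, kernel_size=3, padding=1)
        # Ordinary tensor attribute: not a parameter, not a buffer,
        # so module-level casts do not see it.
        self.hidden = torch.zeros(1, 4, 8, 8)


plain = PlainHidden().half()

print(plain.h2h.weight.dtype)  # the conv weight was cast
print(plain.hidden.dtype)      # the hidden state was not
```

Feeding `plain.hidden` into `plain.h2h` would then mix a float32 input with float16 weights, which is the kind of mismatch the error message reports.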

Try registering the hidden state as a buffer in the module (see register_buffer for more info).
Alternatively, you can explicitly cast any member tensor in the module to float16 by overriding your model's .to() method.
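The buffer approach might look like the sketch below. `BufferHidden`, the buffer name `hidden`, and the tensor shape are assumptions made for illustration; the real CRNN may create its hidden state differently (for example inside `forward`), in which case the tensor would need to be created with the input's dtype instead.

```python
import torch
import torch.nn as nn


class BufferHidden(nn.Module):
    """Hypothetical module that registers its hidden state as a buffer."""

    def __init__(self):
        super().__init__()
        self.h2h = nn.Conv2d(4, 4, kernel_size=3, padding=1)
        # A registered buffer follows the module through .half(), .to(),
        # .type() and .cuda(), just like the parameters do.
        self.register_buffer("hidden", torch.zeros(1, 4, 8, 8))


buffered = BufferHidden().half()

print(buffered.h2h.weight.dtype)  # cast to half precision
print(buffered.hidden.dtype)      # cast along with the weights
```

With the state registered this way, the weights and the hidden tensor share one dtype after the cast, so the convolution no longer sees mixed input and weight types.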
