简体   繁体   English

Pytorch model 训练不使用正向

[英]Pytorch model training without using forward

I'm working on training CLIP model.我正在训练 CLIP model。 Here's the source code of the model https://github.com/openai/CLIP/blob/main/clip/model.py这是model https的源代码://github.com/openai/CLIP/blob/main/clip/model.py

Basically the CLIP object is constructed like this:基本上,CLIP object 的构造如下:

class CLIP(nn.module):
   ...
   def encode_image(self, image):
     return self.visual(image.type(self.dtype))

   def encode_text(self, text):
    x = ... 
    ...
    return x

   def forward(self, image, text):
     image_features = self.encode_image(image)
     text_features = self.encode_text(text)
     ...
     return logits_per_image, logits_per_text

The forward method except pair of image and text, since I want to repurpose CLIP for other task(text-text pairs), I'm not using forward from CLIP, but I'm using others method defined inside CLIP.除了图像和文本对之外的转发方法,因为我想将 CLIP 重新用于其他任务(文本-文本对),所以我没有使用来自 CLIP 的转发,但我使用的是在 CLIP 中定义的其他方法。 My training code look like this:我的训练代码如下所示:

for k in range(epoch):
  for batch in dataloader :
    x,y = batch
    y1 = model.encode_text(x[first_text_part])
    y2 = model.encode_text(x[second_text_part])
    <calculate loss, backward, step, etc>

The problem is, after 1 epoch, all the gradients turn out to be nan even though the loss is not nan.问题是,在 1 个 epoch 之后,即使损失不是 nan,所有梯度都变成 nan。
My suspicion is PyTorch only able to propagate the gradient through the forward method.我的怀疑是 PyTorch 只能通过正向方法传播梯度。
Some source says that forward is not that special ( https://discuss.pytorch.org/t/must-use-forward-function-in-nn-module/50943/3 ), but other source say coding with torch must use the forward ( https://stackoverflow.com/a/58660175/12082666 ).一些消息来源说 forward 并不是那么特别( https://discuss.pytorch.org/t/must-use-forward-function-in-nn-module/50943/3 ),但是其他消息来源说必须使用torch编码前进( https://stackoverflow.com/a/58660175/12082666 )。

The question is, can we train Pytorch network without using forward method?问题是,我们可以在不使用前向方法的情况下训练 Pytorch 网络吗?

The forward() in pytorch in nothing new. pytorch 中的forward()并不是什么新鲜事。 It just attaches the graph of your network when called.它只是在调用时附加您的网络图。 Backpropagation doesnt rely much on forward() because, the gradients are propagated through the graph.反向传播不太依赖 forward(),因为梯度是通过图传播的。

The only difference is that in pytorch source, forward is similar to call () method with all the hooks registered in nn.Module.唯一不同的是,在 pytorch 源码中,forward 类似于call () 方法,所有钩子都注册在 nn.Module 中。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM