
Why are PyTorch transform functions not being differentiated with autograd?

I am writing a set of transforms for input data, and I need them to be differentiable so that I can compute gradients. However, gradients do not seem to be calculated for the resize and normalize transforms.

import torch
from torchvision import transforms
from torchvision.transforms import ToTensor

# images, clip_model, device and text_inputs come from the surrounding script (not shown)
resize = transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None)
crop = transforms.CenterCrop(size=(224, 224))
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))

# Input image as a tensor that tracks gradients
img = torch.Tensor(images[30])
img.requires_grad = True

# Split the channel dimension and rearrange into a 3 x 100 x 100 tensor
rgb = torch.dsplit(torch.Tensor(img),3)
transformed = torch.stack(rgb).reshape(3,100,100)

# Apply the transforms
resized = resize.forward(transformed)
normalized = normalize.forward(resized)

# Encode with CLIP and compare against the text features
image_features = clip_model.encode_image(normalized.unsqueeze(0).to(device))
text_features = clip_model.encode_text(text_inputs)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

When running normalized.backward(), there are no gradients for resized and transformed.

I have tried to find the gradient for each individual transform, but it still does not calculate the gradients.

When I try to reproduce your error by backpropagating from normalized, I get:

RuntimeError: grad can be implicitly created only for scalar outputs

This error means that the tensor you call backward on must be a scalar, not a vector or multi-dimensional tensor. Generally you would reduce it to a scalar first, for example by averaging or summing:

> normalized.mean().backward()

The problem was caused by the transforms themselves: many of them did not affect the gradients as they were rearrangements or transposes of the original tensor.
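One possible reading of this remark is that pure rearrangements (transpose, permute, reshape) do not change gradient values; they only route each value back to the position it came from. The toy tensors below are hypothetical and only illustrate that behaviour.

```python
import torch

# A small leaf tensor and a set of per-element weights (both made up for illustration).
x = torch.arange(6, dtype=torch.float32).reshape(2, 3).requires_grad_()
weights = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# A transpose only rearranges the elements of x.
y = x.t()  # shape (3, 2)

# Weighted sum, so each element of y receives a distinct gradient.
(y * weights).sum().backward()

# The gradient on x is the same set of weights, transposed back.
print(x.grad)
```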
