簡體   English   中英

如何以支持 autograd 的方式圍繞其中心旋轉 PyTorch 圖像張量?

[英]How do I rotate a PyTorch image tensor around it's center in a way that supports autograd?

我想圍繞它的中心隨機旋轉圖像張量(B,C,H,W)(我認為是二維旋轉?)。 我想避免使用 NumPy 和 Kornia,這樣我基本上只需要從 torch 模塊導入。 我也沒有使用torchvision.transforms ,因為我需要它與 autograd 兼容。 本質上,我正在嘗試為像 DeepDream 這樣的可視化技術創建torchvision.transforms.RandomRotation()的 autograd 兼容版本(因此我需要盡可能避免偽影)。

import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


# Somehow rotate tensor around it's center
def rotate_tensor(tensor, radians):
    ...
    return rotated_tensor

# Get a random angle within a specified range 
r_degrees = 5
angle_range = list(range(-r_degrees, r_degrees))
n = random.randint(angle_range[0], angle_range[len(angle_range)-1])

# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180


# test_tensor = preprocess_simple('path/to/file', (512,512))
test_tensor = torch.randn(1,3,512,512)


# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor, ang_rad)


# Optionally use this to check rotated image
# deprocess_simple(output_tensor, 'rotated_image.jpg')

我試圖完成的一些示例輸出:

旋轉圖像的第一個例子 旋轉圖像的第二個例子

所以網格生成器和采樣器是空間變換器(JADERBERG、Max 等)的子模塊。 這些子模塊是不可訓練的,它們讓您可以應用可學習和不可學習的空間轉換。 在這里,我使用這兩個子模塊,並使用它們使用 PyTorch 的函數F.affine_gridF.affine_sample (這些函數分別是生成器和采樣器的實現)通過theta旋轉圖像:

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def get_rot_mat(theta):
    theta = torch.tensor(theta)
    return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                         [torch.sin(theta), torch.cos(theta), 0]])


def rot_img(x, theta, dtype):
    rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)
    grid = F.affine_grid(rot_mat, x.size()).type(dtype)
    x = F.grid_sample(x, grid)
    return x


#Test:
dtype =  torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255]:
plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,2,0)/255)

在上面的例子中,假設我們把我們的形象im看作是一只穿着裙子的跳舞貓: 在此處輸入圖片說明

rotated_im將是一只穿着裙子的逆時針旋轉 90 度旋轉的舞貓:

在此處輸入圖片說明

這是我們所得到的,如果我們稱之為rot_imgtheta eqauls到np.pi/4 在此處輸入圖片說明

最好的部分是它可以區分輸入並具有 autograd 支持! 萬歲!

使用 torchvision 它應該很簡單:

import torchvision.transforms.functional as TF

angle = 30
x = torch.randn(1,3,512,512)

out = TF.rotate(x, angle)

例如,如果x是:

風箏

旋轉 30 度out (注意:逆時針方向):

風箏旋轉

有一個 pytorch 功能:

x = torch.tensor([[0, 1],
            [2, 3]])

x = torch.rot90(x, 1, [0, 1])
>> tensor([[1, 3],
           [0, 2]])

以下是文檔: https : //pytorch.org/docs/stable/generated/torch.rot90.html

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM