[英]Implementing a conv2d backward in pytorch
我想實現 conv2d 的反向 function。
# Inherit from Function
class LinearFunction(Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super(Linear, self).__init__()
self.input_features = input_features
self.output_features = output_features
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter('bias', None)
# Not a very smart way to initialize weights
nn.init.uniform_(self.weight, -0.1, 0.1)
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1, 0.1)
def forward(self, input):
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
我認為我對這個function還沒有一個清晰的認識。
如何向后實現 conv2d function?
這是 conv2d 向后 function 的實現:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from torch.nn import init
import math
import numpy as np
class Conv2dFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
output = F.conv2d(input, weight, bias, stride, padding, dilation, groups)
ctx.save_for_backward(input, weight, bias)
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
stride = ctx.stride
padding = ctx.padding
dilation = ctx.dilation
groups = ctx.groups
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
if ctx.needs_input_grad[1]:
grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum((0,2,3)).squeeze(0)
return grad_input, grad_weight, grad_bias, None, None, None, None
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.