Pytorch 에서는 어떠한 기능을 하는 블록 혹은 함수에 대해 forward() / backward() 를 "torch.autograd" 모듈로 새롭게 정의할 수 있습니다. 이는 "torch.autograd.Function" 클래스를 상속하여 "@staticmethod" 를 이용하여 입력에 대한 함수의 동작을 forward() 함수에, 함수 출력에 대한 기울기를 받아 입력에 대한 기울기를 계산하는 backward() 함수를 새롭게 정의합니다.
Example
Legendre polynomial
$y=a+bP_3 (c+dx)$ 에서 $P_3=\frac{1}{2}(5x^3-3x)$의 Legendre 3차 다항식이라고 가정하겠습니다. $P_3$ 함수에 대한 forward / backward 는 "torch.autograd.Function" 클래스를 상속하여 다음과 같이 기술할 수 있습니다.
class LegendrePolynomial3(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return 0.5 * (5 * input ** 3 - 3 * input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * 1.5 * (5 * input ** 2 - 1)
- forward pass 에서는 입력 텐서를 받아 함수 로직을 수행한 이후의 출력 텐서를 리턴합니다. "ctx" argument 는 이후 backward 연산을 수행하기 위한 정보를 담고 있는 context 객체입니다. ctx.save_for_backward 메소드를 이용하여 입력 텐서를 담슴니다.
- backward pass 에서는 함수 출력에 대한 목적함수의 기울기를 "grad_output" 이란 argument 로 받습니다. Back-propagation 수행을 위해서 함수 입력에 대한 기울기를 계산하여 리턴합니다. $P_3 (x)$를 $x$에 대해서 미분하면 $7.5x^2-1.5$가 되므로 chain-rule 에 의해 이를 "grad_output" 에 곱해주어 리턴합니다.
임의로 생성한 입력, 출력 텐서에 대해 $a, b, c, d$를 훈련 가능한 파라미터로 선언하고 다음과 같이 훈련시켜 값을 구할 수 있습니다. 이때, 위에서 정의한 LegendrePolynomial3 클래스의 "apply" 메소드를 사용합니다.
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
learning_rate = 5e-6
for t in range(2000):
# To apply our Function, we use Function.apply method. We alias this as 'P3'.
P3 = LegendrePolynomial3.apply
# Forward pass: compute predicted y using operations; we compute
# P3 using our custom autograd operation.
y_pred = a + b * P3(c + d * x)
# Compute and print loss
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# Use autograd to compute the backward pass.
loss.backward()
# Update weights using gradient descent
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# Manually zero the gradients after updating weights
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
ReLU
ReLU 함수는 "torch.nn" 모듈에 정의되어 있어 따로 정의할 필요가 없지만 위와 마찬가지로 별도로 구현할 수 있습니다.
class MyReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
backward 함수를 보면 ReLU 가 0보다 작은 부분에 대해서는 기울기가 0이므로 이에 해당하는 "grad_input[input < 0] = 0" 을 정의해준 것을 볼 수 있습니다. 이후에는 마찬가지로 "MyReLU.apply" 를 통해 새로 정의한 ReLU 모듈을 사용할 수 있습니다.
Linear
선형 변환에 대한 forward / backward 함수 또한 마찬가지로 기술할 수 있습니다. 다만, 위의 두 예는 입력만 들어오는 단순한 함수였다면 선형 변환은 입력에 대해 weight 를 곱하고 bias 를 더하는 연산을 수행하므로 forward 함수에 3개의 입력이 ("input", "weight", "bias") argument 로 들어옵니다. 중요한 점은 backward 함수를 정의할 때 forward 함수에 들어온 각 입력 별로 기울기를 계산해 리턴해야 한다는 점입니다.
class LinearFunction(torch.autograd.Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
backward 함수에서는 각 입력의 기울기를 None 으로 초기화하고 각 입력 별로 기울기를 "grad_output" 과 forward 함수에서 저장한 입력 텐서들로 계산합니다. 기술적으로 모든 입력에 대해 기울기를 계산하여야하지만 "bias" 와 같이 계산이 필요하지 않은 경우도 있는데, 이를 각 입력 별로 기울기 계산 필요여부를 튜플로 담고 있는 "ctx.needs_input_grad" 메소드를 통해 알 수 있습니다. 기울기 계산이 필요하지 않다면 해당하는 입력 자리에 None 을 리턴해도 상관없습니다.
Constant multiplication
텐서에 대한 상수곱은 어떻게 구성할 수 있을까요? forward / backward 함수에 필요한 상수는 텐서가 아니므로 "ctx.constant" 속성에 곱한 상수를 저장합니다. backward 함수 기술 시에는 forward 함수에 입력으로 들어온 상수에 대한 기울기가 필요 없으므로 None 을 리턴해야 합니다.
class MulConstant(Function):
@staticmethod
def forward(ctx, tensor, constant):
# ctx is a context object that can be used to stash information
# for backward computation
ctx.constant = constant
return tensor * constant
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
Gradient check
Pytorch 에서는 "torch.autograd.Function" 클래스를 이용해 구성한 함수에 대해 기울기 계산이 제대로 되었는지 수치적으로 검증할 수 있는 "torch.autograd.gradcheck" 함수를 제공합니다. "gradcheck" 함수는 새로 정의한 autograd 함수와 텐서 튜플을 입력으로 받아 backward 함수에서 기술한 기울기와 유한차원법으로 계산한 수치 기울기가 정의한 차이 이내에 있는지를 검사합니다.
from torch.autograd import gradcheck
input = (torch.randn(20,20,dtype=torch.double,requires_grad=False), \
torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
'Machine Learning Models > Pytorch' 카테고리의 다른 글
Pytorch LSTM (0) | 2021.06.26 |
---|---|
Pytorch - embedding (0) | 2021.06.23 |
Pytorch - gather (0) | 2021.06.01 |
Pytorch - scatter (3) | 2021.06.01 |
Pytorch - backward(retain_graph=True) (1) (4) | 2021.05.09 |