본문 바로가기

Machine Learning Models/Pytorch

Pytorch - backward(retain_graph=True) (1)

반응형

Pytorch 에서는 계산된 목적함수의 loss 값에 backward() 함수를 계산하면 모델을 구성하는 파라미터에 대한 gradient (기울기)를 계산합니다. backward() 메소드는 암묵적으로 loss 값이 벡터가 아닌 scalar 라고 가정하여 최종 loss 값에 대한 평균이나 합을 통해 벡터를 하나의 scalar 값으로 만들어 주어야 하는데요, backward(torch.tensor([1], dtype=torch.float) 이 디폴트로 설정되어 있습니다. backward() 메소드를 수행할 때 또다른 유용한 파라미터는 retain_graph 라는 매개변수입니다. ratain 이라는 말에서 알 수 있듯이 텐서들의 연속된 연산으로 구성되는 graph 를 유지한다는 것인데 한 번 살펴보도록 하겠습니다.

Figure 1과 같이 a부터 계산되는 b,c,d,e 가 있습니다. 

Figure 1

Variable a 로부터 b,c,가 계산되고 c에서부터 결과가 분기되어 d,e 라는 두 개의 출력을 생성하는데요, 이를 코드로 구현하면 다음과 같습니다.

import torch
from torch.autograd import Variable

a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()

이 상태에서 d.backward() 를 수행하면 d에 대한 a 의 기울기가 계산됩니다. 하지만 이후 e.backward() 를 수행하게 되면 에러가 발생하게 됩니다.

d.backward()
e.backward()

RuntimeError 를 보면 graph 에 대한 두 번째 backward를 시행시 저장되어 있단 중간 결과값들이 free 되었다고 나오면서 retain_graph=True 로 설정하라고 나옵니다. 즉, Pytorch 에서는 backward() 메소드를 한 번 수행하면 기울기가 계산되면서 이를 계산하기 위한 중간 결과값들이 없어진다는 것이죠. 따라서 d.backward(retain_graph=True) 로 수행하고 e.backward() 를 수행하면 에러없이 동작합니다.

d.backward(retain_graph=True)
e.backward()

backward() 메소드를 여러 번 수행하는 경우는 생각보다 자주 있습니다. 특히 각 층에 대한 여러 개의 목적함수가 존재하는 multi-task 학습에서 주로 사용되는데요. loss1, loss2 가 있을 때, 각 loss 에 대한 기울기를 다음과 같이 계산할 수 있습니다.

# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters

또한 Figure 2와 같이 deep neural networks 에 branch 가 여러 개 존재할 경우 branch 1 loss 에 대한 backward() 를 수행하게 되면 중간 결과값들이 사라지게 되어 branch 2 loss 에 대한 기울기 계산이 수행될 수 없습니다. 따라서 여러 번 backward() 를 수행할 때 retrain_graph=True 로 설정하여야 원하는 대로 최적화를 위한 기울기를 계산할 수 있습니다. 또한, 여러 번 backward() 를 수행하게 되면 grad 속성을 가지고 있는 Variable/Parameter 의 기울기가 누적되서 더해지게 됩니다. 

Figure 2

참조

 


Pytorch - backward(retain_graph=True) (2)

반응형

'Machine Learning Models > Pytorch' 카테고리의 다른 글

Pytorch - embedding  (0) 2021.06.23
Pytorch - autograd 정의  (1) 2021.06.02
Pytorch - gather  (0) 2021.06.01
Pytorch - scatter  (3) 2021.06.01
Pytorch - hook  (5) 2021.05.09