본문 바로가기

Machine Learning Models/Transformer

Vision Transformer (4) - Pytorch 구현

반응형

Vision Transformer (1)

Vision Transformer (2)

Vision Transformer (3) - Attention Map


이번 포스트에서는 Pytorch 프레임워크를 이용하여 Vision Transformer 모델을 구현해보도록 하겠습니다. 복습을 위해 다시 한번 ViT 모델의 구조를 보도록 하겠습니다. Figure 1을 보면 먼저 입력 이미지를 패치 크기만큼 나누어 순서대로 정렬한 후 임베딩을 수행합니다. 이후에 Encoder 에서 Multihead Self-Attention, MLP, Layer Normalization 블락들이 residual connection 과 함께 수행되죠. 최종적인 분류를 위해서는 $z_0^0$에 학습가능한 파라미터로 삽입한 'class token' 에 대한 representation $z_L^0$에 MLP 층을 덧붙입니다.

Figure 1

먼저 필요한 라이브러리를 호출합니다. 이중 einops 라이브러리는 고차원 텐서의 재배열, 차원축소, 반복 등의 복잡한 동작을 지원하며 Pytorch, Keras, Chainer 등의 텐서기반 여러 딥러닝 프레임워크에 동작합니다.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

 

Embedding

Patch embedding

패치 임베딩을 위해서는 입력으로 들어온 이미지를 패치 크기로 나눈후 1차원으로 펼쳐야합니다. 예를 들어 [배치 사이즈, 3, 224, 224] 크기의 배치텐서가 입력으로 들어온다면 [배치 사이즈, 224*224 /(P*P), P*P*3] 1차원 텐서로 만들어주는 것이죠. 이후에는 선형변환으로 feature 차원으로 프로젝션합니다. 물론, 지난 포스트에서 살펴봤듯이 패치 크기만큼의 kernel_size, stride 를 지정해준 2차원 컨볼루션으로도 가능합니다.

Class token

다음으로는 이미지 전체의 representation 을 위한 'class token' 을 패치 임베딩 맨 왼쪽에 추가하는 차례입니다. 'Class token' 자체는 학습가능한 임베딩이므로 "nn.Parameter" 함수를 이용해 랜덤하게 선언해주고 배치 사이즈만큼 늘려줍니다. 이후에 "cat" 함수로 이어붙입니다.

Position embedding

시퀀스 각 요소의 위치 정보를 모델에 삽입하기 위하여 학습 가능한 위치 임베딩을 선언하고 패치 임베딩에 더해줍니다. 'Class token' 이 추가되었으므로 [(196+1)=197, 임베딩차원] 크기를 가집니다. 고정된 sinusidal 함수를 사용해 인위적으로 위치 임베딩을 만들어주는 기존 Transformer 모델과 달리 "nn.Parameter" 함수를 이용하여 랜덤하게 선언하여 초기서부터 학습시킵니다.


전체 임베딩을 구현한 코드는 다음과 같습니다.

 

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break-down the image in s1 x s2 patches and flat them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )
#         self.projection = nn.Sequential(
#            # using a conv layer instead of a linear one -> performance gains
#            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
#            Rearrange('b e (h) (w) -> b (h w) e'),
#        )
        
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))
                
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x

 

Encoder

전체 인코더 구조는 Figure 2와 같습니다.

Figure 2

Attention

Attention 을 위해서는 query, key, value 텐서가 각 헤드마다 필요합니다. 임베딩된 입력 $x$는 [배치 사이즈, 패치길이+1, 임베딩차원] 크기를 가지고 있는 텐서이니 이를 각 헤드마다 임베딩차원을 나눠주는 재배열이 필요합니다. 다음과 같이 결과적으로 [배치 사이즈, 헤드 개수, 패치길이+1, 임베딩차원/헤드개수] 형태를 가진 텐서가 되겠죠.

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.n_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.n_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.n_heads)

Attention 계산을 위해서는 query, key 텐서의 곱하기와 attention map, value 텐서의 곱하기가 필요합니다. 이는 Pytorch einsum 함수를 통해 쉽게 구현할 수 있습니다.

# [batch size, num heads, query length, key length]
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) 
# [batch size, num heads, query length, embedding]
out = torch.einsum('bhal, bhlv -> bhav ', att, values)

이후에는 multihead attention 출력을 1차원으로 만들고 커플링하여 새로운 feature 차원으로 프로젝션해서 다시 [배치 사이즈, 시퀀스길이, 임베딩차원] 형태의 텐서를 만듭니다. 전체 과정을 나타내는 코드는 다음과 같습니다.

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
  • mask 부분은 기존 Transformer 부분에서처럼 attention 을 취할 부분을 제한하는 역할을 하는 것으로 fill_value 변수에 실수형 텐서가 가질수 있는 값의 최소값을 넣어 softmax 함수 결과 0이 되도록 합니다.

MLP

MLP 층은 간단합니다. 중간에 dropout 층이 들어가고 활성화함수로 GELU (Gaussian Error Linear Unit) 이 사용됩니다. 기존 Transformer 모델처럼 중간 차원은 원래 임베딩 차원의 4배가 되도록 설정합니다.

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

전체 인코더 블락 코드는 다음과 같습니다. 먼저 residual connection 을 위한 함수를 정의하고 

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

다음과 같이 인코더 블락을 구성합니다. 이때 "nn.Sequential" 모듈을 상속하고 __init__ 메소드의 아규먼트로 인코더 블락의 구성요소를 순서대로 넣어줌으로써 forward 메소드를 따로 구현할 필요가 없습니다.

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

 

ViT

전체 구조는 위에서 정의한 인코더 블락이 $L$개 만큼 적층된 구조입니다. 키워드 언패킹 아규먼트 "**kwargs" 와 "nn.Sequential" 모듈을 이용해서 다음과 같이 매우 간단하게 구현할 수 있습니다.

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

Classification head

최종 분류를 위한 MLP Head 는 원래 $z_L^0$ 에 덧붙여집니다. 여기서는 전체 시퀀스에 대해 평균을 낸 이후에 MLP Head 를 덧붙여 보겠습니다.

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))

최종적인 ViT 코드는 다음과 같이 1) 임베딩, 2) 인코더, 3) MLP Head 순으로 구성됩니다.

class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
        

마지막으로 torchsummary 모듈의 summary 함수를 이용하면 입력이 forward 패스를 진행하면서 소요되는 파라미터와 형태를 알 수 있습니다.

from torchsummary import summary

summary(ViT(), (3, 224, 224), device='cpu')

 

참조

반응형

'Machine Learning Models > Transformer' 카테고리의 다른 글

Vision Transformer (3) - Attention Map  (3) 2021.06.17
Vision Transformer (2)  (0) 2021.06.16
Vision Transformer (1)  (3) 2021.06.16
Transformer Positional Encoding  (6) 2021.06.16
Transformer 구현  (0) 2021.05.27