본문 바로가기

Artificial Intelligence/Computer Vision

A Visual Guide to Vision Transformers

https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html

 

A Visual Guide to Vision Transformers | MDTURP

A Visual Guide to Vision Transformers This is a visual guide to Vision Transformers (ViTs), a class of deep learning models that have achieved state-of-the-art performance on image classification tasks. Vision Transformers apply the transformer architectur

blog.mdturp.ch

ViT Model overview

 

일반적인 합성곱 신경망(CNN)과 마찬가지로 Vision Transformer도 지도 학습(Supervised Learning) 방식으로 학습합니다. 즉, 이미지와 이에 맞는 레이블(label)로 구성된 데이터셋으로 모델이 학습합니다.

 

Vision Transformer가 내부적으로 어떻게 동작하는지 알아보기 위해, 하나의 데이터(배치 크기 1)에 대해서만 먼저 집중해보겠습니다. 그리고 이 질문을 함께 생각해보시죠: Transformer에 이 데이터를 입력하기 위해서는 어떻게 준비(전처리)해야 할까요?

 

레이블은 나중에 더 관련성있게 살펴보겠습니다. 지금은 이미지 하나만 남겨놓고 보겠습니다.

 

전체 이미지를 동일한 크기의 패치(p x p) 이미지로 나누어 Transformer 내부에서 사용할 수 있도록 준비합니다

 

패치들을 p' = p² x c 크기의 벡터로 평탄화(flaten)합니다. 이 때 p는 패치의 한 변의 크기이고, c는 채널 수입니다. (예를들어, RGB 이미지의 경우 채널 수는 3입니다.)

 

앞에서 이미지 패치로부터 만든 벡터들을 선형 변환을 통해 인코딩합니다. 이렇게 만들어진 패치 임베딩 벡터(Patch Embedding Vector) 는 고정된 크기 d를 갖습니다.

 

이미지 패치들을 모두 고정된 크기의 벡터로 임베딩하게 되면 n x d 크기의 배열을 얻게 됩니다. 여기서 n은 이미지 패치의 개수이고, d는 하나의 패치가 임베딩된 크기입니다.

 

모델을 효과적으로 학습하기 위해, 패치 임베딩에 추가로 분류 토큰(CLS token)이라 부르는 벡터를 추가합니다. 이 벡터는 신경망을 통해 학습 가능한 매개변수로, 무작위로 초기화됩니다. 참고로, CLS 토큰은 하나만 있으며, 모든 데이터들에 동일한 벡터를 추가합니다. (여기까지 하게 되면 n개의 패치 임베딩에 CLS 토큰을 더하여 (n+1)개에 각 임베딩 크기 d인, (n+1) x d 를 갖게 됩니다.)

 

지금까지의 패치 임베딩에는 별도의 위치 정보가 없습니다. 모든 패치 임베딩에 학습 가능한, 무작위로 초기화된 위치 임베딩 벡터(Positional Embedding Vector) 를 더하여 이 문제를 해결합니다. 또한, 앞에서 추가한 분류 토큰(CLS token) 에도 이러한 위치 벡터를 추가합니다. (Transformer에서는 Positional Encoding의 값을 '더해'줍니다. 따라서 벡터의 크기에는 변화가 없습니다.)

 

위치 임베딩 벡터를 추가하면 (n+1) x d 크기의 배열이 남습니다. 이 배열을 Transformer의 입력으로 제공할 것이며, 이에 대해서는 다음 단계에서 더 자세히 설명하겠습니다.

 

Transformer 입력 패치 임베딩 벡터는 여러 큰 벡터에 선형적으로 임베딩됩니다. 이러한 새로운 벡터는 동일한 크기의 세 부분으로 분리됩니다. 이는 각각 Q는 쿼리(Query) 벡터, K는 키(Key) 벡터, V는 값(Value) 벡터입니다. 모든 벡터들을 (n+1)개씩 얻게 됩니다.

 

먼저 어텐션 스코어 A를 계산하기 위해 모든 쿼리 벡터 Q에 모든 키 벡터 K를 곱합니다.

 

이렇게 얻은 어텐션 스코어 행렬 A의 모든 행의 합이 1이 되도록 모든 행에 softmax 함수를 적용합니다.

 

첫 번째 패치 임베딩 벡터에 대한 집계된 문맥 정보(aggregated contextual information) 를 계산하기 위해, 어텐션 행렬의 첫 번째 행에 대해서 연산을 합니다. 여기에 값 벡터 V의 가중치를 사용하여 첫 번째 이미지 패치 임베딩에 대한 집계된 문맥 정보 벡터(aggregated vector) 를 생성합니다.

 

어텐션 스코어 행렬의 다른 행들에 대해서도 위 과정을 반복하여 N+1개의 집계된 문맥 정보 벡터를 구합니다. 즉, 모든 패치마다 하나씩 (=N개) + 분류 토큰(CLS Token)에 대해서 하나 (=1) 입니다. 여기까지 해서 첫번째 어텐션 헤드(Attention Head)를 구합니다.

 

(Transformer의) 멀티-헤드 어텐션을 다루고 있으므로, 다른 QKV들에 대해서 10.1부터 10.5까지의 전체 프로세스를 반복합니다. 위 그림에서는 2개의 헤드만 가정했지만, 일반적으로 ViT는 더 많은 헤드를 갖습니다. 이렇게 여러 개의 집계된 문맥 정보 벡터(Multiple Aggregated Contextual Information Vectors)가 생성됩니다.

 

이렇게 생성한 여러 헤드들을 쌓은 뒤, 패치 임베딩의 크기와 같은 d 크기의 벡터로 매핑시킵니다.

 

이렇게 이전 단계로부터 어텐션 레이어가 완성되었고, 입력 시에 사용했던 것과 정확히 같은 크기의 임베딩들을 얻었습니다.

 

Transformer에서는 잔차 연결(Residual Connection) 을 많이 사용하는데, 이것은 단순히 이전 레이어의 입력을 현재 레이어의 출력에 더해주는 것입니다. 여기서도 잔차 연결을 하겠습니다.

 

이러한 잔차 연결을 통해 (동일한 크기 d인 벡터들끼리 더하여) 같은 크기의 벡터가 생성됩니다.

 

지금까지의 결과(output)를 비선형 활성함수를 갖는 피드 포워드 인공 신경망에 통과시킵니다.

 

Transformer에는 지금까지 연산 이후로 또다른 잔차 연결이 있지만, 여기서는 설명을 간소화하기 위해 건너뛰고 Transformer 레이어 연산을 마무리하겠습니다. 최종적으로 Transformer는 입력 크기와 같은 출력을 생성합니다.

 

지금까지 진행한 10.1부터 10.12까지의 전체 Transformer 연산을 수차례 반복합니다. 여기에서는 6번을 예시로 들었습니다.

 

마지막 단계는 분류 토큰(CLS token) 출력을 확인하는 것입니다. 이벡터는 Vision Transformer 여정의 마지막 단계에서 사용하게 됩니다.

 

최종적이고 마지막 단계에서는 분류 출력 토큰을 완전 연결(fully-connected)된 또 다른 인공 신경망에 통과시켜 입력 이미지에 대한 분류 확률(classification probabilties)을 예측합니다.

 

앞에서 예측한 분류 확률(class probabilties)과 정답(true class label)을 비교하는 표준 크로스-엔트로피 손실 함수(Cross-Entropy Loss Function)을 사용하여 Vision Transformer를 학습합니다. 모델은 역전파(backpropagation) 및 경사 하강법(gradient descent)을 사용하여 손실 함수를 최소화하는 쪽으로 모델의 가중치를 갱신하며 학습합니다.

 

지금까지 시각적 설명을 통해 데이터 준비부터 모델 학습까지 Vision Transformer의 주요 구성 요소들을 살펴봤습니다. 이 설명을 통해 Vision Transformer가 어떻게 동작하는지와 이미지 분류에 어떻게 사용되는지를 이해하는데 도움이 되었기를 바랍니다.

 

# The code was taken from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
# And has only been extended with reference comments to the following blogpost:
# https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html


import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Blogpost step 10.11
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    # Blogpost step 10.6
    def forward(self, x):
        x = self.norm(x)

        #  Blogpost step 10.1
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # Blogpost step 10.2-10.3
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # Blogpost step 10.4-10.5
        out = torch.matmul(attn, v)

        # Blogpost step 10.7-10.8
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):

        # Blogpost step 11
        for attn, ff in self.layers:

            # Blogpost steps 10.1-10.10
            x = attn(x) + x # Blogpost step 10.9

            # Blogpost steps 10.11-10.12
            x = ff(x) + x

        return self.norm(x)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(

            # Blogpost steps 3-4
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),

            # Blogpost step 5
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # Blogpost step 13
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):

        # Blogpost steps 3-6
        x = self.to_patch_embedding(img)

        b, n, _ = x.shape

        # Blogpost step 7
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

        # Blogpost step 8-9
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        # Blogpost step 10-11
        x = self.transformer(x)

        # Blogpost step 12
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # Blogpost step 13
        x = self.to_latent(x)
        return self.mlp_head(x)

 

 

 

Training the VIT-Model on CIFAR-10

import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import tqdm

import torchvision.transforms as transforms
import torchvision


transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')



epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
device = 'cuda'

model = ViT(
    image_size = 32,
    patch_size = 8,
    num_classes = 10,
    dim = 1024,
    depth = 6,
    heads = 12,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

 

print("Start training")

model.to(device)

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm.tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)

        # Blogpost step 14
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_test_accuracy = 0
        epoch_test_loss = 0
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device)

            test_output = model(data)
            test_loss = criterion(test_output, label)

            acc = (test_output.argmax(dim=1) == label).float().mean()
            epoch_test_accuracy += acc / len(test_loader)
            epoch_test_loss += test_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - test_loss : {epoch_test_loss:.4f} - test_acc: {epoch_test_accuracy:.4f}\n"
    )

 

paper URL : https://arxiv.org/abs/2010.11929v2

PR-12 영상 : https://www.youtube.com/watch?v=D72_Cn-XV1g

 

 

반응형
LIST