Programming/AI & ML

[Euron Research 복습 과제] DDIM

YeonJuJeon 2025. 6. 23. 12:04

PyTorch로 구현한 DDIM 기반 MNIST 이미지 생성기

이 글에서는 PyTorch를 사용하여 DDIM(Denoising Diffusion Implicit Models)을 기반으로 한 MNIST 이미지 생성 모델을 구현하는 과정을 단계별로 소개한다. 각 코드 셀은 논문에서 설명된 원리를 충실히 따르며, Sinusoidal Time Embedding, UNet, Residual Block, Sampling 과정을 모두 포함한다.


1. 베타 스케줄 및 시간 관련 텐서 초기화

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# 타임스텝 수 설정
timesteps = 200

# DDIM에서는 여전히 베타 스케줄을 사용하지만, 샘플링은 다름
beta = torch.linspace(0.0001, 0.02, timesteps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
sqrt_alpha_bar = torch.sqrt(alpha_bar)
sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar)

# DDIM sampling을 위한 eta (0이면 Deterministic DDIM)
eta = 0.0

DDIM 모델에서 사용되는 diffusion 스케줄을 정의한다.
beta는 노이즈를 주입하는 정도를 나타내며, alpha, alpha_bar는 각 타임스텝에서의 누적 노이즈 수준을 나타낸다. eta는 샘플링 시 랜덤성의 강도를 제어하는 파라미터이다.


2. 시간 임베딩 및 ResBlock 정의

class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        half_dim = dim // 2
        emb_scale = math.log(10000) / (half_dim - 1)
        self.register_buffer('emb', torch.exp(torch.arange(half_dim) * -emb_scale))

    def forward(self, t):
        """
        주어진 시간 스텝 t에 대해 sinusoidal embedding 생성
        """
        emb = t[:, None] * self.emb[None, :].to(t.device)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_channels)
        )
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        # 입력과 출력 채널이 다를 경우 skip 연결용 1x1 conv
        if in_channels != out_channels:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        """
        x: 이미지 feature map
        t: 시간 임베딩
        """
        h = self.block1(x)
        time_emb = self.time_mlp(t).unsqueeze(-1).unsqueeze(-1)
        h = h + time_emb
        h = self.block2(h)
        return h + self.residual_conv(x)

두 개의 클래스를 정의한다.

  • TimeEmbedding: 각 정수 시간 t에 대해 sinusoidal embedding을 생성하는 모듈이다. GPT에서도 사용되는 방식과 동일하다.
  • ResidualBlock: 시간 정보가 주입되는 Residual Convolution Block이다. 두 개의 Conv2d 블록 사이에 time embedding을 broadcast해준다. 입력 채널 수와 출력 채널 수가 다를 경우 skip connection을 위한 1x1 Conv 레이어를 사용한다.

3. 전체 UNet 정의

class SimpleUNet(nn.Module):
    def __init__(self, time_dim=256):
        super().__init__()
        self.time_embed = TimeEmbedding(time_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )

        self.conv0 = nn.Conv2d(1, 64, 3, padding=1)
        self.res1 = ResidualBlock(64, 128, time_dim)
        self.res2 = ResidualBlock(128, 128, time_dim)
        self.res3 = ResidualBlock(128, 64, time_dim)
        self.out = nn.Conv2d(64, 1, 3, padding=1)

    def forward(self, x, t):
        t = self.time_embed(t)
        t = self.time_mlp(t)
        x = self.conv0(x)
        x = self.res1(x, t)
        x = self.res2(x, t)
        x = self.res3(x, t)
        return self.out(x)

최종적으로 사용할 간단한 UNet 구조를 정의한다. 아래와 같은 구조를 따른다.

  • 입력: MNIST의 grayscale 이미지 (1채널)
  • conv0 → ResBlock1 (64→128)
  • ResBlock2 (128→128)
  • ResBlock3 (128→64)
  • 마지막 출력 Conv2D (64→1)

각 레이어에서 시간 정보를 활용하여 conditional UNet 구조로 작동한다.


4. DDIM 샘플링 함수 정의

@torch.no_grad()
def ddim_sample(model, shape, device, eta=0.0):
    """
    DDIM 방식으로 샘플링 수행
    - eta=0.0이면 deterministic sampling
    - shape: [batch, channel, height, width]
    """
    x = torch.randn(shape, device=device)
    for i in reversed(range(1, timesteps)):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        alpha_t = alpha[i]
        alpha_bar_t = alpha_bar[i]
        alpha_bar_prev = alpha_bar[i - 1] if i > 1 else torch.tensor(1.0)

        # 예측된 노이즈
        pred_noise = model(x, t)

        # x0 복원 (reverse direction)
        x0 = (x - (1 - alpha_t).sqrt() * pred_noise) / alpha_t.sqrt()

        # DDIM 수정된 예측
        sigma = eta * ((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_t / alpha_bar_prev)).sqrt()
        noise = torch.randn_like(x) if i > 1 else 0
        x = alpha_bar_prev.sqrt() * x0 + (1 - alpha_bar_prev - sigma ** 2).sqrt() * pred_noise + sigma * noise

    return x

이 함수는 학습된 모델을 사용하여 DDIM 방식으로 이미지를 생성하는 함수이다. 주요 특징은 다음과 같다.

  • 맨 마지막 타임스텝부터 시작하여, 예측된 노이즈를 바탕으로 이전 스텝의 이미지를 계산한다.
  • eta=0.0이면 deterministic하게 생성된다.
  • x0를 추정하고, 이전 시점의 alpha_bar를 활용하여 이미지 상태를 업데이트한다.

5. MNIST 데이터셋 불러오기

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),  # (0,1) 범위로 정규화됨
    transforms.Lambda(lambda x: x - 0.5),  # [-0.5, 0.5]로 스케일 조정 (Diffusion 모델 특성상)
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

torchvision을 통해 MNIST 데이터를 불러오고, (0,1) 범위의 값을 (-0.5, 0.5)로 조정한다. diffusion 모델에서는 0 중심 분포가 더욱 적합하므로 이와 같은 정규화를 수행한다.


6. DDIMTrainer 클래스 정의

class DDIMTrainer:
    def __init__(self, model, timesteps=200, eta=0.0):
        self.model = model
        self.timesteps = timesteps
        self.eta = eta

        self.device = next(model.parameters()).device

        self.alpha = (1. - beta).to(self.device)
        self.alpha_bar = torch.cumprod(self.alpha, dim=0).to(self.device)
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar).to(self.device)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar).to(self.device)

        self.loss_fn = nn.MSELoss()
    def train_step(self, x):
        """
        하나의 배치에 대해 forward → loss 계산 → backward
        """
        b = x.shape[0]
        t = torch.randint(0, self.timesteps, (b,), device=self.device).long()
        noise = torch.randn_like(x)

        sqrt_ab = self.sqrt_alpha_bar[t].view(-1, 1, 1, 1)
        sqrt_1m_ab = self.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)

        x_noisy = sqrt_ab * x + sqrt_1m_ab * noise

        pred_noise = self.model(x_noisy, t)
        loss = self.loss_fn(pred_noise, noise)
        return loss

    @torch.no_grad()
    def sample_ddim(self, shape):
        """
        DDIM 방식으로 샘플링 (deterministic 가능)
        """
        b = shape[0]
        img = torch.randn(shape).to(self.device)

        for i in reversed(range(0, self.timesteps)):
            t = torch.full((b,), i, device=self.device, dtype=torch.long)
            alpha_t = self.alpha[t].view(-1, 1, 1, 1)
            alpha_bar_t = self.alpha_bar[t].view(-1, 1, 1, 1)
            sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
            sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)

            pred_noise = self.model(img, t)

            # DDIM 역방정식
            x0_pred = (img - sqrt_one_minus_alpha_bar_t * pred_noise) / sqrt_alpha_bar_t
            if i == 0:
                img = x0_pred
            else:
                alpha_bar_prev = self.alpha_bar[t - 1].view(-1, 1, 1, 1)
                sigma = self.eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - self.alpha[t].view(-1, 1, 1, 1)))
                noise = torch.randn_like(img) if self.eta > 0 else 0
                img = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev - sigma**2) * pred_noise + sigma * noise

        return img.clamp(-0.5, 0.5)

이 클래스는 전체 학습 및 샘플링 과정을 담당한다.

  • train_step: 주어진 batch에 대해 랜덤 시간 t를 선택하고, 해당 시점에서 노이즈가 추가된 이미지를 생성한 후, 모델이 예측한 노이즈와 ground-truth 노이즈 간의 MSE를 계산한다.
  • sample_ddim: DDIM 알고리즘을 바탕으로 샘플링을 수행한다. 위의 ddim_sample 함수와 유사하나, 클래스 내부에서 self.alpha, self.alpha_bar 등을 저장하고 활용한다.

7. 모델 학습 및 시각화 루프

from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 모델 정의 (사용자 정의 UNet)
model = SimpleUNet().to(device)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = DDIMTrainer(model, timesteps=200, eta=0.0)

# 학습 루프
for epoch in range(100):
    model.train()
    pbar = tqdm(train_loader)
    for batch in pbar:
        x, _ = batch
        x = x.to(device)

        optimizer.zero_grad()
        loss = trainer.train_step(x)
        loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")
  1. SimpleUNet을 초기화하고 DDIMTrainer에 전달한다.
  2. 100 에폭 동안 학습을 반복한다.
  3. 각 배치마다 train_step을 수행하고, optimizer로 파라미터를 업데이트한다.
  4. 에폭이 끝날 때마다 DDIM 샘플을 생성하고 matplotlib을 통해 시각화한다.
    # 샘플링 및 시각화
    model.eval()
    samples = trainer.sample_ddim((16, 1, 28, 28))

    import matplotlib.pyplot as plt
    grid = torch.cat([s.squeeze().cpu() + 0.5 for s in samples], dim=-1)
    plt.imshow(grid.numpy(), cmap='gray')
    plt.title(f"Epoch {epoch+1} Sample")
    plt.axis('off')
    plt.show()

시각화는 샘플 16장을 가로로 나란히 이어붙여 한 줄의 이미지를 출력한다. 생성된 각 샘플은 모델이 예측한 x0이다.

이게 조금 잘 나온듯?


이 프로젝트는 PyTorch로 diffusion 모델을 처음 구현하는 사용자에게 적합한 구조로 설계되었다. 복잡한 attention이나 multi-resolution 구조 없이, 시간 조건이 주어지는 residual block을 활용한 간단한 UNet 구조이다.

  • DDIM 대신 DDPM, PLMS 등 다른 샘플러로 확장
  • UNet 구조를 더 깊고 넓게 확장
  • CIFAR-10, CelebA와 같은 컬러 이미지 데이터셋으로 확장
  • classifier-free guidance 기법을 활용한 조건부 생성

필요 시 해당 기능들도 단계적으로 추가해볼 수 있다.