PyTorch로 구현한 DDIM 기반 MNIST 이미지 생성기
이 글에서는 PyTorch를 사용하여 DDIM(Denoising Diffusion Implicit Models)을 기반으로 한 MNIST 이미지 생성 모델을 구현하는 과정을 단계별로 소개한다. 각 코드 셀은 논문에서 설명된 원리를 충실히 따르며, Sinusoidal Time Embedding, UNet, Residual Block, Sampling 과정을 모두 포함한다.
1. 베타 스케줄 및 시간 관련 텐서 초기화
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
# 타임스텝 수 설정
timesteps = 200
# DDIM에서는 여전히 베타 스케줄을 사용하지만, 샘플링은 다름
beta = torch.linspace(0.0001, 0.02, timesteps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
sqrt_alpha_bar = torch.sqrt(alpha_bar)
sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar)
# DDIM sampling을 위한 eta (0이면 Deterministic DDIM)
eta = 0.0
DDIM 모델에서 사용되는 diffusion 스케줄을 정의한다.
beta는 노이즈를 주입하는 정도를 나타내며, alpha, alpha_bar는 각 타임스텝에서의 누적 노이즈 수준을 나타낸다. eta는 샘플링 시 랜덤성의 강도를 제어하는 파라미터이다.
2. 시간 임베딩 및 ResBlock 정의
class TimeEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
half_dim = dim // 2
emb_scale = math.log(10000) / (half_dim - 1)
self.register_buffer('emb', torch.exp(torch.arange(half_dim) * -emb_scale))
def forward(self, t):
"""
주어진 시간 스텝 t에 대해 sinusoidal embedding 생성
"""
emb = t[:, None] * self.emb[None, :].to(t.device)
emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
return emb
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, time_dim):
super().__init__()
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_dim, out_channels)
)
self.block1 = nn.Sequential(
nn.GroupNorm(8, in_channels),
nn.SiLU(),
nn.Conv2d(in_channels, out_channels, 3, padding=1)
)
self.block2 = nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, 3, padding=1)
)
# 입력과 출력 채널이 다를 경우 skip 연결용 1x1 conv
if in_channels != out_channels:
self.residual_conv = nn.Conv2d(in_channels, out_channels, 1)
else:
self.residual_conv = nn.Identity()
def forward(self, x, t):
"""
x: 이미지 feature map
t: 시간 임베딩
"""
h = self.block1(x)
time_emb = self.time_mlp(t).unsqueeze(-1).unsqueeze(-1)
h = h + time_emb
h = self.block2(h)
return h + self.residual_conv(x)
두 개의 클래스를 정의한다.
- TimeEmbedding: 각 정수 시간 t에 대해 sinusoidal embedding을 생성하는 모듈이다. GPT에서도 사용되는 방식과 동일하다.
- ResidualBlock: 시간 정보가 주입되는 Residual Convolution Block이다. 두 개의 Conv2d 블록 사이에 time embedding을 broadcast해준다. 입력 채널 수와 출력 채널 수가 다를 경우 skip connection을 위한 1x1 Conv 레이어를 사용한다.
3. 전체 UNet 정의
class SimpleUNet(nn.Module):
def __init__(self, time_dim=256):
super().__init__()
self.time_embed = TimeEmbedding(time_dim)
self.time_mlp = nn.Sequential(
nn.Linear(time_dim, time_dim),
nn.SiLU(),
nn.Linear(time_dim, time_dim)
)
self.conv0 = nn.Conv2d(1, 64, 3, padding=1)
self.res1 = ResidualBlock(64, 128, time_dim)
self.res2 = ResidualBlock(128, 128, time_dim)
self.res3 = ResidualBlock(128, 64, time_dim)
self.out = nn.Conv2d(64, 1, 3, padding=1)
def forward(self, x, t):
t = self.time_embed(t)
t = self.time_mlp(t)
x = self.conv0(x)
x = self.res1(x, t)
x = self.res2(x, t)
x = self.res3(x, t)
return self.out(x)
최종적으로 사용할 간단한 UNet 구조를 정의한다. 아래와 같은 구조를 따른다.
- 입력: MNIST의 grayscale 이미지 (1채널)
- conv0 → ResBlock1 (64→128)
- ResBlock2 (128→128)
- ResBlock3 (128→64)
- 마지막 출력 Conv2D (64→1)
각 레이어에서 시간 정보를 활용하여 conditional UNet 구조로 작동한다.
4. DDIM 샘플링 함수 정의
@torch.no_grad()
def ddim_sample(model, shape, device, eta=0.0):
"""
DDIM 방식으로 샘플링 수행
- eta=0.0이면 deterministic sampling
- shape: [batch, channel, height, width]
"""
x = torch.randn(shape, device=device)
for i in reversed(range(1, timesteps)):
t = torch.full((shape[0],), i, device=device, dtype=torch.long)
alpha_t = alpha[i]
alpha_bar_t = alpha_bar[i]
alpha_bar_prev = alpha_bar[i - 1] if i > 1 else torch.tensor(1.0)
# 예측된 노이즈
pred_noise = model(x, t)
# x0 복원 (reverse direction)
x0 = (x - (1 - alpha_t).sqrt() * pred_noise) / alpha_t.sqrt()
# DDIM 수정된 예측
sigma = eta * ((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_t / alpha_bar_prev)).sqrt()
noise = torch.randn_like(x) if i > 1 else 0
x = alpha_bar_prev.sqrt() * x0 + (1 - alpha_bar_prev - sigma ** 2).sqrt() * pred_noise + sigma * noise
return x
이 함수는 학습된 모델을 사용하여 DDIM 방식으로 이미지를 생성하는 함수이다. 주요 특징은 다음과 같다.
- 맨 마지막 타임스텝부터 시작하여, 예측된 노이즈를 바탕으로 이전 스텝의 이미지를 계산한다.
- eta=0.0이면 deterministic하게 생성된다.
- x0를 추정하고, 이전 시점의 alpha_bar를 활용하여 이미지 상태를 업데이트한다.
5. MNIST 데이터셋 불러오기
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor(), # (0,1) 범위로 정규화됨
transforms.Lambda(lambda x: x - 0.5), # [-0.5, 0.5]로 스케일 조정 (Diffusion 모델 특성상)
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
torchvision을 통해 MNIST 데이터를 불러오고, (0,1) 범위의 값을 (-0.5, 0.5)로 조정한다. diffusion 모델에서는 0 중심 분포가 더욱 적합하므로 이와 같은 정규화를 수행한다.
6. DDIMTrainer 클래스 정의
class DDIMTrainer:
def __init__(self, model, timesteps=200, eta=0.0):
self.model = model
self.timesteps = timesteps
self.eta = eta
self.device = next(model.parameters()).device
self.alpha = (1. - beta).to(self.device)
self.alpha_bar = torch.cumprod(self.alpha, dim=0).to(self.device)
self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar).to(self.device)
self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar).to(self.device)
self.loss_fn = nn.MSELoss()
def train_step(self, x):
"""
하나의 배치에 대해 forward → loss 계산 → backward
"""
b = x.shape[0]
t = torch.randint(0, self.timesteps, (b,), device=self.device).long()
noise = torch.randn_like(x)
sqrt_ab = self.sqrt_alpha_bar[t].view(-1, 1, 1, 1)
sqrt_1m_ab = self.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1)
x_noisy = sqrt_ab * x + sqrt_1m_ab * noise
pred_noise = self.model(x_noisy, t)
loss = self.loss_fn(pred_noise, noise)
return loss
@torch.no_grad()
def sample_ddim(self, shape):
"""
DDIM 방식으로 샘플링 (deterministic 가능)
"""
b = shape[0]
img = torch.randn(shape).to(self.device)
for i in reversed(range(0, self.timesteps)):
t = torch.full((b,), i, device=self.device, dtype=torch.long)
alpha_t = self.alpha[t].view(-1, 1, 1, 1)
alpha_bar_t = self.alpha_bar[t].view(-1, 1, 1, 1)
sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)
pred_noise = self.model(img, t)
# DDIM 역방정식
x0_pred = (img - sqrt_one_minus_alpha_bar_t * pred_noise) / sqrt_alpha_bar_t
if i == 0:
img = x0_pred
else:
alpha_bar_prev = self.alpha_bar[t - 1].view(-1, 1, 1, 1)
sigma = self.eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - self.alpha[t].view(-1, 1, 1, 1)))
noise = torch.randn_like(img) if self.eta > 0 else 0
img = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev - sigma**2) * pred_noise + sigma * noise
return img.clamp(-0.5, 0.5)
이 클래스는 전체 학습 및 샘플링 과정을 담당한다.
- train_step: 주어진 batch에 대해 랜덤 시간 t를 선택하고, 해당 시점에서 노이즈가 추가된 이미지를 생성한 후, 모델이 예측한 노이즈와 ground-truth 노이즈 간의 MSE를 계산한다.
- sample_ddim: DDIM 알고리즘을 바탕으로 샘플링을 수행한다. 위의 ddim_sample 함수와 유사하나, 클래스 내부에서 self.alpha, self.alpha_bar 등을 저장하고 활용한다.
7. 모델 학습 및 시각화 루프
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 모델 정의 (사용자 정의 UNet)
model = SimpleUNet().to(device)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = DDIMTrainer(model, timesteps=200, eta=0.0)
# 학습 루프
for epoch in range(100):
model.train()
pbar = tqdm(train_loader)
for batch in pbar:
x, _ = batch
x = x.to(device)
optimizer.zero_grad()
loss = trainer.train_step(x)
loss.backward()
optimizer.step()
pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")
- SimpleUNet을 초기화하고 DDIMTrainer에 전달한다.
- 100 에폭 동안 학습을 반복한다.
- 각 배치마다 train_step을 수행하고, optimizer로 파라미터를 업데이트한다.
- 에폭이 끝날 때마다 DDIM 샘플을 생성하고 matplotlib을 통해 시각화한다.
# 샘플링 및 시각화
model.eval()
samples = trainer.sample_ddim((16, 1, 28, 28))
import matplotlib.pyplot as plt
grid = torch.cat([s.squeeze().cpu() + 0.5 for s in samples], dim=-1)
plt.imshow(grid.numpy(), cmap='gray')
plt.title(f"Epoch {epoch+1} Sample")
plt.axis('off')
plt.show()
시각화는 샘플 16장을 가로로 나란히 이어붙여 한 줄의 이미지를 출력한다. 생성된 각 샘플은 모델이 예측한 x0이다.
이 프로젝트는 PyTorch로 diffusion 모델을 처음 구현하는 사용자에게 적합한 구조로 설계되었다. 복잡한 attention이나 multi-resolution 구조 없이, 시간 조건이 주어지는 residual block을 활용한 간단한 UNet 구조이다.
- DDIM 대신 DDPM, PLMS 등 다른 샘플러로 확장
- UNet 구조를 더 깊고 넓게 확장
- CIFAR-10, CelebA와 같은 컬러 이미지 데이터셋으로 확장
- classifier-free guidance 기법을 활용한 조건부 생성
필요 시 해당 기능들도 단계적으로 추가해볼 수 있다.
'Programming > AI & ML' 카테고리의 다른 글
[Euron Research 복습 과제] GPT-1 (0) | 2025.05.11 |
---|---|
[OUTTA Alpha팀 Medical AI& 3D Vision 스터디] 딥러닝 1(CNN 3) - 끝 (1) | 2025.05.10 |
[Euron Research 복습 과제] DDPM MNIST (0) | 2025.04.29 |
[OUTTA Alpha팀 Medical AI& 3D Vision 스터디] 딥러닝 1(CNN 2) (0) | 2025.03.16 |
[OUTTA Alpha팀 Medical AI& 3D Vision 스터디] 딥러닝 1(CNN 1) (0) | 2025.03.09 |