Programming/AI & ML

[딥러닝을 활용한 의료 영상 처리 & 모델 개발] Part2-2. MONAI와 TorchIO를 활용한 3D 데이터 증강 코드 리뷰

YeonJuJeon 2025. 1. 2. 20:52

목적: MONAI와 TorchIO를 활용하여 3D 의료 영상을 전처리하고 다양한 증강 기법을 적용하여 모델 학습에 적합한 데이터를 생성


1. 라이브러리 설치 및 기본 설정

  • MONAI 설치: 3D 데이터 증강 및 의료 영상 처리에 유용한 기능 제공.
!pip install monai
  • 필수 라이브러리 임포트: 기본적인 파일 처리 및 시각화를 위한 라이브러리.
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from pathlib import Path
import os
  • 현재 작업 디렉토리 확인:
print(os.getcwd())

2. MONAI 라이브러리 상세 임포트 및 설명

  • MONAI 관련 모듈 임포트: 다양한 전처리 및 증강 기법을 위한 MONAI의 핵심 모듈들.
import monai
from monai.apps import DecathlonDataset, download_and_extract
from monai.data import DataLoader, Dataset
from monai.transforms import (
    EnsureChannelFirstd,
    LoadImaged,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    Compose,
    OneOf,
    CropForegroundd,
    Rand3DElasticd,
    RandAffined,
    RandRotated,
    RandFlipd,
)
from monai.visualize.utils import (
    blend_images,
    matshow3d
)
  • 중요 MONAI Transform 설명:
    • EnsureChannelFirstd: 채널을 첫 번째 차원으로 이동하여 텐서 형태로 변환.
    • LoadImaged: NIfTI, DICOM 이미지 로드.
    • Spacingd: 픽셀 간격 조정 및 리샘플링.
    • CropForegroundd: 전경 영역만 남기기.
    • Rand3DElasticd, RandAffined, RandRotated, RandFlipd: 랜덤 3D 변형 및 증강.

3. 데이터 다운로드 및 준비

  • 데이터 다운로드 설정: download_and_extract: MONAI의 데이터 다운로드 및 압축 해제 함수.
download = True
if download:
    directory = os.environ.get("MONAI_DATA_DIRECTORY")
    root_dir = tempfile.mkdtemp() if directory is None else directory
    print(f"root dir is: {root_dir}")

    resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar"
    compressed_file = os.path.join(root_dir, "Task02_Heart.tar")
    data_dir = os.path.join(root_dir, "Task02_Heart")
    if not os.path.exists(data_dir):
        download_and_extract(resource, compressed_file, root_dir)
  • 데이터 경로 설정 및 리스트 생성: 이미지와 레이블 파일을 매칭하여 데이터 딕셔너리 생성.
train_images = list((Path(data_dir)/"imagesTr").glob("*.nii.gz"))
train_labels = list((Path(data_dir)/"labelsTr").glob("*.nii.gz"))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_data_dicts, val_data_dicts = data_dicts[:-9], data_dicts[-9:]
{'image': PosixPath('/tmp/tmppddxayzs/Task02_Heart/imagesTr/la_007.nii.gz'),
 'label': PosixPath('/tmp/tmppddxayzs/Task02_Heart/labelsTr/la_007.nii.gz')}

4. 데이터 로드 및 기본 시각화

  • 기본 Transform 적용:Compose: 여러 transform을 순차적으로 적용.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
    ])
dataset = Dataset(data=val_data_dicts, transform=transform)
  • 데이터 형태 및 픽셀 간격 확인:
print(f"image shape: {dataset[0]['image'].shape}")
print(f"label shape: {dataset[0]['label'].shape}")
print(f"pixel spacing: {dataset[0]['image'].pixdim}")
  • 3D 이미지 시각화: 슬라이스를 건너뛰며 3D 이미지 시각화.
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)


5. 픽셀 간격 조정 및 리샘플링

  • Spacing Transform 적용: Spacingd: 픽셀 간격을 (2, 2, 3)으로 재조정하여 리샘플링.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(2, 2, 3), mode=("bilinear", "nearest"))
    ])
dataset = Dataset(data=val_data_dicts, transform=transform)
  • 리샘플링 후 데이터 형태 및 시각화 확인: 리샘플링으로 이미지 크기 조정 확인.
print(f"image shape: {dataset[0]['image'].shape}")
print(f"label shape: {dataset[0]['label'].shape}")
print(f"pixel spacing: {dataset[0]['image'].pixdim}")
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)


6. 전경 영역 크롭 및 정규화

  • Foreground 크롭 Transform:CropForegroundd: 배경 노이즈 제거로 비용 절감 및 학습 효율 향상.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ])

dataset = Dataset(data=val_data_dicts, transform=transform)
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)
  • 정규화 Transform 추가:NormalizeIntensityd: 이미지의 픽셀 값을 정규화하여 학습 안정성 향상.
from monai.transforms import NormalizeIntensityd

transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        NormalizeIntensityd(keys="image", channel_wise=True),
    ])

dataset = Dataset(data=val_data_dicts, transform=transform)
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)


7. 다양한 증강 기법 적용

  • 랜덤 증강 Transform 구성:RandFlipd & RandRotate90d: 이미지를 랜덤하게 뒤집기 및 90도 회전하여 데이터 다양성 증가.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=1,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=1,
            max_k=3,
        )
    ])

dataset = Dataset(data=val_data_dicts, transform=transform)

plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)
  • Affine 변형 적용:RandAffined: 이미지에 기하학적 변형을 적용하여 학습 데이터의 다양성 확보.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandAffined(
            keys=["image", "label"],
            shear_range=(0.5,0.5), mode='bilinear', padding_mode='zeros',
            prob=1,
        ),
    ])

dataset = Dataset(data=val_data_dicts, transform=transform)
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)


8. 이미지와 레이블의 블렌딩 시각화

  • 이미지와 레이블 블렌딩: blend_images: 이미지와 레이블을 합성하여 시각적으로 확인.
import torch

ret = blend_images(image=dataset[0]["image"], label=dataset[0]["label"], alpha=0.5, cmap="hsv", rescale_arrays=True)
fig,axs = plt.subplots(1,3)
slice_index = 10 * 5
axs[0].set_title(f"image slice {slice_index}")
axs[0].imshow(dataset[0]["image"][0, :, :, slice_index], cmap="gray")
axs[1].set_title(f"label slice {slice_index}")
axs[1].imshow(dataset[0]["label"][0, :, :, slice_index])
axs[2].set_title(f"blend slice {slice_index}")
axs[2].imshow(torch.moveaxis(ret[:, :, :, slice_index], 0, -1))


9. 하드 증강 기법 적용

  • 추가 증강 Transform 임포트:고급 증강 기법을 통한 데이터 다양성 및 강건성 향상.
from monai.transforms import (
    RandKSpaceSpikeNoised,
    AdjustContrastd,
    GaussianSmoothd,
    RandCoarseDropoutd,
    HistogramNormalized,
)
  • K-Space 노이즈 추가 Transform:RandKSpaceSpikeNoised: 주파수 영역에서 노이즈 추가하여 강건성 향상.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=0,
            a_max=1800,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandKSpaceSpikeNoised(keys=["image"], prob=1, intensity_range=(13, 15), channel_wise=True),
    ])

dataset = Dataset(data=train_data_dicts, transform=transform)
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)
  • 히스토그램 정규화 Transform:HistogramNormalized: 히스토그램 정규화를 통해 영상의 밝기 및 대비 향상.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=0,
            a_max=1800,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        HistogramNormalized(keys=["image"], num_bins=10),
    ])

dataset = Dataset(data=val_data_dicts, transform=transform)
plt = matshow3d(
    volume=dataset[0]["image"][...,1::20],
    fig=None,
    title="input image",
    frame_dim=-1,
    show=True,
    cmap="gray",
)


10. TorchIO를 활용한 추가 증강 기법

  • TorchIO 설치:TorchIO는 또 다른 강력한 의료 영상 증강 라이브러리.
!pip install torchio
  • TorchIO 임포트 및 데이터셋 로드:Colin27: 기본 제공되는 3D MRI 데이터셋.
import torchio as tio

colin = tio.datasets.Colin27()
  • 랜덤 증강 Transform 구성:OneOf: 여러 증강 중 하나를 확률에 따라 선택하여 적용.
transforms_dict = {
    tio.RandomAffine(): 0.75,
    tio.RandomElasticDeformation(): 0.25,
}
transform = tio.OneOf(transforms_dict)

transformed = transform(colin)

transformed.plot()
colin.plot()
transformed.plot()


11. 코드 요약 및 주요 포인트

  • MONAI와 TorchIO의 통합 사용: 두 라이브러리의 강점을 결합하여 다양한 전처리 및 증강 기법 적용.
  • 중요 Transform의 이해 및 활용: 데이터 전처리 과정에서 필요한 Transform을 적절히 선택하고 조합.
  • 데이터 시각화를 통한 검증: matshow3d와 blend_images를 활용하여 전처리 및 증강 결과를 시각적으로 확인.
  • 하드 증강 기법의 적용: RandKSpaceSpikeNoised, HistogramNormalized 등 고급 Transform을 통해 데이터의 다양성 및 모델의 강건성 향상.
  • TorchIO의 활용: 추가적인 증강 기법과 데이터셋을 통해 더욱 다양한 데이터 생성 가능.

학습 포인트:

  • Compose와 같은 MONAI의 기능을 활용하여 복잡한 전처리 파이프라인을 간단하게 구성할 수 있음.
  • SpacingdCropForegroundd를 통해 3D 의료 영상의 크기와 영역을 효과적으로 조정 가능.
  • RandAffined, RandFlipd 등 랜덤 증강 기법을 통해 데이터의 다양성을 확보하여 모델의 일반화 능력 향상.
  • TorchIO를 활용하여 MONAI와는 다른 증강 기법을 추가로 적용 가능.