Programming/AI & ML

[딥러닝을 활용한 의료 영상 처리 & 모델 개발] Part2-4. Brain tumor Semantic segmentation(Brats 2020)을 SegResNet 이용한 학습 실습

YeonJuJeon 2025. 1. 8. 21:08

코드 링크:

https://github.com/ahxlzjt/MedImagingDL/blob/55edfd5ed184027a708c9aa1fdf732680390de3f/ch05_06_3d_brain_tumor_segmentation_infe.ipynb

 

MedImagingDL/ch05_06_3d_brain_tumor_segmentation_infe.ipynb at 55edfd5ed184027a708c9aa1fdf732680390de3f · ahxlzjt/MedImagingDL

#MedicalDL #CV #AI . Contribute to ahxlzjt/MedImagingDL development by creating an account on GitHub.

github.com


0. 라이브러리 설치 및 임포트

  1. !pip install monai, !pip install natsort로 필요 라이브러리 설치
  2. MONAI, PyTorch, NumPy, Matplotlib 등을 불러옴

1. Config 정의 및 Dataset 준비

1.1 cfg 클래스

  • base_path, train_path, test_path 등 프로젝트 경로 지정
  • seed, max_epochs, in_channels=4, out_channels=3 설정

1.2 train_data_dicts, test_data_dicts 생성

  • Brats 2020 데이터 폴더를 glob하여,
    • 5개 파일(*_flair.nii, *_t1.nii, *_t1ce.nii, *_t2.nii, *_seg.nii)이 있으면 train으로
    • 4개 파일(*_flair.nii, *_t1.nii, *_t1ce.nii, *_t2.nii)이 있으면 test
    • 각각 {'image': [...], 'label': [...]} 형태 딕셔너리로 저장

2. Custom Transform: ConvertToMultiChannelBasedOnBratsClassesd

  • Brats의 label은 {0,1,2,3} (배경=0, ED=1, ET=2, NET/NCR=3)
  • 이 클래스를 TC, WT, ET로 나누어 3채널로 변환
    • TC(Tumor Core): label 2 or 3
    • WT(Whole Tumor): label 1 or 2 or 3
    • ET(Enhancing Tumor): label 2만

3. Data Transforms (Train/Val)

3.1 Train Transform

  1. LoadImaged, EnsureChannelFirstd: 4채널 이미지를 로드
  2. ConvertToMultiChannelBasedOnBratsClassesd: 위에서 정의한 3채널 라벨로 매핑
  3. Orientationd(axcodes="RAS"), Spacingd(pixdim=(1,1,1)): 방향·스페이싱 정규화
  4. RandSpatialCropd(roi_size=[224,224,144]): 고정 크기 크롭
  5. NormalizeIntensityd(채널단위로 Z-score 등)
  6. RandScaleIntensityd(factors=0.1): 강도 스케일링

3.2 Val Transform

  • 위와 거의 동일하나, Rand형 변환(augment)은 제외

4. Dataset Split & Loader

  1. train_dict, valid_dict = train_test_split(train_data_dicts, test_size=0.2, random_state=2023)
  2. train_dataset = Dataset(train_dict, transform=train_transform)
  3. train_loader = DataLoader(..., batch_size=1, shuffle=True, ...)
  4. valid_dataset = Dataset(valid_dict, transform=val_transform)
  5. val_loader = DataLoader(..., batch_size=1, shuffle=False, ...)

5. Visualization

  • val_data_example = valid_dataset[2]
  • 이미지 shape, label shape 출력
  • plt.imshow(val_data_example["image"][i, :, :, 60]) 등으로 채널별 슬라이스 시각화
  • 라벨도 3채널로 각각 확인


6. Seed 고정 & Model 초기화

  1. seed_everything(cfg.seed)
    • 랜덤 시드 고정 (PyTorch, NumPy 등)
  2. SegResNet 모델 생성
    • model = SegResNet(...) with in_channels=4, out_channels=3
  3. Loss/Optimizer 설정
    • loss_function = DiceLoss(...)
    • optimizer = torch.optim.Adam(...)
    • lr_scheduler = CosineAnnealingLR(...)
  4. DiceMetric 준비
    • dice_metric = DiceMetric(...)
    • dice_metric_batch = DiceMetric(...) (배치별 평균 계산)

7. Inference 함수 정의

  • sliding_window_inference 사용
  • val_AMP=True 시 torch.cuda.amp.autocast로 연산 가속

8. 학습 루프

  1. 모델 훈련(model.train())
    • DataLoader 반복
    • with torch.cuda.amp.autocast(): outputs = model(inputs)
    • loss_function(outputs, labels)
    • AMP 스케일링(scaler.scale(loss).backward(), scaler.step(optimizer))
    • 일정 스텝마다 train_loss 출력
    • lr_scheduler.step()로 학습률 업데이트
  2. 검증(model.eval())
    • inference + post-processing(Activations+AsDiscrete)
    • dice_metric에 추가
    • TC, WT, ET 각각 dice 출력
    • best_metric 갱신 시 모델 weight 저장(best_metric_model.pth)

Epoch 20


9. 학습 결과 시각화

  • plt.plot(epoch_loss_values)로 Epoch Average Loss
  • plt.plot(metric_values)로 Val Mean Dice 표시
  • 채널별(TC, WT, ET) 그래프도 추가로 그려서 성능 추이 확인

10. Inference on a sample

  1. model.load_state_dict(torch.load("best_metric_model.pth"))
  2. val_input = valid_dataset[2]["image"].unsqueeze(0).cuda()
  3. val_output = inference(val_input)
  4. val_output = post_trans(val_output[0])
  5. 슬라이스 70번 등 임의로 골라 원본 이미지, 라벨, 모델 출력 각각 시각화

11. 정리

  • Dataset: Brats multi-modal(4채널) + 라벨(1채널) → 변환을 통해 3채널 분류(TC, WT, ET)
  • SegResNet 모델 기반 DiceLoss 학습
  • sliding_window_inference로 3D 세그멘테이션 수행
  • Batch=1로 GPU 메모리 절약, AMP로 연산 가속
  • DiceMetric으로 매 epoch 검증 & 최적모델 저장
  • 최종적으로 Validation 예시를 그림으로 확인