728x90
반응형

paper : https://arxiv.org/abs/2207.07027

github : https://github.com/nyuad-cai/MedFuse

 

NYU Abu Dhabi

 

의료 도메인에서는 환자의 다양한 형태의 데이터(모달리티)를 활용해 더 정밀한 예측을 수행하는 것이 중요한 과제로 여겨진다.

예를 들어, 중환자실(ICU)에서는 환자의 연속적인 생체징후나 검사 결과와 같은 임상 시계열 데이터, 그리고 흉부 X-ray 영상처럼 이미지 데이터가 각각 따로 수집되고, 이러한 데이터를 통합해 분석하면 환자의 상태를 더 정확히 파악할 수 있을 것으로 기대 할 수 있다.

하지만 자연영상이나 음성처럼 일반적인 다중모달 데이터와 달리, 의료 데이터는 모달리티마다 수집 시점이 다르거나 아예 누락되는 경우가 많다.

즉, 모든 환자에게 모든 형태의 데이터가 항상 존재하지는 않는다는 말이다.

그렇기 때문에, 모델을 학습시킬 때 모든 모달리티가 항상 동시에 존재한다는 가정을 두게 되면, 현실의 데이터 특성과 맞지 않을 뿐 아니라, 실제로 학습에 사용할 수 있는 샘플 수가 크게 줄어드는 문제가 발생하게 된다. (missing modality issue)

이러한 어려움으로 인해, 현재까지 의료 분야에서 성공적인 예측 모델들은 대부분 단일 모달 데이터에만 의존하는 경향을 보여왔으며, 실제로 공개된 다중모달 의료 데이터 벤치마크도 매우 제한적인 상황이다.

 

MedFuse

 

MedFuse 논문에서는 다음과 같은 두 가지를 제안하고 있다.

  1. 임상 시계열 데이터와 의료 영상을 통합할 수 있으면서도 모달리티 누락에도 유연하게 대응 가능한 새로운 다중모달 융합(fusion) 모델을 설계
  2. 공개된 MIMIC-IV(ICU 임상 데이터)와 MIMIC-CXR(흉부 X선 영상)을 연계하여, 병원 내 사망 예측과 환자 표현형 분류라는 두 가지 주요 임상 과제에 대한 새로운 벤치마크 실험을 수행

 

MedFuse 모델 구조

MedFuse는 모달리티별 인코더(modality-specific encoder)에서 추출한 특징들을 LSTM 기반의 융합 모듈을 통해 통합하고, 이를 바탕으로 최종 예측을 수행하는 아키텍처로, 각 모달 인코더는 해당 데이터의 특성을 효과적으로 반영할 수 있도록 설계된 신경망 구조로 구성되어 있으며, 이렇게 얻어진 두 모달의 잠재 표현(latent representation)은 하나의 시계열로 간주되어 순환 신경망(LSTM)을 통해 처리되게 된다.

 

모달리티별 인코더 구조

1. 임상 시계열 데이터 인코더 (fehr)

이 인코더는 환자의 ICU 입원 중 측정된 연속적인 임상 수치들을 입력으로 받아, 시간 경과에 따른 중요한 패턴을 요약한다.

두 개의 층으로 구성된 LSTM 네트워크를 기반으로 하며, 중간 드롭아웃(dropout)을 통해 과적합을 방지한다.

class LSTM(nn.Module):
    def __init__(self, input_dim=76, num_classes=1, hidden_dim=128, batch_first=True, dropout=0.0, layers=1):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.layers = layers
        for layer in range(layers):
            setattr(self, f'layer{layer}', nn.LSTM(
                input_dim, hidden_dim,
                batch_first=batch_first,
                dropout=dropout)
            )
            input_dim = hidden_dim  
        self.do = nn.Dropout(dropout) if dropout > 0.0 else None
        self.feats_dim = hidden_dim
        self.dense_layer = nn.Linear(hidden_dim, num_classes)

 

입력 시계열은 시간축을 따라 처리되며, 마지막 시점의 은닉 상태(hidden state)를 잠재 표현 벡터로 사용한다.

이 벡터의 차원은 256이며, vehr ∈ ℝ^256으로 정의된다. 이 벡터는 환자의 임상적 상태를 예측하는 분류기 gehr의 입력으로 전달된다.

def forward(self, x, seq_lengths):
        x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True, enforce_sorted=False)
        for layer in range(self.layers):
            x, (ht, _) = getattr(self, f'layer{layer}')(x)
        feats = ht.squeeze()
        if self.do is not None:
            feats = self.do(feats)
        out = self.dense_layer(feats)
        scores = torch.sigmoid(out)
        return scores, feats

 

forward에서는 여러 층의 LSTM을 거쳐 시계열의 마지막 타임스텝에 해당하는 출력을 dense_layer에 입력해 최종 예측값 ŷehr를 출력한다.

학습에는 Binary Cross-Entropy(BCE) 손실 함수*를 사용하여 실제 레이블 yehr과 예측값 사이의 오차를 최소화한다.

*Binary Cross-Entropy(BCE) 손실 함수 : 이진 분류 문제에서 모델의 예측 확률과 실제 레이블 간의 차이를 측정하여, 예측이 실제와 얼마나 다른지를 평가하는 데 사용되는 손실 함수. 예측 확률이 실제 레이블과 가까울수록 낮은 값을 가짐

 

요약하면, 이 인코더는 시계열 입력 Xehr(각 시점마다 76차원 피처 벡터)를 받아 LSTM을 통해 256차원의 함축 표현 vehr을 생성하고, 이를 이용해 환자의 예측 레이블을 추정하는 역할을 수행한다.

 

EHR 시계열 입력 Xehr (배열 형태: [batch, time, 76])

LSTM

마지막 타임스텝 은닉 상태 → vehr (크기: [batch, 256])

Linear(256 → 1) → 시그모이드
↓ ŷehr (예측된 사망확률 또는 임상 이벤트 확률)

 

2. 흉부 X-ray 영상 인코더 (fcxr)

영상 인코더는 환자 입원 중 촬영된 흉부 X-ray 영상을 입력으로 받아, 이를 고차원 특징 벡터로 임베딩하는 역할을 한다.

인코더로는 ResNet-34 합성곱 신경망이 사용되며, 경우에 따라 ImageNet 사전학습 가중치를 초기값으로 활용한다.

(ImageNet은 1천만 장 이상의 다양한 이미지와 1천 개의 클래스(label)로 구성된 대규모 데이터셋으로, ImageNet에서 이미 배운 기본적인 시각적 특징을 활용하겠다는 의미)

class CXRModels(nn.Module):
    def __init__(self, args, device='cpu'):
        super(CXRModels, self).__init__()
        self.args = args
        self.device = device
        self.vision_backbone = getattr(torchvision.models, self.args.vision_backbone)(pretrained=self.args.pretrained)
        # vision_backbone -> ResNet-34
        
        for classifier in ['classifier', 'fc']:
            cls_layer = getattr(self.vision_backbone, classifier, None)
            if cls_layer is not None:
                d_visual = cls_layer.in_features
                setattr(self.vision_backbone, classifier, nn.Identity())
                # in_features 값을 저장하고,
                # resnet34.fc -> nn.Identity()로 변경 
                # (분류기층 제거 -> 클래스 예측하지 않고 512 차원의 feature 벡터만 출력)
                break
                
        self.bce_loss = torch.nn.BCELoss(size_average=True)
        self.classifier = nn.Sequential(
            nn.Linear(d_visual, self.args.vision_num_classes)
        )
        self.feats_dim = d_visual

 

*ResNet-34 구조
입력 이미지 (224x224x3)

conv1 → 7x7 Conv, stride 2

maxpool → 3x3 MaxPool, stride 2

layer1 → BasicBlock × 3

layer2 → BasicBlock × 4

layer3 → BasicBlock × 6

layer4 → BasicBlock × 3

avgpool → Global Average Pooling (출력 크기: 512)
↓ 
fc → Fully Connected Layer: nn.Linear(512, 1000)

 

X-ray 영상은 흑백 단일 채널이기 때문에, ResNet에 입력하기 전에 224×224 해상도로 리사이즈하고, 동일 영상을 3채널로 복제하여 처리한다. R

esNet-34의 평균 풀링(average pooling) 레이어 출력이 이 인코더의 최종 출력 vcxr이 되며, 차원은 512이다 (vcxr ∈ ℝ^512).

def forward(self, x, labels=None, n_crops=0, bs=16):
        lossvalue_bce = torch.zeros(1).to(self.device)

        visual_feats = self.vision_backbone(x)
        preds = self.classifier(visual_feats)

        preds = torch.sigmoid(preds)

        if n_crops > 0:
            preds = preds.view(bs, n_crops, -1).mean(1)
        if labels is not None:
            lossvalue_bce = self.bce_loss(preds, labels)

        return preds, lossvalue_bce, visual_feats

 

gcxr 또한 시그모이드 출력층으로 구성되며, 예측값 ŷcxr은 14차원의 벡터(CheXpert 알고리즘을 이용해 추출)다. 학습 역시 BCE 손실을 사용하여, 실제 영상 레이블 ycxr과의 오차를 최소화한다.

결과적으로 이 인코더는 X-ray 한 장을 512차원의 고차원 특징 벡터로 변환하는 역할을 수행하며, 사전학습을 통해 정확한 병변 표현이 가능하도록 조정된다.

결국 medfuse내의 두 인코더는 서로 다른 레이블(임상 예측값 vs 영상 소견)을 바탕으로 개별적으로 사전 학습된다.

임상 시계열 인코더 fehr는 사망 여부나 표현형 존재 여부 같은 임상 레이블 yehr을 예측하도록 학습되고,

영상 인코더 fcxr는 영상 자체의 병변 라벨 ycxr을 예측하도록 학습된다.

각 인코더는 각기 다른 BCE 손실 함수를 기반으로 독립적으로 최적화되며, 이를 통해 각 모달리티에 특화된 강력한 표현 능력을 갖춘 특징 추출기를 구축할 수 있다.

 

다중모달 융합 모듈 (MedFuse Fusion Module)

모달리티별 인코더가 각각 사전 학습을 마친 후, MedFuse는 이 둘을 결합하여 최종 예측을 수행한다.

이때 융합 단계에서는 인코더에 붙어 있던 개별 분류기(예: gehr, gcxr)는 제거하고, 특징 추출 부분만 유지한다.

두 인코더가 출력하는 잠재 표현은 차원이 서로 다르다 (vehr는 256차원, vcxr는 512차원임)

이를 통합하기 위해, vcxr은 투영층(ϕ)을 통해 256차원으로 변환되며, 이렇게 변환된 벡터는 vcxr*로 표기된다.

이 투영층은 단순한 선형 계층(linear layer)으로 구현된다. 이 과정을 통해 vehr ∈ ℝ^256와 vcxr ∈ ℝ^256이 나란히 결합 가능한 형태가 된다.

class Fusion(nn.Module):
    def __init__(self, args, ehr_model, cxr_model):
        super(Fusion, self).__init__()
        self.args = args
        self.ehr_model = ehr_model
        self.cxr_model = cxr_model

        self.projection = nn.Linear(self.cxr_model.feats_dim, self.ehr_model.feats_dim)  # 512 → 256 투영
        
        self.lstm = nn.LSTM(input_size=self.ehr_model.feats_dim, hidden_size=512, batch_first=True)
        self.classifier = nn.Linear(512, self.args.num_classes)  # 최종 분류기

 

MedFuse의 핵심 아이디어는 이 두 모달 벡터를 하나의 시퀀스로 간주하고, 이를 순차적으로 처리하는 것이다.

이를 위해 두 벡터는 [vehr, vcxr*] 형태의 시퀀스로 묶이고, 이 시퀀스는 LSTM 기반 융합 모듈(ffusion)의 입력으로 전달된다.

이 LSTM은 입력 벡터 크기 256, 은닉 상태 크기 512를 갖는 단일 레이어 구조로, 시퀀스 길이 2를 시간차원으로 따라 읽게 된다.

여기서 첫 번째 time-step에서는 vehr를 입력으로 받아 은닉 상태 h를 업데이트한다.

그리고 다음 두 번째 time-step에서는 이전 상태를 바탕으로 vcxr* 정보를 통합하여 최종 은닉 상태 hfusion을 생성한다.

 

def forward(self, ehr_input, cxr_input=None):
vehr = self.ehr_model(ehr_input)  # ∈ ℝ^256
inputs = [vehr.unsqueeze(1)]      # shape: [batch, 1, 256]

if cxr_input is not None:
    vcxr = self.cxr_model(cxr_input)      # ∈ ℝ^512
    vcxr_proj = self.projection(vcxr)     # ∈ ℝ^256
    inputs.append(vcxr_proj.unsqueeze(1)) # shape: [batch, 1, 256]

sequence = torch.cat(inputs, dim=1)  # shape: [batch, 2, 256
# sequence[:, 0, :] = vehr     ← 첫 번째 time-step
# sequence[:, 1, :] = vcxr*    ← 두 번째 time-step
_, (hn, _) = self.lstm(sequence)     # LSTM이 시간축(time-step)을 따라 처리
hfusion = hn[-1]
output = self.classifier(hfusion)  # 예측값 ŷfusion
return output

 

이러한 순차적 결합 방식은 다음 두 가지 측면에서 설계 되었다.

  1. 의사의 의사결정 과정을 모방)하기 위함 임상의는 일반적으로 임상 수치(EHR)를 먼저 살펴보고, 필요 시 영상 자료(CXR)를 추가로 참고하여 종합적으로 판단한다. MedFuse의 LSTM은 이 흐름을 모방하도록 구성되어 있어, 먼저 vehr를 기반으로 판단한 뒤, vcxr*로 정보를 보완하도록 동작한다.
  2. 모달리티 누락에 대한 유연한 처리를 위함 LSTM은 시퀀스 길이가 변해도 동작할 수 있는 구조이므로, 특정 모달리티가 누락된 경우에도 예측이 가능하다. 예를 들어, X-ray 영상이 없는 환자의 경우 vfusion = [vehr]만 입력되며, LSTM은 이를 단독으로 처리해 hfusion을 생성할 수 있다. 이는 MedFuse가 학습과 추론 모두에서 모달리티 누락을 허용하도록 설계되었음을 의미한다.

 

 

생성된 hfusion은 최종 분류기(gfusion)로 전달되어 예측 결과 ŷfusion을 산출한다.

이때 예측의 대상은 EHR 기반 레이블(yehr)이 되며, 영상에 대해 별도의 예측 목적은 두지 않는다.

즉, 영상은 판단을 보조하는 정보로만 사용된다.

(실제 병원 데이터에서는 모든 환자가 X-ray를 찍지는 않으며, 기본 임상 정보(EHR)는 대부분 존재하기에 ehr을 기준으로 평가한다. Ehr 데이터는 필수라고 할 수 있다)

전체 학습 과정에서는 BCE(Binary Cross-Entropy) 손실을 기반으로 ŷfusion과 yehr 간의 오차를 최소화한다.

이 손실을 통해, 인코더(fehr, fcxr), 투영층(ϕ), LSTM 융합 모듈(ffusion), 최종 분류기(gfusion)까지 모델 전체를 end-to-end로 공동 미세조정(fine-tuning)한다.

 

정리하자면,

MedFuse는 임상 시계열(EHR)을 중심으로 환자의 상태를 예측하며, 흉부 X-ray 영상(CXR)은 보조적 정보로 활용된다.

MedFuse는 ‘모달리티별 인코더(특징 추출) + LSTM 융합 모듈 + 최종 분류기’ 로 구성된 구조를 갖고 있으며,

각 모달리티의 임베딩 벡터를 하나의 시퀀스처럼 LSTM에 순차적으로 입력함으로써, 서로 다른 데이터 간의 상호작용을 효과적으로 학습할 수 있도록 설계되었다.

이 구조는 EHR만 있는 경우에도, CXR이 추가된 경우에도, 심지어 CXR만 존재하는 경우에도 유연하게 작동할 수 있다.

 

데이터셋

MedFuse의 성능을 평가하기 위해, 연구진은 MIMIC-IV와 MIMIC-CXR이라는 두 가지 공개 의료 데이터셋을 연계하여 활용했다.

(https://physionet.org, 데이터 접근을 위해서는 CITI 교육과정 이수 필요)

 

MIMIC-IV는 중환자실(ICU) 환자의 전자 건강기록(EHR)을 포함한 대규모 임상 데이터베이스이며,

MIMIC-CXR은 동일 기관의 흉부 X-ray 영상 및 판독 보고서가 포함된 데이터베이스다.

이 두 데이터셋은 환자 식별자(ID)와 입원 기록을 기준으로 연결(linking)되었으며, 각 ICU 입원 사례마다 해당 입원 기간 중 촬영된 X-ray 영상이 함께 매칭되었다.

이러한 연계를 바탕으로, 아래 두 가지 임상 예측 과제가 정의되었다.

1. 환자 표현형 분류 (Phenotyping): 25개 만성/급성 질환 분류 (multi-label)

2. 병원 내 사망 예측: 입원 후 48시간 내 데이터를 기반으로 사망 여부 예측 (binary)

 

MIMIC-IV (EHR 시계열 데이터)

미국 Beth Israel 병원의 중환자실 환자 EHR 데이터. 17개 임상 변수 사용 (5개 범주형, 12개 연속형).

주요 다음과 같은 전처리를 진행 : 2시간 간격 샘플링, 결측치 보정, 이상치 제거, 범주형 변수는 one-hot encoding

모델 입력 형식: xehr ∈ ℝ^{t×76} (t는 환자별 입원 기간)

 

MIMIC-CXR (흉부 X-ray 영상)

동일 병원의 흉부 방사선 이미지 + 판독 결과 데이터셋. 각 입원 사례에 대해 표현형 분류 과제에는 입원 중 가장 마지막 영상, 사망 예측 과제에는 입원 후 48시간 이내 마지막 영상을 사용함. 즉, 각 입원 사례에 0 또는 1장의 영상이 연결되게 됨.

영상 라벨로는 CheXpert 알고리즘으로 생성된 14개 병변의 이진 라벨을 사용했으며, 이 라벨을 기반으로 ResNet-34 영상 인코더를 사전학습(pre-train) 진행

 

벤치마크 과제 정의

MedFuse가 단일 모달 또는 다중 모달 입력에서 어떻게 성능이 달라지는지를 검증하기 위한 벤치마크를 아래와 같이 정의하였으며, 이 설정은 기존의 MIMIC-IV 기반 사망 예측 연구들과 동일하게 구성되어 있다.

 

1. 환자 표현형 분류 (Phenotype Classification)

환자가 ICU에 입원하는 동안 진단된 주요 질환 또는 상태를 예측하는 다중 레이블 분류 문제로, 예측 대상은 Harutyunyan et al. (2019)에서 정의한 25개 표현형(phenotype) 레이블임.

여기에는 만성 신부전, 간 질환, 패혈증 등과 같은 만성 질환, 급성 질환, 혼합형 상태가 포함됨.

  • 입력: 각 입원 사례에 대해, 전체 ICU 입원 기간 동안의 임상 시계열 데이터 xehr
  • 정답 레이블: yehr ∈ ℝ²⁵ (25개 표현형의 존재 여부를 나타내는 이진 벡터)
  • 데이터 정제:
    • MIMIC-IV의 ICD-9 및 ICD-10 코드를 ICD-9 체계로 통합
    • 미국 CMS의 CCS 분류체계를 적용해 25개 표현형 레이블로 매핑
  • 두 번째 입력(선택적): 해당 입원 동안 가장 마지막으로 촬영된 흉부 X-ray 이미지
  • 평가 지표: AUROC (ROC 곡선 아래 면적), AUPRC (정밀도-재현율 곡선 아래 면적)
    → 다중 레이블 특성상 micro-averaging 방식*으로 산출됨

* micro-averaging : 모든 클래스의 예측 결과를 개별 샘플 단위로 합쳐 전체 정확도(정밀도, 재현율 등)를 계산하는 방식.

 

2. 병원 내 사망 예측 (In-hospital Mortality Prediction)

환자가 ICU에 입실한 후 해당 입원 중 사망할지를 조기에 예측하는 이진 분류 문제로, 입실 후 48시간 동안의 초기 데이터를 활용해 환자의 예후를 판단하는 것이 목표임.

이는 실제 임상에서 조기 위험 판단이 중요한 상황을 반영하여 설정한 것.

  • 입력: 입실 후 첫 48시간 동안의 임상 시계열 데이터 xehr
  • 정답 레이블: yehr ∈ {0, 1} (퇴원 시 생존=0, 사망=1)
  • 제외 조건: 48시간 미만 입원 사례는 데이터 부족으로 인해 학습에서 제외
  • 두 번째 입력(선택적): 입원 초기 48시간 내에 촬영된 마지막 X-ray 이미지 → 일부 환자는 해당 기간 동안 영상이 없을 수 있으며, 이 경우 EHR 단독 입력으로 처리
  • 평가 지표: AUROC, AUPRC

데이터 전처리 및 분할

1. 데이터 전처리

임상 시계열 데이터 (EHR): 17개 변수 선택 → 2시간 간격 샘플링 → 결측값 보정, 정규화, 원-핫 인코딩 → 76차원 시계열 피처 행렬 생성

X-ray 영상 (CXR): 224×224로 리사이즈, 흑백 영상을 3채널로 복제하여 CNN 입력 형식으로 변환

2. 데이터 분할 전략

환자 단위(patient-level) 분할로 학습/검증/테스트 세트 구분 → 동일 환자가 여러 세트에 중복되지 않도록 구성 (70% 학습, 10% 검증, 20% 테스트)

각 입원 사례당 최대 1장의 X-ray만 사용, 나머지 영상은 영상 인코더 사전학습용

3. 부분 페어링 (Partially Paired) 구성

모든 사례에 두 모달리티(EHR, CXR)가 모두 존재하지 않음

논문에서는 완전 페어링 샘플과 단일 모달 샘플을 구분하여 평가

4. 인코더 사전학습용 데이터

영상 인코더: MIMIC-CXR에 CheXpert 알고리즘 적용 → 14개 병변 라벨 생성 → CXRUNI (325,188장 학습 / 15,282장 검증 / 36,625장 테스트) → 인코더 사전학습에 사용 (최종 예측과는 직접 무관)

임상 인코더: 각 과제(EHR task)의 실제 레이블(사망 여부, 표현형)로 사전학습 수행

 

MedFuse 훈련 전략

MedFuse의 효과를 검증하기 위해 연구진은 2단계 훈련 전략을 적용하였다.

먼저 모달리티별 인코더를 개별적으로 사전학습(pre-training)한 뒤,

이를 기반으로 전체 모델을 공동 미세조정(joint fine-tuning)하는 방식이다.

또한 다양한 기존 모델들을 구현해 성능 비교 실험도 함께 수행하였다.

 

1단계: 모달리티별 인코더 사전학습

임상 시계열 인코더는 해당 예측 과제(EHR 학습 세트 기준)에 맞춰 구성된 2-layer LSTM 분류기 구조로 학습된다.

중요한 점은 X-ray 유무와 무관하게 모든 EHR 사례가 사전학습에 사용된다는 것이다. → 즉, (EHR)PARTIAL 세트 전체가 활용된다.

영상 인코더(ResNet-34)는 CXRUNI 학습 세트(약 32만 장의 X-ray)에 대해, CheXpert 14개 병변 라벨을 예측하는 구조로 BCE 손실 기반 학습이 이루어진다.

과적합을 방지하기 위해 필요한 경우 조기 종료(Early Stopping)*와 X-ray 회전*, 좌우 반전*과 같은 데이터 증강도 적용되었다.

*조기 종료 (Early Stopping): 검증 성능이 일정 에폭 동안 향상되지 않으면 학습을 조기에 중단하여 과적합을 방지하는 정규화 기법

X-ray 회전 (X-ray Rotation): 모델이 방향 변화에 강건하도록 학습하게 하기 위함

좌우 반전 (Horizontal Flip): 영상 이미지를 좌우로 뒤집어 모델이 대칭적 특징에도 일반화되도록 학습하게 하기 위함

 

2단계: 다중모달 모델 공동 미세조정 (Fine-tuning)

사전학습이 완료된 후에는 인코더의 분류기 부분을 제거하고, 남은 특징 추출부만 결합하여 MedFuse 전체 모델을 구성한다.

이후 해당 과제의 학습 세트 전체(EHR + CXR PARTIAL)를 이용해 end-to-end로 fine-tuning을 진행한다.

이때 EHR-only 사례도 포함되기 때문에, 학습 샘플 수가 극대화된다.

(하지만, EHR-only 사례가 지나치게 많아지면, 모델이 영상 정보를 무시하고 EHR에만 의존하게 될 수 있다. → 이를 방지하기 위해, 연구진은 EHR-only 샘플의 사용 비율을 조정하는 실험을 수행하였다. AUROC가 가장 높은 조건을 찾는 방식으로 실험을 진행하였으며, 해당 실험을 통해 구한 최적 설정 값을 논문에서는 MedFuse(OPTIMAL)로 명시하고 있다.)

 

기타

Fine-tuning 시에는 기존 인코더 가중치의 학습률은 낮추고, 투영층(ϕ), LSTM 융합 모듈, 최종 분류기 등 새로운 모듈은 상대적으로 높은 학습률로 학습되었다.

전체 훈련은 최대 50 epoch 동안 진행되었으며, 검증 AUROC이 15 epoch 연속으로 향상되지 않을 경우 조기 종료가 적용되었다.

최종적으로는 검증 세트에서 가장 높은 AUROC을 기록한 모델을 선택해 테스트 성능을 평가하였다.

 

비교 대상 모델 (Baseline Models)

MedFuse의 성능을 평가하기 위해, 다양한 기존 다중모달 융합 기법들을 동일한 데이터셋과 조건에서 재현하여 비교하였다.

각 비교 모델은 MedFuse와 동일한 환자 분할, 입력 구조, 사전학습 여부를 기준으로 설정되었다.

 

모델명 결합 방식   학습 방식 데이터 유형 기타 특징 및 참고
Early Fusion 단순히 Concat 인코더는 사전학습 후 고정, 분류기만 학습 ① 페어 데이터② 부분 페어 (X-ray 결측 시 Learnable Embedding 사용) 구조 단순, 학습 안정적
Joint Fusion Contact 인코더부터 분류기까지 end-to-end 공동 학습 ① 페어 데이터② 부분 페어 (Learnable Embedding) 모달 간 상호작용 학습 가능, 학습 불안정 가능성
MMTM Mmm (multi-modal transfer modul). Attention 기반 통계 정보 교환 ResNet-34와 LSTM 사이 MMTM 삽입End-to-end 학습 페어 데이터  
DAFT Affine transformation 영상 feature로 EHR 은닉 상태 변환End-to-end 학습 페어 데이터  
Unified 모달 결합 입력 다중모달 입력 + 분기 구조부분 페어 가능 부분 페어 포함 MedFuse 이전 버전, 구조 복잡
Stacked LSTM (EHR-only) 단일 모달 2-layer LSTM + 분류기 EHR-only 다중모달을 사용하지 않은 기본 성능 비교용 기준선

 

실험 결과 요약

실험은 크게 단일 모달 vs. 다중모달, 완전 페어 vs. 부분 페어 데이터, 표현형 유형, 환자 연령대별 성능 등 다양한 관점에서 진행되었다.

1. MedFuse는 단일 모달보다 일관되게 우수한 성능

  • 두 예측 과제(사망 예측, 표현형 분류) 모두에서 EHR-only 모델 대비 AUROC·AUPRC 모두 향상
  • 특히 사망 예측에서 영상 정보의 기여가 뚜렷 (AUROC +4%, AUPRC +9.4%)

2. 부분 페어 상황에서도 높은 실용성

  • 대부분의 현실적 상황처럼 영상이 없는 경우(EHR-only)에도 성능 유지
  • 영상이 있는 경우에는 더욱 우수한 성능 확보
  • 혼합형 질환에서 가장 큰 효과, 만성 질환에서도 의미 있는 향상, 급성 질환에서는 제한적

3. 연령대별 분석 결과, 40–60세 중장년층에서 성능 향상이 가장 큼

  • AUPRC 24% 향상
  • 시계열 + 영상 정보가 모두 중요한 고위험군에서 모델의 임상적 유용성 높음

4. 기존 다중모달 모델들과 비교해 종합적 성능 우위

  • MedFuse (OPTIMAL)이 모든 비교 모델(Early Fusion, Joint Fusion, MMTM, DAFT, Unified) 중 최고 성능
  • AUROC 기준, 사망 예측 +2.7%p, 표현형 분류 +1.8%p 우세
  • AUPRC에서도 가장 우수

5. 단순한 구조이지만 강력한 성능

  • 복잡한 구조의 MMTM, DAFT, Unified에 비해 오히려 성능이 더 높고, 학습 안정성도 우수
  • 특히 부분 페어 환경에서의 유연성과 실용성이 뛰어남

 

한계점 및 향후 연구 방향

1. 평가 과제의 범위가 제한적임

해당 논문에서는 임상 시계열 데이터(EHR)와 흉부 X-ray 영상이라는 두 가지 모달리티를 활용해, 병원 내 사망 예측과 환자 표현형 분류라는 두 가지 예측 과제에 집중하였다.

하지만 실제 중환자실(ICU) 환경에서는 상태 악화 예측(decompensation), 입원 기간(length of stay) 예측 등 다양한 임상 과제가 중요하게 다뤄진다.

향후 연구에서는 MedFuse의 구조를 이러한 과제들에도 적용해보고, 다중모달 융합 방식의 일반화 가능성을 평가하는 것이 필요하다.

 

2. EHR 결측 상황에 대한 고려 부족

현재 MedFuse는 EHR 시계열 데이터가 항상 존재하고, 영상(X-ray)이 선택적으로 존재하는 시나리오만을 다루고 있다.

그러나 실제 임상에서는 영상만 존재하고 EHR이 없는 경우(예: 응급실에서 영상만 촬영된 경우)도 빈번하다.

MedFuse는 EHR을 필수로 요구하는 구조이기 때문에, 이러한 상황에는 적용이 어렵다.

따라서 향후에는 영상 단독 예측, 또는 랩 검사·병리 정보와 영상의 조합 등 다양한 모달리티 구성에도 대응할 수 있는 유연한 융합 방식이 요구된다.

 

3. 모달리티 수 확장에 대한 실험 부족

본 논문은 두 가지 모달리티(EHR, X-ray)만을 대상으로 했지만, 실제 임상 환경에서는 임상 노트, CT·MRI 영상, 유전체 정보, 실시간 모니터링 데이터 등 훨씬 더 다양한 데이터 유형이 존재한다.

MedFuse는 LSTM 기반 구조로 다수의 입력을 순차적으로 처리할 수 있지만, 모달리티 수가 증가할수록 모델 복잡도 및 계산량이 늘어나며 성능이 영향을 받을 가능성이 있다.

따라서, 3개 이상의 모달리티를 동시에 융합한 구조에 대한 실험과, MedFuse 아키텍처의 확장성 및 한계에 대한 평가가 뒤따라야 한다.

 

4. 모델 해석 가능성 부족

현재 MedFuse는 블랙박스 형태의 예측 모델로, 각 모달리티 또는 입력 피처가 최종 예측에 얼마나 기여했는지 설명하기 어렵다.

하지만 의료 AI에서는 해석 가능성(interpretability)티 필수적인 요소이며, 이를 위해 입력 단계 또는 융합 단계에 어텐션 메커니즘을 적용해볼 수 있다.

예를 들어, 시간별 임상 피처나 영상의 국소 부위에 어텐션 가중치를 부여하고, 융합된 시점에서 모달리티 간의 중요도를 산출하는 방식으로 예측 근거에 대한 직관적인 설명을 제공할 수 있을 것이다.

이는 임상 전문가와의 소통, 예측 신뢰도 향상에도 큰 도움이 된다.

 

5. 멀티모달 학습 연구의 지속 필요성

MedFuse는 공개 데이터셋을 기반으로 다중모달 융합이 실제 예측 성능을 향상시킬 수 있음을 보여주었으며, 실용적 벤치마크로서의 가치도 지닌다.

저자들은 의료 데이터의 양과 다양성이 급속히 증가하는 현 상황에서, 이를 효과적으로 융합할 수 있는 모델이 필수적이라고 강조한다.

향후 연구는 EHR, 영상, 병리 이미지, 실시간 센서 데이터, 자연어 임상 기록 등 복합적인 모달리티를 통합한 환자 상태 예측 모델 개발로 이어질 수 있으며,

이는 궁극적으로 정밀의료 기반의 AI 시스템 실현으로 나아가는 기반이 될 것이다.

728x90
반응형