DuPLUS: Dual-Prompt Vision-Language Framework for Universal Medical Image Segmentation and Prognosis


This repository contains the official PyTorch implementation for DuPLUS, a novel vision-language framework for universal medical image analysis. DuPLUS uses a unique dual-prompt mechanism to control a single model for segmenting diverse anatomies across multiple modalities and seamlessly extends to prognosis prediction.

Figure 1: The DuPLUS architecture, showcasing the dual-prompt mechanism for hierarchical text control.


💡 Overview

Current deep learning models in medical imaging are often task-specific, limiting their generalizability and clinical utility. DuPLUS addresses these limitations with a single, unified framework that can interpret diverse imaging data and perform multiple tasks based on textual instructions.

Key Contributions

Hierarchical Dual-Prompt Mechanism: DuPLUS uses two distinct text prompts for granular control (see the sketch after this list):

  • T_1 (Context Prompt): Conditions the network's encoder-decoder on the broad context, such as imaging modality (CT, MRI, PET) and anatomical region.
  • T_2 (Target Prompt): Steers the prediction head to perform a specific task, such as segmenting a particular organ or tumor.
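
The division of labor between the two prompts can be illustrated with a minimal sketch. Everything below (the `DualPromptHead` module, the FiLM-style modulation for T_1, the dynamic 1x1x1 head for T_2) is a hypothetical illustration of one plausible realization, not the repository's actual API:

```python
import torch
import torch.nn as nn

class DualPromptHead(nn.Module):
    """Hypothetical: t1 modulates features FiLM-style, t2 generates a dynamic kernel."""
    def __init__(self, feat_ch: int, txt_dim: int):
        super().__init__()
        self.film = nn.Linear(txt_dim, 2 * feat_ch)    # T_1 -> per-channel scale/shift
        self.kernel_gen = nn.Linear(txt_dim, feat_ch)  # T_2 -> per-sample 1x1x1 kernel

    def forward(self, feats, t1, t2):
        # feats: (B, C, D, H, W); t1, t2: (B, txt_dim)
        gamma, beta = self.film(t1).chunk(2, dim=1)
        feats = gamma[..., None, None, None] * feats + beta[..., None, None, None]
        # Dynamic convolution: contract channels with a kernel derived from T_2.
        logits = torch.einsum('bcdhw,bc->bdhw', feats, self.kernel_gen(t2))
        return logits.unsqueeze(1)  # (B, 1, D, H, W) foreground logits

head = DualPromptHead(feat_ch=32, txt_dim=512)
out = head(torch.randn(2, 32, 8, 8, 8), torch.randn(2, 512), torch.randn(2, 512))
assert out.shape == (2, 1, 8, 8, 8)
```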

Universal Segmentation: A single DuPLUS model successfully segments over 30 different organs and tumors across 11 public datasets and 3 imaging modalities, outperforming both task-specific and other universal models.

Task Extensibility: The framework is seamlessly extended from segmentation to prognosis prediction by integrating Electronic Health Record (EHR) data, achieving a Concordance Index (CI) of 0.69±0.02 on the HECKTOR dataset.
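
For readers unfamiliar with the metric: the concordance index measures how often predicted risk scores rank patients consistently with their observed survival times. Below is a minimal NumPy sketch of Harrell's C-index, included only as an illustration; the repository may rely on a library implementation:

```python
import numpy as np

def harrell_c_index(time, event, risk):
    """Harrell's C-index: fraction of comparable patient pairs whose predicted
    risks are ordered consistently with their observed survival times."""
    time, event, risk = map(np.asarray, (time, event, risk))
    concordant, comparable = 0.0, 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a pair is only comparable if the earlier time is an observed event
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5  # ties count as half-concordant
    return concordant / comparable

# e.g. harrell_c_index(time=[5, 10, 3], event=[1, 0, 1], risk=[0.9, 0.2, 0.8])
```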

Parameter-Efficient Adaptation: Utilizes techniques like Low-Rank Adaptation (LoRA) to adapt to new tasks (like prognosis) by fine-tuning only a small fraction of the model's parameters.
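
LoRA follows the standard low-rank formulation: the pretrained weight is frozen and a trainable rank-r update B·A is added on top. A self-contained sketch of the idea, independent of how the repository wires it in:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update: W x + (alpha/r) * B A x."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pretrained weights
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no-op at start
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

# Only A and B receive gradients when adapting to a new task such as prognosis.
```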

🚀 Results

DuPLUS's text-driven control allows for on-the-fly changes to the model's target.

Demonstration of fine-grained control: changing only the target prompt (T_2) deterministically switches the segmented organ on the same CT slice.
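
In code, that demonstration amounts to holding the image and context prompt fixed while swapping the target prompt. The snippet below is a runnable stand-in: `encode_text` and `model` are dummy placeholders for the trained DuPLUS text encoder and network, and the prompt wording is illustrative.

```python
import torch

# Dummy stand-ins so the snippet runs; in practice these would be DuPLUS's
# trained text encoder and segmentation network.
def encode_text(prompt: str, dim: int = 512) -> torch.Tensor:
    g = torch.Generator().manual_seed(abs(hash(prompt)) % (2**31))
    return torch.randn(1, dim, generator=g)

def model(volume, t1, t2):
    score = volume.mean(dim=1, keepdim=True) + t1.mean() + t2.mean()
    return (score > 0).float()  # dummy binary mask

ct_volume = torch.randn(1, 1, 32, 64, 64)     # one CT volume, held fixed
t1 = encode_text("abdominal CT scan")         # context prompt, held fixed
for organ in ["liver", "right kidney", "spleen"]:
    t2 = encode_text(f"segment the {organ}")  # only the target prompt changes
    mask = model(ct_volume, t1, t2)           # different structure each call
```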

💾 Supported Datasets

The framework was trained and validated on 11 diverse, publicly available medical imaging datasets:

  • CT: BCV, LiTS, KiTS, AMOS CT, StructSeg (Thorax & Head/Neck)
  • MRI: AMOS MR, CHAOS, M&Ms, Brain Structures
  • PET: AutoPET

⚙️ Installation

Requirements

  • Python 3.8+
  • PyTorch 1.12+
  • CUDA-compatible GPU (16GB+ VRAM recommended for 3D training)

Setup Environment

```bash
# Clone this repository
git clone https://github.com/BioMedIA-MBZUAI/DuPLUS.git
cd DuPLUS

# Create and activate a virtual environment
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`

# Install dependencies
pip install -r requirements.txt
```

Install PyTorch with CUDA Support

For CUDA 11.8:

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
```

For CUDA 12.1:

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```

🏃‍♀️ Training

Configuration

Model and training configurations are defined in YAML files within the `config/universal/` directory. The main configuration file is `universal_resunet_3d.yaml`, which includes dataset paths, model architecture parameters, and training settings.

```yaml
# In config/universal/universal_resunet_3d.yaml
arch: resunet
in_chan: 1
base_chan: 32
use_film: false  # Flag to switch between FiLM and dynamic convolution

# Dataset configuration
dataset_name_list: ['structseg_head_oar', 'amos_ct', 'amos_mr', 'bcv', 'structseg_oar', 'lits', 'kits', 'mnm', 'brain_structure', 'autopet', 'chaos']
dataset_classes_list: [22, 15, 13, 13, 6, 2, 2, 3, 3, 1, 4]  # number of classes per dataset

# Text embedding paths
emb_pth: './text_embeddings.pth'
meta_pth: './text_metadata.json'
emb_mod_pth: './modality_embeddings.pth'

# Training parameters
epochs: 400
training_size: [128, 128, 128]
base_lr: 0.002
optimizer: lamb
```
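
The `emb_pth`, `meta_pth`, and `emb_mod_pth` files hold precomputed prompt embeddings. How they are produced is encoder-specific; the sketch below shows the general recipe using Hugging Face `transformers`, where the encoder choice (`bert-base-uncased`), the mean pooling, and the prompt wording are all assumptions rather than the paper's actual pipeline:

```python
import json
import torch
from transformers import AutoTokenizer, AutoModel

encoder_name = "bert-base-uncased"  # placeholder encoder, not necessarily the paper's
tok = AutoTokenizer.from_pretrained(encoder_name)
enc = AutoModel.from_pretrained(encoder_name).eval()

prompts = {"liver": "segment the liver", "spleen": "segment the spleen"}  # toy subset
embeddings = {}
with torch.no_grad():
    for name, text in prompts.items():
        hidden = enc(**tok(text, return_tensors="pt")).last_hidden_state  # (1, L, H)
        embeddings[name] = hidden.mean(dim=1).squeeze(0)  # mean-pool tokens to a vector

torch.save(embeddings, "text_embeddings.pth")
with open("text_metadata.json", "w") as f:
    json.dump({n: {"prompt": p} for n, p in prompts.items()}, f)
```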

Example Training Commands

The training script uses prompts defined in the configuration file to condition the model.

```bash
# Train the universal DuPLUS model
python train.py --dataset universal --model universal_resunet --dimension 3d --amp --batch_size 8 --unique_name DuPLUS_universal_run

# Train with FiLM modulation instead of dynamic convolution
python train.py --dataset universal --model universal_resunet --dimension 3d --use_film --amp --batch_size 8 --unique_name DuPLUS_film_run

# Multi-GPU training
python train.py --dataset universal --model universal_resunet --dimension 3d --gpu 0,1,2,3 --batch_size 8 --unique_name DuPLUS_multigpu_run
```

Advanced Training Options

```bash
# Training with gradient accumulation for effective larger batch sizes
python train.py --dataset universal --model universal_resunet --dimension 3d --batch_size 4 --gradient_accumulation_steps 2 --unique_name DuPLUS_gradacc_run

# Custom experiment paths
python train.py --dataset universal --model universal_resunet --dimension 3d --cp_path ./experiments/ --unique_name DuPLUS_custom_path
```
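
For reference, this is what `--gradient_accumulation_steps` does conceptually: gradients from several small batches are accumulated before one optimizer step. A toy, self-contained loop (the repository's `train.py` additionally handles AMP, scheduling, and the dual prompts):

```python
import torch
import torch.nn as nn

# Toy setup so the loop runs; the real training data and model come from train.py.
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(4, 16), torch.randint(0, 2, (4,))) for _ in range(8)]

accum_steps = 2  # effective batch size = batch_size * accum_steps
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = criterion(model(x), y)
    (loss / accum_steps).backward()  # scale so accumulated gradients average correctly
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```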

Model Architecture

DuPLUS uses a ResUNet backbone with dynamic feature modulation:

  • universal_resunet: 3D ResUNet with configurable feature modulation (dynamic convolution or FiLM)

Feature Modulation Methods

  1. Dynamic Convolution (default): generates convolution kernels adaptively from the text embedding, offering higher expressiveness (see the sketch after this list)
  2. FiLM (Feature-wise Linear Modulation): applies a text-conditioned linear transformation to feature maps with efficient parameter usage (enable with --use_film)
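
A schematic contrast of the two methods, assuming both are driven by a text embedding `t` of shape `(B, txt_dim)`; the module names and the grouped-convolution trick are illustrative, not the repository's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FiLM3d(nn.Module):
    """FiLM: per-channel affine transform predicted from the text embedding."""
    def __init__(self, channels, txt_dim):
        super().__init__()
        self.proj = nn.Linear(txt_dim, 2 * channels)

    def forward(self, x, t):  # x: (B, C, D, H, W)
        gamma, beta = self.proj(t).chunk(2, dim=1)
        return gamma[..., None, None, None] * x + beta[..., None, None, None]

class DynamicConv3d(nn.Module):
    """Dynamic convolution: the 1x1x1 kernel itself is predicted per sample."""
    def __init__(self, in_ch, out_ch, txt_dim):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        self.proj = nn.Linear(txt_dim, out_ch * in_ch)

    def forward(self, x, t):
        b = x.shape[0]
        w = self.proj(t).view(b * self.out_ch, self.in_ch, 1, 1, 1)
        # Grouped-conv trick: fold the batch into channels to apply per-sample kernels.
        out = F.conv3d(x.reshape(1, -1, *x.shape[2:]), w, groups=b)
        return out.view(b, self.out_ch, *x.shape[2:])

x, t = torch.randn(2, 32, 8, 8, 8), torch.randn(2, 512)
assert FiLM3d(32, 512)(x, t).shape == DynamicConv3d(32, 32, 512)(x, t).shape
```

The trade-off is visible in the parameter counts: FiLM predicts only a scale and shift per channel, while dynamic convolution predicts entire kernels, which is more expressive but heavier.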

📊 Evaluation

The framework automatically evaluates on validation and test sets during training. Results are logged to:

  • Tensorboard logs: experiments/[experiment_name]/tensorboard/
  • Text logs: experiments/[experiment_name]/log.txt
  • Model checkpoints: experiments/[experiment_name]/best.pth

Metrics include:

  • Dice Similarity Coefficient (DSC; see the sketch after this list)
  • Per-class and overall performance
  • Cross-dataset generalization
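
For concreteness, a minimal DSC implementation for binary masks (the repository's exact evaluation code may differ, e.g. in its smoothing term):

```python
import torch

def dice_score(pred, target, eps=1e-5):
    """Dice Similarity Coefficient for binary masks: 2|A ∩ B| / (|A| + |B|)."""
    pred, target = pred.float().flatten(), target.float().flatten()
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)

# Per-class DSC on binary masks; the overall score averages over classes.
pred = torch.tensor([[0, 1, 1], [0, 1, 0]])
gt   = torch.tensor([[0, 1, 0], [0, 1, 0]])
print(dice_score(pred, gt))  # ~0.8
```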

📄 License

This project is licensed under the MIT License. See the LICENSE file for details.

🙏 Acknowledgments

We thank the creators of the numerous public datasets that made this work possible. This research builds upon the foundational work in vision-language models and medical image analysis from the broader scientific community.
