minw1012/multimodal_refrag

Multimodal REFRAG

Multimodal Retrieval-augmented Expansion, Fusion, and Generation — extends the text-only REFRAG (Lin et al., 2025) framework to jointly handle image and text evidence chunks.

Given a question and a mixed pool of retrieved image + text chunks, the model learns a selective expansion policy that decides which chunks to show the decoder in full detail (expanded) and which to compress into a single embedding, balancing answer quality against computational cost.


Architecture Overview

                        ┌──────────────────────────────┐
                        │      Retrieved Chunks        │
                        │  (images + text, shuffled)   │
                        └───────────┬──────────────────┘
                                    │
                    ┌───────────────┤
                    ▼               ▼
            ┌──────────────┐ ┌──────────────┐
            │ CLIP ViT-H-14│ │ RoBERTa-base │    (FROZEN encoders)
            │   (images)   │ │   (text)     │
            └──────┬───────┘ └──────┬───────┘
                   │                │
                   ▼                ▼
            CLS [1,1280]     CLS [1,768]
            Patches [256,1280]  Tokens [T,768]
                   │                │
         ┌─────────┴────────────────┴─────────┐
         │       MultimodalSelectPolicy       │  ← TRAINED (Phase B)
         │  per-modality proj → shared MLP    │
         │  → Bernoulli expand logit per chunk│
         └─────────────────┬──────────────────┘
                           │  expand mask z ∈ {0,1}^N
                           ▼
          ┌─────────────────────────────────────┐
          │        Build Context Sequence       │
          │                                     │
          │  Expanded image   → VisualProjector │
          │                     [256, D_dec]    │
          │  Compressed image → VisualProjector │
          │                     [1, D_dec]      │
          │  Expanded text    → decoder embed   │
          │                     [T, D_dec]      │
          │  Compressed text  → TextProjector   │
          │                     [1, D_dec]      │
          └─────────────────┬───────────────────┘
                           │
                           ▼
              [Q_emb | ctx_chunk_1 | ... | ctx_chunk_N]
                           │
                           ▼
                ┌─────────────────────┐
                │  Decoder LLM (LoRA) │  ← Qwen3-30B-A3B (default, MoE 3B active)
                │  Causal LM          │
                └──────────┬──────────┘
                           │
                           ▼
                        Answer
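
The selection step can be sketched as a small PyTorch module; class, dimension, and method names here are illustrative, not the repository's actual API.

```python
import torch
import torch.nn as nn

class SelectPolicySketch(nn.Module):
    """Per-modality projection -> shared MLP -> Bernoulli expand logit per chunk."""

    def __init__(self, d_img=1280, d_txt=768, d_shared=256):
        super().__init__()
        self.img_proj = nn.Linear(d_img, d_shared)  # CLIP ViT-H-14 CLS is 1280-d
        self.txt_proj = nn.Linear(d_txt, d_shared)  # RoBERTa-base CLS is 768-d
        self.mlp = nn.Sequential(
            nn.Linear(d_shared, d_shared), nn.GELU(),
            nn.Linear(d_shared, 1),                 # one expand logit per chunk
        )

    def forward(self, img_cls, txt_cls):
        # img_cls: [N_img, 1280], txt_cls: [N_txt, 768]
        h = torch.cat([self.img_proj(img_cls), self.txt_proj(txt_cls)], dim=0)
        logits = self.mlp(h).squeeze(-1)            # [N_img + N_txt]
        z = torch.bernoulli(torch.sigmoid(logits))  # expand mask z in {0,1}^N
        return z, logits

policy = SelectPolicySketch()
z, logits = policy(torch.randn(3, 1280), torch.randn(5, 768))  # 8 chunks total
```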

Key Components

| Module | File | Description |
| --- | --- | --- |
| MMREFRAGConfig | config.py | All hyperparameters in one dataclass |
| WebQADataset | dataset.py | Loads WebQA JSON + resolves images from imgs.tsv |
| ImageChunkEncoder | encoders.py | CLIP ViT-H-14 (frozen) — CLS + 256 patch tokens |
| TextChunkEncoder | encoders.py | RoBERTa-base (frozen) — CLS + all tokens |
| VisualProjector | projectors.py | 2-layer MLP: CLIP 1280d -> decoder dim |
| TextProjector | projectors.py | 2-layer MLP: RoBERTa 768d -> decoder dim |
| MultimodalSelectPolicy | policy.py | Per-chunk expand/compress decision (Bernoulli) |
| GRPOSampler | policy.py | Group Relative Policy Optimisation sampler |
| MultimodalREFRAG | model.py | Full model: encode -> policy -> project -> decode |
| contrastive_loss | losses.py | Optional InfoNCE auxiliary loss |
| train.py | train.py | Three-phase training CLI + evaluation + generation |

Training Pipeline

Training follows a three-phase curriculum, each building on the previous:

Phase A — Reconstruction Curriculum (Projector Alignment)

Goal: Train VisualProjector and TextProjector to map encoder CLS tokens into the decoder's embedding space.

  • Trained: VisualProjector, TextProjector
  • Frozen: Encoders, Decoder, Policy
  • Method: CLS embedding -> project -> repeat T times -> decoder reconstructs caption/text
  • Loss: Cross-entropy reconstruction loss
  • Best ckpt: Lowest training loss
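
The Phase A method can be read as the following forward pass, sketched with illustrative dimensions; the real decoder call and loss computation are omitted.

```python
import torch
import torch.nn as nn

D_ENC, D_DEC, T = 768, 1024, 16  # illustrative: encoder dim, decoder dim, target length

# 2-layer MLP, mirroring the VisualProjector/TextProjector description
projector = nn.Sequential(nn.Linear(D_ENC, D_DEC), nn.GELU(), nn.Linear(D_DEC, D_DEC))

cls = torch.randn(1, D_ENC)        # frozen encoder's CLS embedding for one chunk
ctx = projector(cls).repeat(T, 1)  # repeat T times -> [T, D_DEC] decoder context

# The frozen decoder is then asked to reconstruct the caption/text from ctx;
# cross-entropy gradients flow only into the projector.
```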

Phase CPT — Multimodal Continual Pre-Training

Goal: Bridge the gap between projected embeddings and the decoder. After Phase A the decoder has still never attended to compressed embeddings as context (it was frozen throughout); Phase CPT closes that gap.

Three interleaved tasks:

  1. Text next-chunk prediction: compressed text chunk -> predict next chunk tokens
  2. Image caption prediction: compressed image -> predict caption tokens
  3. Mixed-context answer generation: random expand mask -> teacher-force answer

  • Trained: VisualProjector, TextProjector, Decoder (LoRA)
  • Frozen: Encoders, Policy
  • Loss: Cross-entropy (sum of all three tasks)
  • Best ckpt: Lowest val answer-generation loss (Task 3 only)
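
Task 3's random expand mask can be sketched as an independent Bernoulli draw per chunk; the function name and signature are illustrative.

```python
import random

def random_expand_mask(n_chunks, expand_frac=0.1, rng=random):
    """Sample z_i ~ Bernoulli(expand_frac) independently for each chunk."""
    return [1 if rng.random() < expand_frac else 0 for _ in range(n_chunks)]

mask = random_expand_mask(10, expand_frac=0.1)  # mostly zeros at the default rate
```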

Phase B — GRPO End-to-End

Goal: Train the expansion policy via Group Relative Policy Optimisation (GRPO).

For each question, G expand masks are sampled. Each trajectory gets a reward:

R = -CE_loss(answer | Q, context)

# CE_loss: teacher-forced cross-entropy of the decoder on the
# ground-truth answer, given question + context from the expand mask.
# Expansion cost is controlled by the max_expand_frac hard constraint.

advantage_i = (R_i - mean(R)) / std(R)   # GRPO group-relative
policy_loss = -mean_i(advantage_i * log pi(z_i | Q, C))

  • Trained: SelectPolicy only
  • Frozen: Encoders, Projectors, Decoder
  • Loss: GRPO policy loss
  • Best ckpt: Highest val primary metric (EM for YesNo, F1 for others)
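
The group-relative advantage above can be checked with a few lines of plain Python; the rewards below are made-up numbers, and the eps term is a stability detail not in the formula.

```python
from statistics import mean, stdev

def grpo_advantages(rewards, eps=1e-6):
    """Baseline = group mean, scale = group std (plus eps for stability)."""
    mu, sd = mean(rewards), stdev(rewards)
    return [(r - mu) / (sd + eps) for r in rewards]

# G = 4 trajectories for one question; reward R_i = -CE_loss of the answer
rewards = [-2.1, -1.7, -3.0, -1.9]
adv = grpo_advantages(rewards)
# Advantages are centred on zero; the policy gradient then weights
# log pi(z_i | Q, C) by adv[i], reinforcing above-average expand masks.
```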

Compress / Expand Design

| Chunk Type | Compress Mode | Expand Mode |
| --- | --- | --- |
| Text | RoBERTa CLS -> TextProjector -> [1, D_dec] | Original text -> decoder embedding layer -> [T, D_dec] |
| Image | CLIP CLS -> VisualProjector -> [1, D_dec] | CLIP 256 patch tokens -> VisualProjector -> [256, D_dec] |
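
Assembling the decoder context from this table reduces to choosing, per chunk, either the expanded or compressed embedding. A minimal sketch (function name and dimensions illustrative):

```python
import torch

D_DEC = 64  # illustrative decoder embedding dim

def build_context(chunks, z):
    """chunks: list of (expanded_emb [T_i, D], compressed_emb [1, D]) pairs;
    z: expand mask in {0,1}^N. Returns the concatenated context sequence."""
    parts = [exp if zi else comp for (exp, comp), zi in zip(chunks, z)]
    return torch.cat(parts, dim=0)

chunks = [
    (torch.randn(256, D_DEC), torch.randn(1, D_DEC)),  # image: 256 patches vs 1 CLS
    (torch.randn(12, D_DEC),  torch.randn(1, D_DEC)),  # text: 12 tokens vs 1 CLS
]
ctx = build_context(chunks, z=[1, 0])  # expand the image, compress the text
```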

Dataset: WebQA

Each sample contains:

  • Q: question string
  • A: reference answer
  • img_posFacts: 1-2 positive images (support the answer)
  • img_negFacts: ~10 hard negative images (same topic, wrong answer)
  • txt_posFacts: positive text facts (often empty for image-centric questions)
  • txt_negFacts: negative text facts

The pos/neg labels are NOT used to directly supervise the policy (that would cause train/test mismatch). Instead, they define what evidence supports the answer. The policy learns via GRPO: selecting positive chunks leads to lower CE loss (higher reward), which reinforces the selection.

Data Layout

<data_root>/
  data/
    WebQA_train_val.json     # Main dataset JSON
    imgs.tsv                 # Base64-encoded images (tab-separated: image_id \t b64data)
    imgs.lineidx             # Byte offsets for random access into imgs.tsv
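
Random access into imgs.tsv via imgs.lineidx can be sketched as below; this is a self-contained demo with synthetic rows, not the repository's ImgsTSVReader.

```python
import base64, os, tempfile

def read_image_b64(tsv_path, lineidx_path, line_no):
    """Seek to the byte offset of row `line_no` and decode its base64 payload."""
    with open(lineidx_path) as f:
        offsets = [int(x) for x in f.read().split()]
    with open(tsv_path, "rb") as f:
        f.seek(offsets[line_no])
        image_id, b64 = f.readline().decode().rstrip("\n").split("\t")
    return image_id, base64.b64decode(b64)

# Build a tiny imgs.tsv / imgs.lineidx pair with fake image bytes
root = tempfile.mkdtemp()
tsv, idx = os.path.join(root, "imgs.tsv"), os.path.join(root, "imgs.lineidx")
with open(tsv, "wb") as f, open(idx, "w") as g:
    for image_id, raw in [("img_0", b"fake-jpeg-0"), ("img_1", b"fake-jpeg-1")]:
        g.write(f"{f.tell()}\n")  # byte offset of this row
        f.write(f"{image_id}\t{base64.b64encode(raw).decode()}\n".encode())

assert read_image_b64(tsv, idx, 1) == ("img_1", b"fake-jpeg-1")
```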

Installation

# Clone the repository
git clone <repo_url>
cd multimodal_refrag

# Install dependencies
pip install -r requirements.txt

Requirements

  • Python >= 3.9
  • PyTorch >= 2.1.0
  • CUDA (recommended, ~24GB VRAM for Qwen3-30B-A3B with MoE 3B active params)
  • MPS (Apple Silicon) and CPU are also supported but significantly slower

Hardware Requirements

| Phase | Approximate VRAM | Notes |
| --- | --- | --- |
| Phase A | ~16 GB | Decoder frozen, only projector gradients |
| Phase CPT | ~20 GB | Decoder LoRA gradients |
| Phase B | ~24 GB | G trajectories sampled per step (G=4 default) |

Use --max_samples N to limit dataset size for debugging or limited hardware.


Usage

Phase A: Projector Alignment

python -m multimodal_refrag.train phase_a \
    --data_root . \
    --output_dir runs/mm_refrag \
    --decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
    --epochs 3 --batch_size 8 --lr 1e-4

Phase CPT: Continual Pre-Training

python -m multimodal_refrag.train phase_cpt \
    --data_root . \
    --output_dir runs/mm_refrag \
    --decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
    --phase_a_ckpt runs/mm_refrag/phase_a_best.pt \
    --epochs 3 --batch_size 4 --lr 5e-5 \
    --lora_r 16

Phase B: GRPO Policy Training

python -m multimodal_refrag.train phase_b \
    --data_root . \
    --output_dir runs/mm_refrag \
    --decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
    --phase_cpt_ckpt runs/mm_refrag/phase_cpt_best.pt \
    --epochs 5 --batch_size 4 \
    --grpo_G 4 \
    --policy_lr 1e-4 --lora_r 16

Generate (Inference)

python -m multimodal_refrag.train generate \
    --data_root . \
    --output_dir runs/mm_refrag \
    --decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
    --ckpt runs/mm_refrag/phase_b_best.pt \
    --split val \
    --max_new_tokens 128 --temperature 0.0

Results are saved to runs/mm_refrag/generate_<split>.json.

Multi-GPU Training (DDP)

All training phases support torchrun for multi-GPU distributed training. Use CUDA_VISIBLE_DEVICES to select which GPUs to use:

# Phase A on GPU 0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
    -m multimodal_refrag.train phase_a \
    --data_root . --output_dir runs/mm_refrag \
    --epochs 3 --batch_size 8

# Phase CPT on GPU 2,5 only
CUDA_VISIBLE_DEVICES=2,5 torchrun --nproc_per_node=2 \
    -m multimodal_refrag.train phase_cpt \
    --data_root . --output_dir runs/mm_refrag \
    --phase_a_ckpt runs/mm_refrag/phase_a_best.pt \
    --epochs 3 --lora_r 16

# Phase B on GPU 0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
    -m multimodal_refrag.train phase_b \
    --data_root . --output_dir runs/mm_refrag \
    --phase_cpt_ckpt runs/mm_refrag/phase_cpt_best.pt \
    --epochs 5 --grpo_G 8 --policy_lr 1e-4

Checkpointing, logging, and validation run on rank 0 only.


CLI Reference

Common Arguments (all phases)

| Argument | Default | Description |
| --- | --- | --- |
| --data_root | multimodal_refrag | Project root (expects data/ subdirectory) |
| --output_dir | runs/mm_refrag | Checkpoint and log output directory |
| --decoder | Qwen/Qwen3-30B-A3B-Instruct-2507 | HuggingFace decoder model name |
| --seed | 1337 | Random seed |
| --max_neg_imgs | 4 | Max negative images sampled per example |
| --max_neg_txts | 4 | Max negative text facts sampled per example |

Phase A Arguments

| Argument | Default | Description |
| --- | --- | --- |
| --epochs | 3 | Training epochs |
| --batch_size | 8 | Batch size |
| --lr | 1e-4 | Learning rate for projectors |
| --max_samples | None | Limit dataset size |

Phase CPT Arguments

| Argument | Default | Description |
| --- | --- | --- |
| --phase_a_ckpt | None | Path to Phase A checkpoint |
| --epochs | 3 | Training epochs |
| --batch_size | 4 | Batch size |
| --lr | 5e-5 | Learning rate |
| --expand_frac | 0.1 | Fraction of chunks randomly expanded in Task 3 |
| --lora_r | 16 | LoRA rank for decoder |
| --max_samples | None | Limit dataset size |

Phase B Arguments

| Argument | Default | Description |
| --- | --- | --- |
| --phase_cpt_ckpt | None | Phase CPT checkpoint (preferred) |
| --phase_a_ckpt | None | Phase A checkpoint (fallback) |
| --epochs | 5 | Training epochs |
| --batch_size | 4 | Batch size |
| --policy_lr | 1e-4 | Policy learning rate |
| --grpo_G | 4 | GRPO group size (trajectories per question) |
| --max_expand_frac | 0.5 | Max fraction of chunks that can be expanded |
| --lora_r | 16 | LoRA rank for decoder |
| --max_samples | None | Limit dataset size |

Generate Arguments

| Argument | Default | Description |
| --- | --- | --- |
| --ckpt | (required) | Checkpoint path to load |
| --split | val | Dataset split (val, test, train) |
| --max_new_tokens | 128 | Max tokens to generate |
| --temperature | 0.0 | Sampling temperature (0 = greedy) |
| --top_p | 1.0 | Nucleus sampling threshold |
| --lora_r | 16 | LoRA rank for decoder |
| --max_samples | None | Limit dataset size |

Evaluation Metrics

| Metric | Description | Used For |
| --- | --- | --- |
| Primary | EM for YesNo questions, token-F1 for all others | Best checkpoint selection |
| Token F1 | Precision/recall over word tokens (case-insensitive) | Answer quality |
| EM | Exact match (case-insensitive, stripped) | Answer quality |
| Expand Ratio | Mean fraction of chunks expanded | Efficiency measure |
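
Token F1 and EM can be sketched as below; these are illustrative implementations of the standard definitions, not the repository's exact code.

```python
def _norm(s):
    return s.strip().lower()

def exact_match(pred, ref):
    """EM: case-insensitive, whitespace-stripped string equality."""
    return float(_norm(pred) == _norm(ref))

def token_f1(pred, ref):
    """Harmonic mean of token precision and recall (multiset token overlap)."""
    p, r = _norm(pred).split(), _norm(ref).split()
    remaining, common = list(r), 0
    for tok in p:
        if tok in remaining:
            remaining.remove(tok)
            common += 1
    if common == 0:
        return 0.0
    prec, rec = common / len(p), common / len(r)
    return 2 * prec * rec / (prec + rec)

token_f1("a red car", "the red car")  # -> 2/3: two shared tokens out of three each
```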

Checkpoints

Each phase saves:

  • phase_{a,cpt,b}_epoch{N}.pt — per-epoch checkpoint
  • phase_{a,cpt,b}_best.pt — best checkpoint by primary metric

Checkpoint contents:

{
    "visual_proj":  state_dict,   # VisualProjector weights
    "text_proj":    state_dict,   # TextProjector weights
    "policy":       state_dict,   # SelectPolicy weights
    "decoder_lora": state_dict,   # LoRA adapter weights (if peft available)
    # + extra metadata (epoch, loss, val metrics)
}
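
Loading such a checkpoint amounts to restoring each sub-state-dict into its module; a sketch with stand-in modules (the real classes and dimensions differ).

```python
import os, tempfile
import torch
import torch.nn as nn

visual_proj = nn.Linear(8, 4)  # stand-ins for VisualProjector / SelectPolicy
policy = nn.Linear(8, 1)

path = os.path.join(tempfile.mkdtemp(), "phase_b_best.pt")
torch.save({
    "visual_proj": visual_proj.state_dict(),
    "policy": policy.state_dict(),
    "epoch": 5,                # extra metadata travels in the same dict
}, path)

ckpt = torch.load(path, map_location="cpu")
visual_proj.load_state_dict(ckpt["visual_proj"])
policy.load_state_dict(ckpt["policy"])
```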

Project Structure

multimodal_refrag/
    __init__.py         # Package marker
    config.py           # MMREFRAGConfig dataclass
    dataset.py          # WebQADataset + ImgsTSVReader
    encoders.py         # ImageChunkEncoder (CLIP) + TextChunkEncoder (RoBERTa)
    projectors.py       # VisualProjector + TextProjector (2-layer MLP)
    policy.py           # MultimodalSelectPolicy + GRPOSampler
    model.py            # MultimodalREFRAG (full model)
    losses.py           # contrastive_loss (optional auxiliary)
    train.py            # CLI: phase_a / phase_cpt / phase_b / generate
    requirements.txt    # Python dependencies
    README.md           # This file

Key Design Decisions

  1. GRPO with -CE_loss reward: Phase B uses Group Relative Policy Optimisation with reward = -CE_loss (teacher-forced cross-entropy on the answer). CE loss provides dense, continuous signal even when the decoder's generation quality is poor. The advantage is normalized by group std, and the baseline is the group mean — no EMA state needed.

  2. Asymmetric expand design: Expanded text chunks use the decoder's own embedding table (no projector), while expanded images go through the VisualProjector. This is because the decoder already understands its own token embeddings natively.

  3. Three-phase curriculum: Phase A aligns projectors, Phase CPT teaches the decoder to consume projected embeddings, Phase B trains only the policy. Each phase has a clear and narrow learning objective.

  4. pos/neg labels unused in training loss: The policy is never directly told which chunks are positive. It discovers this through reward signal (GRPO), ensuring no train/test mismatch.


References

  • Lin et al. (2025). REFRAG: Retrieval-augmented Expansion, Fusion, and Generation.
  • Shao et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. (GRPO)
  • Cherti et al. (2023). Reproducible scaling laws for contrastive language-image learning. (OpenCLIP)
  • Chang et al. (2022). WebQA: Multihop and Multimodal QA.
