Multimodal Retrieval-augmented Expansion, Fusion, and Generation — extends the text-only REFRAG (Lin et al., 2025) framework to jointly handle image and text evidence chunks.
Given a question and a mixed pool of retrieved image + text chunks, the model learns a selective expansion policy that decides which chunks to show the decoder in full detail (expanded) and which to compress into a single embedding, balancing answer quality against computational cost.
┌──────────────────────────────┐
│ Retrieved Chunks │
│ (images + text, shuffled) │
└───────────┬──────────────────┘
│
┌───────────────┼───────────────┐
▼ ▼ ▼
┌──────────────┐ ┌──────────────┐
│ CLIP ViT-H-14│ │ RoBERTa-base │ (FROZEN encoders)
│ (images) │ │ (text) │
└──────┬───────┘ └──────┬───────┘
│ │
▼ ▼
CLS [1,1280] CLS [1,768]
Patches [256,1280] Tokens [T,768]
│ │
┌─────────┴────────────────┴─────────┐
│ MultimodalSelectPolicy │ ← TRAINED (Phase B)
│ per-modality proj → shared MLP │
│ → Bernoulli expand logit per chunk│
└─────────────────┬──────────────────┘
│ expand mask z ∈ {0,1}^N
▼
┌─────────────────────────────────────┐
│ Build Context Sequence │
│ │
│ Expanded image → VisualProjector │
│ [256, D_dec] │
│ Compressed image→ VisualProjector │
│ [1, D_dec] │
│ Expanded text → decoder embed │
│ [T, D_dec] │
│ Compressed text → TextProjector │
│ [1, D_dec] │
└─────────────────┬───────────────────┘
│
▼
[Q_emb | ctx_chunk_1 | ... | ctx_chunk_N]
│
▼
┌─────────────────────┐
│ Decoder LLM (LoRA) │ ← Qwen3-30B-A3B (default, MoE 3B active)
│ Causal LM │
└──────────┬──────────┘
│
▼
Answer
| Module | File | Description |
|---|---|---|
| MMREFRAGConfig | config.py | All hyperparameters in one dataclass |
| WebQADataset | dataset.py | Loads WebQA JSON + resolves images from imgs.tsv |
| ImageChunkEncoder | encoders.py | CLIP ViT-H-14 (frozen) — CLS + 256 patch tokens |
| TextChunkEncoder | encoders.py | RoBERTa-base (frozen) — CLS + all tokens |
| VisualProjector | projectors.py | 2-layer MLP: CLIP 1280-d -> decoder dim |
| TextProjector | projectors.py | 2-layer MLP: RoBERTa 768-d -> decoder dim |
| MultimodalSelectPolicy | policy.py | Per-chunk expand/compress decision (Bernoulli) |
| GRPOSampler | policy.py | Group Relative Policy Optimisation sampler |
| MultimodalREFRAG | model.py | Full model: encode -> policy -> project -> decode |
| contrastive_loss | losses.py | Optional InfoNCE auxiliary loss |
| train.py | train.py | Three-phase training CLI + evaluation + generation |
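The two projectors listed above are small 2-layer MLPs that map frozen-encoder features into the decoder's embedding space. A minimal sketch (the hidden width, GELU activation, and D_dec = 2048 are illustrative assumptions, not values taken from projectors.py):

```python
import torch
import torch.nn as nn

class ChunkProjector(nn.Module):
    """2-layer MLP: encoder dim -> decoder embedding dim.

    Hidden width and activation are illustrative assumptions;
    see projectors.py for the actual implementation.
    """

    def __init__(self, in_dim: int, dec_dim: int, hidden: int = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dec_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# CLIP ViT-H-14 features are 1280-d, RoBERTa-base features are 768-d
visual_proj = ChunkProjector(1280, 2048)  # D_dec = 2048 assumed for illustration
text_proj = ChunkProjector(768, 2048)

patches = torch.randn(256, 1280)          # expanded image: 256 patch tokens
projected = visual_proj(patches)          # [256, D_dec]
```

The same module class serves both modalities; only the input dimension differs.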
Training follows a three-phase curriculum, each building on the previous:
Phase A. Goal: Train VisualProjector and TextProjector to map encoder CLS tokens into the decoder's embedding space.
| | |
|---|---|
| Trained | VisualProjector, TextProjector |
| Frozen | Encoders, Decoder, Policy |
| Method | CLS embedding -> project -> repeat T times -> decoder reconstructs caption/text |
| Loss | Cross-entropy reconstruction loss |
| Best ckpt | Lowest training loss |
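The Phase A method row can be sketched as follows; the single linear layer standing in for the projector and the decoder dimension D_DEC are illustrative assumptions:

```python
import torch
import torch.nn as nn

D_DEC = 2048                           # assumed decoder embedding dim
projector = nn.Linear(1280, D_DEC)     # stand-in for the 2-layer VisualProjector

cls_emb = torch.randn(1, 1280)         # frozen CLIP CLS embedding of one image
T = 12                                 # length of the target caption in tokens

# Project the CLS embedding and repeat it T times, giving the decoder a
# fixed-length stand-in for the original chunk.
ctx = projector(cls_emb).repeat(T, 1)  # [T, D_DEC]
# The frozen decoder then reconstructs the caption from `ctx`; the
# cross-entropy of that reconstruction is the Phase A loss, and gradients
# flow only into the projector.
```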
Phase CPT. Goal: Bridge the gap between projected embeddings and the decoder. After Phase A, the decoder has never seen compressed embeddings as context; Phase CPT closes that gap.
Three interleaved tasks:
- Text next-chunk prediction: compressed text chunk -> predict next chunk tokens
- Image caption prediction: compressed image -> predict caption tokens
- Mixed-context answer generation: random expand mask -> teacher-force answer
| | |
|---|---|
| Trained | VisualProjector, TextProjector, Decoder (LoRA) |
| Frozen | Encoders, Policy |
| Loss | Cross-entropy (sum of all three tasks) |
| Best ckpt | Lowest val answer-generation loss (Task 3 only) |
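Task 3's random expand mask can be sketched as independent Bernoulli draws per chunk, matching the --expand_frac default of 0.1 (the exact sampling scheme used in train.py is an assumption here):

```python
import torch

def random_expand_mask(n_chunks: int, expand_frac: float = 0.1) -> torch.Tensor:
    """Sample a {0,1} mask: each chunk is expanded independently with
    probability `expand_frac` (illustrative; see train.py for the real scheme)."""
    return (torch.rand(n_chunks) < expand_frac).long()

mask = random_expand_mask(8)  # e.g. a mostly-zero mask with ~10% ones
```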
Phase B. Goal: Train the expansion policy via Group Relative Policy Optimisation (GRPO).
For each question, G expand masks are sampled. Each trajectory i gets a reward:

R_i = -CE_loss(answer | Q, context_i)
# CE_loss: teacher-forced cross-entropy of the decoder on the
# ground-truth answer, given question + the context built from expand mask z_i.
# Expansion cost is controlled by the max_expand_fraction hard constraint.

advantage_i = (R_i - mean(R)) / std(R)   # GRPO group-relative advantage
policy_loss = -mean_i(advantage_i * log pi(z_i | Q, C))
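The reward-to-loss computation above can be sketched for one question's group of G trajectories (a minimal sketch; the repo's GRPOSampler may differ in details such as the std epsilon):

```python
import torch

def grpo_policy_loss(rewards: torch.Tensor, log_probs: torch.Tensor,
                     eps: float = 1e-8) -> torch.Tensor:
    """Group-relative policy loss for one question.

    rewards:   [G] values R_i = -CE_loss(answer | Q, context_i)
    log_probs: [G] log pi(z_i | Q, C), summed over each mask's Bernoulli bits
    Returns the scalar GRPO policy loss for this group.
    """
    # Group mean as baseline, group std as scale: no EMA state required.
    adv = (rewards - rewards.mean()) / (rewards.std() + eps)
    return -(adv * log_probs).mean()
```

Because the baseline is the group mean, a group of identical rewards yields zero advantage and hence zero gradient.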
| | |
|---|---|
| Trained | SelectPolicy only |
| Frozen | Encoders, Projectors, Decoder |
| Loss | GRPO policy loss |
| Best ckpt | Highest val primary metric (EM for YesNo, F1 for others) |
| Chunk Type | Compress Mode | Expand Mode |
|---|---|---|
| Text | RoBERTa CLS -> TextProjector -> [1, D_dec] | Original text -> decoder embedding layer -> [T, D_dec] |
| Image | CLIP CLS -> VisualProjector -> [1, D_dec] | CLIP 256 patch tokens -> VisualProjector -> [256, D_dec] |
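The per-chunk dispatch in the table above can be sketched as a single function; the chunk dict layout, the linear/embedding stand-ins, and D_DEC are illustrative assumptions:

```python
import torch
import torch.nn as nn

D_DEC = 2048                          # assumed decoder embedding dim
visual_proj = nn.Linear(1280, D_DEC)  # stand-in for VisualProjector
text_proj = nn.Linear(768, D_DEC)     # stand-in for TextProjector
embed = nn.Embedding(1000, D_DEC)     # stand-in for the decoder embedding table
                                      # (small illustrative vocab)

def chunk_context(chunk: dict, expand: bool) -> torch.Tensor:
    """Map one retrieved chunk to its decoder-space context tokens."""
    if chunk["type"] == "image":
        feats = chunk["patches"] if expand else chunk["cls"]  # [256,1280] or [1,1280]
        return visual_proj(feats)                             # [256,D] or [1,D]
    if expand:
        return embed(chunk["token_ids"])                      # [T,D] via decoder table
    return text_proj(chunk["cls"])                            # [1,D]
```

Note the asymmetry: expanded text bypasses the projectors entirely and uses the decoder's own embedding table.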
Each sample contains:
- Q: question string
- A: reference answer
- img_posFacts: 1-2 positive images (support the answer)
- img_negFacts: ~10 hard negative images (same topic, wrong answer)
- txt_posFacts: positive text facts (often empty for image-centric questions)
- txt_negFacts: negative text facts
The pos/neg labels are NOT used to directly supervise the policy (that would cause train/test mismatch). Instead, they define what evidence supports the answer. The policy learns via GRPO: selecting positive chunks leads to lower CE loss (higher reward), which reinforces the selection.
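For concreteness, a WebQA-style sample might look like this. The top-level field names follow the list above; all nested keys and values here are invented for illustration:

```python
# Illustrative (not real) WebQA-style sample.
sample = {
    "Q": "Are both towers' clock faces illuminated at night?",
    "A": "Yes.",
    "img_posFacts": [{"image_id": 30001111, "caption": "Clock tower at night"}],
    "img_negFacts": [{"image_id": 30002222, "caption": "Clock tower by day"}],
    "txt_posFacts": [],  # often empty for image-centric questions
    "txt_negFacts": [{"fact": "The bridge was completed in 1894."}],
}
```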
<data_root>/
data/
WebQA_train_val.json # Main dataset JSON
imgs.tsv # Base64-encoded images (tab-separated: image_id \t b64data)
imgs.lineidx # Byte offsets for random access into imgs.tsv
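Random access into imgs.tsv via imgs.lineidx can be sketched as follows, assuming one decimal byte offset per line in the .lineidx file (as the layout above suggests):

```python
import base64

def read_image(tsv_path, lineidx_path, row):
    """Return (image_id, raw_bytes) for the `row`-th line of imgs.tsv,
    seeking directly to its byte offset from imgs.lineidx."""
    with open(lineidx_path) as f:
        offsets = [int(line) for line in f]   # one byte offset per TSV row
    with open(tsv_path, "rb") as f:
        f.seek(offsets[row])                  # jump straight to the row
        image_id, b64 = f.readline().decode("utf-8").rstrip("\n").split("\t")
    return image_id, base64.b64decode(b64)
```

Loading the offsets once and seeking per row avoids reading the (multi-GB) TSV sequentially.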
# Clone the repository
git clone <repo_url>
cd multimodal_refrag
# Install dependencies
pip install -r requirements.txt

- Python >= 3.9
- PyTorch >= 2.1.0
- CUDA (recommended, ~24GB VRAM for Qwen3-30B-A3B with MoE 3B active params)
- MPS (Apple Silicon) and CPU are also supported but significantly slower
| Phase | Approximate VRAM | Notes |
|---|---|---|
| Phase A | ~16 GB | Decoder frozen, only projector gradients |
| Phase CPT | ~20 GB | Decoder LoRA gradients |
| Phase B | ~24 GB | G trajectories sampled per step (G=4 default) |
Use --max_samples N to limit dataset size for debugging or limited hardware.
python -m multimodal_refrag.train phase_a \
--data_root . \
--output_dir runs/mm_refrag \
--decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
--epochs 3 --batch_size 8 --lr 1e-4

python -m multimodal_refrag.train phase_cpt \
--data_root . \
--output_dir runs/mm_refrag \
--decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
--phase_a_ckpt runs/mm_refrag/phase_a_best.pt \
--epochs 3 --batch_size 4 --lr 5e-5 \
--lora_r 16

python -m multimodal_refrag.train phase_b \
--data_root . \
--output_dir runs/mm_refrag \
--decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
--phase_cpt_ckpt runs/mm_refrag/phase_cpt_best.pt \
--epochs 5 --batch_size 4 \
--grpo_G 4 \
--policy_lr 1e-4 --lora_r 16

python -m multimodal_refrag.train generate \
--data_root . \
--output_dir runs/mm_refrag \
--decoder Qwen/Qwen3-30B-A3B-Instruct-2507 \
--ckpt runs/mm_refrag/phase_b_best.pt \
--split val \
--max_new_tokens 128 --temperature 0.0

Results are saved to runs/mm_refrag/generate_<split>.json.
All training phases support torchrun for multi-GPU distributed training.
Use CUDA_VISIBLE_DEVICES to select which GPUs to use:
# Phase A on GPU 0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
-m multimodal_refrag.train phase_a \
--data_root . --output_dir runs/mm_refrag \
--epochs 3 --batch_size 8
# Phase CPT on GPU 2,5 only
CUDA_VISIBLE_DEVICES=2,5 torchrun --nproc_per_node=2 \
-m multimodal_refrag.train phase_cpt \
--data_root . --output_dir runs/mm_refrag \
--phase_a_ckpt runs/mm_refrag/phase_a_best.pt \
--epochs 3 --lora_r 16
# Phase B on GPU 0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
-m multimodal_refrag.train phase_b \
--data_root . --output_dir runs/mm_refrag \
--phase_cpt_ckpt runs/mm_refrag/phase_cpt_best.pt \
--epochs 5 --grpo_G 8 --policy_lr 1e-4

Checkpoints, logging, and validation are only performed on rank 0.
| Argument | Default | Description |
|---|---|---|
| --data_root | multimodal_refrag | Project root (expects data/ subdirectory) |
| --output_dir | runs/mm_refrag | Checkpoint and log output directory |
| --decoder | Qwen/Qwen3-30B-A3B-Instruct-2507 | HuggingFace decoder model name |
| --seed | 1337 | Random seed |
| --max_neg_imgs | 4 | Max negative images sampled per example |
| --max_neg_txts | 4 | Max negative text facts sampled per example |
| Argument | Default | Description |
|---|---|---|
| --epochs | 3 | Training epochs |
| --batch_size | 8 | Batch size |
| --lr | 1e-4 | Learning rate for projectors |
| --max_samples | None | Limit dataset size |
| Argument | Default | Description |
|---|---|---|
| --phase_a_ckpt | None | Path to Phase A checkpoint |
| --epochs | 3 | Training epochs |
| --batch_size | 4 | Batch size |
| --lr | 5e-5 | Learning rate |
| --expand_frac | 0.1 | Fraction of chunks randomly expanded in Task 3 |
| --lora_r | 16 | LoRA rank for decoder |
| --max_samples | None | Limit dataset size |
| Argument | Default | Description |
|---|---|---|
| --phase_cpt_ckpt | None | Phase CPT checkpoint (preferred) |
| --phase_a_ckpt | None | Phase A checkpoint (fallback) |
| --epochs | 5 | Training epochs |
| --batch_size | 4 | Batch size |
| --policy_lr | 1e-4 | Policy learning rate |
| --grpo_G | 4 | GRPO group size (trajectories per question) |
| --max_expand_frac | 0.5 | Max fraction of chunks that can be expanded |
| --lora_r | 16 | LoRA rank for decoder |
| --max_samples | None | Limit dataset size |
| Argument | Default | Description |
|---|---|---|
| --ckpt | (required) | Checkpoint path to load |
| --split | val | Dataset split (val, test, train) |
| --max_new_tokens | 128 | Max tokens to generate |
| --temperature | 0.0 | Sampling temperature (0 = greedy) |
| --top_p | 1.0 | Nucleus sampling threshold |
| --lora_r | 16 | LoRA rank for decoder |
| --max_samples | None | Limit dataset size |
| Metric | Description | Used For |
|---|---|---|
| Primary | EM for YesNo questions, token-F1 for all others | Best checkpoint selection |
| Token F1 | Precision/recall over word tokens (case-insensitive) | Answer quality |
| EM | Exact match (case-insensitive, stripped) | Answer quality |
| Expand Ratio | Mean fraction of chunks expanded | Efficiency measure |
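The token-F1 and EM metrics can be sketched as below; the word-level tokenization is a common choice and an assumption about the repo's exact normalization:

```python
import re
from collections import Counter

def normalize(s: str) -> list[str]:
    """Lowercase and split into word tokens (assumed normalization)."""
    return re.findall(r"\w+", s.lower())

def exact_match(pred: str, ref: str) -> float:
    """Case-insensitive, whitespace-stripped exact match."""
    return float(pred.strip().lower() == ref.strip().lower())

def token_f1(pred: str, ref: str) -> float:
    """Harmonic mean of token precision and recall over the answer words."""
    p, r = normalize(pred), normalize(ref)
    overlap = sum((Counter(p) & Counter(r)).values())
    if overlap == 0:
        return 0.0
    prec, rec = overlap / len(p), overlap / len(r)
    return 2 * prec * rec / (prec + rec)
```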
Each phase saves:
- phase_{a,cpt,b}_epoch{N}.pt — per-epoch checkpoint
- phase_{a,cpt,b}_best.pt — best checkpoint by primary metric
Checkpoint contents:
{
"visual_proj": state_dict, # VisualProjector weights
"text_proj": state_dict, # TextProjector weights
"policy": state_dict, # SelectPolicy weights
"decoder_lora": state_dict, # LoRA adapter weights (if peft available)
# + extra metadata (epoch, loss, val metrics)
}

multimodal_refrag/
__init__.py # Package marker
config.py # MMREFRAGConfig dataclass
dataset.py # WebQADataset + ImgsTSVReader
encoders.py # ImageChunkEncoder (CLIP) + TextChunkEncoder (RoBERTa)
projectors.py # VisualProjector + TextProjector (2-layer MLP)
policy.py # MultimodalSelectPolicy + GRPOSampler
model.py # MultimodalREFRAG (full model)
losses.py # contrastive_loss (optional auxiliary)
train.py # CLI: phase_a / phase_cpt / phase_b / generate
requirements.txt # Python dependencies
README.md # This file
- GRPO with -CE_loss reward: Phase B uses Group Relative Policy Optimisation with reward = -CE_loss (teacher-forced cross-entropy on the answer). CE loss provides a dense, continuous signal even when the decoder's generation quality is poor. The advantage is normalized by the group std, and the baseline is the group mean — no EMA state needed.
- Asymmetric expand design: Expanded text chunks use the decoder's own embedding table (no projector), while expanded images go through the VisualProjector, because the decoder already understands its own token embeddings natively.
- Three-phase curriculum: Phase A aligns projectors, Phase CPT teaches the decoder to consume projected embeddings, Phase B trains only the policy. Each phase has a clear, narrow learning objective.
- pos/neg labels unused in the training loss: The policy is never directly told which chunks are positive; it discovers them through the GRPO reward signal, ensuring no train/test mismatch.
- Lin et al. (2025). REFRAG: Retrieval-augmented Expansion, Fusion, and Generation.
- Shao et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. (GRPO)
- Cherti et al. (2023). Reproducible scaling laws for contrastive language-image learning. (OpenCLIP)
- Chang et al. (2022). WebQA: Multihop and Multimodal QA.