- ✅ R1-SGG-7B, R1-SGG-Zero-7B
- ✅ Support PSG dataset (bbox format only, not Panoptic)
- ✅ Updated loss implementation
- ✅ Always use custom_per_device_train_batch_size instead of per_device_train_batch_size for faster sampling under gradient accumulation.
  ⚠️ The current loss implementation might still be affected by gradient accumulation: trl issue #3021
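The gradient-accumulation caveat can be illustrated with a toy calculation. The sketch below is not taken from this repository or from trl, and it does not reproduce the linked issue; it only shows one generic way a token-averaged loss can shift when a batch is split into micro-batches of unequal length. The per-token loss values and the function names are hypothetical.

```python
# Toy illustration with made-up numbers (not this project's loss code).

def token_mean(losses):
    """Average loss over all tokens in one batch."""
    return sum(losses) / len(losses)


def accumulated_mean(micro_batches):
    """Average of per-micro-batch token means, as naive accumulation would do."""
    return sum(token_mean(mb) for mb in micro_batches) / len(micro_batches)


# Two micro-batches with different numbers of tokens.
micro_batches = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0]]

# Loss if all six tokens were averaged in a single batch.
full_batch = [loss for mb in micro_batches for loss in mb]
global_loss = token_mean(full_batch)

# Loss if each micro-batch is averaged first, then the two means are averaged.
accum_loss = accumulated_mean(micro_batches)

print(global_loss, accum_loss)
```

The two values differ because the shorter micro-batch gets the same weight as the longer one. Whether this applies to a given trainer depends on how it normalizes the loss.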
bash install.sh

Main dependencies:
- torch == 2.5.0 or 2.5.1 (cu124, optional)
- transformers (supports Qwen2VL, Qwen2.5VL)
- trl
- vLLM

Load preprocessed datasets via:
from datasets import load_dataset
db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]

or for PSG:
db_train = load_dataset("JosephZ/psg_train_sg")["train"] # keys: image_id, image, objects, relationships
db_val = load_dataset("JosephZ/psg_test_sg")["train"]

We transformed VG150 into HuggingFace Datasets format with keys:
image_id, image, prompt_open, prompt_close, objects, relationships
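As a quick sanity check after loading, it may help to confirm that a record carries the keys listed above. The following is a minimal sketch: the helper name missing_keys and the stand-in record are hypothetical, and the value types shown are placeholders, not the datasets' actual field formats.

```python
# Keys as listed in this README for the VG150 and PSG datasets.
VG150_KEYS = {"image_id", "image", "prompt_open", "prompt_close",
              "objects", "relationships"}
PSG_KEYS = {"image_id", "image", "objects", "relationships"}


def missing_keys(sample, expected=VG150_KEYS):
    """Return the expected keys that are absent from one record (sorted)."""
    return sorted(set(expected) - set(sample))


# Stand-in record so the sketch runs without downloading anything;
# the values are placeholders only.
fake_sample = {
    "image_id": "0",
    "image": None,
    "prompt_open": "",
    "prompt_close": "",
    "objects": "",
    "relationships": "",
}

print(missing_keys(fake_sample))            # no keys missing
print(missing_keys(fake_sample, PSG_KEYS))  # PSG keys are a subset
```

With the real data, the same check would be missing_keys(db_train[0]) for VG150, or missing_keys(db_train[0], PSG_KEYS) for PSG.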
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
- Qwen/Qwen2.5-VL-3B-Instruct
- Qwen/Qwen2.5-VL-7B-Instruct
For SLURM users:
sbatch scripts/sft/7B_sgg.sh

For local machines:
bash scripts/sft_local/7B_sgg.sh

⏱️ Approximate training time:
- 2B models: ~4 hours (4×A100 SXM4 GPUs)
- 7B models: ~10 hours (4×A100 SXM4 GPUs)
**Update (11/05/2025): to use "Hard Recall"**:
--reward_funcs format_reward edge_hard_reward
For A100 GPUs:
sbatch scripts/grpo/train_a100_2B.sh
(12 hours on 16×A100 GPUs)

For GH200 GPUs:
sbatch scripts/grpo/train_gh200.sh
(16 hours on 16×GH200 GPUs)

For clusters with many RTX_3090/4090 GPUs:
sbatch scripts/grpo/train_fused.sh
- Training 7B models on 24GB cards is possible with ZeRO-3, but slow due to communication bottlenecks.
- (Fun fact: training with 120×RTX_4090 is crazy but severely limited by communication latency.)
💡 Recommended learning rate: 6e-7.
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR

For models trained with predefined categories, append true as a fourth argument:
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true

bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true

DATASET_TYPE=vg # or psg
python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
python src/vg150_eval.py $DATASET sgg_pred_results.json

The GRPOTrainer used in this project is based on trl's GRPOTrainer, extended to support multimodal inputs.
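The inference, gathering, and evaluation commands listed earlier can also be assembled from Python, for example to loop over several checkpoints. This is a minimal sketch that only builds the argument lists using the script paths and positional arguments shown in this README; the helper name build_eval_commands and its parameters are hypothetical, and the second optional positional flag is left out because its meaning is not described here.

```python
import shlex


def build_eval_commands(dataset, model_name, output_dir,
                        predefined_categories=False,
                        dataset_type="vg",
                        results_file="sgg_pred_results.json"):
    """Return the three commands (inference, gather, eval) as argv lists."""
    infer = ["bash", "scripts/inference/run_sgg_inference.sh",
             dataset, model_name, output_dir]
    if predefined_categories:
        # Fourth positional argument for models trained with predefined categories.
        infer.append("true")
    gather = ["python", "src/sgg_gather_preds.py",
              dataset_type, output_dir, results_file]
    evaluate = ["python", "src/vg150_eval.py", dataset, results_file]
    return [infer, gather, evaluate]


# Placeholder values stand in for $DATASET, $MODEL_NAME, and $OUTPUT_DIR.
commands = build_eval_commands("DATASET", "MODEL_NAME", "OUTPUT_DIR",
                               predefined_categories=True)
for cmd in commands:
    print(shlex.join(cmd))
```

Each list can then be passed to subprocess.run(cmd, check=True) from the repository root so that a failing step stops the sequence.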
If you find this work helpful, please cite:
@article{chen2025compile,
title={Compile Scene Graphs with Reinforcement Learning},
author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
journal={arXiv preprint arXiv:2504.13617},
year={2025}
}