End-to-end GRPO (Group Relative Policy Optimization) pipeline for training reasoning models. Implements DeepSeek-R1's approach: synthetic data generation, distributed verification, and RL training that derives rewards from verification rather than from a learned reward model.
Evaluated on full GSM8K test set (1319 problems):
| Metric | Base Model | GRPO Trained | Improvement |
|---|---|---|---|
| Pass@1 | 44.7% | 74.0% | +29.3 pp |
| Pass@4 | 75.1% | 88.0% | +12.9 pp |
| Pass@8 | 84.2% | 92.6% | +8.4 pp |
Pipeline performance:

| Component | Metric | Value |
|---|---|---|
| SGLang | Generation | 3.5 min for 5K samples |
| SGLang | Throughput | 24K tokens/sec |
| Ray | Workers | 4 parallel, balanced distribution |
| GRPO | Trainable params | 0.07% (LoRA) |
| Pipeline | End-to-end | ~12 min on 2x H100 |
┌─────────────────────────────────────────────────────────────────┐
│ DISTRIBUTED REASONING LOOP │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ SGLang │ -> │ Ray │ -> │ GRPO │ │
│ │ Generation │ │ Verification │ │ Training │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │
│ • RadixAttention • 4 parallel • No reward model │
│ • Prefix caching workers • Group-relative │
│ • 10 paths/problem • SymPy verify advantages │
│ • Batched requests • 24K tok/sec • LoRA fine-tuning │
│ │
└─────────────────────────────────────────────────────────────────┘
# System packages (Ubuntu)
apt-get update && apt-get install -y python-is-python3 python3-pip
# Python dependencies
pip install "sglang[all]" "ray[default]" transformers peft trl datasets \
  omegaconf accelerate bitsandbytes jsonschema jinja2 --upgrade

git clone https://github.com/pmukeshreddy/distributed-reasoning-loop.git
cd distributed-reasoning-loop

# Start inference server (GPU 1)
CUDA_VISIBLE_DEVICES=1 nohup python -m sglang.launch_server \
--model-path Qwen/Qwen2.5-1.5B-Instruct \
--host 0.0.0.0 --port 30000 > sglang.log 2>&1 &
# Wait for server to initialize
sleep 45
# Verify server is running
tail -n 3 sglang.log
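Instead of relying on a fixed `sleep 45`, you can poll until the server responds. A small sketch (not part of the repo) that assumes the OpenAI-compatible `/v1/models` route served by `sglang.launch_server`; adjust the URL if your deployment differs:

```python
# Optional readiness check -- assumes the OpenAI-compatible /v1/models route
# exposed by sglang.launch_server on port 30000 (adjust if needed).
import time
import requests

def wait_for_server(url="http://localhost:30000/v1/models", timeout=300):
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=5).status_code == 200:
                print("SGLang server is ready")
                return
        except requests.RequestException:
            pass  # server not up yet, keep polling
        time.sleep(5)
    raise TimeoutError(f"Server at {url} did not come up within {timeout}s")

if __name__ == "__main__":
    wait_for_server()
```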
# Run pipeline (GPU 0)
CUDA_VISIBLE_DEVICES=0 python scripts/run_ray_pipeline.py

Expected output:
Phase 1: SGLang Generation
Generating: 100%|████████████████| 50/50 [03:23<00:00, 4.07s/it]
Phase 2: Ray Verification
Initialized 4 workers of each type
Ray stats: {'total_processed': 5000}
Phase 3: GRPO Training
Training: 100%|████████████████| 210/210 [08:00<00:00]
Done: SGLang -> Ray -> GRPO
# Restart server with base model
pkill -f sglang && sleep 2
CUDA_VISIBLE_DEVICES=1 nohup python -m sglang.launch_server \
--model-path Qwen/Qwen2.5-1.5B-Instruct \
--host 0.0.0.0 --port 30000 \
--trust-remote-code > sglang.log 2>&1 &
sleep 45
# Evaluate
python scripts/eval_pass_at_k.py \
--model http://localhost:30000 \
--dataset gsm8k \
  --k 1 4 8

Expected output:
============================================================
PASS@K RESULTS
============================================================
k Accuracy Tokens/s
------------------------------------------------------------
1 44.7% 23322
4 75.1% 23322
8 84.2% 23322
# Restart server with trained model
pkill -f sglang && sleep 2
CUDA_VISIBLE_DEVICES=1 nohup python -m sglang.launch_server \
--model-path ./outputs/grpo_model \
--host 0.0.0.0 --port 30000 \
--trust-remote-code > sglang.log 2>&1 &
sleep 45
# Evaluate
python scripts/eval_pass_at_k.py \
--model http://localhost:30000 \
--dataset gsm8k \
  --k 1 4 8

Expected output:
============================================================
PASS@K RESULTS
============================================================
k Accuracy Tokens/s
------------------------------------------------------------
1 74.0% 24038
4 88.0% 24038
8 92.6% 24038
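For reference, pass@k is typically computed with the unbiased estimator from the HumanEval/Codex paper: sample n completions per problem, count the c that verify as correct, and estimate the probability that at least one of k random draws solves the problem. A minimal sketch of that calculation, which may differ in detail from `scripts/eval_pass_at_k.py`:

```python
# Unbiased pass@k estimator (Chen et al., 2021): probability that at least one
# of k samples, drawn from n total with c correct, solves the problem.
import math

def pass_at_k(n: int, c: int, k: int) -> float:
    """pass@k = 1 - C(n-c, k) / C(n, k), with the edge case clamped to 1."""
    if n - c < k:
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

def aggregate(results, k):
    # results: list of (n_samples, n_correct) pairs, one per problem
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)

# Example: 8 samples for one problem, 3 verified correct -> pass@4
print(round(pass_at_k(8, 3, 4), 3))
```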
Project structure:

distributed-reasoning-loop/
├── src/
│ ├── data_generator/
│ │ ├── cot_generator.py # SGLang inference
│ │ ├── synthetic_data_pipeline.py # Data generation pipeline
│ │ ├── data_preprocessor.py # Quality filtering, deduplication
│ │ └── dataset_loader.py # GSM8K, HumanEval loaders
│ ├── verifier/
│ │ ├── math_verifier.py # SymPy symbolic verification
│ │ └── code_verifier.py # Docker sandbox execution
│ ├── orchestration/
│ │ ├── ray_workers.py # Distributed processing
│ │ └── kafka_streaming.py # Streaming pipeline
│ ├── training/
│ │ ├── grpo_trainer.py # Group Relative Policy Optimization
│ │ ├── dpo_trainer.py # Direct Preference Optimization
│ │ └── sft_trainer.py # Supervised Fine-Tuning
│ └── evaluation/
│ ├── benchmarks.py # Evaluation metrics
│ └── test_time_compute.py # Pass@k, Best-of-N
├── scripts/
│ ├── run_ray_pipeline.py # Full pipeline script
│ └── eval_pass_at_k.py # Evaluation script
├── config/
│ └── default.yaml # Configuration
└── outputs/
├── synthetic_data/ # Generated data
└── grpo_model/ # Trained model
Generation (SGLang):

- RadixAttention: Automatic prefix caching for shared prompts
- Batched inference: Concurrent requests for high throughput
- Multi-path sampling: 10 reasoning paths per problem
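A client-side sketch of how multi-path sampling can be driven against the running server: fire concurrent requests at the OpenAI-compatible chat completions endpoint and let RadixAttention reuse the shared prompt prefix. The endpoint path, prompt, and sampling parameters here are illustrative assumptions; `src/data_generator/cot_generator.py` may use SGLang's native API instead.

```python
# Sketch: sample multiple reasoning paths per problem via concurrent requests
# to the SGLang server's OpenAI-compatible endpoint. RadixAttention caches the
# shared prompt prefix across these requests.
from concurrent.futures import ThreadPoolExecutor
import requests

URL = "http://localhost:30000/v1/chat/completions"
MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

def sample_path(question: str) -> str:
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system",
             "content": "Reason step by step, then give the final answer after '####'."},
            {"role": "user", "content": question},
        ],
        "temperature": 0.8,   # diversity across sampled paths
        "max_tokens": 512,
    }
    resp = requests.post(URL, json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]

def sample_paths(question: str, n_paths: int = 10) -> list[str]:
    # 10 concurrent requests share the same prompt prefix in the radix cache
    with ThreadPoolExecutor(max_workers=n_paths) as pool:
        return list(pool.map(sample_path, [question] * n_paths))
```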
Verification (Ray):

- Parallel workers: 4 actors processing chunks (1250 samples each)
- Math verification: SymPy symbolic comparison
- Balanced distribution: Even workload across workers
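A minimal sketch of the verification pattern: Ray actors that compare a predicted answer against the gold answer symbolically with SymPy. The actor and method names are illustrative; the real interfaces live in `src/orchestration/ray_workers.py` and `src/verifier/math_verifier.py`.

```python
# Sketch of a Ray verification actor using SymPy symbolic comparison.
import ray
import sympy

@ray.remote
class MathVerifier:
    def verify(self, predicted: str, gold: str) -> bool:
        try:
            # simplify(a - b) == 0 treats e.g. "1/2" and "0.5" as equal
            diff = sympy.simplify(sympy.sympify(predicted) - sympy.sympify(gold))
            return diff == 0
        except (sympy.SympifyError, TypeError):
            return False

    def verify_chunk(self, pairs):
        return [self.verify(p, g) for p, g in pairs]

ray.init(ignore_reinit_error=True)
workers = [MathVerifier.remote() for _ in range(4)]

# Split (prediction, gold) pairs evenly across the 4 workers.
pairs = [("1/2", "0.5"), ("7", "7.0"), ("3*x", "x*3"), ("12", "13")]
chunks = [pairs[i::4] for i in range(4)]
results = ray.get([w.verify_chunk.remote(c) for w, c in zip(workers, chunks)])
print(results)  # [[True], [True], [True], [False]]
```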
Training (GRPO):

- No reward model: Uses group-relative advantages
- Verification-based: Correct = positive, incorrect = negative
- Efficient: LoRA with 0.07% trainable parameters (1,089,536 params)
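The key idea is that each sampled path is scored relative to the other paths generated for the same problem, so no learned critic or reward model is needed: with verification rewards r_i for a group of G paths, the advantage is A_i = (r_i - mean(r)) / (std(r) + eps). A minimal sketch of that computation (illustrative only; `src/training/grpo_trainer.py` is the actual implementation):

```python
# Sketch of GRPO's group-relative advantage: normalize each path's binary
# verification reward against the other paths sampled for the same problem.
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """rewards: (num_problems, group_size) tensor of 0/1 verification rewards."""
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Example: one problem, 10 sampled paths, 3 verified correct.
rewards = torch.tensor([[1., 0., 0., 1., 0., 0., 0., 1., 0., 0.]])
adv = group_relative_advantages(rewards)
# Correct paths get positive advantage, incorrect ones negative; if every path
# in a group agrees, the advantage is 0 and the group gives no learning signal.
print(adv)
```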
Hardware requirements:

- GPU: 2x H100 (80GB) or equivalent
- RAM: 256GB+ recommended
- Storage: 50GB for models and data
References:

- DeepSeek-R1 - GRPO algorithm
- SGLang - RadixAttention inference
- Ray - Distributed computing
- GSM8K - Math reasoning benchmark
- Prime Intellect - Distributed RL infrastructure