Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
2e13017
First commit with mocked scripts and launcher.
finbarrtimbers Nov 11, 2025
1123baa
Add script
hamishivi Nov 11, 2025
c21cf54
Updated code to be longer
finbarrtimbers Nov 11, 2025
8c7f5e3
Updated script
finbarrtimbers Nov 11, 2025
f2d7286
updated code to use 4 nodes
finbarrtimbers Nov 11, 2025
d4db1bc
updated script to not die with checkpoint error
finbarrtimbers Nov 11, 2025
4990aa9
fix init for tokenizer
finbarrtimbers Nov 11, 2025
c8b7626
fixed dataset loader
finbarrtimbers Nov 11, 2025
c39be00
updated dataset
finbarrtimbers Nov 11, 2025
e939ee1
Updated code
finbarrtimbers Nov 11, 2025
b7b4ffa
updated mock script
finbarrtimbers Nov 11, 2025
b1b4422
fixes mock
finbarrtimbers Nov 11, 2025
9f875a7
removed ref policy
finbarrtimbers Nov 11, 2025
85c22e5
removed old files
finbarrtimbers Nov 11, 2025
c690375
Uses a flag now
finbarrtimbers Nov 11, 2025
379ab73
Updated scripts
finbarrtimbers Nov 11, 2025
74b6fd3
fixed kl calculation
finbarrtimbers Nov 11, 2025
54d2cbb
conditionally load config
finbarrtimbers Nov 11, 2025
2f43921
cleaned up PR
finbarrtimbers Nov 12, 2025
59e28fe
Cleaned up calculate_ref_Logprobs
finbarrtimbers Nov 12, 2025
a0edde4
Cleaned up code
finbarrtimbers Nov 12, 2025
5b220a9
Cleaned up code.
finbarrtimbers Nov 12, 2025
5e6a6ad
Merge branch 'main' into finbarr/no-ref-policy
finbarrtimbers Nov 12, 2025
80c04eb
Added docstring plus type annotations
finbarrtimbers Nov 12, 2025
6d02b68
updated message
finbarrtimbers Nov 12, 2025
afa00b2
added comment
finbarrtimbers Nov 12, 2025
e367e3e
Fixed bugs
finbarrtimbers Nov 12, 2025
6dd533d
Merge branch 'main' into finbarr/no-ref-policy
finbarrtimbers Nov 13, 2025
cc453ea
now dschf is plumbed through
finbarrtimbers Nov 13, 2025
5a38087
Merge branch 'main' into finbarr/no-ref-policy
finbarrtimbers Nov 13, 2025
1caaa2c
uses active sampling
finbarrtimbers Nov 13, 2025
93a4171
Merge branch 'main' into finbarr/no-ref-policy
finbarrtimbers Nov 17, 2025
3bfb96b
Another config
finbarrtimbers Nov 13, 2025
e5ee27a
Removed active sampling
finbarrtimbers Nov 13, 2025
51639ef
Cleaned up code
finbarrtimbers Nov 17, 2025
a992387
Added back comments
finbarrtimbers Nov 17, 2025
65679a5
Cleaned up code by removing unnecessary code.
finbarrtimbers Nov 17, 2025
7e2d945
Updated script
finbarrtimbers Nov 17, 2025
8e4f325
set load ref policy true
finbarrtimbers Nov 17, 2025
3b24977
Updated logprob code
finbarrtimbers Nov 17, 2025
53c5b81
Fix double import
finbarrtimbers Nov 18, 2025
1974db8
Ran linter.
finbarrtimbers Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 185 additions & 148 deletions open_instruct/grpo_fast.py

Large diffs are not rendered by default.

64 changes: 54 additions & 10 deletions open_instruct/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,26 @@
# limitations under the License.


import asyncio
import itertools
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Literal, Union

try:
import deepspeed
from deepspeed.runtime.engine import DeepSpeedEngine
except ImportError:
pass
import asyncio

import deepspeed
import pandas as pd
import torch
import transformers
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from deepspeed.runtime.engine import DeepSpeedEngine
from huggingface_hub import HfApi
from rich import print as rprint
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from torch.nn.parallel.distributed import DistributedDataParallel
from transformers import PreTrainedModel, PreTrainedTokenizer

from open_instruct import logger_utils
from open_instruct.ground_truth_utils import VerifierFunction
Expand Down Expand Up @@ -155,6 +150,55 @@ def disable_dropout_in_model(model: torch.nn.Module) -> None:
module.p = 0


def load_ref_policy(
    model_config: ModelConfig,
    ds_config: dict,
    deepspeed_stage: int,
    local_rank: int,
    device: torch.device,
    rank: int,
    checkpoint_path: str | None = None,
) -> transformers.PreTrainedModel:
    """Builds the frozen reference policy used for logprob/KL comparisons.

    Args:
        model_config: Provides the HF model name/path and revision to load.
        ds_config: DeepSpeed config dict used to wrap the model for inference.
        deepspeed_stage: ZeRO stage; only stage 3 shards parameters, any other
            stage pins the whole model onto this rank's GPU.
        local_rank: Local GPU index used for the device map when not sharding.
        device: Device onto which the optional checkpoint is mapped.
        rank: Global process rank, used only in the log message.
        checkpoint_path: Optional path of a state dict to restore.

    Returns:
        The DeepSpeed-wrapped reference policy, switched to eval mode.
    """
    # Inference only ever uses stage 3 (parameter sharding) or stage 0 (none);
    # stage 2 is optimizer sharding, which does not apply without an optimizer.
    # Under stage 3 DeepSpeed manages placement, so no device_map is passed.
    placement_kwargs = {} if deepspeed_stage == 3 else {"device_map": {"": local_rank}}
    model: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        revision=model_config.model_revision,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        use_cache=False,
        **placement_kwargs,
    )
    disable_dropout_in_model(model)
    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()

    if checkpoint_path:
        state_dict = torch.load(checkpoint_path, map_location=device)
        # DeepSpeed wraps the HF model; load into the inner module when present,
        # otherwise directly into the vanilla HF model.
        target = engine.module if hasattr(engine, "module") else engine
        target.load_state_dict(state_dict)
        logger.info(f"{rank=}: Loaded reference policy checkpoint from {checkpoint_path}")
    return engine


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
"""
Compute the entropy of the logits.
Expand Down Expand Up @@ -459,7 +503,7 @@ def get_olmo3_generation_config(tokenizer):
def save_with_accelerate(
accelerator: Accelerator,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizer,
tokenizer: transformers.PreTrainedTokenizer,
output_dir: str,
use_lora: bool = False,
model_attribute_to_save: str | None = None,
Expand All @@ -478,7 +522,7 @@ def save_with_accelerate(
temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id
)

unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model)
unwrapped_model: transformers.PreTrainedModel = accelerator.unwrap_model(model)
if model_attribute_to_save is not None:
unwrapped_model = getattr(unwrapped_model, model_attribute_to_save)
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
Expand Down
31 changes: 11 additions & 20 deletions open_instruct/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
disable_dropout_in_model,
entropy_from_logits,
get_olmo3_generation_config,
load_ref_policy,
log_softmax_and_gather,
print_rich_single_line_metrics,
print_rich_table,
Expand Down Expand Up @@ -558,31 +559,21 @@ def from_pretrained(
)
self.value_model.train()

# reference model
ds_config = get_eval_ds_config(
ds_config, self.ref_policy_hf_ds_config = get_eval_ds_config(
offload=False,
# inference model only has stage 3 (sharding) or stage 0 (no sharding)
# stage 2 is optimizer sharding which doesn't apply to inference
stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
bf16=True,
per_device_train_batch_size=args.per_device_train_batch_size,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
ds_config["gradient_accumulation_steps"] = 1
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
dschf = HfDeepSpeedConfig(ds_config)
else:
dschf = None
print(f"{dschf=}")
self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
use_cache=False,

self.ref_policy: PreTrainedModel = load_ref_policy(
model_config=model_config,
ds_config=ds_config,
deepspeed_stage=args.deepspeed_stage,
local_rank=self.local_rank,
device=self.device,
rank=self.rank,
)
disable_dropout_in_model(self.ref_policy)
self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config)
self.ref_policy.eval()
self.local_metrics = utils.MetricsTracker(device=self.device)

self.offload_to_cpu(self.model)
Expand Down
34 changes: 32 additions & 2 deletions open_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from rich.pretty import pprint
from tqdm import tqdm
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
from transformers.integrations import HfDeepSpeedConfig

from open_instruct import logger_utils

Expand Down Expand Up @@ -1396,19 +1397,48 @@ def get_train_ds_config(
}


def get_eval_ds_config(offload, stage=0, bf16=True):
def get_eval_ds_config(
    offload: bool, stage: int = 0, bf16: bool = True, per_device_train_batch_size: int = 1
) -> "tuple[dict[str, Any], HfDeepSpeedConfig | None]":
    """Creates a DeepSpeed configuration for evaluation.

    Args:
        offload: Whether to offload parameters to CPU.
        stage: ZeRO optimization stage. Only 0 or 3 are relevant as there's no optimizer for eval.
        bf16: Whether to enable bfloat16 precision.
        per_device_train_batch_size: Batch size per GPU (DeepSpeed requires the
            batch-size keys even for inference-only engines).

    Returns:
        Tuple of (DeepSpeed config dict, HfDeepSpeedConfig or None). The
        HfDeepSpeedConfig object is returned so the caller can keep it alive:
        HF accelerate/transformers track the live instance internally via a
        global weakref, so letting it go out of scope disables ZeRO-3 loading.

    Raises:
        ValueError: If stage is not 0 or 3.
    """
    if stage not in (0, 3):
        raise ValueError(
            f"stage must be 0 or 3 for evaluation (got {stage}). 1 or 2 only differ from stage 0 by optimizer sharding, which is irrelevant for evaluation."
        )
    ds_config = {
        "steps_per_print": 100,
        "zero_optimization": {
            "stage": stage,
            "stage3_param_persistence_threshold": "auto",
            "offload_param": {"device": "cpu" if offload else "none", "pin_memory": True},
        },
        "bf16": {"enabled": bf16},
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "train_micro_batch_size_per_gpu": per_device_train_batch_size,
        "gradient_accumulation_steps": 1,
    }
    # stage was validated above, so testing stage directly replaces the old
    # dead `ds_config is not None` check (ds_config is a literal, never None).
    if stage == 3:
        # Instantiating HfDeepSpeedConfig has the side effect of registering
        # itself with transformers so from_pretrained() shards weights for ZeRO-3.
        hf_config = HfDeepSpeedConfig(ds_config)
        logger.info(f"DeepSpeed config: {hf_config}")
    else:
        hf_config = None
    return ds_config, hf_config


def get_optimizer_grouped_parameters(
Expand Down
3 changes: 2 additions & 1 deletion scripts/train/debug/large_test_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ num_prompts=25376
exp_name=rlvr_ace_fn_and_og_ocr_stdio_from_base_with_perf_penalty
BEAKER_IMAGE="${1:-${BEAKER_USER}/open-instruct-integration-test}"
uv run python mason.py \
--cluster ai2/jupiter \
--cluster ai2/saturn \
--image "$BEAKER_IMAGE" \
--pure_docker_mode \
--workspace ai2/open-instruct-dev \
Expand All @@ -20,6 +20,7 @@ uv run python mason.py \
--gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& source configs/beaker_configs/code_api_setup.sh \&\&python open_instruct/grpo_fast.py \
--exp_name ${exp_name} \
--beta 0.0 \
--load_ref_policy false \
--num_samples_per_prompt_rollout 16 \
--num_unique_prompts_rollout 32 \
--num_mini_batches 1 \
Expand Down
3 changes: 2 additions & 1 deletion scripts/train/debug/single_gpu_on_beaker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ uv run python mason.py \
--num_epochs 1 \
--num_learners_per_node 1 \
--vllm_tensor_parallel_size 1 \
--beta 0.01 \
--beta 0.0 \
--load_ref_policy true \
--seed 3 \
--local_eval_every 1 \
--vllm_sync_backend gloo \
Expand Down