diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml
index 0d3828a6..47f3ee88 100644
--- a/.github/workflows/e2e_test.yml
+++ b/.github/workflows/e2e_test.yml
@@ -29,6 +29,7 @@ jobs:
       llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
       llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
       llama-3-8b-sft-name: ${{ steps.run-llama-3-8b-sft.outputs.name }}
+      llama-3-8b-dpo-name: ${{ steps.run-llama-3-8b-dpo.outputs.name }}
       llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }}
       llama-3-8b-fsdp-cp-name: ${{ steps.run-llama-3-8b-fsdp-cp.outputs.name }}
       mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
@@ -262,6 +263,26 @@ jobs:
           task.convert_to_safetensors=False \
           profile_start_step=3

+    - name: Run Llama 3.0 8B DPO
+      id: run-llama-3-8b-dpo
+      env:
+        HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        XLA_IR_DEBUG: 1
+        XLA_HLO_DEBUG: 1
+      run: |
+        name=$(e2e_testing/gen_name.py llama-3-8b-dpo)
+        echo "name=$name" >> "$GITHUB_OUTPUT"
+        tp run ${{ steps.docker-url-option.outputs.value }} \
+          --name $name \
+          torchprime/torch_xla_models/train.py \
+          --config-name llama-3-8b-dpo-w-orca \
+          ici_mesh.fsdp=4 \
+          task.max_steps=20 \
+          task.global_batch_size=16 \
+          task.lr_scheduler.type=constant \
+          task.convert_to_safetensors=False \
+          profile_start_step=3
+
     - name: Run Llama 3.0 8B (ddp + fsdp)
       id: run-llama-3-8b-ddp-fsdp
       env:
@@ -334,6 +355,7 @@ jobs:
            matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
            matrix.config.benchmark == 'mixtral-8x7b' && needs.tp-run.outputs.mixtral-8x7b-name ||
            matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name ||
+           matrix.config.benchmark == 'llama-3-8b-dpo' && needs.tp-run.outputs.llama-3-8b-dpo-name ||
            matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name ||
            matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name ||
            matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name
diff --git a/e2e_testing/step_time_bounds.yaml b/e2e_testing/step_time_bounds.yaml
index 356f45af..08bfdcfc 100644
--- a/e2e_testing/step_time_bounds.yaml
+++ b/e2e_testing/step_time_bounds.yaml
@@ -50,6 +50,13 @@ benchmarks:
     sample_size: 11
     target_loss: 0.4735
     loss_tolerance: 0.001
+  llama-3-8b-dpo:
+    name: Llama 3.0 8B DPO
+    step_time_lower_bound: 0
+    step_time_upper_bound: 1
+    confidence_interval: 0.5
+    average: 0.5
+    sample_size: 1
   llama-3-8b-ddp-fsdp:
     name: Llama 3.0 8B (ddp + fsdp)
     step_time_lower_bound: 3.22900775
diff --git a/torchprime/data/__init__.py b/torchprime/data/__init__.py
index b9e16168..274c6a9e 100644
--- a/torchprime/data/__init__.py
+++ b/torchprime/data/__init__.py
@@ -3,15 +3,18 @@
 """

 from .dataset import make_train_dataset
+from .dpo_dataset import make_dpo_dataset
 from .sft_dataset import make_sft_dataset

 DATASET_BUILDERS = {
   "train": make_train_dataset,
   "sft": make_sft_dataset,
+  "dpo": make_dpo_dataset,
 }

 __all__ = [
   "DATASET_BUILDERS",
   "make_train_dataset",
   "make_sft_dataset",
+  "make_dpo_dataset",
 ]
diff --git a/torchprime/data/dpo_dataset.py b/torchprime/data/dpo_dataset.py
new file mode 100644
index 00000000..c0db366a
--- /dev/null
+++ b/torchprime/data/dpo_dataset.py
@@ -0,0 +1,171 @@
+"""DPO dataset utilities."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from datasets import Dataset
+from transformers.tokenization_utils import PreTrainedTokenizerBase
+
+from .dataset import load_hf_or_json_dataset
+
+TRUNCATE_OPTION = Literal["right", "left", "drop"]
+
+
+def _pad(
+  ids: list[int],
+  labels: list[int],
+  max_length: int,
+  pad_id: int,
+) -> tuple[list[int], list[int], list[int]]:
+  """Pad IDs and labels to ``max_length``.
+
+  Args:
+    ids: Encoded token IDs.
+    labels: Label token IDs.
+    max_length: Desired sequence length.
+    pad_id: Token ID to use for padding.
+
+  Returns:
+    Tuple containing padded ``ids``, ``labels`` and the attention mask.
+  """
+  ids = ids[:max_length]
+  labels = labels[:max_length]
+  attn = [1] * len(ids)
+  # Pad sequences to ``max_length`` and mask out the padding tokens.
+  if len(ids) < max_length:
+    ids = ids + [pad_id] * (max_length - len(ids))
+    labels = labels + [-100] * (max_length - len(labels))
+    attn = attn + [0] * (max_length - len(attn))
+  return ids, labels, attn
+
+
+def _tokenize_pair(
+  example: dict,
+  tokenizer: PreTrainedTokenizerBase,
+  *,
+  max_length: int,
+  truncation: TRUNCATE_OPTION,
+) -> dict | None:
+  """Tokenize a preference pair.
+
+  Each example contains a ``prompt`` and two completions: ``chosen`` and
+  ``rejected``. The completions are concatenated with the prompt and padded to
+  ``max_length``.
+
+  Args:
+    example: Raw example with ``prompt``, ``chosen`` and ``rejected`` fields.
+    tokenizer: Tokenizer used to encode text.
+    max_length: Target length for the encoded sequences.
+    truncation: Strategy to handle sequences longer than ``max_length``. If
+      ``"drop"`` is specified the pair is skipped.
+
+  Returns:
+    A dictionary with encoded tensors or ``None`` if the example was dropped.
+  """
+  prompt = example.get("prompt", "")
+  chosen = example.get("chosen")
+  rejected = example.get("rejected")
+  if chosen is None or rejected is None:
+    return None
+  prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+
+  def build(completion: str):
+    """Encode completion and append EOS token if necessary."""
+    ids = prompt_ids + tokenizer.encode(completion, add_special_tokens=False)
+    # Mask out the prompt portion so that only the completion contributes to the loss.
+    labels = [-100] * len(prompt_ids) + tokenizer.encode(
+      completion, add_special_tokens=False
+    )
+    if tokenizer.eos_token_id is not None:
+      ids.append(tokenizer.eos_token_id)
+      labels.append(tokenizer.eos_token_id)
+    if len(ids) > max_length:
+      if truncation == "drop":
+        # Skip examples that overflow the maximum length.
+        return None
+      if truncation == "left":
+        # Keep the last tokens when truncating from the left.
+        ids = ids[-max_length:]
+        labels = labels[-max_length:]
+      else:
+        # Default to truncating from the right.
+        ids = ids[:max_length]
+        labels = labels[:max_length]
+    return ids, labels
+
+  built_c = build(chosen)
+  built_r = build(rejected)
+  if built_c is None or built_r is None:
+    return None
+
+  # Fall back to the EOS token when the tokenizer has no dedicated PAD token.
+  pad_id = (
+    tokenizer.pad_token_id
+    if tokenizer.pad_token_id is not None
+    else tokenizer.eos_token_id
+  )
+  ids_c, labels_c = built_c
+  ids_r, labels_r = built_r
+  # Pad both completions to a fixed ``block_size``.
+  ids_c, labels_c, mask_c = _pad(ids_c, labels_c, max_length, pad_id)
+  ids_r, labels_r, mask_r = _pad(ids_r, labels_r, max_length, pad_id)
+  return {
+    "chosen_input_ids": ids_c,
+    "chosen_labels": labels_c,
+    "chosen_attention_mask": mask_c,
+    "rejected_input_ids": ids_r,
+    "rejected_labels": labels_r,
+    "rejected_attention_mask": mask_r,
+  }
+
+
+def make_dpo_dataset(
+  hf_dataset_name: str | None = None,
+  hf_dataset_config_name: str | None = None,
+  file_dataset_path: str | None = None,
+  split: str = "train",
+  cache_dir: str | None = None,
+  truncation: TRUNCATE_OPTION = "right",
+  *,
+  tokenizer: PreTrainedTokenizerBase,
+  block_size: int,
+) -> Dataset:
+  """Create a dataset for Direct Preference Optimization.
+
+  The function supports loading data from the Hugging Face hub or from a local
+  JSONL file. Each record must contain ``prompt``, ``chosen`` and ``rejected``
+  fields which represent a single preference pair.
+
+  Args:
+    hf_dataset_name: Optional name of a dataset on the Hugging Face hub.
+    hf_dataset_config_name: Optional dataset configuration name.
+    file_dataset_path: Optional path to a local JSONL file.
+    split: Dataset split to load when using the hub.
+    cache_dir: Directory to cache downloaded data.
+    truncation: Strategy used when sequences exceed ``block_size``.
+    tokenizer: Tokenizer used to encode the examples.
+    block_size: Maximum sequence length after tokenization.
+
+  Returns:
+    A :class:`datasets.Dataset` containing processed pairs ready for training.
+  """
+  data = load_hf_or_json_dataset(
+    hf_dataset_name=hf_dataset_name,
+    hf_dataset_config_name=hf_dataset_config_name,
+    file_dataset_path=file_dataset_path,
+    split=split,
+    cache_dir=cache_dir,
+  )
+
+  records = []
+  for ex in data:
+    out = _tokenize_pair(
+      ex,
+      tokenizer,
+      max_length=block_size,
+      truncation=truncation,
+    )
+    if out is not None:
+      records.append(out)
+  return Dataset.from_list(records)
diff --git a/torchprime/tests/test_dpo_dataset.py b/torchprime/tests/test_dpo_dataset.py
new file mode 100644
index 00000000..f854e1b4
--- /dev/null
+++ b/torchprime/tests/test_dpo_dataset.py
@@ -0,0 +1,62 @@
+import json
+from pathlib import Path
+
+from datasets import Dataset
+from tokenizers import Tokenizer
+from tokenizers.models import WordLevel
+from tokenizers.pre_tokenizers import Whitespace
+from transformers import PreTrainedTokenizerFast
+
+from torchprime.data.dpo_dataset import make_dpo_dataset
+
+
+def _write_json(tmpdir: Path, name: str, data):
+  path = tmpdir / name
+  with path.open("w") as f:
+    for item in data:
+      json.dump(item, f)
+      f.write("\n")
+  return path
+
+
+def _tokenizer():
+  vocab = {
+    "<pad>": 0,
+    "<s>": 1,
+    "</s>": 2,
+    "Hello": 3,
+    "World": 4,
+    "Hey": 5,
+    "<unk>": 6,
+  }
+  model = WordLevel(vocab, unk_token="<unk>")
+  tok = Tokenizer(model)
+  tok.pre_tokenizer = Whitespace()
+  tokenizer = PreTrainedTokenizerFast(
+    tokenizer_object=tok,
+    bos_token="<s>",
+    eos_token="</s>",
+    pad_token="<pad>",
+    unk_token="<unk>",
+  )
+  return tokenizer
+
+
+def test_local_json_pair(tmp_path: Path):
+  data = [
+    {"prompt": "Hello", "chosen": "World", "rejected": "Hey"},
+  ]
+  path = _write_json(tmp_path, "pairs.json", data)
+  tok = _tokenizer()
+  ds = make_dpo_dataset(file_dataset_path=str(path), tokenizer=tok, block_size=8)
+  assert isinstance(ds, Dataset)
+  rec = ds[0]
+  hello_id = tok.convert_tokens_to_ids("Hello")
+  world_id = tok.convert_tokens_to_ids("World")
+  hey_id = tok.convert_tokens_to_ids("Hey")
+  eos = tok.eos_token_id
+  assert rec["chosen_input_ids"][:3] == [hello_id, world_id, eos]
rec["chosen_labels"][0] == -100 + assert rec["chosen_labels"][1] == world_id + assert rec["rejected_input_ids"][:3] == [hello_id, hey_id, eos] + assert rec["rejected_labels"][1] == hey_id diff --git a/torchprime/torch_xla_models/configs/dataset/orca.yaml b/torchprime/torch_xla_models/configs/dataset/orca.yaml new file mode 100644 index 00000000..43242d43 --- /dev/null +++ b/torchprime/torch_xla_models/configs/dataset/orca.yaml @@ -0,0 +1,7 @@ +# Dataset configuration for DPO using the Orca pairs dataset +hf_dataset_name: Intel/orca_dpo_pairs +hf_dataset_config_name: null +split: train +block_size: 256 +cache_dir: /tmp/ +truncation: drop diff --git a/torchprime/torch_xla_models/configs/llama-3-8b-dpo-w-orca.yaml b/torchprime/torch_xla_models/configs/llama-3-8b-dpo-w-orca.yaml new file mode 100644 index 00000000..21b59682 --- /dev/null +++ b/torchprime/torch_xla_models/configs/llama-3-8b-dpo-w-orca.yaml @@ -0,0 +1,15 @@ +# Configuration for DPO training on the Orca dataset + +defaults: + - default + - override model: llama-3-8b + - override dataset: orca + - override task: dpo + - _self_ + +task: + convert_to_safetensors: False + +model: + pretrained_model: meta-llama/Meta-Llama-3-8B + reference_model: meta-llama/Meta-Llama-3-8B diff --git a/torchprime/torch_xla_models/configs/task/dpo.yaml b/torchprime/torch_xla_models/configs/task/dpo.yaml new file mode 100644 index 00000000..fc5f51d8 --- /dev/null +++ b/torchprime/torch_xla_models/configs/task/dpo.yaml @@ -0,0 +1,15 @@ +# Task configuration for Direct Preference Optimization +name: dpo +global_batch_size: 64 +max_steps: 100 +export_checkpoint_path: export +convert_to_safetensors: True +beta: 0.1 +max_grad_norm: 1.0 +max_grad_value: null +optimizer: + learning_rate: 4.e-5 + type: adafactor +lr_scheduler: + type: linear + warmup_steps: 10 diff --git a/torchprime/torch_xla_models/tests/test_dpo_trainer.py b/torchprime/torch_xla_models/tests/test_dpo_trainer.py new file mode 100644 index 00000000..9c71bc9c --- /dev/null +++ b/torchprime/torch_xla_models/tests/test_dpo_trainer.py @@ -0,0 +1,135 @@ +"""Tests for the :class:`DPOTrainer` class.""" + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from torchprime.torch_xla_models.trainer.dpo_trainer import DPOTrainer + + +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 2) + self.loaded = False + self.saved = False + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + logits = self.linear(input_ids) + loss = logits.mean() + return logits, loss + + def from_pretrained(self, path): + self.loaded = True + + def export(self, path): + self.saved = True + + def _maybe_save_checkpoint(self, config): + self.saved = True + + +class DummyDataset(Dataset): + def __init__(self): + self.device = xm.xla_device() + + def __getitem__(self, idx): + return { + "chosen_input_ids": torch.ones(4, device=self.device), + "chosen_attention_mask": torch.ones(4, device=self.device), + "chosen_labels": torch.ones(4, dtype=torch.long, device=self.device), + "rejected_input_ids": torch.zeros(4, device=self.device), + "rejected_attention_mask": torch.ones(4, device=self.device), + "rejected_labels": torch.zeros(4, dtype=torch.long, device=self.device), + } + + def __len__(self): + return 4 + + +class FakeMesh: + def __init__(self): + self.device_ids = [0] + self.axis_names = ("data", "fsdp") + self.mesh_shape = (1, 1) + + 
+  def shape(self):
+    return {"data": 1, "fsdp": 1}
+
+  def get_axis_name_idx(self, axis_name):
+    return self.axis_names.index(axis_name)
+
+  def get_logical_mesh(self):
+    return np.array(self.device_ids).reshape(self.mesh_shape)
+
+
+@pytest.fixture
+def dummy_config():
+  return OmegaConf.create(
+    {
+      "model": {
+        "pure_modules": [],
+        "remat": {
+          "activation_checkpoint_layers": [],
+          "optimization_barrier_layers": [],
+          "scan_layers": None,
+          "offload_tensors": [],
+        },
+        "sharding": {"type": "spmd"},
+        "pretrained_model": "dummy",
+      },
+      "data": {"name": "dummy_dataset", "block_size": 4},
+      "task": {
+        "name": "dpo",
+        "global_batch_size": 2,
+        "max_steps": 1,
+        "max_grad_norm": None,
+        "max_grad_value": None,
+        "export_checkpoint_path": "dummy_export_path",
+        "beta": 0.1,
+        "optimizer": {"type": "adafactor", "learning_rate": 1e-3},
+        "lr_scheduler": {"type": "constant", "warmup_steps": 0},
+      },
+      "run_name": None,
+      "output_dir": "/tmp/test_output",
+      "logging_steps": 1,
+      "profile_step": -1,
+      "profile_dir": "/tmp/profile",
+      "ici_mesh": {"data": 1, "fsdp": 1, "tensor": 1},
+      "dcn_mesh": {},
+      "torch_dtype": "bfloat16",
+    }
+  )
+
+
+def test_dpo_trainer(monkeypatch, dummy_config):
+  from torchprime.torch_xla_models.model_rewriting import sharding_initialization
+
+  # Patch mesh setup
+  monkeypatch.setattr(
+    sharding_initialization, "get_mesh", lambda *args, **kwargs: FakeMesh()
+  )
+  monkeypatch.setattr(
+    sharding_initialization,
+    "shard_torch_xla_model_from_config",
+    lambda model, *args, **kwargs: model,
+  )
+
+  # Patch process index and count
+  monkeypatch.setattr("torch_xla.runtime.process_index", lambda: 0)
+  monkeypatch.setattr("torch_xla.runtime.process_count", lambda: 1)
+
+  device = xm.xla_device()
+  model = DummyModel().to(device)
+  dataset = DummyDataset()
+  trainer = DPOTrainer(model, dummy_config, dataset)
+
+  assert model.loaded is True
+
+  trainer.train_loop()
+
+  assert model.saved is True
diff --git a/torchprime/torch_xla_models/trainer/__init__.py b/torchprime/torch_xla_models/trainer/__init__.py
index 4ed60df2..5200fd6a 100644
--- a/torchprime/torch_xla_models/trainer/__init__.py
+++ b/torchprime/torch_xla_models/trainer/__init__.py
@@ -1,15 +1,18 @@
 """Trainer module for Torch XLA models."""

 from .base_trainer import Trainer
+from .dpo_trainer import DPOTrainer
 from .sft_trainer import SFTTrainer

 TRAINERS = {
   "train": Trainer,
   "sft": SFTTrainer,
+  "dpo": DPOTrainer,
 }

 __all__ = [
   "TRAINERS",
   "Trainer",
   "SFTTrainer",
+  "DPOTrainer",
 ]
diff --git a/torchprime/torch_xla_models/trainer/dpo_trainer.py b/torchprime/torch_xla_models/trainer/dpo_trainer.py
new file mode 100644
index 00000000..cb246df9
--- /dev/null
+++ b/torchprime/torch_xla_models/trainer/dpo_trainer.py
@@ -0,0 +1,150 @@
+"""Trainer for Direct Preference Optimization (DPO)."""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Generator
+from contextlib import contextmanager
+
+import torch
+import torch.nn.functional as F
+import torch_xla
+from omegaconf import DictConfig
+from torch import nn
+
+from torchprime.torch_xla_models.model import model_utils
+
+from .sft_trainer import SFTTrainer
+
+logger = logging.getLogger(__name__)
+
+
+class DPOTrainer(SFTTrainer):
+  """Trainer implementing a simple DPO objective."""
+
+  def __init__(
+    self,
+    model: nn.Module,
+    config: DictConfig,
+    train_dataset,
+  ) -> None:
+    """Initialize the trainer and create the reference model.
+
+    Args:
+      model: The policy model to train.
+      config: Hydra configuration specifying optimizer and model options.
+      train_dataset: Dataset providing preference pairs.
+    """
+    self.beta = getattr(config.task, "beta", 0.1)
+    super().__init__(model, config, train_dataset)
+
+    dtype_name = config.get("torch_dtype", "bfloat16")
+    model_dtype = getattr(torch, dtype_name)
+    with model_utils.set_default_dtype(model_dtype), torch_xla.device():
+      model_class = getattr(config.model, "model_class", None)
+      if model_class is None:
+        # Fall back to the same class as the policy model when the config does
+        # not specify ``model_class``. This is useful for unit tests that
+        # construct a small dummy model directly.
+        self.ref_model = model.__class__()
+      else:
+        # The reference model shares the same architecture as the policy model
+        # and is initialized from pretrained weights. It remains frozen during
+        # training.
+        self.ref_model = model_utils.initialize_model_class(config.model)
+        if getattr(config.model, "pretrained_model", None):
+          self.ref_model.from_pretrained(config.model.pretrained_model)
+    # Keep the reference model on CPU unless needed to save TPU memory.
+    self.ref_model.to("cpu")
+    self.ref_model.eval()
+    # Ensure the reference model does not receive gradient updates.
+    for p in self.ref_model.parameters():
+      p.requires_grad_(False)
+
+  @contextmanager
+  def _ref_model_on_device(self) -> Generator[None, None, None]:
+    """Context manager to temporarily move the reference model to the XLA device."""
+    self.ref_model.to(self.device)
+    try:
+      yield
+    finally:
+      self.ref_model.to("cpu")
+
+  def _seq_log_prob(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Compute the log probability of a sequence.
+
+    Args:
+      logits: Model logits of shape ``[B, T, V]``.
+      labels: Target token IDs of shape ``[B, T]`` with ``-100`` for padding.
+
+    Returns:
+      A tensor of shape ``[B]`` containing the summed log probabilities.
+    """
+    vocab = logits.size(-1)
+    logits = logits[:, :-1].reshape(-1, vocab)
+    labels = labels[:, 1:]
+    log_probs = F.log_softmax(logits, dim=-1)
+    labels_clipped = torch.where(labels == -100, torch.zeros_like(labels), labels)
+    token_log_probs = log_probs.gather(1, labels_clipped.reshape(-1, 1)).squeeze(-1)
+    # Ignore padding tokens when summing probabilities.
+    mask = (labels != -100).reshape(-1)
+    token_log_probs = token_log_probs * mask
+    seq_log_probs = token_log_probs.view(labels.size()).sum(dim=1)
+    return seq_log_probs
+
+  @torch_xla.compile(full_graph=True)
+  def train_step(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor]:
+    """Run a single optimization step.
+
+    The method computes the DPO loss between the current model and the
+    reference model for a batch of preference pairs and updates the model
+    parameters.
+
+    Args:
+      batch: A dictionary containing tokenized ``chosen`` and ``rejected``
+        sequences.
+
+    Returns:
+      A tuple with the loss and gradient norm.
+    """
+    # Forward pass for the policy model on both the preferred and rejected responses.
+    c_logits = self.model(
+      input_ids=batch["chosen_input_ids"],
+      attention_mask=batch["chosen_attention_mask"],
+    )[0]
+    r_logits = self.model(
+      input_ids=batch["rejected_input_ids"],
+      attention_mask=batch["rejected_attention_mask"],
+    )[0]
+
+    # Reference model forward pass is executed without gradient tracking. The
+    # model is temporarily moved to the XLA device to save memory.
+    with self._ref_model_on_device(), torch.no_grad():
+      c_ref = self.ref_model(
+        input_ids=batch["chosen_input_ids"],
+        attention_mask=batch["chosen_attention_mask"],
+      )[0]
+      r_ref = self.ref_model(
+        input_ids=batch["rejected_input_ids"],
+        attention_mask=batch["rejected_attention_mask"],
+      )[0]
+
+    c_logp = self._seq_log_prob(c_logits, batch["chosen_labels"])
+    r_logp = self._seq_log_prob(r_logits, batch["rejected_labels"])
+    c_ref_logp = self._seq_log_prob(c_ref, batch["chosen_labels"])
+    r_ref_logp = self._seq_log_prob(r_ref, batch["rejected_labels"])
+
+    # DPO loss compares the advantage of the policy over the reference model
+    # for the preferred vs. rejected responses.
+    pi_logratios = c_logp - r_logp
+    ref_logratios = c_ref_logp - r_ref_logp
+    losses = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))
+    # Average over the batch to obtain the final loss.
+    loss = losses.mean()
+    loss.backward()
+    grad_norm = self.clip_gradients()
+    self.optimizer.step()
+    self.lr_scheduler.step()
+    # Clear gradients for the next iteration.
+    self.model.zero_grad()
+    return loss, grad_norm
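Note for reviewers: the loss computed in `DPOTrainer.train_step` is the standard DPO objective. A minimal, self-contained sketch of that computation on made-up sequence log-probabilities (toy numbers, illustrative only, not part of this diff) is:

```python
# Illustrative sketch of the DPO objective in DPOTrainer.train_step (toy numbers).
import torch
import torch.nn.functional as F

beta = 0.1  # same default as task/dpo.yaml

# Summed per-sequence log-probs for two preference pairs (batch size 2), as
# _seq_log_prob would return them for the policy and the frozen reference model.
policy_chosen = torch.tensor([-12.0, -15.0])
policy_rejected = torch.tensor([-14.0, -15.5])
ref_chosen = torch.tensor([-13.0, -15.2])
ref_rejected = torch.tensor([-13.5, -15.4])

# loss = -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
pi_logratios = policy_chosen - policy_rejected
ref_logratios = ref_chosen - ref_rejected
loss = -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
print(loss)  # scalar; smaller when the policy separates chosen/rejected more than the reference
```

The trainer obtains the four log-probability vectors from `_seq_log_prob` over the policy and the frozen reference model; the rest of `train_step` is the usual backward/clip/step/zero_grad bookkeeping.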