Add DPO training support (created by AI agent) #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 5 commits into main
22 changes: 22 additions & 0 deletions .github/workflows/e2e_test.yml
@@ -29,6 +29,7 @@ jobs:
llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
llama-3-8b-sft-name: ${{ steps.run-llama-3-8b-sft.outputs.name }}
llama-3-8b-dpo-name: ${{ steps.run-llama-3-8b-dpo.outputs.name }}
llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }}
llama-3-8b-fsdp-cp-name: ${{ steps.run-llama-3-8b-fsdp-cp.outputs.name }}
mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
@@ -262,6 +263,26 @@ jobs:
task.convert_to_safetensors=False \
profile_start_step=3

- name: Run Llama 3.0 8B DPO
id: run-llama-3-8b-dpo
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-dpo)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
--name $name \
torchprime/torch_xla_models/train.py \
--config-name llama-3-8b-dpo-w-orca \
ici_mesh.fsdp=4 \
task.max_steps=20 \
task.global_batch_size=16 \
task.lr_scheduler.type=constant \
task.convert_to_safetensors=False \
profile_start_step=3

- name: Run Llama 3.0 8B (ddp + fsdp)
id: run-llama-3-8b-ddp-fsdp
env:
@@ -334,6 +355,7 @@ jobs:
matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
matrix.config.benchmark == 'mixtral-8x7b' && needs.tp-run.outputs.mixtral-8x7b-name ||
matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name ||
matrix.config.benchmark == 'llama-3-8b-dpo' && needs.tp-run.outputs.llama-3-8b-dpo-name ||
matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name ||
matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name ||
matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name
7 changes: 7 additions & 0 deletions e2e_testing/step_time_bounds.yaml
@@ -50,6 +50,13 @@ benchmarks:
sample_size: 11
target_loss: 0.4735
loss_tolerance: 0.001
llama-3-8b-dpo:
name: Llama 3.0 8B DPO
step_time_lower_bound: 0
step_time_upper_bound: 1
confidence_interval: 0.5
average: 0.5
sample_size: 1
llama-3-8b-ddp-fsdp:
name: Llama 3.0 8B (ddp + fsdp)
step_time_lower_bound: 3.22900775
3 changes: 3 additions & 0 deletions torchprime/data/__init__.py
@@ -3,15 +3,18 @@
"""

from .dataset import make_train_dataset
from .dpo_dataset import make_dpo_dataset
from .sft_dataset import make_sft_dataset

DATASET_BUILDERS = {
"train": make_train_dataset,
"sft": make_sft_dataset,
"dpo": make_dpo_dataset,
}

__all__ = [
"DATASET_BUILDERS",
"make_train_dataset",
"make_sft_dataset",
"make_dpo_dataset",
]
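
With this registry change, a dataset builder can be selected by task name. The trainer wiring is not part of this diff; the following is only a hypothetical dispatch sketch (the helper name build_dataset is illustrative):

from torchprime.data import DATASET_BUILDERS

def build_dataset(task_name: str, **kwargs):
    # kwargs are forwarded to the selected builder
    # (tokenizer, block_size, dataset location, ...).
    # task_name == "dpo" now resolves to make_dpo_dataset.
    return DATASET_BUILDERS[task_name](**kwargs)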
171 changes: 171 additions & 0 deletions torchprime/data/dpo_dataset.py
@@ -0,0 +1,171 @@
"""DPO dataset utilities."""

from __future__ import annotations

from typing import Literal

from datasets import Dataset
from transformers.tokenization_utils import PreTrainedTokenizerBase

from .dataset import load_hf_or_json_dataset

TRUNCATE_OPTION = Literal["right", "left", "drop"]


def _pad(
ids: list[int],
labels: list[int],
max_length: int,
pad_id: int,
) -> tuple[list[int], list[int], list[int]]:
"""Pad IDs and labels to ``max_length``.

Args:
ids: Encoded token IDs.
labels: Label token IDs.
max_length: Desired sequence length.
pad_id: Token ID to use for padding.

Returns:
Tuple containing padded ``ids``, ``labels`` and the attention mask.
"""
ids = ids[:max_length]
labels = labels[:max_length]
attn = [1] * len(ids)
# Pad sequences to ``max_length`` and mask out the padding tokens.
if len(ids) < max_length:
ids = ids + [pad_id] * (max_length - len(ids))
labels = labels + [-100] * (max_length - len(labels))
attn = attn + [0] * (max_length - len(attn))
return ids, labels, attn


def _tokenize_pair(
example: dict,
tokenizer: PreTrainedTokenizerBase,
*,
max_length: int,
truncation: TRUNCATE_OPTION,
) -> dict | None:
"""Tokenize a preference pair.

Each example contains a ``prompt`` and two completions: ``chosen`` and
``rejected``. The completions are concatenated with the prompt and padded to
``max_length``.

Args:
example: Raw example with ``prompt``, ``chosen`` and ``rejected`` fields.
tokenizer: Tokenizer used to encode text.
max_length: Target length for the encoded sequences.
truncation: Strategy to handle sequences longer than ``max_length``. If
``"drop"`` is specified the pair is skipped.

Returns:
A dictionary with encoded tensors or ``None`` if the example was dropped.
"""
prompt = example.get("prompt", "")
chosen = example.get("chosen")
rejected = example.get("rejected")
if chosen is None or rejected is None:
return None
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)

def build(completion: str):
"""Encode completion and append EOS token if necessary."""
ids = prompt_ids + tokenizer.encode(completion, add_special_tokens=False)
# Mask out the prompt portion so that only the completion contributes to the loss.
labels = [-100] * len(prompt_ids) + tokenizer.encode(
completion, add_special_tokens=False
)
if tokenizer.eos_token_id is not None:
ids.append(tokenizer.eos_token_id)
labels.append(tokenizer.eos_token_id)
if len(ids) > max_length:
if truncation == "drop":
# Skip examples that overflow the maximum length.
return None
if truncation == "left":
# Keep the last tokens when truncating from the left.
ids = ids[-max_length:]
labels = labels[-max_length:]
else:
# Default to truncating from the right.
ids = ids[:max_length]
labels = labels[:max_length]
return ids, labels

built_c = build(chosen)
built_r = build(rejected)
if built_c is None or built_r is None:
return None

# Fall back to the EOS token when the tokenizer has no dedicated PAD token.
pad_id = (
tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else tokenizer.eos_token_id
)
ids_c, labels_c = built_c
ids_r, labels_r = built_r
# Pad both completions to ``max_length`` (the caller's ``block_size``).
ids_c, labels_c, mask_c = _pad(ids_c, labels_c, max_length, pad_id)
ids_r, labels_r, mask_r = _pad(ids_r, labels_r, max_length, pad_id)
return {
"chosen_input_ids": ids_c,
"chosen_labels": labels_c,
"chosen_attention_mask": mask_c,
"rejected_input_ids": ids_r,
"rejected_labels": labels_r,
"rejected_attention_mask": mask_r,
}


def make_dpo_dataset(
hf_dataset_name: str | None = None,
hf_dataset_config_name: str | None = None,
file_dataset_path: str | None = None,
split: str = "train",
cache_dir: str | None = None,
truncation: TRUNCATE_OPTION = "right",
*,
tokenizer: PreTrainedTokenizerBase,
block_size: int,
) -> Dataset:
"""Create a dataset for Direct Preference Optimization.

The function supports loading data from the Hugging Face hub or from a local
JSONL file. Each record must contain ``prompt``, ``chosen`` and ``rejected``
fields which represent a single preference pair.

Args:
hf_dataset_name: Optional name of a dataset on the Hugging Face hub.
hf_dataset_config_name: Optional dataset configuration name.
file_dataset_path: Optional path to a local JSONL file.
split: Dataset split to load when using the hub.
cache_dir: Directory to cache downloaded data.
truncation: Strategy used when sequences exceed ``block_size``.
tokenizer: Tokenizer used to encode the examples.
block_size: Maximum sequence length after tokenization.

Returns:
A :class:`datasets.Dataset` containing processed pairs ready for training.
"""
data = load_hf_or_json_dataset(
hf_dataset_name=hf_dataset_name,
hf_dataset_config_name=hf_dataset_config_name,
file_dataset_path=file_dataset_path,
split=split,
cache_dir=cache_dir,
)

records = []
for ex in data:
out = _tokenize_pair(
ex,
tokenizer,
max_length=block_size,
truncation=truncation,
)
if out is not None:
records.append(out)
return Dataset.from_list(records)
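
The builder returns one flat record per preference pair with six fixed-length lists. A hypothetical collate sketch (not part of this PR) showing how those fields could be stacked into batch tensors for a DPO step:

import torch

_DPO_KEYS = (
    "chosen_input_ids", "chosen_labels", "chosen_attention_mask",
    "rejected_input_ids", "rejected_labels", "rejected_attention_mask",
)

def collate_dpo(examples: list[dict]) -> dict[str, torch.Tensor]:
    # Every example is already padded to block_size, so plain stacking works.
    return {k: torch.tensor([ex[k] for ex in examples]) for k in _DPO_KEYS}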
62 changes: 62 additions & 0 deletions torchprime/tests/test_dpo_dataset.py
@@ -0,0 +1,62 @@
import json
from pathlib import Path

from datasets import Dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from transformers import PreTrainedTokenizerFast

from torchprime.data.dpo_dataset import make_dpo_dataset


def _write_json(tmpdir: Path, name: str, data):
path = tmpdir / name
with path.open("w") as f:
for item in data:
json.dump(item, f)
f.write("\n")
return path


def _tokenizer():
vocab = {
"<pad>": 0,
"<s>": 1,
"</s>": 2,
"Hello": 3,
"World": 4,
"Hey": 5,
"<unk>": 6,
}
model = WordLevel(vocab, unk_token="<unk>")
tok = Tokenizer(model)
tok.pre_tokenizer = Whitespace()
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tok,
bos_token="<s>",
eos_token="</s>",
pad_token="<pad>",
unk_token="<unk>",
)
return tokenizer


def test_local_json_pair(tmp_path: Path):
data = [
{"prompt": "Hello", "chosen": "World", "rejected": "Hey"},
]
path = _write_json(tmp_path, "pairs.json", data)
tok = _tokenizer()
ds = make_dpo_dataset(file_dataset_path=str(path), tokenizer=tok, block_size=8)
assert isinstance(ds, Dataset)
rec = ds[0]
hello_id = tok.convert_tokens_to_ids("Hello")
world_id = tok.convert_tokens_to_ids("World")
hey_id = tok.convert_tokens_to_ids("Hey")
eos = tok.eos_token_id
assert rec["chosen_input_ids"][:3] == [hello_id, world_id, eos]
assert rec["chosen_labels"][0] == -100
assert rec["chosen_labels"][1] == world_id
assert rec["rejected_input_ids"][:3] == [hello_id, hey_id, eos]
assert rec["rejected_labels"][1] == hey_id
7 changes: 7 additions & 0 deletions torchprime/torch_xla_models/configs/dataset/orca.yaml
@@ -0,0 +1,7 @@
# Dataset configuration for DPO using the Orca pairs dataset
hf_dataset_name: Intel/orca_dpo_pairs
hf_dataset_config_name: null
split: train
block_size: 256
cache_dir: /tmp/
truncation: drop
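
For reference, this config maps onto the new builder roughly as follows. This is only a sketch, assuming the hub records expose the prompt/chosen/rejected fields the loader expects; the tokenizer checkpoint is illustrative:

from transformers import AutoTokenizer
from torchprime.data.dpo_dataset import make_dpo_dataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
dataset = make_dpo_dataset(
    hf_dataset_name="Intel/orca_dpo_pairs",
    split="train",
    cache_dir="/tmp/",
    truncation="drop",  # overlong pairs are skipped rather than truncated
    tokenizer=tokenizer,
    block_size=256,
)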
15 changes: 15 additions & 0 deletions torchprime/torch_xla_models/configs/llama-3-8b-dpo-w-orca.yaml
@@ -0,0 +1,15 @@
# Configuration for DPO training on the Orca dataset

defaults:
- default
- override model: llama-3-8b
- override dataset: orca
- override task: dpo
- _self_

task:
convert_to_safetensors: False

model:
pretrained_model: meta-llama/Meta-Llama-3-8B
reference_model: meta-llama/Meta-Llama-3-8B
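
To try this config end to end, the invocation mirrors the new CI step above; the docker image option and run name below are placeholders:

# Same overrides as the e2e step; adjust the mesh to your slice.
tp run <docker-image-option> \
  --name llama-3-8b-dpo-smoke \
  torchprime/torch_xla_models/train.py \
  --config-name llama-3-8b-dpo-w-orca \
  ici_mesh.fsdp=4 \
  task.max_steps=20 \
  task.global_batch_size=16 \
  task.lr_scheduler.type=constant \
  task.convert_to_safetensors=False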
15 changes: 15 additions & 0 deletions torchprime/torch_xla_models/configs/task/dpo.yaml
@@ -0,0 +1,15 @@
# Task configuration for Direct Preference Optimization
name: dpo
global_batch_size: 64
max_steps: 100
export_checkpoint_path: export
convert_to_safetensors: True
beta: 0.1
max_grad_norm: 1.0
max_grad_value: null
optimizer:
learning_rate: 4.e-5
type: adafactor
lr_scheduler:
type: linear
warmup_steps: 10
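
This task config introduces beta, and the model config above adds a reference_model. The trainer changes are not shown in this section, so the following is only a minimal sketch of how beta conventionally enters the DPO objective, given summed per-sequence log-probs under the policy and the frozen reference model:

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Log-ratios of policy vs. reference for each completion.
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # Reward margin scaled by beta; maximizing it widens the chosen/rejected gap.
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()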