9 changes: 8 additions & 1 deletion gpt_oss/generate.py
@@ -4,11 +4,18 @@
 # torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/

 import argparse
+import os
+from pathlib import Path

 from gpt_oss.tokenizer import get_tokenizer


-def main(args):
+def main(args: argparse.Namespace) -> None:
+    # Validate checkpoint path exists
+    checkpoint_path = Path(args.checkpoint)
+    if not checkpoint_path.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
Comment on lines +13 to +17


P1: Allow non-local checkpoints for vLLM backend

The new preflight check unconditionally wraps args.checkpoint in Path and raises FileNotFoundError when the path does not exist. This happens before the backend switch, so it also triggers when using the vLLM backend, which previously accepted HuggingFace model identifiers or other non-local URIs and downloaded weights on demand. With the change, any string that is not an existing filesystem path now fails before vLLM can handle it, breaking valid use cases such as --backend vllm --checkpoint mistralai/Mixtral-8x7B. Consider moving the existence check inside the torch/triton branches or restricting it to local backends only.
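A minimal sketch of that suggestion, reusing the names already in this diff; treating "torch" and "triton" as the only local, disk-reading backends is an assumption based on the comment above, not something this PR implements:

import argparse
from pathlib import Path


def main(args: argparse.Namespace) -> None:
    # Only require a local filesystem path for backends that load weights from
    # disk; vLLM resolves identifiers like mistralai/Mixtral-8x7B on its own.
    if args.backend in ("torch", "triton"):
        checkpoint_path = Path(args.checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
    # ... existing `match args.backend:` dispatch continues unchanged ...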



+
     match args.backend:
         case "torch":
             from gpt_oss.torch.utils import init_distributed
3 changes: 2 additions & 1 deletion gpt_oss/tokenizer.py
@@ -1,6 +1,7 @@
 import tiktoken

-def get_tokenizer():
+
+def get_tokenizer() -> tiktoken.Encoding:
     o200k_base = tiktoken.get_encoding("o200k_base")
     tokenizer = tiktoken.Encoding(
         name="o200k_harmony",
65 changes: 52 additions & 13 deletions gpt_oss/torch/utils.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist


-def suppress_output(rank):
+def suppress_output(rank: int) -> None:
     """Suppress printing on the current device. Force printing with `force=True`."""
     import builtins as __builtin__
     builtin_print = __builtin__.print
@@ -20,21 +20,60 @@ def print(*args, **kwargs):

 def init_distributed() -> torch.device:
     """Initialize the model for distributed inference."""
+    # Check CUDA availability
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "CUDA is not available. Please ensure CUDA is installed and accessible."
+        )
+
+    # Initialize distributed inference
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     rank = int(os.environ.get("RANK", 0))
-    if world_size > 1:
-        dist.init_process_group(
-            backend="nccl", init_method="env://", world_size=world_size, rank=rank
+
+    # Validate rank against available devices
+    if rank >= torch.cuda.device_count():
+        raise RuntimeError(
+            f"Rank {rank} exceeds available CUDA devices ({torch.cuda.device_count()}). "
+            f"Please set RANK to a value between 0 and {torch.cuda.device_count() - 1}."
         )
-    torch.cuda.set_device(rank)
-    device = torch.device(f"cuda:{rank}")
+
+    try:
+        if world_size > 1:
+            dist.init_process_group(
+                backend="nccl", init_method="env://", world_size=world_size, rank=rank
+            )
+
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        # Test device accessibility
+        try:
+            torch.cuda.get_device_properties(device)
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Failed to access CUDA device {rank}: {e}. "
+                "Please check device availability and permissions."
+            ) from e

-    # Warm up NCCL to avoid first-time latency
-    if world_size > 1:
-        x = torch.ones(1, device=device)
-        dist.all_reduce(x)
-        torch.cuda.synchronize(device)
+        # Warm up NCCL to avoid first-time latency
+        if world_size > 1:
+            try:
+                x = torch.ones(1, device=device)
+                dist.all_reduce(x)
+                torch.cuda.synchronize(device)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    f"Failed to initialize distributed communication on device {rank}: {e}"
+                ) from e

-    suppress_output(rank)
-    return device
+        suppress_output(rank)
+        return device
+
+    except Exception as e:
+        # Clean up distributed process group if initialization failed
+        if world_size > 1 and dist.is_initialized():
+            try:
+                dist.destroy_process_group()
+            except Exception:
+                pass  # Ignore cleanup errors
+        raise
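For context on how the hardened initializer behaves at launch time, here is a minimal, hypothetical smoke-test script (the module name smoke_init is an assumption, not part of this PR; the launch command mirrors the torchrun comment in gpt_oss/generate.py):

# smoke_init.py — hypothetical; launch with: torchrun --nproc-per-node=2 -m smoke_init
# torchrun exports WORLD_SIZE and RANK, which init_distributed now validates
# against torch.cuda.device_count() before touching NCCL.
import torch.distributed as dist

from gpt_oss.torch.utils import init_distributed

if __name__ == "__main__":
    device = init_distributed()  # raises RuntimeError early if CUDA or the rank is misconfigured
    print(f"initialized {device}")  # only rank 0 prints once suppress_output is installed
    if dist.is_initialized():
        dist.destroy_process_group()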
212 changes: 212 additions & 0 deletions tests/test_tokenizer.py
@@ -0,0 +1,212 @@
"""Unit tests for tokenizer encoding/decoding functionality."""

import pytest
from gpt_oss.tokenizer import get_tokenizer


class TestTokenizerBasics:
"""Test basic tokenizer functionality."""

def test_get_tokenizer_returns_encoding(self):
"""Test that get_tokenizer returns a valid encoding."""
tokenizer = get_tokenizer()
assert tokenizer is not None
assert tokenizer.name == "o200k_harmony"

def test_tokenizer_has_harmony_special_tokens(self):
"""Test that tokenizer includes Harmony special tokens."""
tokenizer = get_tokenizer()
special_tokens = tokenizer._special_tokens

# Verify key Harmony tokens are present
assert "<|channel|>" in special_tokens
assert special_tokens["<|channel|>"] == 200005
assert "<|start|>" in special_tokens
assert special_tokens["<|start|>"] == 200006
assert "<|end|>" in special_tokens
assert special_tokens["<|end|>"] == 200007
assert "<|message|>" in special_tokens
assert special_tokens["<|message|>"] == 200008
assert "<|call|>" in special_tokens
assert special_tokens["<|call|>"] == 200012
assert "<|return|>" in special_tokens
assert special_tokens["<|return|>"] == 200002

def test_tokenizer_has_reserved_tokens(self):
"""Test that tokenizer includes reserved token range."""
tokenizer = get_tokenizer()
special_tokens = tokenizer._special_tokens

# Check reserved tokens exist in range
assert "<|reserved_200013|>" in special_tokens
assert special_tokens["<|reserved_200013|>"] == 200013
assert "<|reserved_201087|>" in special_tokens
assert special_tokens["<|reserved_201087|>"] == 201087


class TestTokenizerEncoding:
"""Test tokenizer encoding functionality."""

def test_encode_simple_text(self):
"""Test encoding simple text."""
tokenizer = get_tokenizer()
text = "Hello, world!"
tokens = tokenizer.encode(text)

assert isinstance(tokens, list)
assert len(tokens) > 0
assert all(isinstance(t, int) for t in tokens)

def test_encode_special_tokens(self):
"""Test encoding text with special tokens."""
tokenizer = get_tokenizer()
text = "<|channel|>final<|message|>Hello<|return|>"
tokens = tokenizer.encode(text, allowed_special="all")

assert 200005 in tokens # <|channel|>
assert 200008 in tokens # <|message|>
assert 200002 in tokens # <|return|>

def test_encode_without_special_allowed_raises(self):
"""Test that encoding special tokens without permission raises error."""
tokenizer = get_tokenizer()
text = "<|channel|>test"

with pytest.raises(ValueError):
tokenizer.encode(text)

def test_encode_empty_string(self):
"""Test encoding empty string."""
tokenizer = get_tokenizer()
tokens = tokenizer.encode("")

assert isinstance(tokens, list)
assert len(tokens) == 0

def test_encode_unicode_text(self):
"""Test encoding unicode text."""
tokenizer = get_tokenizer()
text = "Hello 世界 🌍"
tokens = tokenizer.encode(text)

assert isinstance(tokens, list)
assert len(tokens) > 0


class TestTokenizerDecoding:
"""Test tokenizer decoding functionality."""

def test_decode_simple_tokens(self):
"""Test decoding simple tokens."""
tokenizer = get_tokenizer()
text = "Hello, world!"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)

assert decoded == text

def test_decode_with_special_tokens(self):
"""Test decoding tokens including special tokens."""
tokenizer = get_tokenizer()
text = "<|channel|>final<|message|>Hello<|return|>"
tokens = tokenizer.encode(text, allowed_special="all")
decoded = tokenizer.decode(tokens)

assert decoded == text

def test_decode_empty_list(self):
"""Test decoding empty token list."""
tokenizer = get_tokenizer()
decoded = tokenizer.decode([])

assert decoded == ""

def test_decode_single_token(self):
"""Test decoding single token."""
tokenizer = get_tokenizer()
tokens = tokenizer.encode("a")
decoded = tokenizer.decode(tokens)

assert decoded == "a"

def test_decode_unicode(self):
"""Test decoding unicode tokens."""
tokenizer = get_tokenizer()
text = "Hello 世界 🌍"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)

assert decoded == text


class TestTokenizerRoundTrip:
"""Test encode/decode round-trip consistency."""

@pytest.mark.parametrize("text", [
"Simple text",
"Text with numbers: 123456",
"Special chars: !@#$%^&*()",
"Unicode: 你好世界",
"Emoji: 🚀🌟💡",
"Mixed: Hello 世界 123 🎉",
"Newlines:\nand\ttabs",
])
def test_roundtrip_consistency(self, text):
"""Test that encode->decode returns original text."""
tokenizer = get_tokenizer()
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)

assert decoded == text

def test_roundtrip_with_harmony_format(self):
"""Test round-trip with Harmony message format."""
tokenizer = get_tokenizer()
text = "<|channel|>analysis<|start|><|message|>Thinking...<|end|><|channel|>final<|message|>Answer<|return|>"
tokens = tokenizer.encode(text, allowed_special="all")
decoded = tokenizer.decode(tokens)

assert decoded == text


class TestTokenizerEdgeCases:
"""Test edge cases and error handling."""

def test_encode_very_long_text(self):
"""Test encoding very long text."""
tokenizer = get_tokenizer()
text = "a" * 10000
tokens = tokenizer.encode(text)

assert isinstance(tokens, list)
assert len(tokens) > 0

def test_decode_invalid_token_ids(self):
"""Test decoding with potentially invalid token IDs."""
tokenizer = get_tokenizer()
# Use valid token IDs from the special tokens
tokens = [200005, 200006, 200007]
decoded = tokenizer.decode(tokens)

assert isinstance(decoded, str)

def test_multiple_tokenizer_instances_consistent(self):
"""Test that multiple tokenizer instances behave consistently."""
tokenizer1 = get_tokenizer()
tokenizer2 = get_tokenizer()

text = "Test consistency"
tokens1 = tokenizer1.encode(text)
tokens2 = tokenizer2.encode(text)

assert tokens1 == tokens2

def test_special_token_ids_immutable(self):
"""Test that special token IDs are consistent."""
tokenizer = get_tokenizer()

# Get special tokens multiple times
channel_id_1 = tokenizer.encode("<|channel|>", allowed_special="all")[0]
channel_id_2 = tokenizer.encode("<|channel|>", allowed_special="all")[0]

assert channel_id_1 == channel_id_2 == 200005