Date: Tue, 26 Aug 2025 10:26:52 -0700
Subject: [PATCH 68/91] Update README.md (#154)
---
README.md | 16 +++++++++++++++-
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 876a58f0..4ef20827 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
Try gpt-oss ·
Guides ·
- Model card ·
+ Model card ·
OpenAI blog
@@ -498,3 +498,17 @@ We recommend sampling with `temperature=1.0` and `top_p=1.0`.
The reference implementations in this repository are meant as a starting point and inspiration. Outside of bug fixes, we do not intend to accept new feature contributions. If you build implementations based on this code, such as new tool implementations, you are welcome to contribute them to the [`awesome-gpt-oss.md`](./awesome-gpt-oss.md) file.
[harmony]: https://github.com/openai/harmony
+
+## Citation
+
+```bibtex
+@misc{openai2025gptoss120bgptoss20bmodel,
+      title={gpt-oss-120b \& gpt-oss-20b Model Card},
+ author={OpenAI},
+ year={2025},
+ eprint={2508.10925},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ url={https://arxiv.org/abs/2508.10925},
+}
+```
From 5ec1d16f423a735375a755eb9f511d738c02bbe3 Mon Sep 17 00:00:00 2001
From: Samagra Sharma
Date: Wed, 27 Aug 2025 19:57:03 -0700
Subject: [PATCH 69/91] Added Tensorfuse (AWS) guide (#118)
---
awesome-gpt-oss.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md
index 82cf7071..9cebe650 100644
--- a/awesome-gpt-oss.md
+++ b/awesome-gpt-oss.md
@@ -65,6 +65,8 @@ This is a list of guides and resources to help you get started with the gpt-oss
- [gpt-oss-20b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b)
- AMD
- [gpt-oss-120B on AMD MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot)
+- AWS (Deploy via Tensorfuse)
+ - [Deploy gpt-oss for both 20b and 120b models on AWS EKS](https://tensorfuse.io/docs/guides/modality/text/openai_oss)
## Examples & Tutorials
From a19d0bc94d480505adb1e9b493d66e1a99d24443 Mon Sep 17 00:00:00 2001
From: Daniel Holanda
Date: Thu, 28 Aug 2025 11:58:19 -0700
Subject: [PATCH 70/91] Add Lemonade to `awesome-gpt-oss` (#117)
* Update awesome-gpt-oss.md
* Update awesome-gpt-oss.md
* Update awesome-gpt-oss.md
---------
Co-authored-by: Dominik Kundel
---
awesome-gpt-oss.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md
index 9cebe650..ac5a1c38 100644
--- a/awesome-gpt-oss.md
+++ b/awesome-gpt-oss.md
@@ -32,6 +32,8 @@ This is a list of guides and resources to help you get started with the gpt-oss
- [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss)
- AMD
- [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html)
+- Lemonade
+ - [Running gpt-oss on STX Halo and Radeon dGPUs using Lemonade](https://lemonade-server.ai/news/gpt-oss.html)
- llama.cpp
- [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)
From 0c39f1da17df3a5f895802a4e9d130f033f79ba5 Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Thu, 28 Aug 2025 12:03:52 -0700
Subject: [PATCH 71/91] Add uv python backend (#156)
* add uv python backend
Co-authored-by: simon-mo
* dangerously_use_uv
---------
Co-authored-by: simon-mo
---
gpt_oss/tools/python_docker/docker_tool.py | 32 +++++++++++++++++++++-
1 file changed, 31 insertions(+), 1 deletion(-)
diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py
index c31680ea..3d630cc1 100644
--- a/gpt_oss/tools/python_docker/docker_tool.py
+++ b/gpt_oss/tools/python_docker/docker_tool.py
@@ -3,6 +3,9 @@
import io
import tarfile
from typing import Any, AsyncIterator
+import tempfile
+import os
+import subprocess
import docker
from openai_harmony import (
@@ -18,6 +21,11 @@
_docker_client = None
+PYTHON_EXECUTION_BACKEND = "docker"
+
+if os.environ.get("PYTHON_EXECUTION_BACKEND") == "dangerously_use_uv":
+ PYTHON_EXECUTION_BACKEND = "dangerously_use_uv"
+
def call_python_script(script: str) -> str:
"""
@@ -58,6 +66,21 @@ def call_python_script(script: str) -> str:
return output
+def call_python_script_with_uv(script: str) -> str:
+ """
+    Call a Python script by writing it to a file in a temporary directory
+ and executing it with uv.
+ """
+ with tempfile.TemporaryDirectory() as temp_dir:
+ script_path = os.path.join(temp_dir, "script.py")
+ with open(script_path, "w") as f:
+ f.write(script)
+ exec_result = subprocess.run(
+ ["uv", "run", "--no-project", "python", script_path],
+ capture_output=True)
+ return exec_result.stdout.decode("utf-8")
+
+
class PythonTool(Tool):
def __init__(
self,
@@ -118,5 +141,12 @@ def make_response(
async def _process(self, message: Message) -> AsyncIterator[Message]:
script = message.content[0].text
channel = message.channel
- output = call_python_script(script)
+ if PYTHON_EXECUTION_BACKEND == "docker":
+ output = call_python_script(script)
+ elif PYTHON_EXECUTION_BACKEND == "dangerously_use_uv":
+ output = call_python_script_with_uv(script)
+ else:
+ raise ValueError(
+ f"Invalid PYTHON_EXECUTION_BACKEND: {PYTHON_EXECUTION_BACKEND}"
+ )
yield self._make_response(output, channel=channel)
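For illustration, a minimal sketch of opting into the new uv backend from Python (equivalent to exporting `PYTHON_EXECUTION_BACKEND=dangerously_use_uv` in the shell). The variable is read at import time, so it must be set before the module is imported; this assumes the `gpt_oss` package and its dependencies are installed and `uv` is on the PATH:

```python
import os

# The backend is selected when the module is imported, so set this first.
os.environ["PYTHON_EXECUTION_BACKEND"] = "dangerously_use_uv"

from gpt_oss.tools.python_docker import docker_tool

# Runs the script via `uv run --no-project python ...` in a temporary
# directory instead of inside a Docker container.
print(docker_tool.call_python_script_with_uv("print(2 + 2)"))
```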
From 7be9334950053a888e24887a57dac797a17d6e00 Mon Sep 17 00:00:00 2001
From: Dominik Kundel
Date: Thu, 28 Aug 2025 12:04:11 -0700
Subject: [PATCH 72/91] Update pyproject.toml
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 9ed47f92..fd38db07 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
]
readme = "README.md"
requires-python = ">=3.12,<3.13"
-version = "0.0.4"
+version = "0.0.5"
[project.optional-dependencies]
triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"]
From 8ee92ec85de42a50deefb4075c45f014da5b2e22 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 08:21:27 -0700
Subject: [PATCH 73/91] Metal: add end-to-end benchmarks (#161)
---
gpt_oss/metal/CMakeLists.txt | 4 ++
gpt_oss/metal/benchmark/end-to-end.cc | 82 +++++++++++++++++++++++++++
2 files changed, 86 insertions(+)
create mode 100644 gpt_oss/metal/benchmark/end-to-end.cc
diff --git a/gpt_oss/metal/CMakeLists.txt b/gpt_oss/metal/CMakeLists.txt
index c6a8e32b..d18708f4 100644
--- a/gpt_oss/metal/CMakeLists.txt
+++ b/gpt_oss/metal/CMakeLists.txt
@@ -147,6 +147,10 @@ add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
+add_executable(end-to-end-bench benchmark/end-to-end.cc)
+target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss)
+target_include_directories(end-to-end-bench PRIVATE source/include)
+
# --- [ Python extension ] -----------------------------------------------
find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module
diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc
new file mode 100644
index 00000000..4f73be7a
--- /dev/null
+++ b/gpt_oss/metal/benchmark/end-to-end.cc
@@ -0,0 +1,82 @@
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+
+
+static void end2end(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+
+ gptoss_tokenizer_t tokenizer_ptr = nullptr;
+ status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to retrieve Tokenizer");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_tokenizer_release)> tokenizer(tokenizer_ptr, gptoss_tokenizer_release);
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), nullptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+
+ for (std::uint32_t i = 0; i < 3; i++) {
+        std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ }
+
+ for (auto _ : state) {
+        std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ }
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+}
+
+BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(end2end, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond);
+
+BENCHMARK_MAIN();
From 57e45b11b3a135e3e09d8fc5fea1bc793e003d44 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 08:22:01 -0700
Subject: [PATCH 74/91] Metal: simplify and optimize Responses API adapter
(#162)
---
gpt_oss/responses_api/inference/metal.py | 62 ++----------------------
1 file changed, 5 insertions(+), 57 deletions(-)
diff --git a/gpt_oss/responses_api/inference/metal.py b/gpt_oss/responses_api/inference/metal.py
index 9abe50db..ec84af7e 100644
--- a/gpt_oss/responses_api/inference/metal.py
+++ b/gpt_oss/responses_api/inference/metal.py
@@ -11,68 +11,16 @@ def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
model = Model(checkpoint)
context = Context(model)
- def lcp(cache: list[int], inp: list[int]) -> list[int]:
- i = 0
- max_len = min(len(cache), len(inp))
- while i < max_len and cache[i] == inp[i]:
- i += 1
- return cache[:i]
-
- tokens_so_far = []
-
def infer_next_token(
tokens: list[int], temperature: float = 0.0, new_request: bool = False
) -> int:
"""Infer next token using incremental LCP caching when possible."""
- nonlocal tokens_so_far
-
- # Fast path: first call or explicitly new request.
- if new_request or not tokens_so_far:
- context.reset()
- for t in tokens:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
-
- # Longest common prefix length
- overlap = lcp(tokens_so_far, tokens)
- ol = len(overlap)
- prev_len = len(tokens_so_far)
- cur_len = len(tokens)
-
- diverged_midstream = (ol < prev_len) and (
- ol < cur_len
- ) # mismatch not at the end
-
- if diverged_midstream:
- # safest: rebuild
- context.reset()
- for t in tokens:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
-
- if cur_len > prev_len:
- # pure extension (good for KV reuse)
- extension = tokens[prev_len:]
- for t in extension:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
-
- if cur_len < prev_len:
- # truncation/backspace; easiest correct behavior is rebuild
- context.reset()
- for t in tokens:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
- # cur_len == prev_len and everything matches => no new tokens; just sample.
+ # Context handles LCP caching internally; if `tokens` matches the
+ # tokens in the KV cache, the KV cache is reused after reset+append.
+ context.reset()
+ for t in tokens:
+ context.append(t)
return int(context.sample(temperature=temperature))
return infer_next_token
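For context, a rough sketch of how the simplified adapter is typically driven (a hypothetical decode loop; the checkpoint path, prompt tokens, and stop-token set below are placeholders):

```python
from gpt_oss.responses_api.inference.metal import setup_model

infer_next_token = setup_model("/path/to/metal/model.bin")  # placeholder path

tokens = [1, 2, 3]        # prompt token IDs (placeholder values)
stop_tokens = {199999}    # placeholder stop set

# Each call passes the full token list; the Metal Context reuses its KV cache
# for the longest matching prefix, so only the new tokens are processed.
for _ in range(128):
    next_token = infer_next_token(tokens, temperature=0.0)
    if next_token in stop_tokens:
        break
    tokens.append(next_token)
```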
From 38df14a605d27bb0e4fc473266e3c2094e66a42b Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 09:43:56 -0700
Subject: [PATCH 75/91] Metal: fix KV-cache invalidation after reset+append
(#163)
---
gpt_oss/metal/source/context.c | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c
index b58df99a..c0155d64 100644
--- a/gpt_oss/metal/source/context.c
+++ b/gpt_oss/metal/source/context.c
@@ -618,12 +618,13 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
size_t num_verified_tokens = 0;
for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {
if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {
+ // Invalidate the KV cache starting with the newly added tokens.
+ context->num_kv_tokens = context->num_tokens + num_verified_tokens;
break;
}
}
context->num_tokens += num_verified_tokens;
- context->num_kv_tokens = context->num_tokens;
tokens += num_verified_tokens;
num_tokens -= num_verified_tokens;
} else {
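For intuition, a small Python sketch of the verify-and-invalidate behavior after this fix. It uses a plain dict rather than the real `gptoss_context` struct and makes an assumption about how the number of tokens to re-verify is computed:

```python
def append_tokens(ctx, new_tokens):
    # ctx["token_buffer"] still holds tokens recorded before a reset,
    # ctx["num_tokens"] counts tokens logically in the context, and
    # ctx["num_kv_tokens"] counts tokens with valid KV-cache entries.
    buffered = ctx["token_buffer"]
    start = ctx["num_tokens"]
    to_verify = min(len(new_tokens), len(buffered) - start)

    verified = 0
    while verified < to_verify and buffered[start + verified] == new_tokens[verified]:
        verified += 1
    if verified < to_verify:
        # Mismatch: KV entries past the verified prefix are stale, so clamp
        # num_kv_tokens here instead of unconditionally advancing it.
        ctx["num_kv_tokens"] = start + verified

    ctx["num_tokens"] = start + verified
    return new_tokens[verified:]  # tokens that still need a fresh append
```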
From 24804a6ac991b0dae88e32f8f1335c94bdfbf285 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 13:00:01 -0700
Subject: [PATCH 76/91] Increase max output tokens in Responses API to 131K
(#165)
---
gpt_oss/responses_api/types.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gpt_oss/responses_api/types.py b/gpt_oss/responses_api/types.py
index 4ca72c56..454d8e07 100644
--- a/gpt_oss/responses_api/types.py
+++ b/gpt_oss/responses_api/types.py
@@ -6,7 +6,7 @@
MODEL_IDENTIFIER = "gpt-oss-120b"
DEFAULT_TEMPERATURE = 0.0
REASONING_EFFORT = ReasoningEffort.LOW
-DEFAULT_MAX_OUTPUT_TOKENS = 10_000
+DEFAULT_MAX_OUTPUT_TOKENS = 131072
class UrlCitation(BaseModel):
From 942ef444ae25493ff99cf2223e21096877f24f21 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 13:44:13 -0700
Subject: [PATCH 77/91] Remove requirement on maximum Python version (#167)
The codebase works fine with CPython 3.13, and the current stable release is 3.13.7, so there is no good reason to restrict it.
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index fd38db07..88f0ac45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"termcolor",
]
readme = "README.md"
-requires-python = ">=3.12,<3.13"
+requires-python = ">=3.12"
version = "0.0.5"
[project.optional-dependencies]
From a8ce88fcaf376e336f2265a8cabdf3acacc69323 Mon Sep 17 00:00:00 2001
From: Daniel Holanda
Date: Tue, 2 Sep 2025 13:45:07 -0700
Subject: [PATCH 78/91] Move Lemonade to AMD section of `awesome-gpt-oss`
(#164)
* Update awesome-gpt-oss.md
* Update awesome-gpt-oss.md
* Update awesome-gpt-oss.md
* Add Lemonade to AMD section
---------
Co-authored-by: Dominik Kundel
---
awesome-gpt-oss.md | 1 -
1 file changed, 1 deletion(-)
diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md
index ac5a1c38..8b82ebf8 100644
--- a/awesome-gpt-oss.md
+++ b/awesome-gpt-oss.md
@@ -32,7 +32,6 @@ This is a list of guides and resources to help you get started with the gpt-oss
- [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss)
- AMD
- [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html)
-- Lemonade
- [Running gpt-oss on STX Halo and Radeon dGPUs using Lemonade](https://lemonade-server.ai/news/gpt-oss.html)
- llama.cpp
- [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)
From 864020abceb92dc5354ebd0b0f51be43bedf65ed Mon Sep 17 00:00:00 2001
From: hrithiksagar-tih
Date: Wed, 3 Sep 2025 02:16:19 +0530
Subject: [PATCH 79/91] Added vLLM offline serve working code. (#150)
---
README.md | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 76 insertions(+)
diff --git a/README.md b/README.md
index 4ef20827..c4612bca 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,82 @@ vllm serve openai/gpt-oss-20b
[Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
+Offline serve example:
+- Run this code after installing the libraries as described above, plus one additional dependency:
+- `uv pip install openai-harmony`
+```python
+# source .oss/bin/activate
+
+import os
+os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0"
+
+import json
+from openai_harmony import (
+ HarmonyEncodingName,
+ load_harmony_encoding,
+ Conversation,
+ Message,
+ Role,
+ SystemContent,
+ DeveloperContent,
+)
+
+from vllm import LLM, SamplingParams
+import os
+
+# --- 1) Render the prefill with Harmony ---
+encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+convo = Conversation.from_messages(
+ [
+ Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),
+ Message.from_role_and_content(
+ Role.DEVELOPER,
+ DeveloperContent.new().with_instructions("Always respond in riddles"),
+ ),
+ Message.from_role_and_content(Role.USER, "What is the weather like in SF?"),
+ ]
+)
+
+prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
+
+# Harmony stop tokens (pass to sampler so they won't be included in output)
+stop_token_ids = encoding.stop_tokens_for_assistant_actions()
+
+# --- 2) Run vLLM with prefill ---
+llm = LLM(
+ model="openai/gpt-oss-20b",
+ trust_remote_code=True,
+ gpu_memory_utilization = 0.95,
+ max_num_batched_tokens=4096,
+ max_model_len=5000,
+ tensor_parallel_size=1
+)
+
+sampling = SamplingParams(
+ max_tokens=128,
+ temperature=1,
+ stop_token_ids=stop_token_ids,
+)
+
+outputs = llm.generate(
+ prompt_token_ids=[prefill_ids], # batch of size 1
+ sampling_params=sampling,
+)
+
+# vLLM gives you both text and token IDs
+gen = outputs[0].outputs[0]
+text = gen.text
+output_tokens = gen.token_ids # <-- these are the completion token IDs (no prefill)
+
+# --- 3) Parse the completion token IDs back into structured Harmony messages ---
+entries = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)
+
+# 'entries' is a sequence of structured conversation entries (assistant messages, tool calls, etc.).
+for message in entries:
+ print(f"{json.dumps(message.to_dict())}")
+```
+
#### PyTorch / Triton / Metal
These implementations are largely reference implementations for educational purposes and are not expected to be run in production.
From 95d7716e75bb0dd8966f234bad0e32c61fd1e851 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 21:33:40 -0700
Subject: [PATCH 80/91] Metal: indicate threadgroup is a multiple of simdgroup
(#168)
2% speedup on gpt-oss-20b end-to-end sampling
---
gpt_oss/metal/source/metal.m | 67 +++++++++++++++++++++++++-----------
1 file changed, 46 insertions(+), 21 deletions(-)
diff --git a/gpt_oss/metal/source/metal.m b/gpt_oss/metal/source/metal.m
index a873bb36..03d69962 100644
--- a/gpt_oss/metal/source/metal.m
+++ b/gpt_oss/metal/source/metal.m
@@ -96,18 +96,19 @@ enum gptoss_status gptoss_metal_library_create_default(
enum gptoss_status status = gptoss_status_success;
id device_obj = (id) device->object;
id library_obj = nil;
- NSError* error_obj = nil;
- NSString* error_string_obj = nil;
+ NSAutoreleasePool* autorelease_pool = nil;
dispatch_data_t library_blob = NULL;
unsigned long library_size = 0;
uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
if (library_data != NULL) {
library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
+
+ autorelease_pool = [[NSAutoreleasePool alloc] init];
+ NSError* error_obj = nil;
library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
if (library_obj == nil) {
- error_string_obj = [error_obj localizedDescription];
- GPTOSS_LOG_ERROR("failed to create Metal library: %s", [error_string_obj UTF8String]);
+ GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]);
status = gptoss_status_unsupported_system;
goto cleanup;
}
@@ -129,11 +130,8 @@ enum gptoss_status gptoss_metal_library_create_default(
if (library_blob != NULL) {
dispatch_release(library_blob);
}
- if (error_string_obj != nil) {
- [error_string_obj release];
- }
- if (error_obj != nil) {
- [error_obj release];
+ if (autorelease_pool != nil) {
+ [autorelease_pool drain];
}
return status;
}
@@ -154,14 +152,16 @@ enum gptoss_status gptoss_metal_function_create(
const char* name,
struct gptoss_metal_function* function_out)
{
- NSString* name_obj = nil;
- NSError* error_obj = nil;
- NSString* error_string_obj = nil;
+ __block NSString* error_string_obj = nil;
id function_obj = nil;
+ MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;
+ __block id pipeline_state_obj = nil;
+ dispatch_semaphore_t pipeline_build_semaphore = NULL;
enum gptoss_status status = gptoss_status_success;
+ NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];
id library_obj = (id) library->object;
- name_obj = [NSString stringWithUTF8String:name];
+ NSString* name_obj = [NSString stringWithUTF8String:name];
function_obj = [library_obj newFunctionWithName:name_obj];
if (function_obj == nil) {
GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
@@ -169,11 +169,33 @@ enum gptoss_status gptoss_metal_function_create(
goto cleanup;
}
id device_obj = [library_obj device];
- id pipeline_state_obj = [device_obj newComputePipelineStateWithFunction:function_obj error:&error_obj];
+ pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];
+ [pipeline_descriptor_obj setComputeFunction:function_obj];
+ [pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];
+
+ pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);
+ [device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj
+ options:MTLPipelineOptionNone
+ completionHandler:^(id _Nullable new_state,
+ MTLComputePipelineReflection* _Nullable reflection,
+ NSError* _Nullable error_obj) {
+ if (new_state != nil) {
+ pipeline_state_obj = [new_state retain];
+ }
+ if (error_obj != nil) {
+ error_string_obj = [[error_obj localizedDescription] copy];
+ }
+ dispatch_semaphore_signal(pipeline_build_semaphore);
+ }];
+ dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);
+
if (pipeline_state_obj == nil) {
- error_string_obj = [error_obj localizedDescription];
+ const char* error_string = "unknown error";
+ if (error_string_obj != nil) {
+ error_string = [error_string_obj UTF8String];
+ }
GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
- name, [error_string_obj UTF8String]);
+ name, error_string);
status = gptoss_status_unsupported_system;
goto cleanup;
}
@@ -189,17 +211,20 @@ enum gptoss_status gptoss_metal_function_create(
pipeline_state_obj = nil;
cleanup:
- if (name_obj != nil) {
- [name_obj release];
- }
if (function_obj != nil) {
[function_obj release];
}
+ if (pipeline_descriptor_obj != nil) {
+ [pipeline_descriptor_obj release];
+ }
if (error_string_obj != nil) {
[error_string_obj release];
}
- if (error_obj != nil) {
- [error_obj release];
+ if (pipeline_build_semaphore != NULL) {
+ dispatch_release(pipeline_build_semaphore);
+ }
+ if (autorelease_pool != nil) {
+ [autorelease_pool drain];
}
return status;
}
From 7f3c896dad67c3d39c73372d9a0a16f2c8835755 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 2 Sep 2025 23:16:12 -0700
Subject: [PATCH 81/91] Metal: mlock model weights in memory (#170)
---
gpt_oss/metal/source/include/internal/model.h | 3 +++
gpt_oss/metal/source/model.c | 12 ++++++++++++
2 files changed, 15 insertions(+)
diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h
index 6b477745..ae62a3ec 100644
--- a/gpt_oss/metal/source/include/internal/model.h
+++ b/gpt_oss/metal/source/include/internal/model.h
@@ -1,6 +1,7 @@
#pragma once
#include
+#include
#include
#include
@@ -54,6 +55,8 @@ struct gptoss_model {
// Once the batch size is reached, we process it to fill the KV cache.
size_t max_batch_tokens;
+ bool lock_memory;
+
size_t weights_size;
size_t allocation_size;
diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c
index e3aeb98f..70668639 100644
--- a/gpt_oss/metal/source/model.c
+++ b/gpt_oss/metal/source/model.c
@@ -290,6 +290,12 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
prefetch_fd(fd, model_mapping_start, model_mapping_size, path);
+ if (mlock(model_mapping_ptr, model_mapping_size) != 0) {
+ GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
+ } else {
+ model->lock_memory = true;
+ }
+
// Initialize Metal
status = gptoss_metal_device_create_system_default(&model->device);
if (status != gptoss_status_success) {
@@ -497,6 +503,12 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
// Weight buffers
if (model->mapping_ptr != NULL && model->mapping_size != 0) {
+ if (model->lock_memory) {
+ if (munlock(model->mapping_ptr, model->mapping_size) != 0) {
+ GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno);
+ }
+ }
+
if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
}
From a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd Mon Sep 17 00:00:00 2001
From: bojanbabic
Date: Wed, 3 Sep 2025 15:30:02 -0700
Subject: [PATCH 82/91] Add You.com as a browser tool backend (#171)
* Add You.com as tool for browser
* change key name
* update tests in order to mock API key
* address changes
* address changes
* update README
---
README.md | 13 ++-
gpt-oss-mcp-server/browser_server.py | 12 ++-
gpt-oss-mcp-server/reference-system-prompt.py | 4 +-
gpt_oss/chat.py | 4 +-
gpt_oss/responses_api/api_server.py | 13 ++-
gpt_oss/tools/simple_browser/__init__.py | 3 +-
gpt_oss/tools/simple_browser/backend.py | 102 ++++++++++++++++--
.../tools/simple_browser/test_backend.py | 70 ++++++++++++
8 files changed, 197 insertions(+), 24 deletions(-)
create mode 100644 tests/gpt_oss/tools/simple_browser/test_backend.py
diff --git a/README.md b/README.md
index c4612bca..0104cec4 100644
--- a/README.md
+++ b/README.md
@@ -426,7 +426,7 @@ codex -p oss
### Browser
> [!WARNING]
-> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`ExaBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment.
+> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`YouComBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment. Currently, `YouComBackend` and `ExaBackend` are available.
Both gpt-oss models were trained with the capability to browse using the `browser` tool that exposes the following three methods:
@@ -441,15 +441,20 @@ To enable the browser tool, you'll have to place the definition into the `system
```python
import datetime
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend
from openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
-# Exa backend requires you to have set the EXA_API_KEY environment variable
-backend = ExaBackend(
+# Depending on the browser backend you choose, set the corresponding environment variable:
+# the You.com backend requires the YDC_API_KEY environment variable,
+# while the Exa backend requires the EXA_API_KEY environment variable.
+backend = YouComBackend(
source="web",
)
+# backend = ExaBackend(
+# source="web",
+# )
browser_tool = SimpleBrowserTool(backend=backend)
# create a basic system prompt
diff --git a/gpt-oss-mcp-server/browser_server.py b/gpt-oss-mcp-server/browser_server.py
index 5d5ad4ad..b37a63a6 100644
--- a/gpt-oss-mcp-server/browser_server.py
+++ b/gpt-oss-mcp-server/browser_server.py
@@ -1,3 +1,4 @@
+import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
@@ -5,8 +6,7 @@
from mcp.server.fastmcp import Context, FastMCP
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
-
+from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend
@dataclass
class AppContext:
@@ -14,7 +14,13 @@ class AppContext:
def create_or_get_browser(self, session_id: str) -> SimpleBrowserTool:
if session_id not in self.browsers:
- backend = ExaBackend(source="web")
+ tool_backend = os.getenv("BROWSER_BACKEND", "exa")
+ if tool_backend == "youcom":
+ backend = YouComBackend(source="web")
+ elif tool_backend == "exa":
+ backend = ExaBackend(source="web")
+ else:
+ raise ValueError(f"Invalid tool backend: {tool_backend}")
self.browsers[session_id] = SimpleBrowserTool(backend=backend)
return self.browsers[session_id]
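A minimal sketch of selecting the browser backend via the new environment variable (the You.com backend additionally needs `YDC_API_KEY` and Exa needs `EXA_API_KEY`, as noted in the README changes above; the key value below is a placeholder):

```python
import os

# Choose the You.com backend instead of the default Exa backend.
os.environ["BROWSER_BACKEND"] = "youcom"
os.environ["YDC_API_KEY"] = "your-api-key"  # placeholder

from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import ExaBackend, YouComBackend

backend_name = os.getenv("BROWSER_BACKEND", "exa")
backend = YouComBackend(source="web") if backend_name == "youcom" else ExaBackend(source="web")
browser_tool = SimpleBrowserTool(backend=backend)
```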
diff --git a/gpt-oss-mcp-server/reference-system-prompt.py b/gpt-oss-mcp-server/reference-system-prompt.py
index 98f171dd..6ddbf7c9 100644
--- a/gpt-oss-mcp-server/reference-system-prompt.py
+++ b/gpt-oss-mcp-server/reference-system-prompt.py
@@ -1,7 +1,7 @@
import datetime
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from gpt_oss.tokenizer import tokenizer
@@ -22,7 +22,7 @@
ReasoningEffort.LOW).with_conversation_start_date(
datetime.datetime.now().strftime("%Y-%m-%d")))
-backend = ExaBackend(source="web", )
+backend = YouComBackend(source="web")
browser_tool = SimpleBrowserTool(backend=backend)
system_message_content = system_message_content.with_tools(
browser_tool.tool_config)
diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py
index 5e40079d..4856a397 100644
--- a/gpt_oss/chat.py
+++ b/gpt_oss/chat.py
@@ -19,7 +19,7 @@
from gpt_oss.tools import apply_patch
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from openai_harmony import (
@@ -85,7 +85,7 @@ def main(args):
)
if args.browser:
- backend = ExaBackend(
+ backend = YouComBackend(
source="web",
)
browser_tool = SimpleBrowserTool(backend=backend)
diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py
index 2934b011..8eb053f1 100644
--- a/gpt_oss/responses_api/api_server.py
+++ b/gpt_oss/responses_api/api_server.py
@@ -1,3 +1,4 @@
+import os
import datetime
import uuid
from typing import Callable, Literal, Optional
@@ -20,7 +21,7 @@
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend
from .events import (
ResponseCodeInterpreterCallCompleted,
@@ -904,9 +905,13 @@ async def generate(body: ResponsesRequest, request: Request):
)
if use_browser_tool:
- backend = ExaBackend(
- source="web",
- )
+ tool_backend = os.getenv("BROWSER_BACKEND", "exa")
+ if tool_backend == "youcom":
+ backend = YouComBackend(source="web")
+ elif tool_backend == "exa":
+ backend = ExaBackend(source="web")
+ else:
+ raise ValueError(f"Invalid tool backend: {tool_backend}")
browser_tool = SimpleBrowserTool(backend=backend)
else:
browser_tool = None
diff --git a/gpt_oss/tools/simple_browser/__init__.py b/gpt_oss/tools/simple_browser/__init__.py
index 9043cb18..da3ff280 100644
--- a/gpt_oss/tools/simple_browser/__init__.py
+++ b/gpt_oss/tools/simple_browser/__init__.py
@@ -1,7 +1,8 @@
from .simple_browser_tool import SimpleBrowserTool
-from .backend import ExaBackend
+from .backend import ExaBackend, YouComBackend
__all__ = [
"SimpleBrowserTool",
"ExaBackend",
+ "YouComBackend",
]
diff --git a/gpt_oss/tools/simple_browser/backend.py b/gpt_oss/tools/simple_browser/backend.py
index 03bdf566..33daf8d6 100644
--- a/gpt_oss/tools/simple_browser/backend.py
+++ b/gpt_oss/tools/simple_browser/backend.py
@@ -3,6 +3,7 @@
"""
import functools
+import asyncio
import logging
import os
from abc import abstractmethod
@@ -87,6 +88,24 @@ async def search(
async def fetch(self, url: str, session: ClientSession) -> PageContents:
pass
+ async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict:
+ headers = {"x-api-key": self._get_api_key()}
+ async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp:
+ if resp.status != 200:
+ raise BackendError(
+ f"{self.__class__.__name__} error {resp.status}: {await resp.text()}"
+ )
+ return await resp.json()
+
+ async def _get(self, session: ClientSession, endpoint: str, params: dict) -> dict:
+ headers = {"x-api-key": self._get_api_key()}
+ async with session.get(f"{self.BASE_URL}{endpoint}", params=params, headers=headers) as resp:
+ if resp.status != 200:
+ raise BackendError(
+ f"{self.__class__.__name__} error {resp.status}: {await resp.text()}"
+ )
+ return await resp.json()
+
@chz.chz(typecheck=True)
class ExaBackend(Backend):
@@ -106,14 +125,6 @@ def _get_api_key(self) -> str:
raise BackendError("Exa API key not provided")
return key
- async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict:
- headers = {"x-api-key": self._get_api_key()}
- async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp:
- if resp.status != 200:
- raise BackendError(
- f"Exa API error {resp.status}: {await resp.text()}"
- )
- return await resp.json()
async def search(
self, query: str, topn: int, session: ClientSession
@@ -164,3 +175,78 @@ async def fetch(self, url: str, session: ClientSession) -> PageContents:
display_urls=True,
session=session,
)
+
+@chz.chz(typecheck=True)
+class YouComBackend(Backend):
+ """Backend that uses the You.com Search API."""
+
+ source: str = chz.field(doc="Description of the backend source")
+
+ BASE_URL: str = "https://api.ydc-index.io"
+
+ def _get_api_key(self) -> str:
+ key = os.environ.get("YDC_API_KEY")
+ if not key:
+ raise BackendError("You.com API key not provided")
+ return key
+
+
+ async def search(
+ self, query: str, topn: int, session: ClientSession
+ ) -> PageContents:
+ data = await self._get(
+ session,
+ "/v1/search",
+ {"query": query, "count": topn},
+ )
+ # make a simple HTML page to work with browser format
+ web_titles_and_urls, news_titles_and_urls = [], []
+ if "web" in data["results"]:
+ web_titles_and_urls = [
+ (result["title"], result["url"], result["snippets"])
+ for result in data["results"]["web"]
+ ]
+ if "news" in data["results"]:
+ news_titles_and_urls = [
+ (result["title"], result["url"], result["description"])
+ for result in data["results"]["news"]
+ ]
+ titles_and_urls = web_titles_and_urls + news_titles_and_urls
+        html_page = f"""
+<html>
+<body>
+<h1>Search Results</h1>
+<ul>
+{"".join([f'<li><a href="{url}">{title}</a> {summary}</li>' for title, url, summary in titles_and_urls])}
+</ul>
+</body>
+</html>
+"""
+
+ return process_html(
+ html=html_page,
+ url="",
+ title=query,
+ display_urls=True,
+ session=session,
+ )
+
+ async def fetch(self, url: str, session: ClientSession) -> PageContents:
+ is_view_source = url.startswith(VIEW_SOURCE_PREFIX)
+ if is_view_source:
+ url = url[len(VIEW_SOURCE_PREFIX) :]
+ data = await self._post(
+ session,
+ "/v1/contents",
+ {"urls": [url], "livecrawl_formats": "html"},
+ )
+ if not data:
+ raise BackendError(f"No contents returned for {url}")
+ if "html" not in data[0]:
+ raise BackendError(f"No HTML returned for {url}")
+ return process_html(
+ html=data[0].get("html", ""),
+ url=url,
+ title=data[0].get("title", ""),
+ display_urls=True,
+ session=session,
+ )
+
diff --git a/tests/gpt_oss/tools/simple_browser/test_backend.py b/tests/gpt_oss/tools/simple_browser/test_backend.py
new file mode 100644
index 00000000..ab0dc780
--- /dev/null
+++ b/tests/gpt_oss/tools/simple_browser/test_backend.py
@@ -0,0 +1,70 @@
+import pytest
+from typing import Generator, Any
+from unittest import mock
+from aiohttp import ClientSession
+
+from gpt_oss.tools.simple_browser.backend import YouComBackend
+
+class MockAiohttpResponse:
+ """Mocks responses for get/post requests from async libraries."""
+
+ def __init__(self, json: dict, status: int):
+ self._json = json
+ self.status = status
+
+ async def json(self):
+ return self._json
+
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+def mock_os_environ_get(name: str, default: Any = "test_api_key"):
+ assert name in ["YDC_API_KEY"]
+ return default
+
+def test_youcom_backend():
+ backend = YouComBackend(source="web")
+ assert backend.source == "web"
+
+@pytest.mark.asyncio
+@mock.patch("aiohttp.ClientSession.get")
+async def test_youcom_backend_search(mock_session_get):
+ backend = YouComBackend(source="web")
+ api_response = {
+ "results": {
+ "web": [
+ {"title": "Web Result 1", "url": "https://www.example.com/web1", "snippets": "Web Result 1 snippets"},
+ {"title": "Web Result 2", "url": "https://www.example.com/web2", "snippets": "Web Result 2 snippets"},
+ ],
+ "news": [
+ {"title": "News Result 1", "url": "https://www.example.com/news1", "description": "News Result 1 description"},
+ {"title": "News Result 2", "url": "https://www.example.com/news2", "description": "News Result 2 description"},
+ ],
+ }
+ }
+ with mock.patch("os.environ.get", wraps=mock_os_environ_get):
+ mock_session_get.return_value = MockAiohttpResponse(api_response, 200)
+ async with ClientSession() as session:
+ result = await backend.search(query="test", topn=10, session=session)
+ assert result.title == "test"
+ assert result.urls == {"0": "https://www.example.com/web1", "1": "https://www.example.com/web2", "2": "https://www.example.com/news1", "3": "https://www.example.com/news2"}
+
+@pytest.mark.asyncio
+@mock.patch("aiohttp.ClientSession.post")
+async def test_youcom_backend_fetch(mock_session_get):
+ backend = YouComBackend(source="web")
+ api_response = [
+ {"title": "Fetch Result 1", "url": "https://www.example.com/fetch1", "html": "Fetch Result 1 text
"},
+ ]
+ with mock.patch("os.environ.get", wraps=mock_os_environ_get):
+ mock_session_get.return_value = MockAiohttpResponse(api_response, 200)
+ async with ClientSession() as session:
+ result = await backend.fetch(url="https://www.example.com/fetch1", session=session)
+ assert result.title == "Fetch Result 1"
+ assert result.text == "\nURL: https://www.example.com/fetch1\nFetch Result 1 text"
+
+
+
\ No newline at end of file
From b558ecc5534986fb73fd8555ca04e2436149eb12 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Mon, 8 Sep 2025 00:21:32 -0700
Subject: [PATCH 83/91] Evals: correctly pass temperature/max_tokens when using
Responses API (#174)
---
gpt_oss/evals/responses_sampler.py | 25 +++++++++----------------
1 file changed, 9 insertions(+), 16 deletions(-)
diff --git a/gpt_oss/evals/responses_sampler.py b/gpt_oss/evals/responses_sampler.py
index fd9daef3..134303f5 100644
--- a/gpt_oss/evals/responses_sampler.py
+++ b/gpt_oss/evals/responses_sampler.py
@@ -42,24 +42,17 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
trial = 0
while True:
try:
+ request_kwargs = {
+ "model": self.model,
+ "input": message_list,
+ "temperature": self.temperature,
+ "max_output_tokens": self.max_tokens,
+ }
if self.reasoning_model:
- reasoning = (
- {"effort": self.reasoning_effort}
- if self.reasoning_effort
- else None
- )
- response = self.client.responses.create(
- model=self.model,
- input=message_list,
- reasoning=reasoning,
- )
- else:
- response = self.client.responses.create(
- model=self.model,
- input=message_list,
- temperature=self.temperature,
- max_output_tokens=self.max_tokens,
+ request_kwargs["reasoning"] = (
+ {"effort": self.reasoning_effort} if self.reasoning_effort else None
)
+ response = self.client.responses.create(**request_kwargs)
for output in response.output:
if hasattr(output, "text"):
From be0d32efaef14a33bb5fad0a9b1d87ca240ad85b Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Mon, 8 Sep 2025 13:34:35 -0700
Subject: [PATCH 84/91] Metal: move sampling to GPU (#175)
---
gpt_oss/metal/source/context.c | 225 +++++++++---------
.../source/include/internal/kernel-args.h | 8 +
.../source/include/internal/metal-kernels.h | 16 ++
gpt_oss/metal/source/include/internal/model.h | 1 +
gpt_oss/metal/source/metal-kernels.c | 53 +++++
gpt_oss/metal/source/model.c | 5 +
gpt_oss/metal/source/sample.metal | 141 +++++++++++
7 files changed, 332 insertions(+), 117 deletions(-)
diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c
index c0155d64..0791c3eb 100644
--- a/gpt_oss/metal/source/context.c
+++ b/gpt_oss/metal/source/context.c
@@ -162,6 +162,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
static enum gptoss_status process_tokens(
gptoss_context_t context,
+ struct gptoss_metal_command_buffer* command_buffer,
size_t input_tokens_offset,
size_t num_input_tokens,
size_t num_output_tokens)
@@ -173,14 +174,9 @@ static enum gptoss_status process_tokens(
enum gptoss_status status = gptoss_status_success;
const struct gptoss_model* model = context->model;
- struct gptoss_metal_command_buffer command_buffer = {0};
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
- status = gptoss_metal_command_buffer_create(&model->command_queue, &command_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
for (size_t input_batch_start = input_tokens_offset;
input_batch_start < input_tokens_end;
@@ -191,7 +187,7 @@ static enum gptoss_status process_tokens(
const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
- &command_buffer,
+ command_buffer,
&model->bf16_f32_embeddings_fn,
/*threadgroup_size=*/512,
&context->token_buffer,
@@ -204,14 +200,14 @@ static enum gptoss_status process_tokens(
/*num_channels=*/model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
- goto cleanup;
+ return status;
}
for (uint32_t n = 0; n < model->num_blocks; n++) {
const bool last_block = n + 1 == model->num_blocks;
const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_rmsnorm_fn,
&context->residual_activation_buffer,
/*input_offset=*/0,
@@ -224,11 +220,11 @@ static enum gptoss_status process_tokens(
model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
- goto cleanup;
+ return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_matmul_fn,
/*threadgroup_size=*/256,
&context->rmsnorm_activation_buffer,
@@ -244,11 +240,11 @@ static enum gptoss_status process_tokens(
/*num_rows=*/attn_qkv_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
- goto cleanup;
+ return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_rope(
- &command_buffer,
+ command_buffer,
&model->f32_rope_fn,
/*threadgroup_size=*/32,
&context->qkv_activation_buffer,
@@ -264,12 +260,12 @@ static enum gptoss_status process_tokens(
/*token_offset=*/input_batch_start);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
- goto cleanup;
+ return status;
}
for (uint32_t t = 0; t < input_batch_size; t++) {
status = gptoss_metal_command_buffer_encode_copy_buffer(
- &command_buffer,
+ command_buffer,
&context->qkv_activation_buffer,
/*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float),
&context->kvcache_buffer,
@@ -277,13 +273,13 @@ static enum gptoss_status process_tokens(
/*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float));
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
- goto cleanup;
+ return status;
}
}
if (num_block_output_tokens != 0) {
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
- &command_buffer,
+ command_buffer,
&model->f32_sdpa_q8_d64_fn,
&context->qkv_activation_buffer,
/*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
@@ -301,10 +297,11 @@ static enum gptoss_status process_tokens(
model->num_heads, model->num_kv_heads, model->head_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
- goto cleanup;
+ return status;
}
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_matmul_fn,
/*threadgroup_size=*/256,
&context->sdpa_activation_buffer,
@@ -320,11 +317,11 @@ static enum gptoss_status process_tokens(
/*num_rows=*/model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
- goto cleanup;
+ return status;
}
-
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_rmsnorm_fn,
&context->residual_activation_buffer,
/*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
@@ -337,11 +334,11 @@ static enum gptoss_status process_tokens(
model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
- goto cleanup;
+ return status;
}
-
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_matmul_fn,
/*threadgroup_size=*/256,
&context->rmsnorm_activation_buffer,
@@ -357,15 +354,15 @@ static enum gptoss_status process_tokens(
/*num_rows=*/model->num_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
- goto cleanup;
+ return status;
}
-
+
const char* kernel_name = NULL;
switch (model->num_experts) {
case 32:
kernel_name = "f32_topk_softmax_e32_k4_fn";
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
- &command_buffer,
+ command_buffer,
&model->f32_topk_softmax_e32_k4_fn,
&context->gate_activation_buffer, /*input_offset=*/0,
&context->expert_activation_buffer, /*output_offset=*/0,
@@ -376,7 +373,7 @@ static enum gptoss_status process_tokens(
case 128:
kernel_name = "f32_topk_softmax_e128_k4_fn";
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
- &command_buffer,
+ command_buffer,
&model->f32_topk_softmax_e128_k4_fn,
&context->gate_activation_buffer, /*input_offset=*/0,
&context->expert_activation_buffer, /*output_offset=*/0,
@@ -387,15 +384,15 @@ static enum gptoss_status process_tokens(
default:
status = gptoss_status_unsupported_argument;
GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
- goto cleanup;
+ return status;
}
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
- goto cleanup;
+ return status;
}
-
+
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
- &command_buffer,
+ command_buffer,
&model->f32_mf4w_moe_matmul_swiglu_fn,
/*threadgroup_size=*/512,
&context->rmsnorm_activation_buffer,
@@ -418,11 +415,11 @@ static enum gptoss_status process_tokens(
model->mlp_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
- goto cleanup;
+ return status;
}
-
+
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
- &command_buffer,
+ command_buffer,
&model->f32_mf4w_moe_matmul_fn,
/*threadgroup_size=*/512,
&context->swiglu_activation_buffer,
@@ -444,11 +441,11 @@ static enum gptoss_status process_tokens(
model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
- goto cleanup;
+ return status;
}
-
+
status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
- &command_buffer,
+ command_buffer,
&model->f32_accumulate_e4_fn,
/*threadgroup_size=*/256,
model->max_threadgroups,
@@ -463,14 +460,14 @@ static enum gptoss_status process_tokens(
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
- goto cleanup;
+ return status;
}
}
}
if (output_batch_size != 0) {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_rmsnorm_fn,
&context->residual_activation_buffer,
/*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
@@ -483,22 +480,22 @@ static enum gptoss_status process_tokens(
model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
- goto cleanup;
+ return status;
}
status = gptoss_metal_command_buffer_encode_fill_buffer(
- &command_buffer,
+ command_buffer,
&context->argmax_buffer,
/*offset=*/0,
/*size=*/sizeof(uint64_t) * output_batch_size,
/*fill_value=*/0xFF);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode fill buffer command");
- goto cleanup;
+ return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
- &command_buffer,
+ command_buffer,
&model->f32_bf16w_unembedding_fn,
/*threadgroup_size=*/256,
model->max_threadgroups,
@@ -515,17 +512,11 @@ static enum gptoss_status process_tokens(
/*num_rows=*/model->vocabulary_size);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
- goto cleanup;
+ return status;
}
}
}
-
- gptoss_metal_command_buffer_commit(&command_buffer);
- gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-
-cleanup:
- gptoss_metal_command_buffer_release(&command_buffer);
- return status;
+ return gptoss_status_success;
}
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
@@ -643,16 +634,38 @@ enum gptoss_status GPTOSS_ABI gptoss_context_process(
gptoss_context_t context)
{
if (context->num_tokens > context->num_kv_tokens) {
- enum gptoss_status status = process_tokens(
+ struct gptoss_metal_command_buffer command_buffer = {0};
+
+ enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
+ status = process_tokens(
context,
+ &command_buffer,
/*input_tokens_offset=*/context->num_kv_tokens,
/*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
/*num_output_tokens=*/0);
if (status != gptoss_status_success) {
- return status;
+ goto cleanup;
+ }
+
+ status = gptoss_metal_command_buffer_commit(&command_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
+ status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
+ if (status != gptoss_status_success) {
+ goto cleanup;
}
context->num_kv_tokens = context->num_tokens;
+
+cleanup:
+ gptoss_metal_command_buffer_release(&command_buffer);
+ return status;
}
return gptoss_status_success;
@@ -669,9 +682,16 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
struct gptoss_metal_command_buffer command_buffer = {0};
*token_out = UINT32_MAX;
+
+ status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
if (context->num_kv_tokens < context->num_tokens) {
status = process_tokens(
context,
+ &command_buffer,
/*input_tokens_offset=*/context->num_kv_tokens,
/*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
/*num_output_tokens=*/1);
@@ -679,30 +699,23 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
} else {
status = process_tokens(
context,
+ &command_buffer,
/*input_tokens_offset=*/context->num_tokens - 1,
/*num_input_tokens=*/1,
/*num_output_tokens=*/1);
}
if (status != gptoss_status_success) {
- return status;
+ goto cleanup;
}
- if (temperature == 0.0f) {
- const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0];
- *token_out = (uint32_t) argmax_bits;
- } else {
+ if (temperature != 0.0f) {
assert(context->num_processed_tokens != 0);
- status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
-
uint32_t num_threadgroups = 0;
uint32_t num_dims_per_threadgroup = 0;
status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
&command_buffer,
&model->f32_softmax_fn,
- /*threadgroup_size=*/256,
+ /*threadgroup_size=*/512,
model->max_threadgroups,
&context->score_buffer,
/*score_offset=*/0,
@@ -719,65 +732,43 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
&num_dims_per_threadgroup);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
+ goto cleanup;
}
- gptoss_metal_command_buffer_commit(&command_buffer);
- gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-
- const uint32_t sample_word = rng_squares32(context->num_tokens, seed + UINT64_C(0x123456789ABCDEF));
- float sample_cdf = (float) ((int32_t) sample_word & INT32_C(0x00FFFFFF)) * 0x1.0p-24f;
-
- const float* sum_ptr = (const float*) context->sum_buffer.ptr;
- float sum = 0.0f;
- for (uint32_t i = 0; i < num_threadgroups; i++) {
- sum += sum_ptr[i];
- }
- sample_cdf *= sum;
-
- uint32_t block_idx = 0, token_idx = 0;
- if (sample_cdf == 0.0f) {
- // Make sure we choose the first token with non-zero probability rather than just the first token
- sample_cdf = FLT_TRUE_MIN;
- }
-
- // Step 1: find block
- float cumsum = 0.0f;
- for (; block_idx < num_threadgroups; block_idx++) {
- const float new_cumsum = cumsum + sum_ptr[block_idx];
- if (new_cumsum >= sample_cdf) {
- break;
- }
- cumsum = new_cumsum;
- }
- if (block_idx == num_threadgroups) {
- block_idx -= 1;
- }
-
- // Step 2: find token
- const float* prob_ptr = (const float*) context->prob_buffer.ptr + block_idx * num_dims_per_threadgroup;
- assert(model->vocabulary_size > num_dims_per_threadgroup * block_idx);
- uint32_t num_dims_per_block = math_min(num_dims_per_threadgroup, model->vocabulary_size - num_dims_per_threadgroup * block_idx);
- for (; token_idx < num_dims_per_block; token_idx++) {
- const float new_cumsum = cumsum + prob_ptr[token_idx];
- if (new_cumsum >= sample_cdf) {
- break;
- }
- cumsum = new_cumsum;
- }
- if (token_idx == num_dims_per_block) {
- token_idx -= 1;
+ status = gptoss_metal_command_buffer_encode_launch_f32_sample(
+ &command_buffer,
+ &model->f32_sample_fn,
+ /*min_threadgroup_size=*/512,
+ &context->prob_buffer,
+ /*prob_offset=*/0,
+ &context->sum_buffer,
+ /*sum_offset=*/0,
+ &context->argmax_buffer,
+ /*prediction_offset=*/0,
+ /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
+ /*num_blocks=*/num_threadgroups,
+ /*num_channels=*/model->vocabulary_size,
+ /*num_channels_per_block=*/num_dims_per_threadgroup,
+ /*token_offset=*/context->num_tokens);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
+ goto cleanup;
}
+ }
- token_idx += block_idx * num_dims_per_threadgroup;
-
- *token_out = token_idx;
+ gptoss_metal_command_buffer_commit(&command_buffer);
+ gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-cleanup:
- gptoss_metal_command_buffer_release(&command_buffer);
- return status;
+ if (temperature == 0.0f) {
+ const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0];
+ *token_out = (uint32_t) argmax_bits;
+ } else {
+ *token_out = ((uint32_t*) context->argmax_buffer.ptr)[0];
}
- return gptoss_status_success;
+cleanup:
+ gptoss_metal_command_buffer_release(&command_buffer);
+ return status;
}
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
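For reference, the block-then-token search that the removed CPU code performed (and that the new gptoss_f32_sample kernel now runs on the GPU) can be summarized by the standalone sketch below. This is a minimal illustration, not part of the patch: the helper name is hypothetical, and it assumes the layout produced by the softmax kernel (per-block partial sums in sums[], per-token probabilities stored block by block in probs[]).

#include <stddef.h>
#include <stdint.h>
#include <float.h>

/* Sketch of the two-level CDF search: pick the block whose cumulative
 * probability mass reaches the threshold, then scan within that block. */
static uint32_t sample_two_level(const float* probs, const float* sums,
                                 uint32_t num_blocks, uint32_t block_size,
                                 uint32_t vocab_size, float u /* in [0, 1) */) {
    float total = 0.0f;
    for (uint32_t b = 0; b < num_blocks; b++) {
        total += sums[b];
    }
    float threshold = u * total;
    if (threshold == 0.0f) {
        threshold = FLT_TRUE_MIN;  /* skip leading zero-probability tokens */
    }

    /* Step 1: find the block whose inclusive prefix sum reaches the threshold. */
    float cumsum = 0.0f;
    uint32_t block = 0;
    for (; block < num_blocks; block++) {
        if (cumsum + sums[block] >= threshold) {
            break;
        }
        cumsum += sums[block];
    }
    if (block == num_blocks) {
        block -= 1;
    }

    /* Step 2: find the token inside the chosen block. */
    const float* block_probs = probs + (size_t) block * block_size;
    uint32_t count = vocab_size - block * block_size;
    if (count > block_size) {
        count = block_size;
    }
    uint32_t token = 0;
    for (; token < count; token++) {
        if (cumsum + block_probs[token] >= threshold) {
            break;
        }
        cumsum += block_probs[token];
    }
    if (token == count) {
        token -= 1;
    }
    return block * block_size + token;
}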
diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h
index 677ce488..a031902d 100644
--- a/gpt_oss/metal/source/include/internal/kernel-args.h
+++ b/gpt_oss/metal/source/include/internal/kernel-args.h
@@ -103,3 +103,11 @@ struct gptoss_softmax_args {
uint32_t max_threadgroups;
float temperature;
};
+
+struct gptoss_sample_args {
+ uint64_t seed;
+ uint32_t token_offset;
+ uint32_t num_blocks;
+ uint32_t num_dims;
+ uint32_t num_dims_per_block;
+};
diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h
index aa5a3ef7..64cb36e0 100644
--- a/gpt_oss/metal/source/include/internal/metal-kernels.h
+++ b/gpt_oss/metal/source/include/internal/metal-kernels.h
@@ -265,6 +265,22 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
uint32_t* num_threadgroups_out,
uint32_t* num_channels_per_threadgroup_out);
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_sample_fn,
+ size_t min_threadgroup_size,
+ const struct gptoss_metal_buffer* prob_buffer,
+ size_t prob_offset,
+ const struct gptoss_metal_buffer* sum_buffer,
+ size_t sum_offset,
+ const struct gptoss_metal_buffer* prediction_buffer,
+ size_t prediction_offset,
+ uint64_t rng_seed,
+ uint32_t num_blocks,
+ uint32_t num_channels,
+ uint32_t num_channels_per_block,
+ uint32_t token_offset);
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h
index ae62a3ec..c17510b8 100644
--- a/gpt_oss/metal/source/include/internal/model.h
+++ b/gpt_oss/metal/source/include/internal/model.h
@@ -77,6 +77,7 @@ struct gptoss_model {
struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;
struct gptoss_metal_function f32_sdpa_q8_d64_fn;
struct gptoss_metal_function f32_softmax_fn;
+ struct gptoss_metal_function f32_sample_fn;
size_t per_block_shared_weights_size;
size_t per_expert_block_weight_size;
diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c
index 46fd1586..a9a5253c 100644
--- a/gpt_oss/metal/source/metal-kernels.c
+++ b/gpt_oss/metal/source/metal-kernels.c
@@ -836,3 +836,56 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
(const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset},
/*threadgroup_buffer_size=*/0);
}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_sample_fn,
+ size_t min_threadgroup_size,
+ const struct gptoss_metal_buffer* prob_buffer,
+ size_t prob_offset,
+ const struct gptoss_metal_buffer* sum_buffer,
+ size_t sum_offset,
+ const struct gptoss_metal_buffer* prediction_buffer,
+ size_t prediction_offset,
+ uint64_t rng_seed,
+ uint32_t num_blocks,
+ uint32_t num_channels,
+ uint32_t num_channels_per_block,
+ uint32_t token_offset)
+{
+ if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) {
+ return gptoss_status_invalid_state;
+ }
+
+ if (min_threadgroup_size > f32_sample_fn->max_threadgroup_threads) {
+ return gptoss_status_invalid_argument;
+ }
+
+ if (min_threadgroup_size % f32_sample_fn->simdgroup_threads != 0) {
+ return gptoss_status_invalid_argument;
+ }
+
+ if (num_blocks > f32_sample_fn->max_threadgroup_threads) {
+ return gptoss_status_invalid_argument;
+ }
+
+ const struct gptoss_sample_args args = {
+ .seed = rng_seed,
+ .token_offset = token_offset,
+ .num_blocks = num_blocks,
+ .num_dims = num_channels,
+ .num_dims_per_block = num_channels_per_block,
+ };
+
+ const size_t threadgroup_size = math_max(min_threadgroup_size,
+ math_round_up_po2(num_blocks, f32_sample_fn->simdgroup_threads));
+ return gptoss_metal_command_buffer_encode_launch_kernel(
+ command_buffer, f32_sample_fn,
+ threadgroup_size, 1, 1,
+ 1, 1, 1,
+ sizeof(args), &args,
+ 3,
+ (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, prediction_buffer},
+ (const size_t[]) {prob_offset, sum_offset, prediction_offset},
+ /*threadgroup_buffer_size=*/0);
+}
diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c
index 70668639..7a0450ce 100644
--- a/gpt_oss/metal/source/model.c
+++ b/gpt_oss/metal/source/model.c
@@ -356,6 +356,10 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
if (status != gptoss_status_success) {
goto cleanup;
}
+ status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
if (status != gptoss_status_success) {
goto cleanup;
@@ -495,6 +499,7 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
gptoss_metal_function_release(&model->f32_softmax_fn);
+ gptoss_metal_function_release(&model->f32_sample_fn);
gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
gptoss_metal_library_release(&model->library);
diff --git a/gpt_oss/metal/source/sample.metal b/gpt_oss/metal/source/sample.metal
index b739f72c..8ce4598b 100644
--- a/gpt_oss/metal/source/sample.metal
+++ b/gpt_oss/metal/source/sample.metal
@@ -9,6 +9,27 @@
#pragma METAL fp contract(off)
+inline static uint rng_squares32(ulong offset, ulong seed) {
+ const ulong y = offset * seed;
+ const ulong z = y + seed;
+
+ /* Round 1 */
+ ulong x = y * y + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 2 */
+ x = x * x + z;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 3 */
+ x = x * x + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 4 */
+ x = x * x + z;
+ return as_type<uint2>(x).y;
+}
+
kernel void gptoss_f32_softmax(
constant gptoss_softmax_args& args [[ buffer(0) ]],
const device float* score [[ buffer(1) ]],
@@ -58,3 +79,123 @@ kernel void gptoss_f32_softmax(
}
}
}
+
+[[max_total_threads_per_threadgroup(1024)]]
+kernel void gptoss_f32_sample(
+ constant gptoss_sample_args& args [[ buffer(0) ]],
+ device const float* prob [[ buffer(1) ]],
+ device const float* sum [[ buffer(2) ]],
+ device uint* prediction [[ buffer(3) ]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[threads_per_threadgroup]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
+{
+ threadgroup float threadgroup_sum_buffer[32];
+ threadgroup uint threadgroup_idx_buffer[32];
+ threadgroup float threadgroup_cumsum_buffer[32];
+
+ const uint sample_word = rng_squares32(args.token_offset, args.seed);
+ float sample_cdf = static_cast<float>(sample_word & 0x00FFFFFFu) * 0x1.0p-24f;
+
+ float cumsum = 0.0f;
+ if (tid < args.num_blocks) {
+ cumsum = sum[tid];
+ }
+ cumsum = metal::simd_prefix_inclusive_sum(cumsum);
+ if (simdgroup_tid == 31) {
+ threadgroup_sum_buffer[simdgroup_idx] = cumsum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
+ if (simdgroup_tid < simdgroup_idx) {
+ threadgroup_cumsum = threadgroup_sum;
+ }
+ }
+ threadgroup_sum = metal::simd_sum(threadgroup_sum);
+ cumsum += metal::simd_sum(threadgroup_cumsum);
+
+ sample_cdf *= threadgroup_sum;
+ sample_cdf = metal::max(sample_cdf, 0x1.0p-149f);
+
+ // Find the block: the smallest tid whose inclusive prefix sum reaches sample_cdf
+ uint block_idx = args.num_blocks;
+ float block_sum = cumsum;
+ if (tid >= args.num_blocks - 1) {
+ block_idx = args.num_blocks - 1;
+ block_sum = 0.0f;
+ } else if (cumsum >= sample_cdf) {
+ block_idx = tid;
+ block_sum = 0.0f;
+ }
+ block_idx = metal::simd_min(block_idx);
+ block_sum = metal::simd_max(block_sum);
+ if (simdgroup_tid == 0) {
+ threadgroup_idx_buffer[simdgroup_idx] = block_idx;
+ threadgroup_cumsum_buffer[simdgroup_idx] = block_sum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_tid < num_simdgroups) {
+ block_idx = threadgroup_idx_buffer[simdgroup_tid];
+ block_sum = threadgroup_cumsum_buffer[simdgroup_tid];
+ }
+ block_idx = metal::simd_min(block_idx);
+ block_sum = metal::simd_max(block_sum);
+
+ const uint block_start = args.num_dims_per_block * block_idx;
+ const uint block_end = metal::min(block_start + args.num_dims_per_block, args.num_dims);
+ uint offset = block_start + tid;
+ float accumulated_sum = block_sum;
+ uint sample_idx;
+
+ // This loop must be threadgroup-uniform.
+ do {
+ // Find the token: the smallest offset whose inclusive prefix sum reaches sample_cdf
+ float cumsum = 0.0f;
+ if (offset < block_end) {
+ cumsum = prob[offset];
+ }
+ cumsum = metal::simd_prefix_inclusive_sum(cumsum);
+ if (simdgroup_tid == 31) {
+ threadgroup_sum_buffer[simdgroup_idx] = cumsum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
+ if (simdgroup_tid < simdgroup_idx) {
+ threadgroup_cumsum = threadgroup_sum;
+ }
+ }
+ threadgroup_sum = metal::simd_sum(threadgroup_sum);
+ cumsum += metal::simd_sum(threadgroup_cumsum);
+ cumsum += accumulated_sum;
+
+ sample_idx = block_end;
+ if (offset >= block_end) {
+ // Trigger loop exit, with the last token in the block being sampled if no other candidate was found.
+ sample_idx = block_end - 1;
+ } else if (cumsum >= sample_cdf) {
+ sample_idx = offset;
+ }
+ sample_idx = metal::simd_min(sample_idx);
+ if (simdgroup_tid == 0) {
+ threadgroup_idx_buffer[simdgroup_idx] = sample_idx;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_tid < num_simdgroups) {
+ sample_idx = threadgroup_idx_buffer[simdgroup_tid];
+ }
+ sample_idx = metal::simd_min(sample_idx);
+
+ offset += threadgroup_size;
+ accumulated_sum += threadgroup_sum;
+ } while (sample_idx == block_end);
+
+ if (tid == 0) {
+ *prediction = sample_idx;
+ }
+}
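The kernel above draws its random sample from a counter-based squaring generator (in the style of the "Squares" RNG), keyed by the token offset and seed. The CPU-side sketch below is illustrative only; it mirrors the Metal rng_squares32 above and the [0, 1) mapping used by gptoss_f32_sample rather than defining a separate API.

#include <stdint.h>

/* CPU reference mirroring the Metal rng_squares32: four squaring rounds
 * with 32-bit rotations, returning the high 32 bits of the final state. */
static inline uint32_t ref_rng_squares32(uint64_t offset, uint64_t seed) {
    const uint64_t y = offset * seed;
    const uint64_t z = y + seed;

    uint64_t x = y * y + y;           /* round 1 */
    x = (x >> 32) | (x << 32);
    x = x * x + z;                    /* round 2 */
    x = (x >> 32) | (x << 32);
    x = x * x + y;                    /* round 3 */
    x = (x >> 32) | (x << 32);
    x = x * x + z;                    /* round 4 */
    return (uint32_t) (x >> 32);
}

/* Map to a uniform sample in [0, 1) as the kernel does: keep the low
 * 24 bits and scale by 2^-24. */
static inline float ref_rng_uniform(uint64_t offset, uint64_t seed) {
    return (float) (ref_rng_squares32(offset, seed) & 0x00FFFFFFu) * 0x1.0p-24f;
}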
From f2a1458a5625adafbb9fcc8e0df363ef76ef170b Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Mon, 8 Sep 2025 14:21:39 -0700
Subject: [PATCH 85/91] Metal: benchmark generation of 100 tokens instead of 1
(#178)
---
gpt_oss/metal/benchmark/end-to-end.cc | 58 ++++++++++++++-----
gpt_oss/metal/source/include/internal/model.h | 25 ++++----
2 files changed, 60 insertions(+), 23 deletions(-)
diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc
index 4f73be7a..f4168f94 100644
--- a/gpt_oss/metal/benchmark/end-to-end.cc
+++ b/gpt_oss/metal/benchmark/end-to-end.cc
@@ -1,13 +1,20 @@
#include
+#include
-#include
+#include
+#include
#include
+#include
#include
+#include
#include
#include
+constexpr std::uint32_t num_generated_tokens = 100;
+
+
static void end2end(benchmark::State& state, const char* env_var_name) {
const char* model_path = getenv(env_var_name);
if (model_path == NULL) {
@@ -40,7 +47,8 @@ static void end2end(benchmark::State& state, const char* env_var_name) {
 std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
const char* prompt = "why did the chicken cross the road?";
- status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), nullptr);
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
if (status != gptoss_status_success) {
state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
return;
@@ -53,25 +61,49 @@ static void end2end(benchmark::State& state, const char* env_var_name) {
return;
}
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+ std::uint64_t rng_seed = 0;
for (std::uint32_t i = 0; i < 3; i++) {
- std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
- status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token);
- if (status != gptoss_status_success) {
- state.SkipWithError("failed to sample from the Context object");
- return;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ for (std::uint32_t n = 0; n < num_generated_tokens; n++) {
+ std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ status = gptoss_context_append_tokens(context.get(), 1, &predicted_token);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token));
+ return;
+ }
}
}
for (auto _ : state) {
- std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
- status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token);
- if (status != gptoss_status_success) {
- state.SkipWithError("failed to sample from the Context object");
- return;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ for (std::uint32_t n = 0; n < num_generated_tokens; n++) {
+ std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ status = gptoss_context_append_tokens(context.get(), 1, &predicted_token);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token));
+ return;
+ }
}
}
- state.counters["tokens"] =
+ state.counters["generations"] =
benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * num_generated_tokens, benchmark::Counter::kIsRate);
}
BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH")
diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h
index c17510b8..50ed201c 100644
--- a/gpt_oss/metal/source/include/internal/model.h
+++ b/gpt_oss/metal/source/include/internal/model.h
@@ -1,6 +1,8 @@
#pragma once
-#include <stdatomic.h>
+#ifndef __cplusplus
+ #include <stdatomic.h>
+#endif
#include
#include
#include
@@ -9,7 +11,11 @@
struct gptoss_tokenizer {
+#ifndef __cplusplus
atomic_uint_least64_t ref_count;
+#else
+ uint_least64_t ref_count;
+#endif
void* mapping_ptr;
size_t mapping_size;
@@ -24,7 +30,11 @@ struct gptoss_tokenizer {
};
struct gptoss_model {
+#ifndef __cplusplus
atomic_uint_least64_t ref_count;
+#else
+ uint_least64_t ref_count;
+#endif
struct gptoss_tokenizer* tokenizer;
@@ -108,7 +118,11 @@ struct gptoss_model {
#define GPTOSS_DEFAULT_BATCH_SIZE 128
struct gptoss_context {
+#ifndef __cplusplus
atomic_uint_least64_t ref_count;
+#else
+ uint_least64_t ref_count;
+#endif
struct gptoss_model* model;
// Number of tokens processed in the context.
@@ -140,12 +154,3 @@ struct gptoss_context {
struct gptoss_metal_buffer argmax_buffer;
struct gptoss_metal_buffer kvcache_buffer;
};
-
-struct gptoss_sampler {
- atomic_uint_least64_t ref_count;
-
- float temperature;
- float top_p;
- float presence_penalty;
- float frequency_penalty;
-};
From 152fc0ce3b500752214c8d59440ef3a909e1e556 Mon Sep 17 00:00:00 2001
From: Maratyszcza
Date: Tue, 9 Sep 2025 10:20:33 -0700
Subject: [PATCH 86/91] Metal: support generating multiple tokens at once
(#179)
---
gpt_oss/metal/benchmark/end-to-end.cc | 52 ++---
gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc | 4 +
gpt_oss/metal/include/gpt-oss/functions.h | 4 +-
gpt_oss/metal/python/context.c | 57 +++--
gpt_oss/metal/source/accumulate.metal | 4 +
gpt_oss/metal/source/context.c | 200 ++++++++++++------
gpt_oss/metal/source/embeddings.metal | 5 +
gpt_oss/metal/source/generate.c | 3 +-
.../source/include/internal/kernel-args.h | 8 +-
.../source/include/internal/metal-kernels.h | 35 ++-
gpt_oss/metal/source/include/internal/model.h | 1 +
gpt_oss/metal/source/matmul.metal | 8 +
gpt_oss/metal/source/metal-kernels.c | 115 ++++++----
gpt_oss/metal/source/moematmul.metal | 8 +
gpt_oss/metal/source/rmsnorm.metal | 4 +
gpt_oss/metal/source/rope.metal | 5 +
gpt_oss/metal/source/sample.metal | 10 +-
gpt_oss/metal/source/sdpa.metal | 4 +
gpt_oss/metal/source/topk.metal | 8 +
.../metal/test/embeddings-kernel-tester.hpp | 4 +
gpt_oss/metal/test/matmul-kernel-tester.hpp | 4 +
gpt_oss/metal/test/rmsnorm-kernel-tester.hpp | 4 +
gpt_oss/metal/test/rope-kernel-tester.hpp | 5 +
gpt_oss/responses_api/inference/metal.py | 29 ++-
24 files changed, 409 insertions(+), 172 deletions(-)
diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc
index f4168f94..0a242340 100644
--- a/gpt_oss/metal/benchmark/end-to-end.cc
+++ b/gpt_oss/metal/benchmark/end-to-end.cc
@@ -1,6 +1,7 @@
#include
#include
+#include
#include
#include
#include
@@ -12,7 +13,7 @@
#include
-constexpr std::uint32_t num_generated_tokens = 100;
+constexpr std::uint32_t kNumGeneratedTokens = 100;
static void end2end(benchmark::State& state, const char* env_var_name) {
@@ -30,14 +31,6 @@ static void end2end(benchmark::State& state, const char* env_var_name) {
}
 std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
- gptoss_tokenizer_t tokenizer_ptr = nullptr;
- status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr);
- if (status != gptoss_status_success) {
- state.SkipWithError("failed to retrieve Tokenizer");
- return;
- }
- std::unique_ptr<std::remove_pointer_t<gptoss_tokenizer_t>, decltype(&gptoss_tokenizer_release)> tokenizer(tokenizer_ptr, gptoss_tokenizer_release);
-
gptoss_context_t context_ptr = nullptr;
status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
if (status != gptoss_status_success) {
@@ -60,50 +53,51 @@ static void end2end(benchmark::State& state, const char* env_var_name) {
state.SkipWithError("failed to prefill Context object");
return;
}
-
const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
std::uint64_t rng_seed = 0;
for (std::uint32_t i = 0; i < 3; i++) {
+ const std::uint64_t current_rng_seed = rng_seed++;
context->num_kv_tokens = num_prompt_tokens;
context->num_tokens = num_prompt_tokens;
- for (std::uint32_t n = 0; n < num_generated_tokens; n++) {
- std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
- status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token);
+ std::array<std::uint32_t, kNumGeneratedTokens> tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
if (status != gptoss_status_success) {
state.SkipWithError("failed to sample from the Context object");
return;
}
- status = gptoss_context_append_tokens(context.get(), 1, &predicted_token);
- if (status != gptoss_status_success) {
- state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token));
- return;
- }
- }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
}
for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
context->num_kv_tokens = num_prompt_tokens;
context->num_tokens = num_prompt_tokens;
- for (std::uint32_t n = 0; n < num_generated_tokens; n++) {
- std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
- status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token);
+ std::array<std::uint32_t, kNumGeneratedTokens> tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
if (status != gptoss_status_success) {
state.SkipWithError("failed to sample from the Context object");
return;
}
- status = gptoss_context_append_tokens(context.get(), 1, &predicted_token);
- if (status != gptoss_status_success) {
- state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token));
- return;
- }
- }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
}
+
state.counters["generations"] =
benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
state.counters["tokens"] =
- benchmark::Counter(state.iterations() * num_generated_tokens, benchmark::Counter::kIsRate);
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH")
diff --git a/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc b/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
index 17515942..ee7551c2 100644
--- a/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
+++ b/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
@@ -26,6 +26,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) {
Buffer input_buffer{device, num_tokens * num_channels * sizeof(float)};
Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)};
Buffer output_buffer{device, num_tokens * num_channels * sizeof(float)};
+ Buffer control_buffer{device, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
{
CommandBuffer command_buffer{command_queue};
@@ -69,6 +71,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) {
/*weight_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens,
num_channels,
kEpsilon),
diff --git a/gpt_oss/metal/include/gpt-oss/functions.h b/gpt_oss/metal/include/gpt-oss/functions.h
index 085ebe0d..6ddde253 100644
--- a/gpt_oss/metal/include/gpt-oss/functions.h
+++ b/gpt_oss/metal/include/gpt-oss/functions.h
@@ -290,7 +290,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
gptoss_context_t context,
float temperature,
uint64_t seed,
- uint32_t* token_out);
+ size_t max_tokens,
+ uint32_t* tokens_out,
+ size_t* num_tokens_out);
/*
* Increments a Context object's reference count.
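With this signature, gptoss_context_sample both samples and appends up to max_tokens tokens in a single call and reports how many were produced. A hedged usage sketch follows; the generate() helper is hypothetical, the include path assumes this repo's gpt_oss/metal/include layout, and model/context setup plus error reporting are omitted.

#include <stddef.h>
#include <stdint.h>
#include <gpt-oss/functions.h>  /* public API header, per this repo's include/ layout */

/* Generate num_tokens_to_generate tokens in batches of up to 64. */
static enum gptoss_status generate(gptoss_context_t context,
                                   float temperature,
                                   uint64_t seed,
                                   size_t num_tokens_to_generate) {
    enum gptoss_status status = gptoss_status_success;
    size_t num_generated = 0;
    while (num_generated < num_tokens_to_generate) {
        uint32_t tokens[64];  /* scratch for one batch of sampled token IDs */
        size_t max_tokens = num_tokens_to_generate - num_generated;
        if (max_tokens > 64) {
            max_tokens = 64;
        }
        size_t num_tokens = 0;
        /* Sampled tokens are appended to the context internally, so no
         * separate gptoss_context_append_tokens() call is needed. */
        status = gptoss_context_sample(context, temperature, seed,
                                       max_tokens, tokens, &num_tokens);
        if (status != gptoss_status_success) {
            return status;
        }
        if (num_tokens == 0) {
            break;  /* defensive: avoid spinning if nothing was generated */
        }
        num_generated += num_tokens;
    }
    return status;
}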
diff --git a/gpt_oss/metal/python/context.c b/gpt_oss/metal/python/context.c
index d71cc396..abc031af 100644
--- a/gpt_oss/metal/python/context.c
+++ b/gpt_oss/metal/python/context.c
@@ -120,25 +120,54 @@ static PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) {
}
static PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
- static char *kwlist[] = {"temperature", "seed", NULL};
+ static char *kwlist[] = {"max_output_tokens", "temperature", "seed", NULL};
+ PyObject* token_list_obj = NULL;
+ uint32_t* token_ptr = NULL;
+ unsigned int max_output_tokens = 0;
unsigned long long seed = 0;
float temperature = 1.0f;
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$fK", kwlist,
- &temperature, &seed))
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "I|$fK", kwlist,
+ &max_output_tokens, &temperature, &seed))
{
return NULL;
}
- uint32_t token_out = UINT32_MAX;
- enum gptoss_status status = gptoss_context_sample(
- self->handle, temperature, (uint64_t) seed, &token_out);
+ token_ptr = (uint32_t*) PyMem_Malloc(max_output_tokens * sizeof(uint32_t));
+ if (token_ptr == NULL) {
+ goto error;
+ }
+
+ size_t num_tokens = 0;
+ const enum gptoss_status status = gptoss_context_sample(
+ self->handle, temperature, (uint64_t) seed,
+ (size_t) max_output_tokens, token_ptr, &num_tokens);
if (status != gptoss_status_success) {
// TODO: set exception
- return NULL;
+ goto error;
}
- return PyLong_FromUnsignedLong((unsigned long) token_out);
+ token_list_obj = PyList_New((Py_ssize_t) num_tokens);
+ if (token_list_obj == NULL) {
+ goto error;
+ }
+
+ for (size_t t = 0; t < num_tokens; t++) {
+ PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
+ if (token_obj == NULL) {
+ goto error;
+ }
+
+ PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
+ }
+
+ PyMem_Free(token_ptr);
+ return token_list_obj;
+
+error:
+ PyMem_Free(token_ptr);
+ Py_XDECREF(token_list_obj);
+ return NULL;
}
static PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) {
@@ -155,7 +184,7 @@ static PyMethodDef PyGPTOSSContext_methods[] = {
{"__copy__", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, "Create a copy of the Context"},
{"append", (PyCFunction) PyGPTOSSContext_append, METH_O, "Append bytes to the Context"},
{"process", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, "Process tokens in the Context"},
- {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token prediction from the Context"},
+ {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token predictions from the Context"},
{"reset", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, "Discard the content of the Context"},
{NULL},
};
@@ -184,7 +213,6 @@ static PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* clo
static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {
PyObject* token_list_obj = NULL;
- PyObject* token_obj = NULL;
uint32_t* token_ptr = NULL;
size_t num_tokens = 0;
@@ -210,14 +238,12 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure
}
for (size_t t = 0; t < num_tokens; t++) {
- token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
+ PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
if (token_obj == NULL) {
goto error;
}
- if (PyList_SetItem(token_list_obj, (Py_ssize_t) t, token_obj) < 0) {
- goto error;
- }
- token_obj = NULL; // PyList_SetItem stole the reference
+
+ PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
}
PyMem_Free(token_ptr);
@@ -225,7 +251,6 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure
error:
PyMem_Free(token_ptr);
- Py_XDECREF(token_obj);
Py_XDECREF(token_list_obj);
return NULL;
}
diff --git a/gpt_oss/metal/source/accumulate.metal b/gpt_oss/metal/source/accumulate.metal
index f7ebc506..70dc4c2b 100644
--- a/gpt_oss/metal/source/accumulate.metal
+++ b/gpt_oss/metal/source/accumulate.metal
@@ -12,11 +12,15 @@ kernel void gptoss_f32_accumulate_e4(
const device float4* input [[ buffer(1) ]],
const device gptoss_expert_prediction* expert [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint2 threadgroup_size [[ threads_per_threadgroup ]])
{
const uint num_active_experts = 4;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c
index 0791c3eb..2d246294 100644
--- a/gpt_oss/metal/source/context.c
+++ b/gpt_oss/metal/source/context.c
@@ -82,6 +82,10 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
}
// Input/output buffers
+ status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
if (status != gptoss_status_success) {
goto cleanup;
@@ -196,6 +200,8 @@ static enum gptoss_status process_tokens(
/*weight_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_channels=*/model->embedding_dim);
if (status != gptoss_status_success) {
@@ -215,6 +221,8 @@ static enum gptoss_status process_tokens(
/*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
&context->rmsnorm_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_channels=*/model->embedding_dim,
model->rmsnorm_epsilon);
@@ -235,6 +243,8 @@ static enum gptoss_status process_tokens(
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
&context->qkv_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_cols=*/model->embedding_dim,
/*num_rows=*/attn_qkv_dim);
@@ -248,6 +258,9 @@ static enum gptoss_status process_tokens(
&model->f32_rope_fn,
/*threadgroup_size=*/32,
&context->qkv_activation_buffer,
+ /*input_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
model->rope_theta,
model->interpolation_scale,
model->yarn_offset,
@@ -291,6 +304,8 @@ static enum gptoss_status process_tokens(
/*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
&context->sdpa_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
num_block_output_tokens,
input_batch_start + input_batch_size - num_block_output_tokens,
@@ -312,6 +327,8 @@ static enum gptoss_status process_tokens(
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/num_block_output_tokens,
/*num_cols=*/model->num_heads * model->head_dim,
/*num_rows=*/model->embedding_dim);
@@ -329,6 +346,8 @@ static enum gptoss_status process_tokens(
/*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
&context->rmsnorm_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
num_block_output_tokens,
model->embedding_dim,
model->rmsnorm_epsilon);
@@ -349,6 +368,8 @@ static enum gptoss_status process_tokens(
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
&context->gate_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/num_block_output_tokens,
/*num_cols=*/model->embedding_dim,
/*num_rows=*/model->num_experts);
@@ -366,6 +387,7 @@ static enum gptoss_status process_tokens(
&model->f32_topk_softmax_e32_k4_fn,
&context->gate_activation_buffer, /*input_offset=*/0,
&context->expert_activation_buffer, /*output_offset=*/0,
+ &context->control_buffer, /*control_offset=*/0,
num_block_output_tokens,
model->num_experts,
model->num_active_experts);
@@ -377,6 +399,7 @@ static enum gptoss_status process_tokens(
&model->f32_topk_softmax_e128_k4_fn,
&context->gate_activation_buffer, /*input_offset=*/0,
&context->expert_activation_buffer, /*output_offset=*/0,
+ &context->control_buffer, /*control_offset=*/0,
num_block_output_tokens,
model->num_experts,
model->num_active_experts);
@@ -407,6 +430,8 @@ static enum gptoss_status process_tokens(
/*bias_offset=*/model->mlp_swiglu_bias_offset,
&context->swiglu_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
model->swiglu_limit,
model->per_expert_block_weight_size,
num_block_output_tokens,
@@ -434,6 +459,8 @@ static enum gptoss_status process_tokens(
/*bias_offset=*/model->mlp_out_bias_offset,
&context->moe_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_active_experts,
@@ -455,6 +482,8 @@ static enum gptoss_status process_tokens(
/*expert_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &context->control_buffer,
+ /*control_offset=*/0,
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
@@ -475,6 +504,8 @@ static enum gptoss_status process_tokens(
/*weight_offset=*/model->rmsnorm_weight_offset,
&context->rmsnorm_activation_buffer,
/*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/output_batch_size,
/*num_channels=*/model->embedding_dim,
model->rmsnorm_epsilon);
@@ -507,6 +538,8 @@ static enum gptoss_status process_tokens(
/*output_offset=*/0,
&context->argmax_buffer,
/*argmax_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
/*num_tokens=*/output_batch_size,
/*num_cols=*/model->embedding_dim,
/*num_rows=*/model->vocabulary_size);
@@ -641,6 +674,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_process(
goto cleanup;
}
+ struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
+ control->abort = 0;
+
status = process_tokens(
context,
&command_buffer,
@@ -675,96 +711,121 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
gptoss_context_t context,
float temperature,
uint64_t seed,
- uint32_t* token_out)
+ size_t max_tokens,
+ uint32_t* tokens_out,
+ size_t* num_tokens_out)
{
enum gptoss_status status = gptoss_status_success;
const struct gptoss_model* model = context->model;
struct gptoss_metal_command_buffer command_buffer = {0};
- *token_out = UINT32_MAX;
+ *num_tokens_out = 0;
- status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
+ const uint32_t num_original_tokens = context->num_tokens;
- if (context->num_kv_tokens < context->num_tokens) {
- status = process_tokens(
- context,
- &command_buffer,
- /*input_tokens_offset=*/context->num_kv_tokens,
- /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
- /*num_output_tokens=*/1);
- context->num_kv_tokens = context->num_tokens;
- } else {
- status = process_tokens(
- context,
- &command_buffer,
- /*input_tokens_offset=*/context->num_tokens - 1,
- /*num_input_tokens=*/1,
- /*num_output_tokens=*/1);
- }
+ status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
if (status != gptoss_status_success) {
goto cleanup;
}
- if (temperature != 0.0f) {
- assert(context->num_processed_tokens != 0);
- uint32_t num_threadgroups = 0;
- uint32_t num_dims_per_threadgroup = 0;
- status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
- &command_buffer,
- &model->f32_softmax_fn,
- /*threadgroup_size=*/512,
- model->max_threadgroups,
- &context->score_buffer,
- /*score_offset=*/0,
- &context->argmax_buffer,
- /*argmax_offset=*/0,
- &context->prob_buffer,
- /*prob_offset=*/0,
- &context->sum_buffer,
- /*sum_offset=*/0,
- model->vocabulary_size,
- /*num_tokens=*/1,
- temperature,
- &num_threadgroups,
- &num_dims_per_threadgroup);
+ struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
+ control->abort = 0;
+
+ for (size_t t = 0; t < max_tokens; t++) {
+ if (context->num_kv_tokens < context->num_tokens) {
+ status = process_tokens(
+ context,
+ &command_buffer,
+ /*input_tokens_offset=*/context->num_kv_tokens,
+ /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
+ /*num_output_tokens=*/1);
+ context->num_kv_tokens = context->num_tokens;
+ } else {
+ status = process_tokens(
+ context,
+ &command_buffer,
+ /*input_tokens_offset=*/context->num_tokens - 1,
+ /*num_input_tokens=*/1,
+ /*num_output_tokens=*/1);
+ }
if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
goto cleanup;
}
- status = gptoss_metal_command_buffer_encode_launch_f32_sample(
- &command_buffer,
- &model->f32_sample_fn,
- /*min_threadgroup_size=*/512,
- &context->prob_buffer,
- /*prob_offset=*/0,
- &context->sum_buffer,
- /*sum_offset=*/0,
- &context->argmax_buffer,
- /*prediction_offset=*/0,
- /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
- /*num_blocks=*/num_threadgroups,
- /*num_channels=*/model->vocabulary_size,
- /*num_channels_per_block=*/num_dims_per_threadgroup,
- /*token_offset=*/context->num_tokens);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
- goto cleanup;
+ if (temperature != 0.0f) {
+ assert(context->num_processed_tokens != 0);
+ uint32_t num_threadgroups = 0;
+ uint32_t num_dims_per_threadgroup = 0;
+ status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
+ &command_buffer,
+ &model->f32_softmax_fn,
+ /*threadgroup_size=*/512,
+ model->max_threadgroups,
+ &context->score_buffer,
+ /*score_offset=*/0,
+ &context->argmax_buffer,
+ /*argmax_offset=*/0,
+ &context->prob_buffer,
+ /*prob_offset=*/0,
+ &context->sum_buffer,
+ /*sum_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ model->vocabulary_size,
+ /*num_tokens=*/1,
+ temperature,
+ &num_threadgroups,
+ &num_dims_per_threadgroup);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
+ goto cleanup;
+ }
+
+ status = gptoss_metal_command_buffer_encode_launch_f32_sample(
+ &command_buffer,
+ &model->f32_sample_fn,
+ /*min_threadgroup_size=*/512,
+ &context->prob_buffer,
+ /*prob_offset=*/0,
+ &context->sum_buffer,
+ /*sum_offset=*/0,
+ &context->token_buffer,
+ /*token_offset=*/context->num_tokens * sizeof(uint32_t),
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
+ /*rng_offset=*/context->num_tokens,
+ /*num_blocks=*/num_threadgroups,
+ /*num_channels=*/model->vocabulary_size,
+ /*num_channels_per_block=*/num_dims_per_threadgroup);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
+ goto cleanup;
+ }
+ } else {
+ status = gptoss_metal_command_buffer_encode_copy_buffer(
+ &command_buffer,
+ &context->argmax_buffer,
+ /*input_offset=*/0,
+ &context->token_buffer,
+ /*output_offset=*/context->num_tokens * sizeof(uint32_t),
+ /*size=*/sizeof(uint32_t));
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode copy buffer");
+ goto cleanup;
+ }
}
+ context->num_tokens += 1;
+ context->num_kv_tokens = context->num_tokens;
}
gptoss_metal_command_buffer_commit(&command_buffer);
gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
- if (temperature == 0.0f) {
- const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0];
- *token_out = (uint32_t) argmax_bits;
- } else {
- *token_out = ((uint32_t*) context->argmax_buffer.ptr)[0];
- }
+ const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
+ const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;
+ memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
+ *num_tokens_out = num_generated_tokens;
cleanup:
gptoss_metal_command_buffer_release(&command_buffer);
@@ -805,6 +866,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release(
gptoss_metal_buffer_release(&context->moe_activation_buffer);
// Input/output buffers
+ gptoss_metal_buffer_release(&context->control_buffer);
gptoss_metal_buffer_release(&context->token_buffer);
gptoss_metal_buffer_release(&context->score_buffer);
gptoss_metal_buffer_release(&context->prob_buffer);
diff --git a/gpt_oss/metal/source/embeddings.metal b/gpt_oss/metal/source/embeddings.metal
index b4541d21..9cc7d121 100644
--- a/gpt_oss/metal/source/embeddings.metal
+++ b/gpt_oss/metal/source/embeddings.metal
@@ -9,10 +9,15 @@ kernel void gptoss_bf16_f32_embeddings(
const device uint* tokens [[ buffer(1) ]],
const device bfloat4* weights [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
+ if (control->abort != 0) {
+ return;
+ }
+
const uint t = tokens[gid];
weights += t * args.num_vecs;
diff --git a/gpt_oss/metal/source/generate.c b/gpt_oss/metal/source/generate.c
index 1711410a..36a5527b 100644
--- a/gpt_oss/metal/source/generate.c
+++ b/gpt_oss/metal/source/generate.c
@@ -268,8 +268,9 @@ int main(int argc, char *argv[]) {
while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
uint32_t predicted_token = UINT32_MAX;
+ size_t num_predicted_tokens = 0;
const uint64_t inference_start_timestamp = mach_continuous_time();
- status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, &predicted_token);
+ status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);
if (status != gptoss_status_success) {
fprintf(stderr, "Error: failed to sample from the Context object\n");
goto error;
diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h
index a031902d..259eaa8a 100644
--- a/gpt_oss/metal/source/include/internal/kernel-args.h
+++ b/gpt_oss/metal/source/include/internal/kernel-args.h
@@ -9,6 +9,10 @@ struct gptoss_expert_prediction {
float score;
};
+struct gptoss_control {
+ uint32_t abort;
+};
+
struct gptoss_topk_args {
uint32_t num_vecs_per_token;
};
@@ -105,8 +109,8 @@ struct gptoss_softmax_args {
};
struct gptoss_sample_args {
- uint64_t seed;
- uint32_t token_offset;
+ uint64_t rng_seed;
+ uint32_t rng_offset;
uint32_t num_blocks;
uint32_t num_dims;
uint32_t num_dims_per_block;
diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h
index 64cb36e0..269f025d 100644
--- a/gpt_oss/metal/source/include/internal/metal-kernels.h
+++ b/gpt_oss/metal/source/include/internal/metal-kernels.h
@@ -74,6 +74,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels);
@@ -86,6 +88,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels,
float epsilon);
@@ -102,6 +106,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
@@ -118,6 +124,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
@@ -135,6 +143,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
size_t output_offset,
const struct gptoss_metal_buffer* argmax_buffer,
size_t argmax_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
@@ -155,6 +165,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float swiglu_limit,
uint32_t expert_stride,
uint32_t num_tokens,
@@ -178,6 +190,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t expert_stride,
uint32_t num_tokens,
uint32_t num_active_experts,
@@ -189,6 +203,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
const struct gptoss_metal_function* f32_rope_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* activations_buffer,
+ size_t activations_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float rope_base,
float interpolation_scale,
float yarn_offset,
@@ -211,6 +228,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
size_t expert_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_experts);
@@ -222,6 +241,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
size_t input_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_active_experts);
@@ -239,6 +260,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
size_t s_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t window,
uint32_t num_q_tokens,
uint32_t num_kv_tokens,
@@ -259,6 +282,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
float temperature,
@@ -273,13 +298,15 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
- const struct gptoss_metal_buffer* prediction_buffer,
- size_t prediction_offset,
+ const struct gptoss_metal_buffer* token_buffer,
+ size_t token_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint64_t rng_seed,
+ uint32_t rng_offset,
uint32_t num_blocks,
uint32_t num_channels,
- uint32_t num_channels_per_block,
- uint32_t token_offset);
+ uint32_t num_channels_per_block);
#ifdef __cplusplus
} // extern "C"
diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h
index 50ed201c..34e273aa 100644
--- a/gpt_oss/metal/source/include/internal/model.h
+++ b/gpt_oss/metal/source/include/internal/model.h
@@ -147,6 +147,7 @@ struct gptoss_context {
struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
// Input/output buffers.
+ struct gptoss_metal_buffer control_buffer;
struct gptoss_metal_buffer token_buffer; // uint32 token IDs
struct gptoss_metal_buffer score_buffer; // unembedding outputs
struct gptoss_metal_buffer prob_buffer;
diff --git a/gpt_oss/metal/source/matmul.metal b/gpt_oss/metal/source/matmul.metal
index 6396f6cc..a4ec60d5 100644
--- a/gpt_oss/metal/source/matmul.metal
+++ b/gpt_oss/metal/source/matmul.metal
@@ -23,12 +23,16 @@ kernel void gptoss_f32_bf16w_matmul(
const device bfloat4* weight [[ buffer(2) ]],
const device bfloat* bias [[ buffer(3) ]],
device float* output [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row = gid.x * num_simdgroups + simdgroup_idx;
@@ -68,6 +72,7 @@ kernel void gptoss_f32_bf16w_unembedding(
const device bfloat4* weight [[ buffer(2) ]],
device float* output [[ buffer(3) ]],
device metal::atomic_ulong* argmax [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
@@ -75,6 +80,9 @@ kernel void gptoss_f32_bf16w_unembedding(
{
const uint simdgroup_size = 32;
threadgroup uint2 threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx;
diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c
index a9a5253c..1316fa50 100644
--- a/gpt_oss/metal/source/metal-kernels.c
+++ b/gpt_oss/metal/source/metal-kernels.c
@@ -197,6 +197,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels)
{
@@ -224,9 +226,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
threadgroup_size, 1, 1,
num_tokens, 1, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer},
- (const size_t[]) {token_offset, weight_offset, output_offset},
+ 4,
+ (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer, control_buffer},
+ (const size_t[]) {token_offset, weight_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -239,6 +241,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels,
float epsilon)
@@ -271,9 +275,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
/*threadgroup_size=*/1024, 1, 1,
num_tokens, 1, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer},
- (const size_t[]) {input_offset, weight_offset, output_offset},
+ 4,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -289,6 +293,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows)
@@ -329,9 +335,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
threadgroup_size, 1, 1,
num_rows / num_simdgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset},
+ 5,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -347,6 +353,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows)
@@ -387,9 +395,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
threadgroup_size, 1, 1,
num_rows / num_simdgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset},
+ 5,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -406,6 +414,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
size_t output_offset,
const struct gptoss_metal_buffer* argmax_buffer,
size_t argmax_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows)
@@ -443,9 +453,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
threadgroup_size, 1, 1,
num_threadgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer},
- (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset},
+ 5,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -465,6 +475,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float swiglu_limit,
uint32_t expert_stride,
uint32_t num_tokens,
@@ -517,9 +529,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
threadgroup_size, 1, 1,
(2 * num_rows) / num_simdgroups, num_tokens, num_active_experts,
sizeof(args), &args,
- 6,
- (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
+ 7,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -539,6 +551,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t expert_stride,
uint32_t num_tokens,
uint32_t num_active_experts,
@@ -589,9 +603,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
threadgroup_size, 1, 1,
num_rows / num_simdgroups, num_tokens, num_active_experts,
sizeof(args), &args,
- 6,
- (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
+ 7,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -600,6 +614,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
const struct gptoss_metal_function* f32_rope_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* activations_buffer,
+ size_t activations_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float rope_base,
float interpolation_scale,
float yarn_offset,
@@ -642,7 +659,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
threadgroup_size, 1, 1,
num_qk_heads / num_simdgroups, num_tokens, 1,
sizeof(args), &args,
- 1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL,
+ 2,
+ (const struct gptoss_metal_buffer *[]) {activations_buffer, control_buffer},
+ (const size_t[]) {activations_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -657,6 +676,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
size_t expert_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_experts)
@@ -690,9 +711,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
threadgroup_size, 1, 1,
num_threadgroups, num_tokens, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer},
- (const size_t[]) {input_offset, expert_offset, output_offset},
+ 4,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, expert_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -703,6 +724,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
size_t input_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_active_experts)
@@ -726,9 +749,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
/*threadgroup_size=*/32, 1, 1,
num_tokens, 1, 1,
sizeof(args), &args,
- 2,
- (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer},
- (const size_t[]) {input_offset, output_offset},
+ 3,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -745,6 +768,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
size_t s_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t window,
uint32_t num_q_tokens,
uint32_t num_kv_tokens,
@@ -783,9 +808,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
threadgroup_size, 1, 1,
num_q_tokens, num_kv_heads, 1,
sizeof(args), &args,
- 5,
- (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer},
- (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset},
+ 6,
+ (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer, control_buffer},
+ (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset, control_offset},
/*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));
}
@@ -802,6 +827,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
float temperature,
@@ -831,9 +858,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
threadgroup_size, 1, 1,
num_threadgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer},
- (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset},
+ 5,
+ (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer, control_buffer},
+ (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
@@ -845,13 +872,15 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
- const struct gptoss_metal_buffer* prediction_buffer,
- size_t prediction_offset,
+ const struct gptoss_metal_buffer* token_buffer,
+ size_t token_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint64_t rng_seed,
+ uint32_t rng_offset,
uint32_t num_blocks,
uint32_t num_channels,
- uint32_t num_channels_per_block,
- uint32_t token_offset)
+ uint32_t num_channels_per_block)
{
if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) {
return gptoss_status_invalid_state;
@@ -870,8 +899,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
}
const struct gptoss_sample_args args = {
- .seed = rng_seed,
- .token_offset = token_offset,
+ .rng_seed = rng_seed,
+ .rng_offset = rng_offset,
.num_blocks = num_blocks,
.num_dims = num_channels,
.num_dims_per_block = num_channels_per_block,
@@ -884,8 +913,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
threadgroup_size, 1, 1,
1, 1, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, prediction_buffer},
- (const size_t[]) {prob_offset, sum_offset, prediction_offset},
+ 4,
+ (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, token_buffer, control_buffer},
+ (const size_t[]) {prob_offset, sum_offset, token_offset, control_offset},
/*threadgroup_buffer_size=*/0);
}
diff --git a/gpt_oss/metal/source/moematmul.metal b/gpt_oss/metal/source/moematmul.metal
index 6e2f6950..58247484 100644
--- a/gpt_oss/metal/source/moematmul.metal
+++ b/gpt_oss/metal/source/moematmul.metal
@@ -24,6 +24,7 @@ kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
const device uchar* weight_scales [[ buffer(4) ]],
const device bfloat* bias [[ buffer(5) ]],
device float* output [[ buffer(6) ]],
+ const device gptoss_control* control [[ buffer(7) ]],
uint3 gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
@@ -32,6 +33,9 @@ kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
{
const uint simdgroup_size = 32;
threadgroup float threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row = gid.x * num_simdgroups + simdgroup_idx;
@@ -130,6 +134,7 @@ kernel void gptoss_f32_mf4w_moe_matmul(
const device uchar* weight_scales [[ buffer(4) ]],
const device bfloat* bias [[ buffer(5) ]],
device float* output [[ buffer(6) ]],
+ const device gptoss_control* control [[ buffer(7) ]],
uint3 gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
@@ -137,6 +142,9 @@ kernel void gptoss_f32_mf4w_moe_matmul(
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row = gid.x * num_simdgroups + simdgroup_idx;
diff --git a/gpt_oss/metal/source/rmsnorm.metal b/gpt_oss/metal/source/rmsnorm.metal
index ceb690f0..fc4bcaa2 100644
--- a/gpt_oss/metal/source/rmsnorm.metal
+++ b/gpt_oss/metal/source/rmsnorm.metal
@@ -14,12 +14,16 @@ kernel void gptoss_f32_bf16w_rmsnorm(
const device float4* input [[ buffer(1) ]],
const device bfloat4* weights [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
const uint simdgroup_size = 32;
threadgroup float threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
input += gid * args.num_vecs;
output += gid * args.num_vecs;
diff --git a/gpt_oss/metal/source/rope.metal b/gpt_oss/metal/source/rope.metal
index 2739b5fa..ce4c3c8f 100644
--- a/gpt_oss/metal/source/rope.metal
+++ b/gpt_oss/metal/source/rope.metal
@@ -13,9 +13,14 @@
kernel void gptoss_f32_rope(
constant gptoss_rope_args& args [[ buffer(0) ]],
device float2* activations [[ buffer(1) ]],
+ const device gptoss_control* control [[ buffer(2) ]],
uint2 gid [[thread_position_in_grid]])
{
const uint num_head_dims = 64;
+ if (control->abort != 0) {
+ return;
+ }
+
     const float head_idx = static_cast<float>(gid.x % (num_head_dims / 2));
const uint token_idx = args.token_offset + gid.y;
activations += gid.y * args.token_stride + gid.x;
diff --git a/gpt_oss/metal/source/sample.metal b/gpt_oss/metal/source/sample.metal
index 8ce4598b..4a0efe3b 100644
--- a/gpt_oss/metal/source/sample.metal
+++ b/gpt_oss/metal/source/sample.metal
@@ -36,6 +36,7 @@ kernel void gptoss_f32_softmax(
const device uint2* argmax [[ buffer(2) ]],
device float* prob [[ buffer(3) ]],
device float* sum [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
uint tidx [[thread_index_in_threadgroup]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 threadgroup_size [[threads_per_threadgroup]],
@@ -44,6 +45,9 @@ kernel void gptoss_f32_softmax(
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
threadgroup float threadgroup_sumexp[32];
+ if (control->abort != 0) {
+ return;
+ }
score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
@@ -86,6 +90,7 @@ kernel void gptoss_f32_sample(
device const float* prob [[ buffer(1) ]],
device const float* sum [[ buffer(2) ]],
device uint* prediction [[ buffer(3) ]],
+ device gptoss_control* control [[ buffer(4) ]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[threads_per_threadgroup]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
@@ -95,8 +100,11 @@ kernel void gptoss_f32_sample(
threadgroup float threadgroup_sum_buffer[32];
threadgroup uint threadgroup_idx_buffer[32];
threadgroup float threadgroup_cumsum_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
- const uint sample_word = rng_squares32(args.token_offset, args.seed);
+ const uint sample_word = rng_squares32(args.rng_offset, args.rng_seed);
     float sample_cdf = static_cast<float>(sample_word & 0x00FFFFFFu) * 0x1.0p-24f;
float cumsum = 0.0f;
diff --git a/gpt_oss/metal/source/sdpa.metal b/gpt_oss/metal/source/sdpa.metal
index 5050cb41..459bbe28 100644
--- a/gpt_oss/metal/source/sdpa.metal
+++ b/gpt_oss/metal/source/sdpa.metal
@@ -18,6 +18,7 @@ kernel void gptoss_f32_sdpa_q8_d64(
const device float* v [[ buffer(3) ]],
const device bfloat* s [[ buffer(4) ]],
device float* output [[ buffer(5) ]],
+ const device gptoss_control* control [[ buffer(6) ]],
threadgroup void* threadgroup_buffer [[ threadgroup(0) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 tid [[thread_position_in_threadgroup]],
@@ -26,6 +27,9 @@ kernel void gptoss_f32_sdpa_q8_d64(
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_q_heads = 64;
const uint num_kv_heads = 8;
diff --git a/gpt_oss/metal/source/topk.metal b/gpt_oss/metal/source/topk.metal
index d3532ac6..90f4e51c 100644
--- a/gpt_oss/metal/source/topk.metal
+++ b/gpt_oss/metal/source/topk.metal
@@ -14,11 +14,15 @@ kernel void gptoss_f32_topk_softmax_e128_k4(
constant gptoss_topk_args& args [[ buffer(0) ]],
const device float4* input [[ buffer(1) ]],
device gptoss_expert_prediction* output [[ buffer(2) ]],
+ const device gptoss_control* control [[ buffer(3) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]])
{
const uint num_experts = 128;
const uint num_active_experts = 4;
+ if (control->abort != 0) {
+ return;
+ }
input += gid * (num_experts / 4);
output += gid * num_active_experts;
@@ -132,11 +136,15 @@ kernel void gptoss_f32_topk_softmax_e32_k4(
constant gptoss_topk_args& args [[ buffer(0) ]],
const device float* input [[ buffer(1) ]],
device gptoss_expert_prediction* output [[ buffer(2) ]],
+ const device gptoss_control* control [[ buffer(3) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]])
{
const uint num_experts = 32;
const uint num_active_experts = 4;
+ if (control->abort != 0) {
+ return;
+ }
input += gid * num_experts;
output += gid * num_active_experts;
diff --git a/gpt_oss/metal/test/embeddings-kernel-tester.hpp b/gpt_oss/metal/test/embeddings-kernel-tester.hpp
index fd810c6d..83092a8c 100644
--- a/gpt_oss/metal/test/embeddings-kernel-tester.hpp
+++ b/gpt_oss/metal/test/embeddings-kernel-tester.hpp
@@ -69,6 +69,8 @@ class EmbeddingsKernelTester {
metal::Buffer token_buffer{device_, sizeof(std::uint32_t)};
metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)};
metal::Buffer output_buffer{device_, num_channels() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
         std::uint32_t* token_ptr = static_cast<std::uint32_t*>(token_buffer.ptr());
for (std::uint32_t t = 0; t < num_tokens(); t++) {
@@ -85,6 +87,8 @@ class EmbeddingsKernelTester {
/*weight_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens(),
num_channels()),
"gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings");
diff --git a/gpt_oss/metal/test/matmul-kernel-tester.hpp b/gpt_oss/metal/test/matmul-kernel-tester.hpp
index ec13af6b..30826f70 100644
--- a/gpt_oss/metal/test/matmul-kernel-tester.hpp
+++ b/gpt_oss/metal/test/matmul-kernel-tester.hpp
@@ -78,6 +78,8 @@ class MatMulKernelTester {
metal::Buffer weight_buffer{device_, num_rows() * num_cols() * sizeof(gptoss_bfloat16)};
metal::Buffer bias_buffer{device_, num_rows() * sizeof(gptoss_bfloat16)};
metal::Buffer output_buffer{device_, num_tokens() * num_rows() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
command_buffer.encode_launch_f32_fill_random(
f32_fill_random_fn_,
@@ -115,6 +117,8 @@ class MatMulKernelTester {
/*bias_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens(),
num_cols(),
num_rows()),
diff --git a/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp b/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
index 16a6da64..3111eecb 100644
--- a/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
+++ b/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
@@ -64,6 +64,8 @@ class RMSNormKernelTester {
metal::Buffer input_buffer{device_, num_tokens() * num_channels() * sizeof(float)};
metal::Buffer weight_buffer{device_, num_channels() * sizeof(gptoss_bfloat16)};
metal::Buffer output_buffer{device_, num_tokens() * num_channels() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
metal::CommandBuffer command_buffer{command_queue_};
@@ -90,6 +92,8 @@ class RMSNormKernelTester {
/*weight_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens(),
num_channels(),
epsilon()),
diff --git a/gpt_oss/metal/test/rope-kernel-tester.hpp b/gpt_oss/metal/test/rope-kernel-tester.hpp
index 602912a1..cb930621 100644
--- a/gpt_oss/metal/test/rope-kernel-tester.hpp
+++ b/gpt_oss/metal/test/rope-kernel-tester.hpp
@@ -112,6 +112,8 @@ class RoPEKernelTester {
metal::Buffer activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};
metal::Buffer ref_activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
metal::CommandBuffer command_buffer{command_queue_};
@@ -138,6 +140,9 @@ class RoPEKernelTester {
f32_rope_fn_.handle(),
threadgroup_size(),
activations_buffer.handle(),
+ /*activations_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
frequency_base(),
/*interpolation_scale=*/1.0f,
/*yarn_offset=*/0.0f,
diff --git a/gpt_oss/responses_api/inference/metal.py b/gpt_oss/responses_api/inference/metal.py
index ec84af7e..9b62b660 100644
--- a/gpt_oss/responses_api/inference/metal.py
+++ b/gpt_oss/responses_api/inference/metal.py
@@ -5,22 +5,39 @@
from gpt_oss.metal import Context, Model
+# Tunables
+MAX_OUTPUT_TOKENS = 100
+
+
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
"""Load the Metal model and return an inference function."""
model = Model(checkpoint)
context = Context(model)
+ seed = 0
+ output_tokens = []
+
def infer_next_token(
tokens: list[int], temperature: float = 0.0, new_request: bool = False
) -> int:
"""Infer next token using incremental LCP caching when possible."""
+ nonlocal output_tokens
+
+ if new_request:
+ output_tokens = []
+
+ if len(output_tokens) == 0:
+ # Context handles LCP caching internally; if `tokens` matches the
+ # tokens in the KV cache, the KV cache is reused after reset+append.
+ context.reset()
+ for t in tokens:
+ context.append(t)
+
+ output_tokens = context.sample(max_output_tokens=MAX_OUTPUT_TOKENS,
+ temperature=temperature,
+ seed=seed)
- # Context handles LCP caching internally; if `tokens` matches the
- # tokens in the KV cache, the KV cache is reused after reset+append.
- context.reset()
- for t in tokens:
- context.append(t)
- return int(context.sample(temperature=temperature))
+ return int(output_tokens.pop(0))
return infer_next_token
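
The reworked `infer_next_token` above samples up to `MAX_OUTPUT_TOKENS` tokens from the Metal context in one call and pops them one at a time on later calls, re-priming the context only when the buffer runs dry. A minimal usage sketch, assuming a placeholder checkpoint path and prompt token ids (neither comes from this patch):

```python
# Illustrative sketch only: drives the batched-sampling wrapper defined above.
# The checkpoint path and prompt token ids are placeholders, not real values.
from gpt_oss.responses_api.inference.metal import setup_model

infer_next_token = setup_model("/path/to/checkpoint")  # placeholder path

prompt = [1, 2, 3]          # hypothetical prompt token ids
generated: list[int] = []

# First call marks a new request: the context is reset, the prompt is
# appended, and a batch of tokens is sampled and buffered.
token = infer_next_token(prompt, temperature=1.0, new_request=True)
for _ in range(16):
    generated.append(token)
    # Later calls pop buffered tokens; once the buffer is empty the context
    # is rebuilt from the full token list and a new batch is sampled.
    token = infer_next_token(prompt + generated, temperature=1.0)
```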
From 1b5b45a77fb1175fc9b165d7efe8af84721e7569 Mon Sep 17 00:00:00 2001
From: ibahmed-oai
Date: Tue, 9 Sep 2025 23:14:19 -0700
Subject: [PATCH 87/91] Adding prefill benchmarking for metal backend (#181)
---
gpt_oss/metal/benchmark/end-to-end.cc | 155 ++++++++++++++++++++--
gpt_oss/metal/include/gpt-oss/functions.h | 6 +-
gpt_oss/metal/python/model.c | 2 +-
gpt_oss/metal/source/generate.c | 2 +-
gpt_oss/metal/source/model.c | 5 +-
5 files changed, 154 insertions(+), 16 deletions(-)
diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc
index 0a242340..6637de67 100644
--- a/gpt_oss/metal/benchmark/end-to-end.cc
+++ b/gpt_oss/metal/benchmark/end-to-end.cc
@@ -2,9 +2,10 @@
#include
#include
-#include
#include
+#include
#include
+#include