From 04257d5a8f7802956af29639d83a7676354ef871 Mon Sep 17 00:00:00 2001 From: Romain Huet Date: Tue, 5 Aug 2025 11:22:04 -0700 Subject: [PATCH 01/91] Update README --- README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 76b998ef..231697f3 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,15 @@ -
gpt-oss-120 -


-Try gpt-oss | Guides | Model card -
Learn more about OpenAI's open models
-Download gpt-oss-120b and gpt-oss-20b on Hugging Face +

+ Try gpt-oss · + Guides · + Model card · + OpenAI blog

+

+ Download gpt-oss-120b and gpt-oss-20b on Hugging Face +

+
-
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. From 0f0336796c55a417bc3919656eee6eea9dec31d7 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 5 Aug 2025 11:23:56 -0700 Subject: [PATCH 02/91] Try fix pypi ci (#13) --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 72d7411e..f00c72df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,7 @@ version = "0.0.1" [project.optional-dependencies] triton = [ - "triton @ git+https://github.com/triton-lang/triton.git", - "triton_kernels @ git+https://github.com/triton-lang/triton.git#subdirectory=python/triton_kernels", + "triton", "safetensors>=0.5.3", "torch>=2.7.0", ] From f615ce39b3f96bfe891508c6ce5babeacdad0e30 Mon Sep 17 00:00:00 2001 From: mkusaka Date: Wed, 6 Aug 2025 03:24:13 +0900 Subject: [PATCH 03/91] fix: Correct broken links in awesome-gpt-oss.md (#12) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix HTTP to HTTPS for Hugging Face blog link - Fix Groq blog link: HTTP to HTTPS, add /blog/ path, fix typo (open-model → open-models) - Fix TensorRT-LLM documentation filename (blog_9 to blog9) --- awesome-gpt-oss.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index 37befa2a..fc5cc527 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -25,7 +25,7 @@ This is a list of guides and resources to help you get started with the gpt-oss - [Use gpt-oss-120b with LM Studio](https://lmstudio.ai/models/openai/gpt-oss-120b) - Hugging Face & Transformers - [How to run gpt-oss with Transformers](https://cookbook.openai.com/articles/gpt-oss/run-transformers) - - [Hugging Face & gpt-oss launch blog](http://huggingface.co/blog/welcome-openai-gpt-oss) + - [Hugging Face & gpt-oss launch blog](https://huggingface.co/blog/welcome-openai-gpt-oss) - [Collection of Hugging Face examples](https://github.com/huggingface/gpt-oss-recipes) - NVIDIA - [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) @@ -36,12 +36,12 @@ This is a list of guides and resources to help you get started with the gpt-oss - [How to run gpt-oss with vLLM](https://cookbook.openai.com/articles/gpt-oss/run-vllm) - NVIDIA - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/gpt-oss/run-nvidia) - - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog_9_Deploying_GPT_OSS_on_TRTLLM.md) + - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md) ### Cloud - Groq - - [Groq & gpt-oss launch blog](http://groq.com/day-zero-support-for-openai-open-model) + - [Groq & gpt-oss launch blog](https://groq.com/blog/day-zero-support-for-openai-open-models) - [gpt-oss-120b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-120b) - [gpt-oss-20b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-20b) - [gpt-oss with built-in web search on GroqCloud](https://console.groq.com/docs/browser-search) From 08e50b3243f2d6a0a6930bfb43dfae20955debaa Mon Sep 17 00:00:00 2001 From: Sumit Aryal <58778103+sumitaryal@users.noreply.github.com> Date: Wed, 6 Aug 2025 00:11:02 +0545 Subject: [PATCH 04/91] Python 
Agents SDK Example (#14) --- examples/agents-sdk-python/example.py | 107 ++++++++++++++++++++++ examples/agents-sdk-python/pyproject.toml | 9 ++ 2 files changed, 116 insertions(+) create mode 100644 examples/agents-sdk-python/example.py create mode 100644 examples/agents-sdk-python/pyproject.toml diff --git a/examples/agents-sdk-python/example.py b/examples/agents-sdk-python/example.py new file mode 100644 index 00000000..06aa7d70 --- /dev/null +++ b/examples/agents-sdk-python/example.py @@ -0,0 +1,107 @@ +import asyncio +from pathlib import Path +import shutil + +from openai import AsyncOpenAI +from agents import ( + Agent, + ItemHelpers, + Runner, + set_default_openai_api, + set_default_openai_client, + set_tracing_disabled, + function_tool, +) +from agents.mcp import MCPServerStdio +from pydantic import BaseModel + + +class WeatherParams(BaseModel): + location: str + + +async def prompt_user(question: str) -> str: + """Async input prompt function""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, input, question) + + +async def main(): + # Set up OpenAI client for local server (e.g., Ollama) + openai_client = AsyncOpenAI( + api_key="local", + base_url="http://localhost:11434/v1", + ) + + # Get current working directory + samples_dir = str(Path.cwd()) + + # Create MCP server for filesystem operations + mcp_server = MCPServerStdio( + name="Filesystem MCP Server, via npx", + params={ + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + samples_dir, + ], + }, + ) + + # Connect to MCP server + await mcp_server.connect() + + # Configure agents SDK + set_tracing_disabled(True) + set_default_openai_client(openai_client) + set_default_openai_api("chat_completions") + + # Define weather tool + @function_tool + async def search_tool(location: str) -> str: + return f"The weather in {location} is sunny." + + # Create agent + agent = Agent( + name="My Agent", + instructions="You are a helpful assistant.", + tools=[search_tool], + model="gpt-oss:20b-test", + mcp_servers=[mcp_server], + ) + + # Get user input + user_input = await prompt_user("> ") + + # Run agent with streaming + result = Runner.run_streamed(agent, user_input) + + # Process streaming results + async for event in result.stream_events(): + if event.type == "raw_response_event": + continue + elif event.type == "agent_updated_stream_event": + print(f"Agent updated: {event.new_agent.name}") + elif event.type == "run_item_stream_event": + if event.item.type == "tool_call_item": + print("-- Tool was called") + elif event.item.type == "tool_call_output_item": + print(f"-- Tool output: {event.item.output}") + elif event.item.type == "message_output_item": + print( + f"-- Message output:\n {ItemHelpers.text_message_output(event.item)}" + ) + else: + pass + + print("=== Run complete ===") + + +if __name__ == "__main__": + + if not shutil.which("npx"): + raise RuntimeError( + "npx is not installed. Please install it with `npm install -g npx`." 
+ ) + asyncio.run(main()) diff --git a/examples/agents-sdk-python/pyproject.toml b/examples/agents-sdk-python/pyproject.toml new file mode 100644 index 00000000..e8d24a81 --- /dev/null +++ b/examples/agents-sdk-python/pyproject.toml @@ -0,0 +1,9 @@ +[project] +name = "agents-sdk-python" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "openai-agents>=0.2.4", +] From 9e5b84198755a6e3d89d5f63f96c9bdef6ea3d84 Mon Sep 17 00:00:00 2001 From: Li Yang <76434265+hewliyang@users.noreply.github.com> Date: Wed, 6 Aug 2025 02:27:31 +0800 Subject: [PATCH 05/91] readme: fix python tool ref (#10) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 231697f3..01fd4d5f 100644 --- a/README.md +++ b/README.md @@ -402,7 +402,7 @@ The model got trained on using a python tool to perform calculations and other a #### Usage -To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_python()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example: +To enable the python tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_python()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example: ```python import datetime From 1e47b7043616c1163837534c1e90cbf0fa917515 Mon Sep 17 00:00:00 2001 From: Mohammad Miadh Angkad Date: Wed, 6 Aug 2025 02:28:29 +0800 Subject: [PATCH 06/91] docs: Fix another extra "= messages" (#7) Fix another extra "= messages" in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 00f7a0d0..77465c13 100644 --- a/README.md +++ b/README.md @@ -446,7 +446,7 @@ token_ids = encoding.render_conversation_for_completion(conversation, Role.ASSIS # ... # parse the output -messages = messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT) +messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT) last_message = messages[-1] if last_message.recipient == "python": # perform python call From 51bfa9ed2412eac37d22fe22f6382eeeff456c6b Mon Sep 17 00:00:00 2001 From: Ricky Saull Date: Tue, 5 Aug 2025 14:29:13 -0400 Subject: [PATCH 07/91] Fix typos and grammar in README (#6) Corrected several typos and updated all references from 'Pytorch' to 'PyTorch' for consistency. Improved clarity in model descriptions and usage instructions throughout the README. --- README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 77465c13..518e6f0b 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,9 @@ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. 
-We're releasing two flavors of the open models: +We're releasing two flavors of these open models: -- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fits into a single H100 GPU (117B parameters with 5.1B active parameters) +- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single H100 GPU (117B parameters with 5.1B active parameters) - `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters) Both models were trained on our [harmony response format][harmony] and should only be used with the harmony format as it will not work correctly otherwise. @@ -76,7 +76,7 @@ vllm serve openai/gpt-oss-20b [Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm) -#### Pytorch / Triton / Metal +#### PyTorch / Triton / Metal These implementations are largely reference implementations for educational purposes and are not expected to be run in production. @@ -116,14 +116,14 @@ Check out our [awesome list](./awesome-gpt-oss.md) for a broader collection of g This repository provides a collection of reference implementations: - **Inference:** - - [`torch`](#reference-pytorch-implementation) — a non-optimized [Pytorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4x H100s because it's not optimized - - [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [Pytorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. using CUDA graphs and basic caching + - [`torch`](#reference-pytorch-implementation) — a non-optimized [PyTorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4x H100s because it's not optimized + - [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [PyTorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. using CUDA graphs and basic caching - [`metal`](#reference-metal-implementation) — a Metal-specific implementation for running the models on Apple Silicon hardware - **Tools:** - [`browser`](#browser) — a reference implementation of the browser tool the models got trained on - [`python`](#python) — a stateless reference implementation of the python tool the model got trained on - **Client examples:** - - [`chat`](#terminal-chat) — a basic terminal chat application that uses the Pytorch or Triton implementations for inference along with the python and browser tools + - [`chat`](#terminal-chat) — a basic terminal chat application that uses the PyTorch or Triton implementations for inference along with the python and browser tools - [`responses_api`](#responses-api) — an example Responses API compatible server that implements the browser tool along with other Responses-compatible functionality ## Setup @@ -212,7 +212,7 @@ If you encounter `torch.OutOfMemoryError` make sure to turn on the expandable al ## Reference Metal implementation -Additionally we are providing a reference implementation for Metal to run on Apple Silicon. This implementation is not production ready but is accurate to the Pytorch implementation. +Additionally we are providing a reference implementation for Metal to run on Apple Silicon. This implementation is not production-ready but is accurate to the PyTorch implementation. 
The implementation will get automatically compiled when running the `.[metal]` installation on an Apple Silicon device: @@ -248,7 +248,7 @@ We also include two system tools for the model: browsing and python container. C ### Terminal Chat -The terminal chat application is a basic example on how to use the harmony format together with the Pytorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used. +The terminal chat application is a basic example on how to use the harmony format together with the PyTorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used. ```bash usage: python -m gpt_oss.chat [-h] [-r REASONING_EFFORT] [-a] [-b] [--show-browser-results] [-p] [--developer-message DEVELOPER_MESSAGE] [-c CONTEXT] [--raw] [--backend {triton,torch,vllm}] FILE @@ -402,7 +402,7 @@ To improve performance the tool caches requests so that the model can revisit a ### Python -The model got trained on using a python tool to perform calculations and other actions as part of its chain-of-thought. During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony]. +The model was trained to use using a python tool to perform calculations and other actions as part of its chain-of-thought. During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony]. > [!WARNING] > This implementation runs in a permissive Docker container which could be problematic in cases like prompt injections. It's serving as an example and you should consider implementing your own container restrictions in production. @@ -436,7 +436,7 @@ if use_python_tool: system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content) # create the overall prompt -messages = [system_message, Message.from_role_and_content(Role.USER, "What's the squareroot of 9001?")] +messages = [system_message, Message.from_role_and_content(Role.USER, "What's the square root of 9001?")] conversation = Conversation.from_messages(messages) # convert to tokens From 90743264994602e7ec8af6a8207249b2e0b2593c Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Tue, 5 Aug 2025 11:44:29 -0700 Subject: [PATCH 08/91] Update LICENSE --- LICENSE | 367 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 184 insertions(+), 183 deletions(-) diff --git a/LICENSE b/LICENSE index 4ecba18e..d6456956 100644 --- a/LICENSE +++ b/LICENSE @@ -1,181 +1,182 @@ + Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" @@ -186,16 +187,16 @@ APPENDIX: How to apply the Apache License to your work. same "printed page" as the copyright notice for easier identification within third-party archives. -Copyright 2025 OpenAI + Copyright [yyyy] [name of copyright owner] -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. From 89fe402d10a59879781a1eb0a64affdf4c278a4d Mon Sep 17 00:00:00 2001 From: Mihajlo Micic <44226809+micic-mihajlo@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:50:06 -0400 Subject: [PATCH 09/91] Add comprehensive test suite for Responses API (#20) The project had almost no test coverage - just a single test checking if the API returns 200. This adds proper testing infrastructure and 21 new tests covering the main API functionality. Tests now cover response creation, error handling, tools, sessions, performance, and usage tracking. All tests passing. 
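A quick way to exercise the new suite locally is a plain pytest run. The sketch below is an assumed development setup (this patch does not pin test dependencies): it presumes `pytest` and `httpx`, which FastAPI's `TestClient` relies on, are installed alongside the package.

```bash
# Assumed dev setup: install the package in editable mode plus the test-only tools.
pip install -e . pytest httpx

# Run the new Responses API tests with verbose output.
pytest tests/test_api_endpoints.py -v
```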
--- tests/conftest.py | 118 ++++++++++++++++++ tests/test_api_endpoints.py | 230 ++++++++++++++++++++++++++++++++++++ 2 files changed, 348 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_api_endpoints.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..4c008a3a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,118 @@ +import os +import sys +import pytest +from typing import Generator, Any +from unittest.mock import Mock, MagicMock +from fastapi.testclient import TestClient + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from openai_harmony import ( + HarmonyEncodingName, + load_harmony_encoding, +) +from gpt_oss.responses_api.api_server import create_api_server + + +@pytest.fixture(scope="session") +def harmony_encoding(): + return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + +@pytest.fixture +def mock_infer_token(harmony_encoding): + fake_tokens = harmony_encoding.encode( + "<|channel|>final<|message|>Test response<|return|>", + allowed_special="all" + ) + token_queue = fake_tokens.copy() + + def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int: + nonlocal token_queue + if len(token_queue) == 0: + token_queue = fake_tokens.copy() + return token_queue.pop(0) + return _mock_infer + + +@pytest.fixture +def api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]: + app = create_api_server( + infer_next_token=mock_infer_token, + encoding=harmony_encoding + ) + with TestClient(app) as client: + yield client + + +@pytest.fixture +def sample_request_data(): + return { + "model": "gpt-oss-120b", + "input": "Hello, how can I help you today?", + "stream": False, + "reasoning_effort": "low", + "temperature": 0.7, + "tools": [] + } + + +@pytest.fixture +def mock_browser_tool(): + mock = MagicMock() + mock.search.return_value = ["Result 1", "Result 2"] + mock.open_page.return_value = "Page content" + mock.find_on_page.return_value = "Found text" + return mock + + +@pytest.fixture +def mock_python_tool(): + mock = MagicMock() + mock.execute.return_value = { + "output": "print('Hello')", + "error": None, + "exit_code": 0 + } + return mock + + +@pytest.fixture(autouse=True) +def reset_test_environment(): + test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH'] + original_values = {} + + for var in test_env_vars: + if var in os.environ: + original_values[var] = os.environ[var] + del os.environ[var] + + yield + + for var, value in original_values.items(): + os.environ[var] = value + + +@pytest.fixture +def performance_timer(): + import time + + class Timer: + def __init__(self): + self.start_time = None + self.end_time = None + + def start(self): + self.start_time = time.time() + + def stop(self): + self.end_time = time.time() + return self.elapsed + + @property + def elapsed(self): + if self.start_time and self.end_time: + return self.end_time - self.start_time + return None + + return Timer() \ No newline at end of file diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py new file mode 100644 index 00000000..7fd354bb --- /dev/null +++ b/tests/test_api_endpoints.py @@ -0,0 +1,230 @@ +import pytest +import json +import asyncio +from fastapi import status +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestResponsesEndpoint: + + def test_basic_response_creation(self, api_client, sample_request_data): + response = api_client.post("/v1/responses", json=sample_request_data) + 
assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "id" in data + assert data["object"] == "response" + assert data["model"] == sample_request_data["model"] + + def test_response_with_high_reasoning(self, api_client, sample_request_data): + sample_request_data["reasoning_effort"] = "high" + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "id" in data + assert data["status"] == "completed" + + def test_response_with_medium_reasoning(self, api_client, sample_request_data): + sample_request_data["reasoning_effort"] = "medium" + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "id" in data + assert data["status"] == "completed" + + def test_response_with_invalid_model(self, api_client, sample_request_data): + sample_request_data["model"] = "invalid-model" + response = api_client.post("/v1/responses", json=sample_request_data) + # Should still accept but might handle differently + assert response.status_code == status.HTTP_200_OK + + def test_response_with_empty_input(self, api_client, sample_request_data): + sample_request_data["input"] = "" + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_response_with_tools(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "browser_search" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_response_with_custom_temperature(self, api_client, sample_request_data): + for temp in [0.0, 0.5, 1.0, 1.5, 2.0]: + sample_request_data["temperature"] = temp + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "usage" in data + + def test_streaming_response(self, api_client, sample_request_data): + sample_request_data["stream"] = True + with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response: + assert response.status_code == status.HTTP_200_OK + # Verify we get SSE events + for line in response.iter_lines(): + if line and line.startswith("data: "): + event_data = line[6:] # Remove "data: " prefix + if event_data != "[DONE]": + json.loads(event_data) # Should be valid JSON + break + + +class TestResponsesWithSession: + + def test_response_with_session_id(self, api_client, sample_request_data): + session_id = "test-session-123" + sample_request_data["session_id"] = session_id + + # First request + response1 = api_client.post("/v1/responses", json=sample_request_data) + assert response1.status_code == status.HTTP_200_OK + data1 = response1.json() + + # Second request with same session + sample_request_data["input"] = "Follow up question" + response2 = api_client.post("/v1/responses", json=sample_request_data) + assert response2.status_code == status.HTTP_200_OK + data2 = response2.json() + + # Should have different response IDs + assert data1["id"] != data2["id"] + + def test_response_continuation(self, api_client, sample_request_data): + # Create initial response + response1 = api_client.post("/v1/responses", json=sample_request_data) + assert response1.status_code == status.HTTP_200_OK + data1 = response1.json() + response_id = data1["id"] + + # Continue the response + 
continuation_request = { + "model": sample_request_data["model"], + "response_id": response_id, + "input": "Continue the previous thought" + } + response2 = api_client.post("/v1/responses", json=continuation_request) + assert response2.status_code == status.HTTP_200_OK + + +class TestErrorHandling: + + def test_missing_required_fields(self, api_client): + # Model field has default, so test with empty JSON + response = api_client.post("/v1/responses", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_invalid_reasoning_effort(self, api_client, sample_request_data): + sample_request_data["reasoning_effort"] = "invalid" + response = api_client.post("/v1/responses", json=sample_request_data) + # May handle gracefully or return error + assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY] + + def test_malformed_json(self, api_client): + response = api_client.post( + "/v1/responses", + data="not json", + headers={"Content-Type": "application/json"} + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_extremely_long_input(self, api_client, sample_request_data): + # Test with very long input + sample_request_data["input"] = "x" * 100000 + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + +class TestToolIntegration: + + def test_browser_search_tool(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "browser_search" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_function_tool_integration(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "function", + "name": "test_function", + "parameters": {"type": "object", "properties": {}}, + "description": "Test function" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_multiple_tools(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "browser_search" + }, + { + "type": "function", + "name": "test_function", + "parameters": {"type": "object", "properties": {}}, + "description": "Test function" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + +class TestPerformance: + + def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer): + performance_timer.start() + response = api_client.post("/v1/responses", json=sample_request_data) + elapsed = performance_timer.stop() + + assert response.status_code == status.HTTP_200_OK + # Response should be reasonably fast for mock inference + assert elapsed < 5.0 # 5 seconds threshold + + def test_multiple_sequential_requests(self, api_client, sample_request_data): + # Test multiple requests work correctly + for i in range(3): + data = sample_request_data.copy() + data["input"] = f"Request {i}" + response = api_client.post("/v1/responses", json=data) + assert response.status_code == status.HTTP_200_OK + + +class TestUsageTracking: + + def test_usage_object_structure(self, api_client, sample_request_data): + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + + assert "usage" in data + usage = data["usage"] + assert 
"input_tokens" in usage + assert "output_tokens" in usage + assert "total_tokens" in usage + # reasoning_tokens may not always be present + # assert "reasoning_tokens" in usage + + # Basic validation + assert usage["input_tokens"] >= 0 + assert usage["output_tokens"] >= 0 + assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"] + + def test_usage_increases_with_longer_input(self, api_client, sample_request_data): + # Short input + response1 = api_client.post("/v1/responses", json=sample_request_data) + usage1 = response1.json()["usage"] + + # Longer input + sample_request_data["input"] = sample_request_data["input"] * 10 + response2 = api_client.post("/v1/responses", json=sample_request_data) + usage2 = response2.json()["usage"] + + # Longer input should use more tokens + assert usage2["input_tokens"] > usage1["input_tokens"] \ No newline at end of file From 8fe4ee2088c76f5a9b35653b36d5647d7c6940dc Mon Sep 17 00:00:00 2001 From: Jack Clayton Date: Tue, 5 Aug 2025 11:24:18 -1000 Subject: [PATCH 10/91] Fix import for metal example (#24) It was previously pointing to an empty __init__.py. Also remove unused date import. --- gpt_oss/metal/examples/generate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gpt_oss/metal/examples/generate.py b/gpt_oss/metal/examples/generate.py index b9c0beac..3b781999 100644 --- a/gpt_oss/metal/examples/generate.py +++ b/gpt_oss/metal/examples/generate.py @@ -3,8 +3,7 @@ import argparse import sys -from datetime import date -from gpt_oss import Context, Model +from gpt_oss.metal import Context, Model parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter) From 246e377ba5e27bed550e04ccdb196ca715f63147 Mon Sep 17 00:00:00 2001 From: Niles Burbank <24277297+nsburbank@users.noreply.github.com> Date: Tue, 5 Aug 2025 17:27:09 -0400 Subject: [PATCH 11/91] Add some additional links to awesome-gpt-oss.md (#22) Add links to relevant AMD resources --- awesome-gpt-oss.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index fc5cc527..40dca51d 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -29,6 +29,8 @@ This is a list of guides and resources to help you get started with the gpt-oss - [Collection of Hugging Face examples](https://github.com/huggingface/gpt-oss-recipes) - NVIDIA - [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) +- AMD + - [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html) ### Server @@ -37,6 +39,8 @@ This is a list of guides and resources to help you get started with the gpt-oss - NVIDIA - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/gpt-oss/run-nvidia) - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md) +- AMD + - [Running the Latest Open Models from OpenAI on AMD AI Hardware](https://rocm.blogs.amd.com/ecosystems-and-partners/openai-day-0/README.html) ### Cloud @@ -55,6 +59,8 @@ This is a list of guides and resources to help you get started with the gpt-oss - [Cloudflare & gpt-oss launch blog post](http://blog.cloudflare.com/openai-gpt-oss-on-workers-ai) - [gpt-oss-120b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-120b) - [gpt-oss-20b on Cloudflare Workers 
AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b) +- AMD + - [gpt-oss-120B on AMD MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot) ## Examples & Tutorials From 0a8f5f29d3cf79fb6d2e7431f190282fd086991e Mon Sep 17 00:00:00 2001 From: cwhitelam <79796271+cwhitelam@users.noreply.github.com> Date: Tue, 5 Aug 2025 17:28:41 -0400 Subject: [PATCH 12/91] Correct small grammar issues for better comprehension (#21) * Correct small grammar issues for better comprehension * Update README.md --------- Co-authored-by: Christopher Whitelam Co-authored-by: Dominik Kundel --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 518e6f0b..3144d7a6 100644 --- a/README.md +++ b/README.md @@ -18,16 +18,16 @@ We're releasing two flavors of these open models: - `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single H100 GPU (117B parameters with 5.1B active parameters) - `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters) -Both models were trained on our [harmony response format][harmony] and should only be used with the harmony format as it will not work correctly otherwise. +Both models were trained using our [harmony response format][harmony] and should only be used with this format; otherwise, they will not work correctly. ### Highlights - **Permissive Apache 2.0 license:** Build freely without copyleft restrictions or patent risk—ideal for experimentation, customization, and commercial deployment. - **Configurable reasoning effort:** Easily adjust the reasoning effort (low, medium, high) based on your specific use case and latency needs. -- **Full chain-of-thought:** Gain complete access to the model's reasoning process, facilitating easier debugging and increased trust in outputs. It's not intended to be shown to end users. +- **Full chain-of-thought:** Provides complete access to the model's reasoning process, facilitating easier debugging and greater trust in outputs. This information is not intended to be shown to end users. - **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning. - **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs. -- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, making `gpt-oss-120b` run on a single H100 GPU and the `gpt-oss-20b` model run within 16GB of memory. +- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, allowing `gpt-oss-120b` to run on a single H100 GPU and `gpt-oss-20b` to run within 16GB of memory.. ### Inference examples @@ -402,7 +402,7 @@ To improve performance the tool caches requests so that the model can revisit a ### Python -The model was trained to use using a python tool to perform calculations and other actions as part of its chain-of-thought. During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony]. +The model was trained to use a python tool to perform calculations and other actions as part of its chain-of-thought. 
During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony]. > [!WARNING] > This implementation runs in a permissive Docker container which could be problematic in cases like prompt injections. It's serving as an example and you should consider implementing your own container restrictions in production. From 3a68b4f56530b9621314a13e56c5c4aa7bc95e2a Mon Sep 17 00:00:00 2001 From: mkusaka Date: Wed, 6 Aug 2025 06:29:10 +0900 Subject: [PATCH 13/91] fix: Correct multiple documentation URLs (#17) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix OpenAI Cookbook NVIDIA article URL (remove incorrect gpt-oss/ prefix) - Fix Groq Responses API documentation URL (responses → responses-api) --- awesome-gpt-oss.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index 40dca51d..c8a57a22 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -37,7 +37,7 @@ This is a list of guides and resources to help you get started with the gpt-oss - vLLM - [How to run gpt-oss with vLLM](https://cookbook.openai.com/articles/gpt-oss/run-vllm) - NVIDIA - - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/gpt-oss/run-nvidia) + - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/run-nvidia) - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md) - AMD - [Running the Latest Open Models from OpenAI on AMD AI Hardware](https://rocm.blogs.amd.com/ecosystems-and-partners/openai-day-0/README.html) @@ -50,7 +50,7 @@ This is a list of guides and resources to help you get started with the gpt-oss - [gpt-oss-20b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-20b) - [gpt-oss with built-in web search on GroqCloud](https://console.groq.com/docs/browser-search) - [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution) - - [Responses API on Groq](https://console.groq.com/docs/responses) + - [Responses API on Groq](https://console.groq.com/docs/responses-api) - NVIDIA - [NVIDIA launch blog post](https://blogs.nvidia.com/blog/openai-gpt-oss/) - [NVIDIA & gpt-oss developer launch blog post](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/) From ba7d80ab89c1112211fa49d462d34c2afc106127 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 5 Aug 2025 14:35:56 -0700 Subject: [PATCH 14/91] Fix chat demo (#26) --- gpt_oss/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index ed2bda21..47d27067 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -120,7 +120,7 @@ def main(args): ]) ) messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content)) - elif args.developer_message: + else: developer_message_content = DeveloperContent.new().with_instructions(args.developer_message) messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content)) From a6d9d90ab56fefaf1eeaa4c4305a9e341184f435 Mon Sep 17 00:00:00 2001 From: draczer01 
<31465304+draczer01@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:42:46 -0600 Subject: [PATCH 15/91] set platform for CI purposes (#18) used a specific platform accepted by PyPI --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f00c72df..61746718 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,3 +48,5 @@ cmake.args = [ "-DCMAKE_BUILD_TYPE=Release", "-DBUILD_SHARED_LIBS=OFF", ] +[tool.scikit-build.wheel] +plat-name = "manylinux_2_17_x86_64" From d8db548846acb4ce63bc7027a54ec84e4764cb32 Mon Sep 17 00:00:00 2001 From: 3n0chK4n Date: Wed, 6 Aug 2025 00:15:12 +0100 Subject: [PATCH 16/91] Fix TOML parsing errors in pyproject.toml for scikit-build configuration (#27) * Update wheel configuration in pyproject.toml to include package tree * Added python dependency and shell globbing for metal command --- README.md | 3 ++- pyproject.toml | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3144d7a6..3f3f4de1 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ This repository provides a collection of reference implementations: ### Requirements +- python 3.12 - On macOS: Install the Xcode CLI tools --> `xcode-select --install` - On Linux: These reference implementations require CUDA - On Windows: These reference implementations have not been tested on Windows. Try using solutions like Ollama if you are trying to run the model locally. @@ -151,7 +152,7 @@ If you want to modify the code or try the metal implementation set the project u ```shell git clone https://github.com/openai/gpt-oss.git -pip install -e .[metal] +pip install -e ".[metal]" ``` ## Download the model diff --git a/pyproject.toml b/pyproject.toml index 61746718..38c43768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,11 +42,10 @@ packages = ["gpt_oss"] [tool.scikit-build] cmake.source-dir = "."
# pick up the root CMakeLists.txt -wheel.packages = ["gpt_oss"] # copy the whole Python package tree cmake.args = [ "-DGPTOSS_BUILD_PYTHON=ON", "-DCMAKE_BUILD_TYPE=Release", "-DBUILD_SHARED_LIBS=OFF", ] [tool.scikit-build.wheel] -plat-name = "manylinux_2_17_x86_64" +packages = ["gpt_oss"] # copy the whole Python package tree From f1774c5110445b68f6effc2304ab8e4c55af3b44 Mon Sep 17 00:00:00 2001 From: Scott Lessans <142930063+scott-oai@users.noreply.github.com> Date: Tue, 5 Aug 2025 17:44:04 -0700 Subject: [PATCH 17/91] fix ci/pypi (#30) --- README.md | 3 +- _build/gpt_oss_build_backend/backend.py | 140 ++++++++++++++++++++++++ pyproject.toml | 11 +- 3 files changed, 146 insertions(+), 8 deletions(-) create mode 100644 _build/gpt_oss_build_backend/backend.py diff --git a/README.md b/README.md index 3f3f4de1..7d4f2791 100644 --- a/README.md +++ b/README.md @@ -152,7 +152,7 @@ If you want to modify the code or try the metal implementation set the project u ```shell git clone https://github.com/openai/gpt-oss.git -pip install -e ".[metal]" +GPTOSS_BUILD_METAL=1 pip install -e ".[metal]" ``` ## Download the model @@ -228,6 +228,7 @@ python gpt_oss/metal/scripts/create-local-model.py -s -d bool: + return str(os.environ.get("GPTOSS_BUILD_METAL", "")).strip() in TRUE_VALUES + + +def _setuptools_backend(): + from setuptools import build_meta as _bm # type: ignore + + return _bm + + +def _scikit_build_backend(): + return import_module("scikit_build_core.build") + + +def _backend(): + return _scikit_build_backend() if _use_metal_backend() else _setuptools_backend() + + +# Required PEP 517 hooks + +def build_wheel( + wheel_directory: str, + config_settings: Mapping[str, Any] | None = None, + metadata_directory: str | None = None, +) -> str: + return _backend().build_wheel(wheel_directory, config_settings, metadata_directory) + + +def build_sdist( + sdist_directory: str, config_settings: Mapping[str, Any] | None = None +) -> str: + return _backend().build_sdist(sdist_directory, config_settings) + + +def prepare_metadata_for_build_wheel( + metadata_directory: str, config_settings: Mapping[str, Any] | None = None +) -> str: + # Fallback if backend doesn't implement it + be = _backend() + fn = getattr(be, "prepare_metadata_for_build_wheel", None) + if fn is None: + # setuptools exposes it; scikit-build-core may not. Defer to building a wheel for metadata. 
+ return _setuptools_backend().prepare_metadata_for_build_wheel( + metadata_directory, config_settings + ) + return fn(metadata_directory, config_settings) + + +# Optional hooks + +def build_editable( + editable_directory: str, config_settings: Mapping[str, Any] | None = None +) -> str: + be = _backend() + fn = getattr(be, "build_editable", None) + if fn is None: + # setuptools implements build_editable; if not available, raise the standard error + raise RuntimeError("Editable installs not supported by the selected backend") + return fn(editable_directory, config_settings) + + +def get_requires_for_build_wheel( + config_settings: Mapping[str, Any] | None = None, +) -> Sequence[str]: + if _use_metal_backend(): + # Add dynamic build requirements only when building the Metal backend + return [ + "scikit-build-core>=0.10", + "pybind11>=2.12", + "cmake>=3.26", + "ninja", + ] + # setuptools usually returns [] + return list(_setuptools_backend().get_requires_for_build_wheel(config_settings)) + + +def get_requires_for_build_sdist( + config_settings: Mapping[str, Any] | None = None, +) -> Sequence[str]: + # No special requirements for SDist + be = _backend() + fn = getattr(be, "get_requires_for_build_sdist", None) + if fn is None: + return [] + return list(fn(config_settings)) + + +def get_requires_for_build_editable( + config_settings: Mapping[str, Any] | None = None, +) -> Sequence[str]: + if _use_metal_backend(): + return [ + "scikit-build-core>=0.10", + "pybind11>=2.12", + "cmake>=3.26", + "ninja", + ] + be = _setuptools_backend() + fn = getattr(be, "get_requires_for_build_editable", None) + if fn is None: + return [] + return list(fn(config_settings)) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 38c43768..b40487a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,19 +23,16 @@ requires-python = ">=3.12,<3.13" version = "0.0.1" [project.optional-dependencies] -triton = [ - "triton", - "safetensors>=0.5.3", - "torch>=2.7.0", -] +triton = ["triton", "safetensors>=0.5.3", "torch>=2.7.0"] torch = ["safetensors>=0.5.3", "torch>=2.7.0"] metal = ["numpy", "tqdm", "safetensors", "torch"] test = ["pytest>=8.4.1", "httpx>=0.28.1"] eval = ["pandas", "numpy", "openai", "jinja2", "tqdm", "blobfile"] [build-system] -requires = ["scikit-build-core>=0.9", "pybind11>=2.12", "cmake>=3.26", "ninja"] -build-backend = "scikit_build_core.build" +requires = ["setuptools>=68"] +build-backend = "gpt_oss_build_backend.backend" +backend-path = ["_build"] [tool.setuptools] packages = ["gpt_oss"] From 4931694686fadfa74a80554473d32f7dd4d059f3 Mon Sep 17 00:00:00 2001 From: Scott Lessans Date: Tue, 5 Aug 2025 17:52:43 -0700 Subject: [PATCH 18/91] fix build --- MANIFEST.in | 1 + _build/gpt_oss_build_backend/__init__.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 MANIFEST.in create mode 100644 _build/gpt_oss_build_backend/__init__.py diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..7bd37930 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include _build * \ No newline at end of file diff --git a/_build/gpt_oss_build_backend/__init__.py b/_build/gpt_oss_build_backend/__init__.py new file mode 100644 index 00000000..2f46b29d --- /dev/null +++ b/_build/gpt_oss_build_backend/__init__.py @@ -0,0 +1 @@ +"""In-tree PEP 517 backend package for gpt-oss.""" \ No newline at end of file From 754a56b63d38f3f4ff624271ad57f4276231ea49 Mon Sep 17 00:00:00 2001 From: "vol@" Date: Wed, 6 Aug 2025 13:57:00 -0700 Subject: [PATCH 19/91] evals: add chat 
completions API sampler (#59) * evals: admit --sampler chat_completions * gpt_oss.evals: allow modifying the model names --- gpt_oss/evals/__main__.py | 105 ++++++------------ ...sampler.py => chat_completions_sampler.py} | 41 ++++--- gpt_oss/evals/healthbench_eval.py | 7 +- 3 files changed, 65 insertions(+), 88 deletions(-) rename gpt_oss/evals/{chat_completion_sampler.py => chat_completions_sampler.py} (64%) diff --git a/gpt_oss/evals/__main__.py b/gpt_oss/evals/__main__.py index 7e95ab7e..6e93e1f2 100644 --- a/gpt_oss/evals/__main__.py +++ b/gpt_oss/evals/__main__.py @@ -6,9 +6,9 @@ from .gpqa_eval import GPQAEval from .aime_eval import AIME25Eval from .healthbench_eval import HealthBenchEval -from .chat_completion_sampler import ( +from .chat_completions_sampler import ( OPENAI_SYSTEM_MESSAGE_API, - ChatCompletionSampler, + ChatCompletionsSampler, ) from .responses_sampler import ResponsesSampler @@ -19,12 +19,23 @@ def main(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--list-models", action="store_true", help="List available models" + "--model", + type=str, + default="gpt-oss-120b,gpt-oss-20b", + help="Select a model by name. Accepts a comma-separated list.", ) parser.add_argument( - "--model", + "--reasoning-effort", + type=str, + default="low,medium,high", + help="Reasoning effort (low, medium, high). Accepts a comma-separated list.", + ) + parser.add_argument( + "--sampler", type=str, - help="Select a model by name. Also accepts a comma-separated list of models.", + choices=["responses", "chat_completions"], + default="responses", + help="Sampler backend to use for models.", ) parser.add_argument( "--base-url", @@ -36,7 +47,7 @@ def main(): "--eval", type=str, default="gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25", - help="Select an eval by name. Also accepts a comma-separated list of evals.", + help="Select an eval by name. 
Accepts a comma-separated list.", ) parser.add_argument( "--temperature", @@ -59,71 +70,26 @@ def main(): args = parser.parse_args() - models = { - "120b-low": ResponsesSampler( - model="gpt-oss-120b", - reasoning_model=True, - reasoning_effort="low", - temperature=args.temperature, - base_url=args.base_url, - ), - "120b": ResponsesSampler( - model="gpt-oss-120b", - reasoning_model=True, - reasoning_effort="medium", - temperature=args.temperature, - base_url=args.base_url, - ), - "120b-high": ResponsesSampler( - model="gpt-oss-120b", - reasoning_model=True, - reasoning_effort="high", - temperature=args.temperature, - base_url=args.base_url, - ), - "20b-low": ResponsesSampler( - model="gpt-oss-20b", - reasoning_model=True, - reasoning_effort="low", - temperature=args.temperature, - base_url=args.base_url, - ), - "20b": ResponsesSampler( - model="gpt-oss-20b", - reasoning_model=True, - reasoning_effort="medium", - temperature=args.temperature, - base_url=args.base_url, - ), - "20b-high": ResponsesSampler( - model="gpt-oss-20b", - reasoning_model=True, - reasoning_effort="high", - temperature=args.temperature, - base_url=args.base_url, - ), - } - - if args.list_models: - print("Available models:") - for model_name in models.keys(): - print(f" - {model_name}") - return - - if args.model: - models_chosen = args.model.split(",") - for model_name in models_chosen: - if model_name not in models: - print(f"Error: Model '{model_name}' not found.") - return - models = {model_name: models[model_name] for model_name in models_chosen} + sampler_cls = ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler + + models = {} + for model_name in args.model.split(","): + for reasoning_effort in args.reasoning_effort.split(","): + models[f"{model_name}-{reasoning_effort}"] = sampler_cls( + model=model_name, + reasoning_model=True, + reasoning_effort=reasoning_effort, + temperature=args.temperature, + base_url=args.base_url, + ) print(f"Running with args {args}") - grading_sampler = ChatCompletionSampler( + grading_sampler = ChatCompletionsSampler( model="gpt-4.1-2025-04-14", system_message=OPENAI_SYSTEM_MESSAGE_API, max_tokens=2048, + base_url="https://api.openai.com/v1", ) def get_evals(eval_name, debug_mode): @@ -172,17 +138,15 @@ def get_evals(eval_name, debug_mode): case _: raise Exception(f"Unrecognized eval type: {eval_name}") - evals_list = args.eval.split(",") evals = {} - for eval_name in evals_list: + for eval_name in args.eval.split(","): evals[eval_name] = get_evals(eval_name, args.debug) - print(evals) debug_suffix = "_DEBUG" if args.debug else "" print(debug_suffix) mergekey2resultpath = {} - print(f"Running the following evals: {list(evals.keys())}") - print(f"Running evals for the following models: {list(models.keys())}") + print(f"Running the following evals: {evals}") + print(f"Running evals for the following models: {models}") now = datetime.now() date_str = now.strftime("%Y%m%d_%H%M%S") @@ -220,6 +184,7 @@ def get_evals(eval_name, debug_mode): print(f"Writing all results to {full_result_filename}") mergekey2resultpath[f"{file_stem}"] = result_filename + merge_metrics = [] for eval_model_name, result_filename in mergekey2resultpath.items(): try: diff --git a/gpt_oss/evals/chat_completion_sampler.py b/gpt_oss/evals/chat_completions_sampler.py similarity index 64% rename from gpt_oss/evals/chat_completion_sampler.py rename to gpt_oss/evals/chat_completions_sampler.py index 4a1f9618..557bc906 100644 --- a/gpt_oss/evals/chat_completion_sampler.py +++ 
b/gpt_oss/evals/chat_completions_sampler.py @@ -6,6 +6,7 @@ from .types import MessageList, SamplerBase, SamplerResponse + OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." OPENAI_SYSTEM_MESSAGE_CHATGPT = ( "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." @@ -13,10 +14,8 @@ ) -class ChatCompletionSampler(SamplerBase): - """ - Sample from OpenAI's chat completion API - """ +class ChatCompletionsSampler(SamplerBase): + """Sample from a Chat Completions compatible API.""" def __init__( self, @@ -24,17 +23,21 @@ def __init__( system_message: str | None = None, temperature: float = 0.5, max_tokens: int = 1024, + reasoning_model: bool = False, + reasoning_effort: str | None = None, + base_url: str = "http://localhost:8000/v1", ): self.api_key_name = "OPENAI_API_KEY" - self.client = OpenAI() - # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY + self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60) self.model = model self.system_message = system_message self.temperature = temperature self.max_tokens = max_tokens + self.reasoning_model = reasoning_model + self.reasoning_effort = reasoning_effort self.image_format = "url" - def _pack_message(self, role: str, content: Any): + def _pack_message(self, role: str, content: Any) -> dict[str, Any]: return {"role": str(role), "content": content} def __call__(self, message_list: MessageList) -> SamplerResponse: @@ -45,12 +48,21 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: trial = 0 while True: try: - response = self.client.chat.completions.create( - model=self.model, - messages=message_list, - temperature=self.temperature, - max_tokens=self.max_tokens, - ) + if self.reasoning_model: + response = self.client.chat.completions.create( + model=self.model, + messages=message_list, + reasoning_effort=self.reasoning_effort, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + else: + response = self.client.chat.completions.create( + model=self.model, + messages=message_list, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) content = response.choices[0].message.content if content is None: raise ValueError("OpenAI API returned empty response; retrying") @@ -59,7 +71,6 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: response_metadata={"usage": response.usage}, actual_queried_message_list=message_list, ) - # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU except openai.BadRequestError as e: print("Bad Request Error", e) return SamplerResponse( @@ -68,7 +79,7 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: actual_queried_message_list=message_list, ) except Exception as e: - exception_backoff = 2**trial # expontial back off + exception_backoff = 2 ** trial # exponential back off print( f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", e, diff --git a/gpt_oss/evals/healthbench_eval.py b/gpt_oss/evals/healthbench_eval.py index 77a6b3a5..bf136ac3 100644 --- a/gpt_oss/evals/healthbench_eval.py +++ b/gpt_oss/evals/healthbench_eval.py @@ -26,9 +26,9 @@ import numpy as np from . 
import report -from .chat_completion_sampler import ( +from .chat_completions_sampler import ( OPENAI_SYSTEM_MESSAGE_API, - ChatCompletionSampler, + ChatCompletionsSampler, ) from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult @@ -540,10 +540,11 @@ def physician_completions_main( now = datetime.now() date_str = now.strftime("%Y%m%d_%H%M") - grading_sampler = ChatCompletionSampler( + grading_sampler = ChatCompletionsSampler( model="gpt-4.1-2025-04-14", system_message=OPENAI_SYSTEM_MESSAGE_API, max_tokens=2048, + base_url="https://api.openai.com/v1", ) dummy_sampler = SamplerBase() From d0a300a40d6502a1bdd73d18464f3d69440656e0 Mon Sep 17 00:00:00 2001 From: "vol@" Date: Wed, 6 Aug 2025 16:00:30 -0700 Subject: [PATCH 20/91] evals: log reasoning and extend max_tokens for chat completions (#62) --- gpt_oss/evals/__main__.py | 8 +++-- gpt_oss/evals/basic_eval.py | 38 +++++++++++++++++++++++ gpt_oss/evals/chat_completions_sampler.py | 10 ++++-- gpt_oss/evals/responses_sampler.py | 3 +- 4 files changed, 52 insertions(+), 7 deletions(-) create mode 100644 gpt_oss/evals/basic_eval.py diff --git a/gpt_oss/evals/__main__.py b/gpt_oss/evals/__main__.py index 6e93e1f2..bb34e2c3 100644 --- a/gpt_oss/evals/__main__.py +++ b/gpt_oss/evals/__main__.py @@ -3,6 +3,7 @@ from datetime import datetime from . import report +from .basic_eval import BasicEval from .gpqa_eval import GPQAEval from .aime_eval import AIME25Eval from .healthbench_eval import HealthBenchEval @@ -81,6 +82,7 @@ def main(): reasoning_effort=reasoning_effort, temperature=args.temperature, base_url=args.base_url, + max_tokens=131_072, ) print(f"Running with args {args}") @@ -98,9 +100,11 @@ def get_evals(eval_name, debug_mode): ) # Set num_examples = None to reproduce full evals match eval_name: + case "basic": + return BasicEval() case "gpqa": return GPQAEval( - n_repeats=8, + n_repeats=1 if args.debug else 8, num_examples=num_examples, debug=debug_mode, n_threads=args.n_threads or 1, @@ -131,7 +135,7 @@ def get_evals(eval_name, debug_mode): ) case "aime25": return AIME25Eval( - n_repeats=8, + n_repeats=1 if args.debug else 8, num_examples=num_examples, n_threads=args.n_threads or 1, ) diff --git a/gpt_oss/evals/basic_eval.py b/gpt_oss/evals/basic_eval.py new file mode 100644 index 00000000..77995307 --- /dev/null +++ b/gpt_oss/evals/basic_eval.py @@ -0,0 +1,38 @@ +""" +Basic eval +""" +from . 
import report + +from .types import Eval, EvalResult, SamplerBase, SingleEvalResult + +class BasicEval(Eval): + def __init__(self,): + self.examples = [{ + "question": "hi", + "answer": "hi, how can i help?", + }] + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(row: dict): + sampler_response = sampler([ + sampler._pack_message(content=row["question"], role="user") + ]) + response_text = sampler_response.response_text + extracted_answer = response_text + actual_queried_prompt_messages = sampler_response.actual_queried_message_list + score = 1.0 if len(extracted_answer) > 0 else 0.0 + html = report.jinja_env.from_string(report.HTML_JINJA).render( + prompt_messages=actual_queried_prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=row["answer"], + extracted_answer=extracted_answer, + ) + convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] + return SingleEvalResult( + html=html, score=score, convo=convo, metrics={"chars": len(response_text)} + ) + + results = report.map_with_progress(fn, self.examples, num_threads=1) + return report.aggregate_results(results) + diff --git a/gpt_oss/evals/chat_completions_sampler.py b/gpt_oss/evals/chat_completions_sampler.py index 557bc906..29c1a0a8 100644 --- a/gpt_oss/evals/chat_completions_sampler.py +++ b/gpt_oss/evals/chat_completions_sampler.py @@ -27,7 +27,6 @@ def __init__( reasoning_effort: str | None = None, base_url: str = "http://localhost:8000/v1", ): - self.api_key_name = "OPENAI_API_KEY" self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60) self.model = model self.system_message = system_message @@ -63,8 +62,13 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: temperature=self.temperature, max_tokens=self.max_tokens, ) - content = response.choices[0].message.content - if content is None: + + choice = response.choices[0] + content = choice.message.content + if getattr(choice.message, "reasoning", None): + message_list.append(self._pack_message("assistant", choice.message.reasoning)) + + if not content: raise ValueError("OpenAI API returned empty response; retrying") return SamplerResponse( response_text=content, diff --git a/gpt_oss/evals/responses_sampler.py b/gpt_oss/evals/responses_sampler.py index ec4e0485..fd9daef3 100644 --- a/gpt_oss/evals/responses_sampler.py +++ b/gpt_oss/evals/responses_sampler.py @@ -17,12 +17,11 @@ def __init__( model: str, developer_message: str | None = None, temperature: float = 1.0, - max_tokens: int = 1024, + max_tokens: int = 131_072, reasoning_model: bool = False, reasoning_effort: str | None = None, base_url: str = "http://localhost:8000/v1", ): - self.api_key_name = "OPENAI_API_KEY" self.client = OpenAI(base_url=base_url, timeout=24*60*60) self.model = model self.developer_message = developer_message From 4f5ca7fa6075782c2356a9664bdabcccc3d880cb Mon Sep 17 00:00:00 2001 From: "vol@" Date: Thu, 7 Aug 2025 11:11:08 -0700 Subject: [PATCH 21/91] chat / api_server: do not include developer messages to reduce mismatch (#79) --- gpt_oss/chat.py | 10 ++++++---- gpt_oss/responses_api/api_server.py | 30 ++++++++++++++--------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index 47d27067..5e40079d 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -120,9 +120,11 @@ def main(args): ]) ) messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content)) - else: + elif args.developer_message: 
developer_message_content = DeveloperContent.new().with_instructions(args.developer_message) messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content)) + else: + developer_message_content = None if args.raw: conversation = Conversation.from_messages(messages) @@ -142,9 +144,9 @@ def main(args): print(termcolor.colored("Browser Tool:", "cyan"), "Enabled" if args.browser else "Disabled", flush=True) print(termcolor.colored("Python Tool:", "cyan"), "Enabled" if args.python else "Disabled", flush=True) print(termcolor.colored("Apply Patch Function:", "cyan"), "Enabled" if args.apply_patch else "Disabled", flush=True) - # Developer message - print(termcolor.colored("Developer Message:", "yellow"), flush=True) - print(developer_message_content.instructions, flush=True) + if developer_message_content: + print(termcolor.colored("Developer Message:", "yellow"), flush=True) + print(developer_message_content.instructions, flush=True) # Print the system message and the user message start MESSAGE_PADDING = 12 diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py index 908d86c0..5ea7fc15 100644 --- a/gpt_oss/responses_api/api_server.py +++ b/gpt_oss/responses_api/api_server.py @@ -793,16 +793,16 @@ def _ensure_list(inp): system_message = Message.from_role_and_content( Role.SYSTEM, system_message_content ) + messages = [system_message] - developer_message_content = DeveloperContent.new().with_instructions( - body.instructions - ) + if body.instructions or body.tools: + developer_message_content = DeveloperContent.new().with_instructions( + body.instructions + ) - tools = [] - if body.tools: + tools = [] for tool in body.tools: if tool.type == "function": - has_functions = True tools.append( ToolDescription.new( tool.name, @@ -810,17 +810,17 @@ def _ensure_list(inp): tool.parameters, ) ) - - if len(tools) > 0: - developer_message_content = developer_message_content.with_function_tools( - tools - ) - developer_message = Message.from_role_and_content( - Role.DEVELOPER, developer_message_content - ) + if tools: + developer_message_content = developer_message_content.with_function_tools( + tools + ) + + developer_message = Message.from_role_and_content( + Role.DEVELOPER, developer_message_content + ) - messages = [system_message, developer_message] + messages.append(developer_message) if isinstance(body.input, str): user_message = Message.from_role_and_content(Role.USER, body.input) From 4d514ddc2b9f566ffc5da06721903a68e78f024d Mon Sep 17 00:00:00 2001 From: Sebastian Sarco Date: Thu, 7 Aug 2025 15:11:37 -0300 Subject: [PATCH 22/91] Fix typos 'lenght' -> 'length' (#78) --- gpt_oss/tools/simple_browser/page_contents.py | 4 ++-- gpt_oss/tools/simple_browser/simple_browser_tool.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gpt_oss/tools/simple_browser/page_contents.py b/gpt_oss/tools/simple_browser/page_contents.py index 4a18fc97..e1e1b951 100644 --- a/gpt_oss/tools/simple_browser/page_contents.py +++ b/gpt_oss/tools/simple_browser/page_contents.py @@ -87,13 +87,13 @@ def mark_lines(text: str) -> str: @functools.cache -def _tiktoken_vocabulary_lenghts(enc_name: str) -> list[int]: +def _tiktoken_vocabulary_lengths(enc_name: str) -> list[int]: encoding = tiktoken.get_encoding(enc_name) return [len(encoding.decode([i])) for i in range(encoding.n_vocab)] def warmup_caches(enc_names: list[str]) -> None: - for _ in map(_tiktoken_vocabulary_lenghts, enc_names): + for _ in map(_tiktoken_vocabulary_lengths, enc_names): pass diff 
--git a/gpt_oss/tools/simple_browser/simple_browser_tool.py b/gpt_oss/tools/simple_browser/simple_browser_tool.py index 9d84a310..913ee0bd 100644 --- a/gpt_oss/tools/simple_browser/simple_browser_tool.py +++ b/gpt_oss/tools/simple_browser/simple_browser_tool.py @@ -102,8 +102,8 @@ def max_chars_per_token(enc_name: str) -> int: def get_tokens(text: str, enc_name: str) -> Tokens: encoding = tiktoken.get_encoding(enc_name) tokens = encoding.encode(text, disallowed_special=()) - _vocabulary_lenghts = _tiktoken_vocabulary_lengths(enc_name) - tok2idx = [0] + list(itertools.accumulate(_vocabulary_lenghts[i] for i in tokens))[ + _vocabulary_lengths = _tiktoken_vocabulary_lengths(enc_name) + tok2idx = [0] + list(itertools.accumulate(_vocabulary_lengths[i] for i in tokens))[ :-1 ] result = Tokens(tokens=tokens, tok2idx=tok2idx) From 9568e6ee38bb7a8ec6d51f4966cd51cbd3f3c80c Mon Sep 17 00:00:00 2001 From: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com> Date: Thu, 7 Aug 2025 20:13:55 +0200 Subject: [PATCH 23/91] fix f string errors in streamlit chat (#73) --- examples/streamlit/streamlit_chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/streamlit/streamlit_chat.py b/examples/streamlit/streamlit_chat.py index d03fe9c0..f36c37b8 100644 --- a/examples/streamlit/streamlit_chat.py +++ b/examples/streamlit/streamlit_chat.py @@ -173,7 +173,7 @@ def run(container): item = data.get("item", {}) if item.get("type") == "function_call": with container.chat_message("function_call", avatar="🔨"): - st.markdown(f"Called `{item.get("name")}`") + st.markdown(f"Called `{item.get('name')}`") st.caption("Arguments") st.code(item.get("arguments", ""), language="json") if item.get("type") == "web_search_call": @@ -223,7 +223,7 @@ def run(container): st.markdown(item["text"]) elif msg.get("type") == "function_call": with st.chat_message("function_call", avatar="🔨"): - st.markdown(f"Called `{msg.get("name")}`") + st.markdown(f"Called `{msg.get('name')}`") st.caption("Arguments") st.code(msg.get("arguments", ""), language="json") elif msg.get("type") == "function_call_output": From e490130b39cab9df886c303eb7ffedb70eaa950d Mon Sep 17 00:00:00 2001 From: Ignacio Correcher Date: Thu, 7 Aug 2025 20:14:27 +0200 Subject: [PATCH 24/91] Fixing typos and grammatical improvements. (#72) --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7d4f2791..598fe992 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ OpenAI blog

- Download gpt-oss-120b and gpt-oss-20b on Hugging Face + Download gpt-oss-120b and gpt-oss-20b on Hugging Face


@@ -27,13 +27,13 @@ Both models were trained using our [harmony response format][harmony] and should - **Full chain-of-thought:** Provides complete access to the model's reasoning process, facilitating easier debugging and greater trust in outputs. This information is not intended to be shown to end users. - **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning. - **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs. -- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, allowing `gpt-oss-120b` to run on a single H100 GPU and `gpt-oss-20b` to run within 16GB of memory.. +- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, allowing `gpt-oss-120b` to run on a single H100 GPU and `gpt-oss-20b` to run within 16GB of memory. ### Inference examples #### Transformers -You can use `gpt-oss-120b` and `gpt-oss-20b` with Transformers. If you use Transformers's chat template it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package. +You can use `gpt-oss-120b` and `gpt-oss-20b` with Transformers. If you use the Transformers chat template it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package. ```python from transformers import pipeline @@ -398,7 +398,7 @@ if last_message.recipient.startswith("browser"): #### Details -To control the context window size this tool use a scrollable window of text that the model can interact with. So it might fetch the first 50 lines of a page and then scroll to the next 20 lines after that. The model has also been trained to then use citations from this tool in its answers. +To control the context window size this tool uses a scrollable window of text that the model can interact with. So it might fetch the first 50 lines of a page and then scroll to the next 20 lines after that. The model has also been trained to then use citations from this tool in its answers. To improve performance the tool caches requests so that the model can revisit a different part of a page without having to reload the page. For that reason you should create a new browser instance for every request. 
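The browser-tool guidance that closes the patch above (create a new browser instance for every request, because the tool caches fetched pages) can be made concrete with a small sketch. It assumes the `SimpleBrowserTool` and `ExaBackend` names from the reference implementation in `gpt_oss/tools/simple_browser`; treat the exact constructor arguments as assumptions if your checkout differs.

```python
# Minimal sketch, assuming the reference browser tool classes in this repo.
# A fresh tool per request means a fresh page cache, so cached page snapshots
# from one conversation never leak into the next.
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import ExaBackend


def new_browser_tool() -> SimpleBrowserTool:
    backend = ExaBackend(source="web")  # search backend used by the reference tool
    return SimpleBrowserTool(backend=backend)


def handle_request(harmony_messages):
    browser_tool = new_browser_tool()  # one instance per request -> empty cache
    # ... run the sampling loop here, dispatching any message whose recipient
    # starts with "browser." to browser_tool ...
    return harmony_messages
```

Because the cache lives on the tool instance, scoping the instance to a single request keeps revisits cheap within that request without letting stale pages carry over between requests.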
From 98f62cced4828c62ee559f909f2c535bce891a51 Mon Sep 17 00:00:00 2001 From: MirzaSamadAhmedBaig <89132160+Mirza-Samad-Ahmed-Baig@users.noreply.github.com> Date: Thu, 7 Aug 2025 23:15:18 +0500 Subject: [PATCH 25/91] fix: max_tokens handling in generate.py (#70) --- gpt_oss/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py index cd60ca4c..dfaaa6f1 100644 --- a/gpt_oss/generate.py +++ b/gpt_oss/generate.py @@ -28,7 +28,8 @@ def main(args): tokenizer = get_tokenizer() tokens = tokenizer.encode(args.prompt) - for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=args.limit, return_logprobs=True): + max_tokens = None if args.limit == 0 else args.limit + for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True): tokens.append(token) decoded_token = tokenizer.decode([token]) print( From 7e64492fa8ae26b9463fb38d30b8cf11ca020081 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Thu, 7 Aug 2025 12:58:44 -0700 Subject: [PATCH 26/91] Support concurrent sampling from multiple Contexts (#83) Move activation buffers from Model to Context, so they are no longer shared across contexts and multiple contexts can sample in parallel --- gpt_oss/metal/source/context.c | 114 +++++++++++++----- gpt_oss/metal/source/include/internal/model.h | 23 ++-- gpt_oss/metal/source/model.c | 49 -------- 3 files changed, 94 insertions(+), 92 deletions(-) diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index af6a5d65..c356ea44 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -47,6 +47,41 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create( atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed); context->max_tokens = context_length; + // Activation buffers + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, 
&context->swiglu_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + + // Input/output buffers status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer); if (status != gptoss_status_success) { goto cleanup; @@ -73,7 +108,11 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create( } context->kvcache_size = context->kvcache_buffer.size; - context->allocation_size = context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size; + context->allocation_size = + context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size + + context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size + + context->gate_activation_buffer.size + context->expert_activation_buffer.size + context->swiglu_activation_buffer.size + context->moe_activation_buffer.size + + context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size; context->model = model; gptoss_model_retain(model); @@ -139,7 +178,7 @@ static enum gptoss_status process_batch( (context->num_tokens - context->num_batch_tokens) * sizeof(uint32_t), &model->shared_weight_buffer, /*weight_offset=*/0, - &model->residual_activation_buffer, + &context->residual_activation_buffer, /*output_offset=*/0, /*num_tokens=*/context->num_batch_tokens, /*num_channels=*/model->embedding_dim); @@ -154,11 +193,11 @@ static enum gptoss_status process_batch( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( &command_buffer, &model->f32_bf16w_rmsnorm_fn, - &model->residual_activation_buffer, + &context->residual_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, - &model->rmsnorm_activation_buffer, + &context->rmsnorm_activation_buffer, /*output_offset=*/0, /*num_tokens=*/context->num_batch_tokens, /*num_channels=*/model->embedding_dim, @@ -171,13 +210,13 @@ static enum gptoss_status process_batch( &command_buffer, &model->f32_bf16w_matmul_fn, /*threadgroup_size=*/256, - &model->rmsnorm_activation_buffer, + &context->rmsnorm_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n, &model->shared_weight_buffer, /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, - &model->qkv_activation_buffer, + &context->qkv_activation_buffer, /*output_offset=*/0, /*num_tokens=*/context->num_batch_tokens, /*num_cols=*/model->embedding_dim, @@ -191,7 +230,7 @@ static enum gptoss_status process_batch( &command_buffer, &model->f32_rope_fn, /*threadgroup_size=*/32, - &model->qkv_activation_buffer, + &context->qkv_activation_buffer, model->rope_theta, model->interpolation_scale, model->yarn_offset, @@ -209,7 +248,7 @@ static enum gptoss_status process_batch( for (uint32_t t = 0; t < context->num_batch_tokens; t++) { status = gptoss_metal_command_buffer_encode_copy_buffer( &command_buffer, - &model->qkv_activation_buffer, + &context->qkv_activation_buffer, /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float), 
&context->kvcache_buffer, /*output_offset=*/(n * context->max_tokens + context->num_kv_tokens + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float), @@ -223,7 +262,7 @@ static enum gptoss_status process_batch( status = gptoss_metal_command_buffer_encode_launch_f32_sdpa( &command_buffer, &model->f32_sdpa_q8_d64_fn, - &model->qkv_activation_buffer, + &context->qkv_activation_buffer, /*q_offset=*/attn_qkv_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), &context->kvcache_buffer, /*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float), @@ -231,7 +270,7 @@ static enum gptoss_status process_batch( /*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float), &model->shared_weight_buffer, /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n, - &model->sdpa_activation_buffer, /*output_offset=*/0, + &context->sdpa_activation_buffer, /*output_offset=*/0, /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX, num_output_tokens, context->num_kv_tokens + (context->num_batch_tokens - num_output_tokens), model->num_heads, model->num_kv_heads, model->head_dim); @@ -243,13 +282,13 @@ static enum gptoss_status process_batch( &command_buffer, &model->f32_bf16w_matmul_fn, /*threadgroup_size=*/256, - &model->sdpa_activation_buffer, + &context->sdpa_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n, &model->shared_weight_buffer, /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, - &model->residual_activation_buffer, + &context->residual_activation_buffer, /*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), /*num_tokens=*/num_output_tokens, /*num_cols=*/model->num_heads * model->head_dim, @@ -262,11 +301,11 @@ static enum gptoss_status process_batch( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( &command_buffer, &model->f32_bf16w_rmsnorm_fn, - &model->residual_activation_buffer, + &context->residual_activation_buffer, /*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), &model->shared_weight_buffer, /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, - &model->rmsnorm_activation_buffer, + &context->rmsnorm_activation_buffer, /*output_offset=*/0, num_output_tokens, model->embedding_dim, @@ -280,13 +319,13 @@ static enum gptoss_status process_batch( &command_buffer, &model->f32_bf16w_matmul_fn, /*threadgroup_size=*/256, - &model->rmsnorm_activation_buffer, + &context->rmsnorm_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n, &model->shared_weight_buffer, /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n, - &model->gate_activation_buffer, + &context->gate_activation_buffer, /*output_offset=*/0, /*num_tokens=*/num_output_tokens, /*num_cols=*/model->embedding_dim, @@ -303,8 +342,8 @@ static enum gptoss_status process_batch( status = gptoss_metal_command_buffer_encode_launch_f32_topk( &command_buffer, &model->f32_topk_softmax_e32_k4_fn, - &model->gate_activation_buffer, /*input_offset=*/0, - &model->expert_activation_buffer, /*output_offset=*/0, + &context->gate_activation_buffer, /*input_offset=*/0, + 
&context->expert_activation_buffer, /*output_offset=*/0, num_output_tokens, model->num_experts, model->num_active_experts); @@ -314,8 +353,8 @@ static enum gptoss_status process_batch( status = gptoss_metal_command_buffer_encode_launch_f32_topk( &command_buffer, &model->f32_topk_softmax_e128_k4_fn, - &model->gate_activation_buffer, /*input_offset=*/0, - &model->expert_activation_buffer, /*output_offset=*/0, + &context->gate_activation_buffer, /*input_offset=*/0, + &context->expert_activation_buffer, /*output_offset=*/0, num_output_tokens, model->num_experts, model->num_active_experts); @@ -334,12 +373,12 @@ static enum gptoss_status process_batch( &command_buffer, &model->f32_mf4w_moe_matmul_swiglu_fn, /*threadgroup_size=*/512, - &model->rmsnorm_activation_buffer, /*input_offset=*/0, - &model->expert_activation_buffer, /*expert_offset=*/0, + &context->rmsnorm_activation_buffer, /*input_offset=*/0, + &context->expert_activation_buffer, /*expert_offset=*/0, &model->block_weight_buffers[n], /*weight_block_offset=*/0, &model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_swiglu_scale_offset, &model->block_weight_buffers[n], /*bias_offset=*/model->mlp_swiglu_bias_offset, - &model->swiglu_activation_buffer, /*output_offset=*/0, + &context->swiglu_activation_buffer, /*output_offset=*/0, model->swiglu_limit, model->per_expert_block_weight_size, num_output_tokens, @@ -355,12 +394,12 @@ static enum gptoss_status process_batch( &command_buffer, &model->f32_mf4w_moe_matmul_fn, /*threadgroup_size=*/512, - &model->swiglu_activation_buffer, /*input_offset=*/0, - &model->expert_activation_buffer, /*expert_offset=*/0, + &context->swiglu_activation_buffer, /*input_offset=*/0, + &context->expert_activation_buffer, /*expert_offset=*/0, &model->block_weight_buffers[n], /*weight_block_offset=*/model->mlp_out_block_offset, &model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_out_scale_offset, &model->block_weight_buffers[n], /*bias_offset=*/model->mlp_out_bias_offset, - &model->moe_activation_buffer, /*output_offset=*/0, + &context->moe_activation_buffer, /*output_offset=*/0, model->per_expert_block_weight_size, num_output_tokens, model->num_active_experts, @@ -376,11 +415,11 @@ static enum gptoss_status process_batch( &model->f32_accumulate_e4_fn, /*threadgroup_size=*/256, model->max_threadgroups, - &model->moe_activation_buffer, + &context->moe_activation_buffer, /*input_offset=*/0, - &model->expert_activation_buffer, + &context->expert_activation_buffer, /*expert_offset=*/0, - &model->residual_activation_buffer, + &context->residual_activation_buffer, /*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), model->embedding_dim, num_output_tokens, @@ -395,11 +434,11 @@ static enum gptoss_status process_batch( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( &command_buffer, &model->f32_bf16w_rmsnorm_fn, - &model->residual_activation_buffer, + &context->residual_activation_buffer, /*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), &model->shared_weight_buffer, /*weight_offset=*/model->rmsnorm_weight_offset, - &model->rmsnorm_activation_buffer, + &context->rmsnorm_activation_buffer, /*output_offset=*/0, /*num_tokens=*/num_output_tokens, /*num_channels=*/model->embedding_dim, @@ -424,7 +463,7 @@ static enum gptoss_status process_batch( &model->f32_bf16w_unembedding_fn, /*threadgroup_size=*/256, model->max_threadgroups, - &model->rmsnorm_activation_buffer, + 
&context->rmsnorm_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, /*weight_offset=*/model->unembedding_weight_offset, @@ -700,6 +739,17 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release( { if (context != NULL) { if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) { + // Activation buffers + gptoss_metal_buffer_release(&context->residual_activation_buffer); + gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer); + gptoss_metal_buffer_release(&context->qkv_activation_buffer); + gptoss_metal_buffer_release(&context->sdpa_activation_buffer); + gptoss_metal_buffer_release(&context->gate_activation_buffer); + gptoss_metal_buffer_release(&context->expert_activation_buffer); + gptoss_metal_buffer_release(&context->swiglu_activation_buffer); + gptoss_metal_buffer_release(&context->moe_activation_buffer); + + // Input/output buffers gptoss_metal_buffer_release(&context->token_buffer); gptoss_metal_buffer_release(&context->score_buffer); gptoss_metal_buffer_release(&context->prob_buffer); diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index e2a45647..af24419e 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -75,17 +75,6 @@ struct gptoss_model { struct gptoss_metal_function f32_sdpa_q8_d64_fn; struct gptoss_metal_function f32_softmax_fn; - // Activation buffers. - // TODO: merge into a single buffer. - struct gptoss_metal_buffer residual_activation_buffer; // Residual stream - struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output - struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output - struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output - struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output - struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions - struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output - struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert) - size_t per_block_shared_weights_size; size_t per_expert_block_weight_size; @@ -135,6 +124,18 @@ struct gptoss_context { size_t kvcache_size; size_t allocation_size; + // Activation buffers. + // TODO: merge into a single buffer. + struct gptoss_metal_buffer residual_activation_buffer; // Residual stream + struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output + struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output + struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output + struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output + struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions + struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output + struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert) + + // Input/output buffers. 
struct gptoss_metal_buffer token_buffer; // uint32 token IDs struct gptoss_metal_buffer score_buffer; // unembedding outputs struct gptoss_metal_buffer prob_buffer; diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c index aba8a27e..e3aeb98f 100644 --- a/gpt_oss/metal/source/model.c +++ b/gpt_oss/metal/source/model.c @@ -421,45 +421,6 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( model->weights_size += moe_block_weight_size; } - // Activation buffers - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &model->residual_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &model->rmsnorm_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &model->qkv_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &model->sdpa_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(float), NULL, &model->gate_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &model->expert_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &model->swiglu_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &model->moe_activation_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - - model->allocation_size = - model->residual_activation_buffer.size + model->rmsnorm_activation_buffer.size + - model->qkv_activation_buffer.size + model->sdpa_activation_buffer.size + - model->gate_activation_buffer.size + model->expert_activation_buffer.size + model->swiglu_activation_buffer.size + model->moe_activation_buffer.size; - // Commit tokenizer model->tokenizer = tokenizer; tokenizer = NULL; @@ -510,16 +471,6 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release( if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) { gptoss_tokenizer_release(model->tokenizer); - // Activation buffers - gptoss_metal_buffer_release(&model->residual_activation_buffer); - gptoss_metal_buffer_release(&model->rmsnorm_activation_buffer); - gptoss_metal_buffer_release(&model->qkv_activation_buffer); - gptoss_metal_buffer_release(&model->sdpa_activation_buffer); - gptoss_metal_buffer_release(&model->gate_activation_buffer); - gptoss_metal_buffer_release(&model->expert_activation_buffer); - gptoss_metal_buffer_release(&model->swiglu_activation_buffer); - gptoss_metal_buffer_release(&model->moe_activation_buffer); - // Weight buffers 
gptoss_metal_buffer_release(&model->shared_weight_buffer); for (uint32_t n = 0; n < model->num_blocks; n++) { From ec7914d092ec6c0347e51df6802d946bd3913f61 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Fri, 8 Aug 2025 10:28:06 -0700 Subject: [PATCH 27/91] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 598fe992..b49ec6ec 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/ We're releasing two flavors of these open models: -- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single H100 GPU (117B parameters with 5.1B active parameters) +- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single 80GB GPU (like NVIDIA H100 or AMD MI300X) (117B parameters with 5.1B active parameters) - `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters) Both models were trained using our [harmony response format][harmony] and should only be used with this format; otherwise, they will not work correctly. @@ -27,7 +27,7 @@ Both models were trained using our [harmony response format][harmony] and should - **Full chain-of-thought:** Provides complete access to the model's reasoning process, facilitating easier debugging and greater trust in outputs. This information is not intended to be shown to end users. - **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning. - **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs. -- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, allowing `gpt-oss-120b` to run on a single H100 GPU and `gpt-oss-20b` to run within 16GB of memory. +- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, allowing `gpt-oss-120b` to run on a single 80GB GPU (like NVIDIA H100 or AMD MI300X) and `gpt-oss-20b` to run within 16GB of memory. ### Inference examples From 4589fbb727fe35f99fbf13ced1bd8882e1ee599b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 8 Aug 2025 16:41:48 -0400 Subject: [PATCH 28/91] fix packaging (#90) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b40487a2..77d8d26c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ requires = ["setuptools>=68"] build-backend = "gpt_oss_build_backend.backend" backend-path = ["_build"] -[tool.setuptools] -packages = ["gpt_oss"] +[tool.setuptools.packages.find] +include = ["gpt_oss*"] [tool.scikit-build] cmake.source-dir = "." 
# pick up the root CMakeLists.txt From 1a9e106a1eeb0151e97f970a597ed5389f82bde4 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Sat, 9 Aug 2025 18:18:14 -0700 Subject: [PATCH 29/91] Update README.md (#29) - Put quotes around `pip -e` args - Use `hf` over `huggingface-cli` (deprecated) --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index b49ec6ec..ca4cfb42 100644 --- a/README.md +++ b/README.md @@ -161,10 +161,10 @@ You can download the model weights from the [Hugging Face Hub](https://huggingfa ```shell # gpt-oss-120b -huggingface-cli download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/ +hf download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/ # gpt-oss-20b -huggingface-cli download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/ +hf download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/ ``` ## Reference PyTorch implementation @@ -174,7 +174,7 @@ We include an inefficient reference PyTorch implementation in [gpt_oss/torch/mod To run the reference implementation. Install dependencies: ```shell -pip install -e .[torch] +pip install -e ".[torch]" ``` And then run: @@ -198,7 +198,7 @@ pip install -r python/requirements.txt pip install -e . --verbose --no-build-isolation # Install the gpt-oss triton implementation -pip install -e .[triton] +pip install -e ".[triton]" ``` And then run: @@ -218,7 +218,7 @@ Additionally we are providing a reference implementation for Metal to run on App The implementation will get automatically compiled when running the `.[metal]` installation on an Apple Silicon device: ```shell -pip install -e .[metal] +pip install -e ".[metal]" ``` To perform inference you'll need to first convert the SafeTensor weights from Hugging Face into the right format using: @@ -230,8 +230,8 @@ python gpt_oss/metal/scripts/create-local-model.py -s -d Date: Sat, 9 Aug 2025 18:23:40 -0700 Subject: [PATCH 30/91] Update README.md (#58) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ca4cfb42..dc6a5249 100644 --- a/README.md +++ b/README.md @@ -171,7 +171,7 @@ hf download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/ We include an inefficient reference PyTorch implementation in [gpt_oss/torch/model.py](gpt_oss/torch/model.py). This code uses basic PyTorch operators to show the exact model architecture, with a small addition of supporting tensor parallelism in MoE so that the larger model can run with this code (e.g., on 4xH100 or 2xH200). In this implementation, we upcast all weights to BF16 and run the model in BF16. -To run the reference implementation. Install dependencies: +To run the reference implementation, install these dependencies: ```shell pip install -e ".[torch]" From 220a05855f332ed70e5f29add6bb228d83ad2b63 Mon Sep 17 00:00:00 2001 From: Hasan Erdem AK <70165677+hasanerdemak@users.noreply.github.com> Date: Sun, 10 Aug 2025 04:25:56 +0300 Subject: [PATCH 31/91] Fix typos and improve grammar in README (#61) Corrected several typographical errors and improved grammar throughout the README for better clarity and professionalism. Changes include fixing word forms, possessives, and minor phrasing issues. 
Co-authored-by: Romain Huet --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index dc6a5249..4322345d 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ print(outputs[0]["generated_text"][-1]) #### vLLM -vLLM recommends using [`uv`](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible webserver. The following command will automatically download the model and start the server. +vLLM recommends using [`uv`](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible web server. The following command will automatically download the model and start the server. ```bash uv pip install --pre vllm==0.10.1+gptoss \ @@ -130,7 +130,7 @@ This repository provides a collection of reference implementations: ### Requirements -- python 3.12 +- Python 3.12 - On macOS: Install the Xcode CLI tools --> `xcode-select --install` - On Linux: These reference implementations require CUDA - On Windows: These reference implementations have not been tested on Windows. Try using solutions like Ollama if you are trying to run the model locally. @@ -171,7 +171,7 @@ hf download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/ We include an inefficient reference PyTorch implementation in [gpt_oss/torch/model.py](gpt_oss/torch/model.py). This code uses basic PyTorch operators to show the exact model architecture, with a small addition of supporting tensor parallelism in MoE so that the larger model can run with this code (e.g., on 4xH100 or 2xH200). In this implementation, we upcast all weights to BF16 and run the model in BF16. -To run the reference implementation, install these dependencies: +To run the reference implementation, install the dependencies: ```shell pip install -e ".[torch]" @@ -227,7 +227,7 @@ To perform inference you'll need to first convert the SafeTensor weights from Hu python gpt_oss/metal/scripts/create-local-model.py -s -d ``` -Or downloaded the pre-converted weight: +Or download the pre-converted weight: ```shell hf download openai/gpt-oss-120b --include "metal/*" --local-dir gpt-oss-120b/metal/ @@ -250,7 +250,7 @@ We also include two system tools for the model: browsing and python container. C ### Terminal Chat -The terminal chat application is a basic example on how to use the harmony format together with the PyTorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used. +The terminal chat application is a basic example of how to use the harmony format together with the PyTorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used. 
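For example, a minimal invocation might look like the sketch below. This is not taken from the patch itself: the `-b`/`-p` flags are assumed (from the usage text that follows) to toggle the browser and python tools, and `gpt-oss-120b/original/` is assumed to hold the original checkpoint; confirm the exact semantics with `python -m gpt_oss.chat -h`.

```bash
# Hedged sketch: terminal chat with the triton backend, with the browser and
# python tools enabled. The flag meanings (-b = browser, -p = python) are
# assumed from the usage string below, not confirmed by this patch series.
python -m gpt_oss.chat --backend triton -b -p gpt-oss-120b/original/
```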
```bash usage: python -m gpt_oss.chat [-h] [-r REASONING_EFFORT] [-a] [-b] [--show-browser-results] [-p] [--developer-message DEVELOPER_MESSAGE] [-c CONTEXT] [--raw] [--backend {triton,torch,vllm}] FILE @@ -289,7 +289,7 @@ You can start this server with the following inference backends: - `triton` — uses the triton implementation - `metal` — uses the metal implementation on Apple Silicon only -- `ollama` — uses the Ollama /api/generate API as a inference solution +- `ollama` — uses the Ollama /api/generate API as an inference solution - `vllm` — uses your installed vllm version to perform inference - `transformers` — uses your installed transformers version to perform local inference @@ -468,10 +468,10 @@ if last_message.recipient == "python": We released the models with native quantization support. Specifically, we use [MXFP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) for the linear projection weights in the MoE layer. We store the MoE tensor in two parts: -- `tensor.blocks` stores the actual fp4 values. We pack every two value in one `uint8` value. +- `tensor.blocks` stores the actual fp4 values. We pack every two values in one `uint8` value. - `tensor.scales` stores the block scale. The block scaling is done among the last dimension for all MXFP4 tensors. -All other tensors will be in BF16. We also recommend use BF16 as the activation precision for the model. +All other tensors will be in BF16. We also recommend using BF16 as the activation precision for the model. ### Recommended Sampling Parameters From 954f47f88b610ca3d92b1ce895d92563b7a8381a Mon Sep 17 00:00:00 2001 From: palenciavik Date: Sat, 9 Aug 2025 18:34:19 -0700 Subject: [PATCH 32/91] Update README.md (#71) Addresses multiple small typos and grammatical errors in the main README.md as well as some improvements in phrasing for clarity. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4322345d..5346cf0c 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Both models were trained using our [harmony response format][harmony] and should #### Transformers -You can use `gpt-oss-120b` and `gpt-oss-20b` with Transformers. If you use the Transformers chat template it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package. +You can use `gpt-oss-120b` and `gpt-oss-20b` with the Transformers library. If you use Transformers' chat template, it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package. ```python from transformers import pipeline @@ -279,7 +279,7 @@ options: ``` > [!NOTE] -> The torch and triton implementation requires original checkpoint under `gpt-oss-120b/original/` and `gpt-oss-20b/original/` respectively. While vLLM uses the Hugging Face converted checkpoint under `gpt-oss-120b/` and `gpt-oss-20b/` root directory respectively. +> The torch and triton implementations require original checkpoint under `gpt-oss-120b/original/` and `gpt-oss-20b/original/` respectively. While vLLM uses the Hugging Face converted checkpoint under `gpt-oss-120b/` and `gpt-oss-20b/` root directory respectively. 
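To make the two checkpoint layouts concrete, here is a hedged sketch using the `hf` CLI already shown in this README. The first command mirrors the download instructions above; the second simply drops the `--include` filter so the Hugging Face converted checkpoint lands at the `gpt-oss-120b/` root, as the note describes for vLLM (this also fetches the other subfolders in the repo).

```bash
# torch / triton: only the original weights are needed, so filter the download
# and keep them under gpt-oss-120b/original/
hf download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/

# vLLM: uses the converted checkpoint at the root of gpt-oss-120b/, so pull the
# full snapshot without an --include filter.
hf download openai/gpt-oss-120b --local-dir gpt-oss-120b/
```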
### Responses API From f4096361dd3e4fefa80755ec16e6a2ca1491cc76 Mon Sep 17 00:00:00 2001 From: melad Date: Sun, 10 Aug 2025 04:35:37 +0300 Subject: [PATCH 33/91] Update README.md (#87) Critical edits Co-authored-by: Romain Huet --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5346cf0c..6d3beca6 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python -m gpt_oss.generate --backend triton gpt-oss-120b/original/ ``` -If you encounter `torch.OutOfMemoryError` make sure to turn on the expandable allocator to avoid crashes when loading weights from the checkpoint. +If you encounter `torch.OutOfMemoryError`, make sure to turn on the expandable allocator to avoid crashes when loading weights from the checkpoint. ## Reference Metal implementation From 82a3bad69b8071196e3dfee2b52316595b3dcd02 Mon Sep 17 00:00:00 2001 From: Buddhsen Tripathi Date: Sun, 10 Aug 2025 07:13:21 +0530 Subject: [PATCH 34/91] Update README.md (#41) Co-authored-by: Romain Huet --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6d3beca6..039922e8 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Check out our [awesome list](./awesome-gpt-oss.md) for a broader collection of g This repository provides a collection of reference implementations: - **Inference:** - - [`torch`](#reference-pytorch-implementation) — a non-optimized [PyTorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4x H100s because it's not optimized + - [`torch`](#reference-pytorch-implementation) — a non-optimized [PyTorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4× H100 GPUs due to lack of optimization. - [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [PyTorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. 
using CUDA graphs and basic caching - [`metal`](#reference-metal-implementation) — a Metal-specific implementation for running the models on Apple Silicon hardware - **Tools:** @@ -227,7 +227,7 @@ To perform inference you'll need to first convert the SafeTensor weights from Hu python gpt_oss/metal/scripts/create-local-model.py -s -d ``` -Or download the pre-converted weight: +Or download the pre-converted weights: ```shell hf download openai/gpt-oss-120b --include "metal/*" --local-dir gpt-oss-120b/metal/ From d7f9708d44cb019f781710d1e9fc48f0f6e05371 Mon Sep 17 00:00:00 2001 From: Yoginth Date: Sat, 9 Aug 2025 18:45:39 -0700 Subject: [PATCH 35/91] fix: typos across the codebase (#69) --- gpt_oss/evals/healthbench_eval.py | 2 +- gpt_oss/metal/include/gpt-oss/functions.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gpt_oss/evals/healthbench_eval.py b/gpt_oss/evals/healthbench_eval.py index bf136ac3..1862898b 100644 --- a/gpt_oss/evals/healthbench_eval.py +++ b/gpt_oss/evals/healthbench_eval.py @@ -72,7 +72,7 @@ } ``` -As another example, if the critera says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this: +As another example, if the criteria says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this: ```json { diff --git a/gpt_oss/metal/include/gpt-oss/functions.h b/gpt_oss/metal/include/gpt-oss/functions.h index a81bf50a..9966493c 100644 --- a/gpt_oss/metal/include/gpt-oss/functions.h +++ b/gpt_oss/metal/include/gpt-oss/functions.h @@ -267,7 +267,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_reset( gptoss_context_t context); /* - * Pre-process the tokens in the Context and generate probability distrubution over the next token. + * Pre-process the tokens in the Context and generate probability distribution over the next token. * * @param context Context object created by gptoss_context_create. * From 0d45dfd0060823758240156e240bb995428997db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Jos=C3=A9=20Estrada?= <69777842+jjestrada2@users.noreply.github.com> Date: Sat, 9 Aug 2025 21:47:59 -0400 Subject: [PATCH 36/91] [MINOR] fix: correct spelling error from "wnat" to "want" (#99) --- gpt-oss-mcp-server/python_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt-oss-mcp-server/python_server.py b/gpt-oss-mcp-server/python_server.py index bea86587..7ec35308 100644 --- a/gpt-oss-mcp-server/python_server.py +++ b/gpt-oss-mcp-server/python_server.py @@ -20,7 +20,7 @@ When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you. """, annotations={ - # Harmony format don't wnat this schema to be part of it because it's simple text in text out + # Harmony format don't want this schema to be part of it because it's simple text in text out "include_in_prompt": False, }) async def python(code: str) -> str: From c77966fc0fda390b0abeeecdec7134433fe9f224 Mon Sep 17 00:00:00 2001 From: Tomoya Fujita Date: Sun, 10 Aug 2025 10:52:21 +0900 Subject: [PATCH 37/91] a few typo fixes. 
(#102) Signed-off-by: Tomoya Fujita From 79eaf7fcfb77db428dc40fde4dd662bc79b8daa4 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Mon, 11 Aug 2025 14:04:03 -0700 Subject: [PATCH 38/91] Add API compatibility test (#114) * add compatibility test * update readme * update test suite * fix example config * fix typo * bump version --- compatibility-test/.gitignore | 142 +++ compatibility-test/README.md | 29 + compatibility-test/analysis.ts | 142 +++ compatibility-test/cases.jsonl | 30 + compatibility-test/index.ts | 196 ++++ compatibility-test/package-lock.json | 1633 ++++++++++++++++++++++++++ compatibility-test/package.json | 11 + compatibility-test/providers.ts | 15 + compatibility-test/runCase.ts | 331 ++++++ compatibility-test/tools.ts | 156 +++ pyproject.toml | 2 +- 11 files changed, 2686 insertions(+), 1 deletion(-) create mode 100644 compatibility-test/.gitignore create mode 100644 compatibility-test/README.md create mode 100644 compatibility-test/analysis.ts create mode 100644 compatibility-test/cases.jsonl create mode 100644 compatibility-test/index.ts create mode 100644 compatibility-test/package-lock.json create mode 100644 compatibility-test/package.json create mode 100644 compatibility-test/providers.ts create mode 100644 compatibility-test/runCase.ts create mode 100644 compatibility-test/tools.ts diff --git a/compatibility-test/.gitignore b/compatibility-test/.gitignore new file mode 100644 index 00000000..2ba323b0 --- /dev/null +++ b/compatibility-test/.gitignore @@ -0,0 +1,142 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# Snowpack dependency directory (https://snowpack.dev/) +web_modules/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional stylelint cache +.stylelintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variable files +.env +.env.* +!.env.example + +# parcel-bundler cache (https://parceljs.org/) +.cache +.parcel-cache + +# Next.js build output +.next +out + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and not Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# vuepress v2.x temp and cache directory +.temp +.cache + +# Sveltekit cache directory +.svelte-kit/ + +# vitepress build output +**/.vitepress/dist + +# vitepress cache directory +**/.vitepress/cache + +# Docusaurus cache and generated files +.docusaurus + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# Firebase cache directory +.firebase/ + +# TernJS port file +.tern-port + +# 
Stores VSCode versions used for testing VSCode extensions +.vscode-test + +# yarn v3 +.pnp.* +.yarn/* +!.yarn/patches +!.yarn/plugins +!.yarn/releases +!.yarn/sdks +!.yarn/versions + +# Vite logs files +vite.config.js.timestamp-* +vite.config.ts.timestamp-* + +rollout_*.jsonl +analysis_*.json \ No newline at end of file diff --git a/compatibility-test/README.md b/compatibility-test/README.md new file mode 100644 index 00000000..22e0007f --- /dev/null +++ b/compatibility-test/README.md @@ -0,0 +1,29 @@ +# API Compatibility Test + +This script uses the Agents SDK in TypeScript and the underlying OpenAI client to verify the shape of the API calls but also whether the API performs tool calling. + +## What it tests + +1. + +## How to run + +0. Run `npm install` in this directory. +1. Update `providers.ts` to create an entry for the API to test. Change `vllm` to the provider name of your choice. Use `chat` for Chat Completions tests and `responses` for Responses API tests. +2. Run an initial quick test to make sure things work. This will only run one test + +``` +npm start -- --provider -n 1 -k 1 +``` + +3. Run the full test (runs each test 5 times to test consistency) + +``` +npm start -- --provider -k 5 +``` + +## Considerations + +1. The tests will fail if the API shape does not match the expected behavior +2. Events in the chat API are currently not tested +3. If the schema validation succeeds but the input is wrong the test will still pass for this test. That's because it's likely more of a prompt engineering issue or a validator issue than an API issue as it still nailed the input diff --git a/compatibility-test/analysis.ts b/compatibility-test/analysis.ts new file mode 100644 index 00000000..9c5cf97d --- /dev/null +++ b/compatibility-test/analysis.ts @@ -0,0 +1,142 @@ +export function analyze(caseResults: any[], tries: number) { + // Group results by unique task: test_case + apiType + type TaskKey = string; + const taskKeyFor = (r: any): TaskKey => + `${r.test_case}::${r.result?.apiType}`; + + const successesByTask: Map> = new Map(); + + // Count wrong-input tool calls (schema correct but incorrect arguments) + let wrongInputToolCalls = 0; + + // Count invalid response shapes per API type + const totalByApiType: Record = {}; + const invalidByApiType: Record = {}; + + for (const r of caseResults) { + if (!r?.result || typeof r.result.apiType !== "string") continue; + + // Parse attempt index from run_id `${i}_${k}` safely + let attemptIndex: number | undefined; + if (typeof r.run_id === "string") { + const parts = r.run_id.split("_"); + const k = Number(parts[1]); + if (Number.isFinite(k)) attemptIndex = k; + } + + const key = taskKeyFor(r); + if (!successesByTask.has(key)) successesByTask.set(key, new Map()); + if (attemptIndex != null) { + successesByTask.get(key)!.set(attemptIndex, Boolean(r.success)); + } + + const d = r.result.toolCallingDetails ?? {}; + const calledToolAtLeastOnce = Boolean(d.calledToolAtLeastOnce); + const calledToolWithRightSchema = Boolean(d.calledToolWithRightSchema); + const calledToolWithRightArguments = Boolean( + d.calledToolWithRightArguments + ); + if ( + calledToolAtLeastOnce && + calledToolWithRightSchema && + !calledToolWithRightArguments + ) { + wrongInputToolCalls++; + } + + // Track invalid/total per apiType for response shape + const apiType = r.result.apiType as string; + totalByApiType[apiType] = (totalByApiType[apiType] ?? 
0) + 1; + const isValidResponse = r.result.validResponse === true; + if (!isValidResponse) { + invalidByApiType[apiType] = (invalidByApiType[apiType] ?? 0) + 1; + } + } + + const totalTasks = successesByTask.size; + + // Compute pass@k and pass^k for k = 1..tries + const passAtKByK: number[] = []; + const passHatKByK: number[] = []; + + for (let k = 1; k <= tries; k++) { + let tasksSuccessfulK = 0; // any success in first k attempts + let tasksAllSuccessfulK = 0; // all success in first k attempts + + for (const [, attemptsMap] of successesByTask) { + let anySuccess = false; + let allSuccess = true; + for (let i = 0; i < k; i++) { + const v = attemptsMap.get(i) === true; + anySuccess = anySuccess || v; + if (!v) allSuccess = false; + } + if (anySuccess) tasksSuccessfulK++; + if (allSuccess) tasksAllSuccessfulK++; + } + + const passAtK = totalTasks > 0 ? tasksSuccessfulK / totalTasks : 0; + const passHatK = totalTasks > 0 ? tasksAllSuccessfulK / totalTasks : 0; + passAtKByK.push(passAtK); + passHatKByK.push(passHatK); + } + + // Convenience: final k=tries values + const passAtK = passAtKByK[tries - 1] ?? 0; + const passHatK = passHatKByK[tries - 1] ?? 0; + + return { + totalTasks, + passAtKByK, + passHatKByK, + passAtK, + passHatK, + wrongInputToolCalls, + // New stats for invalid response shapes per API + invalidByApiType, + totalByApiType, + }; +} + +export function printAnalysis( + stats: ReturnType, + caseResults: any[], + provider: string, + selectedLines: string[], + tries: number, + skipped: number, + analysisFile: string +) { + const formatPerK = (arr: number[]) => + Array.from({ length: tries }, (_, i) => { + const v = arr[i] ?? 0; + return `${i + 1}=${v.toFixed(3)}`; + }).join(", "); + + console.log("Summary:"); + console.log(` Provider: ${provider}`); + console.log(` Total input cases: ${selectedLines.length}`); + console.log(` Tries: ${tries}`); + console.log(` Total tasks: ${stats.totalTasks}`); + console.log(` Total runs: ${caseResults.length}`); + // Conditionally print invalid response shape stats per API type + if ((stats.totalByApiType["responses"] ?? 0) > 0) { + const bad = stats.invalidByApiType["responses"] ?? 0; + const tot = stats.totalByApiType["responses"] ?? 0; + console.log(` Invalid Responses API responses: ${bad} (out of ${tot})`); + } + if ((stats.totalByApiType["chat"] ?? 0) > 0) { + const bad = stats.invalidByApiType["chat"] ?? 0; + const tot = stats.totalByApiType["chat"] ?? 
0; + console.log( + ` Invalid Chat Completions API responses: ${bad} (out of ${tot})` + ); + } + console.log(` pass@k (k=1..${tries}): ${formatPerK(stats.passAtKByK)}`); + console.log(` pass^k (k=1..${tries}): ${formatPerK(stats.passHatKByK)}`); + console.log(` pass@k (k=${tries}): ${stats.passAtK.toFixed(3)}`); + console.log(` pass^k (k=${tries}): ${stats.passHatK.toFixed(3)}`); + console.log(` Wrong-input tool calls: ${stats.wrongInputToolCalls}`); + console.log(` Invalid cases.jsonl lines: ${skipped}`); + console.log(` Analysis written to ${analysisFile}`); +} diff --git a/compatibility-test/cases.jsonl b/compatibility-test/cases.jsonl new file mode 100644 index 00000000..29e7d4e8 --- /dev/null +++ b/compatibility-test/cases.jsonl @@ -0,0 +1,30 @@ +{"tool_name":"get_system_health","input":"Hey, quick check: is everything up and running?","expected_arguments":"{}"} +{"tool_name":"get_system_health","input":"Status report please.","expected_arguments":"{}"} +{"tool_name":"get_system_health","input":"Can you confirm the LLM health before we start?","expected_arguments":"{}"} +{"tool_name":"get_system_health","input":"Need a health snapshot.","expected_arguments":"{}"} +{"tool_name":"get_system_health","input":"Hi, what's the current system health?","expected_arguments":"{}"} +{"tool_name":"markdown_to_html","input":"Convert this markdown to HTML:\n\n# Title\n\nSome *italic* text.","expected_arguments":"{\"markdown\":\"# Title\\n\\nSome *italic* text.\"}"} +{"tool_name":"markdown_to_html","input":"Hey, could you turn `## Docs` into HTML?","expected_arguments":"{\"markdown\":\"## Docs\"}"} +{"tool_name":"markdown_to_html","input":"Please render the following markdown:\n\n- item 1\n- item 2","expected_arguments":"{\"markdown\":\"- item 1\\n- item 2\"}"} +{"tool_name":"markdown_to_html","input":"I have `**bold**` markdown; give me HTML.","expected_arguments":"{\"markdown\":\"**bold**\"}"} +{"tool_name":"markdown_to_html","input":"Markdown to HTML: > quote","expected_arguments":"{\"markdown\":\"> quote\"}"} +{"tool_name":"detect_language","input":"Hey, what language is this: 'Buenos días, ¿cómo estás?'","expected_arguments":"{\"text\":\"Buenos días, ¿cómo estás?\"}"} +{"tool_name":"detect_language","input":"Identify the language: \"Guten Morgen\"","expected_arguments":"{\"text\":\"Guten Morgen\"}"} +{"tool_name":"detect_language","input":"Language detection needed: こんにちは、お元気ですか?","expected_arguments":"{\"text\":\"こんにちは、お元気ですか?\"}"} +{"tool_name":"detect_language","input":"Detect language for: 'Привет, как дела?'","expected_arguments":"{\"text\":\"Привет, как дела?\"}"} +{"tool_name":"detect_language","input":"What language is 'Bonjour tout le monde'?","expected_arguments":"{\"text\":\"Bonjour tout le monde\"}"} +{"tool_name":"generate_chart","input":"Plot a simple line chart for these points: (1,2),(2,4),(3,9).","expected_arguments":"{\"data\":[[1,2],[2,4],[3,9]],\"chart_type\":\"line\"}"} +{"tool_name":"generate_chart","input":"Hey, can I get a bar chart of my sales: 10, 20, 30 across Q1–Q3?","expected_arguments":"{\"data\":[[1,10],[2,20],[3,30]],\"chart_type\":\"bar\",\"title\":\"Quarterly Sales\"}"} +{"tool_name":"generate_chart","input":"Make a scatter chart titled 'Experiment' with x label Time and y label Value for data [ [0,1], [1,1.5], [2,2.2] ].","expected_arguments":"{\"data\":[[0,1],[1,1.5],[2,2.2]],\"chart_type\":\"scatter\",\"title\":\"Experiment\",\"x_label\":\"Time\",\"y_label\":\"Value\"}"} +{"tool_name":"generate_chart","input":"Create a line chart of temperatures 70,72,68,65 
over 4 days, label x as 'Day'.","expected_arguments":"{\"data\":[[1,70],[2,72],[3,68],[4,65]],\"chart_type\":\"line\",\"x_label\":\"Day\"}"} +{"tool_name":"generate_chart","input":"Visualize visits per day with a bar chart; numbers: 100,150,120.","expected_arguments":"{\"data\":[[1,100],[2,150],[3,120]],\"chart_type\":\"bar\",\"title\":\"Daily Visits\",\"y_label\":\"Visitors\"}"} +{"tool_name":"query_database","input":"Give me the ids and emails from users table, limit 5.","expected_arguments":"{\"table\":\"users\",\"columns\":[\"id\",\"email\"],\"limit\":5}"} +{"tool_name":"query_database","input":"Hey, fetch order_id and amount from orders where status is 'shipped'.","expected_arguments":"{\"table\":\"orders\",\"columns\":[\"order_id\",\"amount\"],\"filters\":\"status = 'shipped'\"}"} +{"tool_name":"query_database","input":"Retrieve name and price from products ordered by price descending, top 10 please.","expected_arguments":"{\"table\":\"products\",\"columns\":[\"name\",\"price\"],\"limit\":10,\"order_by\":\"price DESC\"}"} +{"tool_name":"query_database","input":"I need the first 3 log entries from audit_log table.","expected_arguments":"{\"table\":\"audit_log\",\"columns\":[\"id\",\"timestamp\",\"action\"],\"limit\":3}"} +{"tool_name":"query_database","input":"Query the customers table for name, city where city = 'Berlin'.","expected_arguments":"{\"table\":\"customers\",\"columns\":[\"name\",\"city\"],\"filters\":\"city = 'Berlin'\"}"} +{"tool_name":"get_weather","input":"What's the weather in San Francisco right now?","expected_arguments":"{\"location\":\"San Francisco\"}"} +{"tool_name":"get_weather","input":"Weather for Tokyo, please.","expected_arguments":"{\"location\":\"Tokyo\"}"} +{"tool_name":"get_weather","input":"Get me the current weather for 10001.","expected_arguments":"{\"location\":\"10001\"}"} +{"tool_name":"get_weather","input":"How's the weather in Paris today?","expected_arguments":"{\"location\":\"Paris\"}"} +{"tool_name":"get_weather","input":"Check the weather for Sydney.","expected_arguments":"{\"location\":\"Sydney\"}"} diff --git a/compatibility-test/index.ts b/compatibility-test/index.ts new file mode 100644 index 00000000..ca6b03dc --- /dev/null +++ b/compatibility-test/index.ts @@ -0,0 +1,196 @@ +import { parseArgs } from "node:util"; +import { createWriteStream } from "node:fs"; +import { readFile, writeFile } from "node:fs/promises"; +import path from "node:path"; +import process from "node:process"; +import { runCase, RunCaseSummary } from "./runCase"; +import { Listr, ListrTaskWrapper } from "listr2"; +import { analyze, printAnalysis } from "./analysis"; + +function formatTimestamp(d: Date): string { + const pad = (n: number) => String(n).padStart(2, "0"); + const yyyy = d.getFullYear(); + const mm = pad(d.getMonth() + 1); + const dd = pad(d.getDate()); + const hh = pad(d.getHours()); + const mi = pad(d.getMinutes()); + const ss = pad(d.getSeconds()); + return `${yyyy}${mm}${dd}_${hh}${mi}${ss}`; +} + +async function main() { + const args = parseArgs({ + options: { + cases: { type: "string", short: "c", default: "cases.jsonl" }, + provider: { type: "string", short: "p", default: "openai" }, + streaming: { type: "boolean", short: "s", default: false }, + maxTurns: { type: "string", short: "t", default: "10" }, + n: { type: "string", short: "n" }, + strict: { type: "boolean", short: "s", default: false }, + tries: { type: "string", short: "k", default: "1" }, + }, + }); + const casesPathArg = args.values.cases; + const provider = args.values.provider as 
string; + const streaming = Boolean(args.values.streaming); + const maxTurns = Number(args.values.maxTurns ?? 10); + const nRaw = args.values.n as string | undefined; + const triesRaw = args.values.tries as string | undefined; + const tries = triesRaw != null ? Number(triesRaw) : 1; + const limit = nRaw != null ? Number(nRaw) : undefined; + if (limit != null && (!Number.isFinite(limit) || limit <= 0)) { + console.error("--n must be a positive integer"); + process.exitCode = 1; + return; + } + + if (!casesPathArg) { + console.error("--cases is required (path to JSONL file)"); + process.exitCode = 1; + return; + } + + const casesPath = path.isAbsolute(casesPathArg) + ? casesPathArg + : path.join(process.cwd(), casesPathArg); + + const timestamp = formatTimestamp(new Date()); + const defaultFilename = `rollout_${provider}_${timestamp}.jsonl`; + const outputFile = path.join(process.cwd(), defaultFilename); + const analysisFile = path.join( + process.cwd(), + `analysis_${provider}_${timestamp}.json` + ); + + let fileContent: string; + try { + fileContent = await readFile(casesPath, "utf8"); + } catch (err: any) { + console.error( + `Failed to read cases file at ${casesPath}: ${err?.message ?? err}` + ); + process.exitCode = 1; + return; + } + + const lines = fileContent + .split(/\r?\n/) + .map((l) => l.trim()) + .filter((l) => l.length > 0); + + const selectedLines = + typeof limit === "number" ? lines.slice(0, limit) : lines; + + const out = createWriteStream(outputFile, { flags: "w", encoding: "utf8" }); + + const writeLine = (obj: any) => + new Promise((resolve, reject) => { + const str = JSON.stringify(obj) + "\n"; + out.write(str, (err) => (err ? reject(err) : resolve())); + }); + + // Accumulators for post-run analysis + let skipped = 0; // invalid JSON lines + const caseResults: Array<{ + run_id: string; + success: boolean; + provider: string; + test_case: number; + tool_name: string; + input: string; + result: RunCaseSummary; + }> = []; + + async function processIndex( + i: number, + k: number, + task: ListrTaskWrapper + ) { + const line = selectedLines[i]; + let caseObj: any; + try { + caseObj = JSON.parse(line); + } catch (err: any) { + console.error( + `Skipping invalid JSON on line ${i + 1}: ${err?.message ?? err}` + ); + skipped++; + return; + } + + try { + const summaries = await runCase(provider, caseObj, { + maxTurns, + streaming, + strict: args.values.strict, + }); + + for (const summary of summaries) { + const record = { + run_id: `${i}_${k}`, + success: summary.success, + provider, + test_case: i, + tool_name: caseObj.tool_name, + input: caseObj.input, + result: summary, + }; + task.output = `Case ${i} (attempt ${k + 1}): ${ + summary.success ? "Success" : "Failed" + } ${summary.toolCallingDetails.warning || ""}`; + caseResults.push(record); + await writeLine(record); + } + } catch (err: any) { + const record = { + provider, + test_case: i, + tool_name: caseObj?.tool_name, + input: caseObj?.input, + expected_output: caseObj?.expected_output, + instructions: caseObj?.instructions, + error: String(err?.message ?? err), + }; + await writeLine(record); + task.output = `Case ${i} failed: ${err?.message ?? 
err}`; + } + } + + const listr = new Listr<{ + output: string; + }>( + selectedLines.flatMap((line, index) => { + return Array.from({ length: tries }, (_, attempt) => ({ + title: `Processing case ${index} (attempt ${attempt + 1})`, + task: async (_, task) => { + await processIndex(index, attempt, task); + }, + rendererOptions: { persistentOutput: true }, + })); + }), + { + concurrent: 5, + } + ); + + await listr.run(); + + await new Promise((resolve) => out.end(resolve)); + console.log(`Results written to ${outputFile}`); + const stats = analyze(caseResults, tries); + await writeFile(analysisFile, JSON.stringify(stats, null, 2), "utf8"); + printAnalysis( + stats, + caseResults, + provider, + selectedLines, + tries, + skipped, + analysisFile + ); +} + +main().catch((err) => { + console.error(err); + process.exitCode = 1; +}); diff --git a/compatibility-test/package-lock.json b/compatibility-test/package-lock.json new file mode 100644 index 00000000..89b6a5e8 --- /dev/null +++ b/compatibility-test/package-lock.json @@ -0,0 +1,1633 @@ +{ + "name": "compatibility-test", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "dependencies": { + "@openai/agents": "^0.0.15", + "ajv": "^8.17.1", + "listr2": "^9.0.1" + } + }, + "node_modules/@modelcontextprotocol/sdk": { + "version": "1.17.1", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.17.1.tgz", + "integrity": "sha512-CPle1OQehbWqd25La9Ack5B07StKIxh4+Bf19qnpZKJC1oI22Y0czZHbifjw1UoczIfKBwBDAp/dFxvHG13B5A==", + "license": "MIT", + "optional": true, + "dependencies": { + "ajv": "^6.12.6", + "content-type": "^1.0.5", + "cors": "^2.8.5", + "cross-spawn": "^7.0.5", + "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", + "express": "^5.0.1", + "express-rate-limit": "^7.5.0", + "pkce-challenge": "^5.0.0", + "raw-body": "^3.0.0", + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@modelcontextprotocol/sdk/node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "license": "MIT", + "optional": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/@modelcontextprotocol/sdk/node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "license": "MIT", + "optional": true + }, + "node_modules/@openai/agents": { + "version": "0.0.15", + "resolved": "https://registry.npmjs.org/@openai/agents/-/agents-0.0.15.tgz", + "integrity": "sha512-B8y+WyWOeHowflPx09pyCfcqikC4OYWK27HTyNGt1oraXv93CzuamSr76iAaU1nWQ1MPbUwl6LHPX4BPUikVkQ==", + "license": "MIT", + "dependencies": { + "@openai/agents-core": "0.0.15", + "@openai/agents-openai": "0.0.15", + "@openai/agents-realtime": "0.0.15", + "debug": "^4.4.0", + "openai": "^5.10.1" + } + }, + "node_modules/@openai/agents-core": { + "version": "0.0.15", + "resolved": "https://registry.npmjs.org/@openai/agents-core/-/agents-core-0.0.15.tgz", + "integrity": 
"sha512-ODTqttjW0s0ejBe5PKnYRlFbJSZH2IO6OtUlRhIKmWiWrX6pGRxvpKjTSOXy8DEtpRHBj6Nhky0UoSlO6eOkDQ==", + "license": "MIT", + "dependencies": { + "@openai/zod": "npm:zod@3.25.40 - 3.25.67", + "debug": "^4.4.0", + "openai": "^5.10.1" + }, + "optionalDependencies": { + "@modelcontextprotocol/sdk": "^1.12.0" + }, + "peerDependencies": { + "zod": "3.25.40 - 3.25.67" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, + "node_modules/@openai/agents-openai": { + "version": "0.0.15", + "resolved": "https://registry.npmjs.org/@openai/agents-openai/-/agents-openai-0.0.15.tgz", + "integrity": "sha512-YIX3n98HdmmWKkb/71OB+DCQUYyGEpqfzPjejzdtNLUvAEs3jvXf7nkC8oTISsuCwrirgBz0rQEefeo0oUlyFQ==", + "license": "MIT", + "dependencies": { + "@openai/agents-core": "0.0.15", + "@openai/zod": "npm:zod@3.25.40 - 3.25.67", + "debug": "^4.4.0", + "openai": "^5.10.1" + } + }, + "node_modules/@openai/agents-realtime": { + "version": "0.0.15", + "resolved": "https://registry.npmjs.org/@openai/agents-realtime/-/agents-realtime-0.0.15.tgz", + "integrity": "sha512-kSZzMyij9Xt3BpMb/9snuVnu7a5qKZLyhtN/kWMA+wmfETvWz23BBz6tbO5xOmurAt9//OktkB+94e0T0RBtlA==", + "license": "MIT", + "dependencies": { + "@openai/agents-core": "0.0.15", + "@openai/zod": "npm:zod@3.25.40 - 3.25.67", + "@types/ws": "^8.18.1", + "debug": "^4.4.0", + "ws": "^8.18.1" + } + }, + "node_modules/@openai/zod": { + "name": "zod", + "version": "3.25.67", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz", + "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/@types/node": { + "version": "24.2.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-24.2.0.tgz", + "integrity": "sha512-3xyG3pMCq3oYCNg7/ZP+E1ooTaGB4cG8JWRsqqOYQdbWNY4zbaV0Ennrd7stjiJEFZCaybcIgpTjJWHRfBSIDw==", + "license": "MIT", + "dependencies": { + "undici-types": "~7.10.0" + } + }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/accepts": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-2.0.0.tgz", + "integrity": "sha512-5cvg6CtKwfgdmVqY1WIiXKc3Q1bkRqGLi+2W/6ao+6Y7gu/RCwRuAhGEzh5B4KlszSuTLgZYuqFqo5bImjNKng==", + "license": "MIT", + "optional": true, + "dependencies": { + "mime-types": "^3.0.0", + "negotiator": "^1.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-escapes": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/ansi-escapes/-/ansi-escapes-7.0.0.tgz", + "integrity": "sha512-GdYO7a61mR0fOlAsvC9/rIHf7L96sBc6dEWzeOu+KAea5bZyQRPIpojrVoI4AXGJS/ycu/fBTdLrUkA4ODrvjw==", + "license": "MIT", + "dependencies": { + "environment": "^1.0.0" + }, + "engines": { + 
"node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ansi-regex": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.1.0.tgz", + "integrity": "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/ansi-styles": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz", + "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/body-parser": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-2.2.0.tgz", + "integrity": "sha512-02qvAaxv8tp7fBa/mw1ga98OGm+eCbqzJOKoRt70sLmfEEi+jyBYVTDGfCL/k06/4EMk/z01gCe7HoCH/f2LTg==", + "license": "MIT", + "optional": true, + "dependencies": { + "bytes": "^3.1.2", + "content-type": "^1.0.5", + "debug": "^4.4.0", + "http-errors": "^2.0.0", + "iconv-lite": "^0.6.3", + "on-finished": "^2.4.1", + "qs": "^6.14.0", + "raw-body": "^3.0.0", + "type-is": "^2.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "license": "MIT", + "optional": true, + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/cli-cursor": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/cli-cursor/-/cli-cursor-5.0.0.tgz", + "integrity": "sha512-aCj4O5wKyszjMmDT4tZj93kxyydN/K5zPWSCe6/0AV/AA1pqe5ZBIw0a2ZfPQV7lL5/yb5HsUreJ6UFAF1tEQw==", + "license": "MIT", + "dependencies": { + "restore-cursor": "^5.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/cli-truncate": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/cli-truncate/-/cli-truncate-4.0.0.tgz", + "integrity": "sha512-nPdaFdQ0h/GEigbPClz11D0v/ZJEwxmeVZGeMo3Z5StPtUTkA9o1lD6QwoirYiSDzbcwn2XcjwmCp68W1IS4TA==", + "license": "MIT", + "dependencies": { + "slice-ansi": "^5.0.0", + "string-width": "^7.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": 
"https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/colorette": { + "version": "2.0.20", + "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz", + "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==", + "license": "MIT" + }, + "node_modules/content-disposition": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-1.0.0.tgz", + "integrity": "sha512-Au9nRL8VNUut/XSzbQA38+M78dzP4D+eqg3gfJHMIHHYa3bg067xj1KxMUWj+VULbiZMowKngFFbKczUrNJ1mg==", + "license": "MIT", + "optional": true, + "dependencies": { + "safe-buffer": "5.2.1" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie-signature": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.2.2.tgz", + "integrity": "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=6.6.0" + } + }, + "node_modules/cors": { + "version": "2.8.5", + "resolved": "https://registry.npmjs.org/cors/-/cors-2.8.5.tgz", + "integrity": "sha512-KIHbLJqu73RGr/hnbrO9uBeixNGuvSQjul/jdFvS/KFSIH1hWVd1ng7zOHx+YrEfInLG7q4n6GHQ9cDtxv/P6g==", + "license": "MIT", + "optional": true, + "dependencies": { + "object-assign": "^4", + "vary": "^1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", + "optional": true, + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/debug": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", + "integrity": "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + "optional": true, + 
"dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", + "license": "MIT", + "optional": true + }, + "node_modules/emoji-regex": { + "version": "10.4.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-10.4.0.tgz", + "integrity": "sha512-EC+0oUMY1Rqm4O6LLrgjtYDvcVYTy7chDnM4Q7030tP4Kwj3u/pR6gP9ygnp2CJMK5Gq+9Q2oqmrFJAz01DXjw==", + "license": "MIT" + }, + "node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/environment": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/environment/-/environment-1.1.0.tgz", + "integrity": "sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "license": "MIT", + "optional": true, + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha512-NiSupZ4OeuGwr68lGIeym/ksIZMJodUGOSCZ/FSnTxcrekbvqrgdUxlJOMpijaKZVjAJrWrGs/6Jy8OMuyj9ow==", + "license": "MIT", + "optional": true + }, + "node_modules/etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/eventemitter3": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-5.0.1.tgz", + "integrity": "sha512-GWkBvjiSZK87ELrYOSESUYeVIc9mvLLf/nXalMOS5dYrgZq9o5OVkbZAVM06CVxYsCwH9BDZFPlQTlPA1j4ahA==", + "license": "MIT" + }, + "node_modules/eventsource": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/eventsource/-/eventsource-3.0.7.tgz", + "integrity": 
"sha512-CRT1WTyuQoD771GW56XEZFQ/ZoSfWid1alKGDYMmkt2yl8UXrVR4pspqWNEcqKvVIzg6PAltWjxcSSPrboA4iA==", + "license": "MIT", + "optional": true, + "dependencies": { + "eventsource-parser": "^3.0.1" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/eventsource-parser": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.3.tgz", + "integrity": "sha512-nVpZkTMM9rF6AQ9gPJpFsNAMt48wIzB5TQgiTLdHiuO8XEDhUgZEhqKlZWXbIzo9VmJ/HvysHqEaVeD5v9TPvA==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/express": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/express/-/express-5.1.0.tgz", + "integrity": "sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==", + "license": "MIT", + "optional": true, + "dependencies": { + "accepts": "^2.0.0", + "body-parser": "^2.2.0", + "content-disposition": "^1.0.0", + "content-type": "^1.0.5", + "cookie": "^0.7.1", + "cookie-signature": "^1.2.1", + "debug": "^4.4.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "finalhandler": "^2.1.0", + "fresh": "^2.0.0", + "http-errors": "^2.0.0", + "merge-descriptors": "^2.0.0", + "mime-types": "^3.0.0", + "on-finished": "^2.4.1", + "once": "^1.4.0", + "parseurl": "^1.3.3", + "proxy-addr": "^2.0.7", + "qs": "^6.14.0", + "range-parser": "^1.2.1", + "router": "^2.2.0", + "send": "^1.1.0", + "serve-static": "^2.2.0", + "statuses": "^2.0.1", + "type-is": "^2.0.1", + "vary": "^1.1.2" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/express-rate-limit": { + "version": "7.5.1", + "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-7.5.1.tgz", + "integrity": "sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/express-rate-limit" + }, + "peerDependencies": { + "express": ">= 4.11" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "license": "MIT" + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "license": "MIT", + "optional": true + }, + "node_modules/fast-uri": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", + "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/finalhandler": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-2.1.0.tgz", + "integrity": "sha512-/t88Ty3d5JWQbWYgaOGCCYfXRwV1+be02WqYYlL6h0lEiUAMPM8o8qKGO01YIkOHzka2up08wvgYD0mDiI+q3Q==", + "license": "MIT", + "optional": true, + 
"dependencies": { + "debug": "^4.4.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "on-finished": "^2.4.1", + "parseurl": "^1.3.3", + "statuses": "^2.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/forwarded": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", + "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fresh": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-2.0.0.tgz", + "integrity": "sha512-Rx/WycZ60HOaqLKAi6cHRKKI7zxWbJ31MhntmtwMoaTeF7XFH9hhBp8vITaMidfljRQ6eYWCKkaTK+ykVJHP2A==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "optional": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-east-asian-width": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-east-asian-width/-/get-east-asian-width-1.3.0.tgz", + "integrity": "sha512-vpeMIQKxczTD/0s2CdEWHcb0eeJe6TFjxb+J5xgX7hScxqrGuyjmv4c1D4A/gelKfyox0gJJwIHF+fLjeaM8kQ==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "optional": true, + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": 
"2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/http-errors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", + "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "depd": "2.0.0", + "inherits": "2.0.4", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "toidentifier": "1.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-errors/node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/iconv-lite": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz", + "integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==", + "license": "MIT", + "optional": true, + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "license": "ISC", + "optional": true + }, + "node_modules/ipaddr.js": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", + "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-4.0.0.tgz", + "integrity": "sha512-O4L094N2/dZ7xqVdrXhh9r1KODPJpFms8B5sGdJLPy664AgvXsreZUyCQQNItZRDlYug4xStLjNp/sz3HvBowQ==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-promise": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-4.0.0.tgz", + "integrity": "sha512-hvpoI6korhJMnej285dSg6nu1+e6uxs7zG3BYAm5byqDsgJNWwxzM6z6iZiAgQR4TJ30JmBTOwqZUw3WlyH3AQ==", + "license": "MIT", + "optional": true + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "license": "ISC", + "optional": true + }, + "node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "license": "MIT" + }, + "node_modules/listr2": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/listr2/-/listr2-9.0.1.tgz", + 
"integrity": "sha512-SL0JY3DaxylDuo/MecFeiC+7pedM0zia33zl0vcjgwcq1q1FWWF1To9EIauPbl8GbMCU0R2e0uJ8bZunhYKD2g==", + "license": "MIT", + "dependencies": { + "cli-truncate": "^4.0.0", + "colorette": "^2.0.20", + "eventemitter3": "^5.0.1", + "log-update": "^6.1.0", + "rfdc": "^1.4.1", + "wrap-ansi": "^9.0.0" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/log-update": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/log-update/-/log-update-6.1.0.tgz", + "integrity": "sha512-9ie8ItPR6tjY5uYJh8K/Zrv/RMZ5VOlOWvtZdEHYSTFKZfIBPQa9tOAEeAWhd+AnIneLJ22w5fjOYtoutpWq5w==", + "license": "MIT", + "dependencies": { + "ansi-escapes": "^7.0.0", + "cli-cursor": "^5.0.0", + "slice-ansi": "^7.1.0", + "strip-ansi": "^7.1.0", + "wrap-ansi": "^9.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/log-update/node_modules/is-fullwidth-code-point": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-5.0.0.tgz", + "integrity": "sha512-OVa3u9kkBbw7b8Xw5F9P+D/T9X+Z4+JruYVNapTjPYZYUznQ5YfWeFkOj606XYYW8yugTfC8Pj0hYqvi4ryAhA==", + "license": "MIT", + "dependencies": { + "get-east-asian-width": "^1.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/log-update/node_modules/slice-ansi": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-7.1.0.tgz", + "integrity": "sha512-bSiSngZ/jWeX93BqeIAbImyTbEihizcwNjFoRUIY/T1wWQsfsm2Vw1agPKylXvQTU7iASGdHhyqRlqQzfz+Htg==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.2.1", + "is-fullwidth-code-point": "^5.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/chalk/slice-ansi?sponsor=1" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/media-typer": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-1.1.0.tgz", + "integrity": "sha512-aisnrDP4GNe06UcKFnV5bfMNPBUw4jsLGaWwWfnH3v02GnBuXX2MCVn5RbrWo0j3pczUilYblq7fQ7Nw2t5XKw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/merge-descriptors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-2.0.0.tgz", + "integrity": "sha512-Snk314V5ayFLhp3fkUREub6WtjBfPdCPY1Ln8/8munuLuiYhsABgBVWsozAG+MWMbVEvcdcpbi9R7ww22l9Q3g==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/mime-db": { + "version": "1.54.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.54.0.tgz", + "integrity": "sha512-aU5EJuIN2WDemCcAp2vFBfp/m4EAhWJnUNSSw0ixs7/kXbd6Pg64EmwJkNdFhB8aWt1sH2CTXrLxo/iAGV3oPQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-3.0.1.tgz", + "integrity": "sha512-xRc4oEhT6eaBpU1XF7AjpOFD+xQmXNB5OVKwp4tqCuBpHLS/ZbBDrc07mYTDqVMg6PfxUjjNp85O6Cd2Z/5HWA==", + 
"license": "MIT", + "optional": true, + "dependencies": { + "mime-db": "^1.54.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mimic-function": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/mimic-function/-/mimic-function-5.0.1.tgz", + "integrity": "sha512-VP79XUPxV2CigYP3jWwAUFSku2aKqBH7uTAapFWCBqutsbmDo96KY5o8uh6U+/YSIn5OxJnXp73beVkpqMIGhA==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/negotiator": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-1.0.0.tgz", + "integrity": "sha512-8Ofs/AUQh8MaEcrlq5xOX0CQ9ypTF5dl78mjlMNfOK08fzpgTHQRQPBxcPlEtIw0yRpws+Zo/3r+5WRby7u3Gg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-inspect": { + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/on-finished": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", + "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", + "license": "MIT", + "optional": true, + "dependencies": { + "ee-first": "1.1.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "license": "ISC", + "optional": true, + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/onetime": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/onetime/-/onetime-7.0.0.tgz", + "integrity": "sha512-VXJjc87FScF88uafS3JllDgvAm+c/Slfz06lorj2uAY34rlUu0Nt+v8wreiImcrgAjjIHp1rXpTDlLOGw29WwQ==", + "license": "MIT", + "dependencies": { + "mimic-function": "^5.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/openai": { + "version": "5.12.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-5.12.0.tgz", + "integrity": "sha512-vUdt02xiWgOHiYUmW0Hj1Qu9OKAiVQu5Bd547ktVCiMKC1BkB5L3ImeEnCyq3WpRKR6ZTaPgekzqdozwdPs7Lg==", + "license": "Apache-2.0", + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, + "node_modules/parseurl": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", + 
"integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/path-to-regexp": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-8.2.0.tgz", + "integrity": "sha512-TdrF7fW9Rphjq4RjrW0Kp2AW0Ahwu9sRGTkS6bvDi0SCwZlEZYmcfDbEsTz8RVk0EHIS/Vd1bv3JhG+1xZuAyQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=16" + } + }, + "node_modules/pkce-challenge": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-5.0.0.tgz", + "integrity": "sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=16.20.0" + } + }, + "node_modules/proxy-addr": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", + "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", + "license": "MIT", + "optional": true, + "dependencies": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/qs": { + "version": "6.14.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz", + "integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==", + "license": "BSD-3-Clause", + "optional": true, + "dependencies": { + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">=0.6" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/range-parser": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", + "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/raw-body": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-3.0.0.tgz", + "integrity": "sha512-RmkhL8CAyCRPXCE28MMH0z2PNWQBNk2Q09ZdxM9IOOXwxwZbN+qbWaatPkdkWIKL2ZVDImrN/pK5HTRz2PcS4g==", + "license": "MIT", + "optional": true, + "dependencies": { + "bytes": "3.1.2", + "http-errors": "2.0.0", + "iconv-lite": "0.6.3", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/restore-cursor": { + "version": "5.1.0", + "resolved": 
"https://registry.npmjs.org/restore-cursor/-/restore-cursor-5.1.0.tgz", + "integrity": "sha512-oMA2dcrw6u0YfxJQXm342bFKX/E4sG9rbTzO9ptUcR/e8A33cHuvStiYOwH7fszkZlZ1z/ta9AAoPk2F4qIOHA==", + "license": "MIT", + "dependencies": { + "onetime": "^7.0.0", + "signal-exit": "^4.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/rfdc": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/rfdc/-/rfdc-1.4.1.tgz", + "integrity": "sha512-q1b3N5QkRUWUl7iyylaaj3kOpIT0N2i9MqIEQXP73GVsN9cw3fdx8X63cEmWhJGi2PPCF23Ijp7ktmd39rawIA==", + "license": "MIT" + }, + "node_modules/router": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/router/-/router-2.2.0.tgz", + "integrity": "sha512-nLTrUKm2UyiL7rlhapu/Zl45FwNgkZGaCpZbIHajDYgwlJCOzLSk+cIPAnsEqV955GjILJnKbdQC1nVPz+gAYQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "debug": "^4.4.0", + "depd": "^2.0.0", + "is-promise": "^4.0.0", + "parseurl": "^1.3.3", + "path-to-regexp": "^8.0.0" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "optional": true + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "license": "MIT", + "optional": true + }, + "node_modules/send": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/send/-/send-1.2.0.tgz", + "integrity": "sha512-uaW0WwXKpL9blXE2o0bRhoL2EGXIrZxQ2ZQ4mgcfoBxdFmQold+qWsD2jLrfZ0trjKL6vOw0j//eAwcALFjKSw==", + "license": "MIT", + "optional": true, + "dependencies": { + "debug": "^4.3.5", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "fresh": "^2.0.0", + "http-errors": "^2.0.0", + "mime-types": "^3.0.1", + "ms": "^2.1.3", + "on-finished": "^2.4.1", + "range-parser": "^1.2.1", + "statuses": "^2.0.1" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/serve-static": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-2.2.0.tgz", + "integrity": "sha512-61g9pCh0Vnh7IutZjtLGGpTA355+OPn2TyDv/6ivP2h/AdAVX9azsoxmg2/M6nZeQZNYBEwIcsne1mJd9oQItQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "parseurl": "^1.3.3", + "send": "^1.2.0" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", + "license": "ISC", + "optional": true + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "license": 
"MIT", + "optional": true, + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/side-channel": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", + "license": "MIT", + "optional": true, + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "license": "MIT", + "optional": true, + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "license": "MIT", + "optional": true, + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "license": "MIT", + "optional": true, + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/slice-ansi": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-5.0.0.tgz", + "integrity": "sha512-FC+lgizVPfie0kkhqUScwRu1O/lF6NOgJmlCgK+/LYxDCTk8sGelYaHDhFcDN+Sn3Cv+3VSa4Byeo+IMCzpMgQ==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.0.0", + "is-fullwidth-code-point": "^4.0.0" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/slice-ansi?sponsor=1" + } + }, + "node_modules/statuses": { + 
"version": "2.0.2", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", + "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/string-width": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-7.2.0.tgz", + "integrity": "sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^10.3.0", + "get-east-asian-width": "^1.0.0", + "strip-ansi": "^7.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/strip-ansi": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz", + "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=0.6" + } + }, + "node_modules/type-is": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-2.0.1.tgz", + "integrity": "sha512-OZs6gsjF4vMp32qrCbiVSkrFmXtG/AZhY3t0iAMrMBiAZyV9oALtXO8hsrHbMXF9x6L3grlFuwW2oAz7cav+Gw==", + "license": "MIT", + "optional": true, + "dependencies": { + "content-type": "^1.0.5", + "media-typer": "^1.1.0", + "mime-types": "^3.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/undici-types": { + "version": "7.10.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.10.0.tgz", + "integrity": "sha512-t5Fy/nfn+14LuOc2KNYg75vZqClpAiqscVvMygNnlsHBFpSXdJaYtXMcdNLpl/Qvc3P2cB3s6lOV51nqsFq4ag==", + "license": "MIT" + }, + "node_modules/unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "license": "BSD-2-Clause", + "optional": true, + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "license": "ISC", + "optional": true, + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + 
"node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/wrap-ansi": { + "version": "9.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-9.0.0.tgz", + "integrity": "sha512-G8ura3S+3Z2G+mkgNRq8dqaFZAuxfsxpBB8OCTGRTCtp+l/v9nbFNmCUP1BZMts3G1142MsZfn6eeUKrr4PD1Q==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.2.1", + "string-width": "^7.0.0", + "strip-ansi": "^7.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "license": "ISC", + "optional": true + }, + "node_modules/ws": { + "version": "8.18.3", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", + "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/zod": { + "version": "3.25.67", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz", + "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==", + "license": "MIT", + "optional": true, + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.6", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.6.tgz", + "integrity": "sha512-h/z3PKvcTcTetyjl1fkj79MHNEjm+HpD6NXheWjzOekY7kV+lwDYnHw+ivHkijnCSMz1yJaWBD9vu/Fcmk+vEg==", + "license": "ISC", + "optional": true, + "peerDependencies": { + "zod": "^3.24.1" + } + } + } +} diff --git a/compatibility-test/package.json b/compatibility-test/package.json new file mode 100644 index 00000000..66d51439 --- /dev/null +++ b/compatibility-test/package.json @@ -0,0 +1,11 @@ +{ + "type": "module", + "dependencies": { + "@openai/agents": "^0.0.15", + "ajv": "^8.17.1", + "listr2": "^9.0.1" + }, + "scripts": { + "start": "tsx index.ts" + } +} diff --git a/compatibility-test/providers.ts b/compatibility-test/providers.ts new file mode 100644 index 00000000..91f58e0f --- /dev/null +++ b/compatibility-test/providers.ts @@ -0,0 +1,15 @@ +export const PROVIDERS = { + vllm: { + apiBaseUrl: "http://localhost:8000/v1", + apiKey: "vllm", + apiType: ["responses", "chat"], // choose from responses, chat, or both + modelName: "openai/gpt-oss-120b", + providerDetails: { + // add any provider-specific details here. 
These will be passed as part of every request + // for example to fix the provider for openrouter, you can do: + // provider: { + // only: ["example"], + // }, + }, + }, +}; diff --git a/compatibility-test/runCase.ts b/compatibility-test/runCase.ts new file mode 100644 index 00000000..fd066c0c --- /dev/null +++ b/compatibility-test/runCase.ts @@ -0,0 +1,331 @@ +import { + Agent, + Runner, + OpenAIResponsesModel, + OpenAIChatCompletionsModel, + RunResult, + StreamedRunResult, + FunctionTool, + setTracingDisabled, +} from "@openai/agents"; +import { Ajv } from "ajv"; +import { OpenAI } from "openai"; +import { PROVIDERS } from "./providers"; +import { TOOLS_MAP } from "./tools"; + +setTracingDisabled(true); + +const ajv = new Ajv(); + +export type Case = { + tool_name: string; + input: string; + expected_arguments: string; + instructions?: string; +}; + +// Summary shape for each apiType +export type RunCaseSummary = { + apiType: string; + success: boolean; + validResponse: boolean; + validEvents?: boolean; + details: Record; + history: any[]; + successToolCall: boolean; + toolCallingDetails: Record; +}; + +export async function runCase( + provider: string, + caseData: Case, + { + maxTurns, + streaming, + strict, + }: { maxTurns: number; streaming: boolean; strict: boolean } +): Promise { + const config = PROVIDERS[provider]; + if (!config) { + throw new Error( + `Provider ${provider} not found. Valid providers are: ${Object.keys( + PROVIDERS + ).join(", ")}` + ); + } + + const agent = new Agent({ + name: caseData.tool_name, + instructions: caseData.instructions, + tools: [TOOLS_MAP[caseData.tool_name]], + }); + + const client = new OpenAI({ + apiKey: config.apiKey, + baseURL: config.apiBaseUrl, + }); + + const summaries: RunCaseSummary[] = []; + + for (const apiType of config.apiType) { + const runner = new Runner({ + model: + apiType === "responses" + ? new OpenAIResponsesModel(client, config.modelName) + : new OpenAIChatCompletionsModel(client, config.modelName), + modelSettings: { + providerData: config.providerDetails ?? {}, + }, + }); + + let result: RunResult | StreamedRunResult; + let streamedEvents: any[] | undefined = undefined; + if (streaming) { + result = await runner.run(agent, caseData.input, { + stream: streaming, + maxTurns: maxTurns, + }); + if (result instanceof StreamedRunResult) { + // Collect streaming events if applicable + streamedEvents = []; + for await (const event of result) { + if (event.type === "raw_model_stream_event") { + if (event.data.type === "model") { + streamedEvents.push(event.data.event); + } + } + } + await result.completed; + } + } else { + result = await runner.run(agent, caseData.input, { + maxTurns: maxTurns, + }); + } + + const { success: successToolCall, details: toolCallingDetails } = + testToolCall(apiType, caseData, result, strict); + + const { validResponse, details } = testOutputData( + apiType, + result.rawResponses, + streaming + ); + + const { validEvents, details: eventsDetails } = streaming + ? testEvents(apiType, streamedEvents) + : { validEvents: true, details: {} }; + + let success = successToolCall && validResponse; + if (streaming) { + success = success && validEvents; + } + const summary: RunCaseSummary = { + apiType, + success, + validResponse, + validEvents, + details: { + ...details, + ...eventsDetails, + }, + history: result?.rawResponses.map((entry) => entry.providerData) ?? 
[], + successToolCall, + toolCallingDetails, + }; + + summaries.push(summary); + } + + return summaries; +} + +function testToolCall(apiType, caseData, result, strict) { + let details: Record = {}; + result.newItems.forEach((item) => { + // for this test for now we only care if the tool is called at least once + if (details.calledToolAtLeastOnce) { + return; + } + + const isToolCall = item.type === "tool_call_item"; + if (isToolCall) { + if (item.rawItem.type === "function_call") { + if (item.rawItem.name === caseData.tool_name) { + const validate = ajv.compile( + (TOOLS_MAP[caseData.tool_name] as FunctionTool).parameters + ); + const valid = validate(JSON.parse(item.rawItem.arguments)); + details.calledToolWithRightSchema = valid; + details.calledToolAtLeastOnce = true; + + if (details.calledToolWithRightSchema) { + const parsedArguments = JSON.parse(item.rawItem.arguments); + const expectedArguments = JSON.parse(caseData.expected_arguments); + details.calledToolWithRightArguments = deepEqual( + parsedArguments, + expectedArguments + ); + if (!details.calledToolWithRightArguments) { + if (details.calledToolWithRightSchema) { + details.warning = `Tool call with wrong arguments but correct schema. Check logs for full details. Not failing this test. Parsed: ${JSON.stringify( + parsedArguments + )} Expected: ${JSON.stringify(expectedArguments)}`; + } + details.actualArguments = parsedArguments; + details.expectedArguments = expectedArguments; + } + } + } + } + } + }); + + return { + success: + !!details.calledToolAtLeastOnce && + !!details.calledToolWithRightSchema && + (!strict || !!details.calledToolWithRightArguments), + details, + }; +} + +function testEvents(apiType, events) { + // In an ideal world we would check all the events to follow and reconstruct the final response + // and then compare it against the final response in the response.completed event + // for now we just check that certain events are present + + let details: Record = {}; + let validEvents: boolean = false; + + if (apiType === "chat") { + let hasReasoningDeltas = false; + for (const event of events) { + hasReasoningDeltas = + hasReasoningDeltas || + (typeof event.choices[0].delta.reasoning === "string" && + event.choices[0].delta.reasoning.length > 0); + } + details.hasReasoningDeltas = hasReasoningDeltas; + validEvents = hasReasoningDeltas; + } + + if (apiType === "responses") { + let hasReasoningDeltaEvents = false; + let hasReasoningDoneEvents = false; + for (const event of events) { + if (event.type === "raw_model_stream_event") { + if (event.data.type === "model") { + if (event.data.event.type === "response.reasoning_text.delta") { + hasReasoningDeltaEvents = true; + } + if (event.data.event.type === "response.reasoning_text.done") { + hasReasoningDoneEvents = true; + } + } + } + } + + details.hasReasoningDeltaEvents = hasReasoningDeltaEvents; + details.hasReasoningDoneEvents = hasReasoningDoneEvents; + validEvents = + details.hasReasoningDeltaEvents && details.hasReasoningDoneEvents; + } + + return { + validEvents, + details, + }; +} + +function testOutputData(apiType, rawResponses, streaming) { + let details: Record = {}; + let validResponse: boolean = false; + + if (apiType === "chat") { + for (const response of rawResponses) { + if (streaming && !response.providerData) { + // with Chat Completions we don't have a final response object that's native so we skip this test + return { + validResponse: true, + details: { + skippedBecauseStreaming: true, + }, + }; + } + + // this is the actual HTTP response from 
the provider + // Since it's not guaranteed that every response has a reasoning field, we check if it's present + // at least once across all responses + const data = response.providerData; + const message = data.choices[0].message; + if (message.role === "assistant" && !message.refusal) { + details.hasReasoningField = + details.hasReasoningField || + ("reasoning" in message && typeof message.reasoning === "string"); + details.hasReasoningContentField = + details.hasReasoningContentField || + ("reasoning_content" in message && + typeof message.reasoning_content === "string"); + + validResponse = + validResponse || + (details.hasReasoningField && message.reasoning.length > 0); + } + } + } else if (apiType === "responses") { + // this is the actual HTTP response from the provider + const data = rawResponses[0].providerData; + for (const item of data.output) { + // Since it's not guaranteed that every response has a reasoning field, we check if it's present + // at least once across all responses + + if (item.type === "reasoning") { + details.hasReasoningContentArray = Array.isArray(item.content); + details.hasReasoningContentArrayLength = item.content.length > 0; + details.hasReasoningContentArrayItemType = item.content.every( + (item) => item.type === "reasoning_text" + ); + details.hasReasoningContentArrayItemText = item.content.every( + (item) => item.text.length > 0 + ); + + validResponse = + details.hasReasoningContentArray && + details.hasReasoningContentArrayLength && + details.hasReasoningContentArrayItemType && + details.hasReasoningContentArrayItemText; + } + } + } + + return { + validResponse, + details, + }; +} + +function deepEqual(a: any, b: any): boolean { + if (a === b) return true; + if (typeof a !== typeof b) return false; + if (a && b && typeof a === "object") { + if (Array.isArray(a) !== Array.isArray(b)) return false; + if (Array.isArray(a)) { + if (a.length !== b.length) return false; + for (let i = 0; i < a.length; i++) { + if (!deepEqual(a[i], b[i])) return false; + } + return true; + } else { + const aKeys = Object.keys(a); + const bKeys = Object.keys(b); + if (aKeys.length !== bKeys.length) return false; + for (const key of aKeys) { + if (!b.hasOwnProperty(key)) return false; + if (!deepEqual(a[key], b[key])) return false; + } + return true; + } + } + return false; +} diff --git a/compatibility-test/tools.ts b/compatibility-test/tools.ts new file mode 100644 index 00000000..d2d4db6e --- /dev/null +++ b/compatibility-test/tools.ts @@ -0,0 +1,156 @@ +import { Tool, tool } from "@openai/agents"; + +function convertToTool(toolData: any) { + return tool({ + name: toolData.name, + description: toolData.description, + parameters: toolData.parameters, + execute: async (parameters) => { + return toolData.output; + }, + strict: false, + }); +} + +export const TOOLS = [ + { + type: "function", + name: "get_weather", + description: "Get the weather for a given location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The location to get the weather for", + }, + }, + required: ["location"], + additionalProperties: false, + }, + output: '{"weather":"sunny"}', + }, + { + type: "function", + name: "get_system_health", + description: + "Returns the current health status of the LLM runtime—use before critical operations to verify the service is live.", + parameters: { type: "object", properties: {} }, + output: '{"status":"ok","uptime_seconds":372045}', + }, + { + type: "function", + name: "markdown_to_html", + description: + 
"Converts a Markdown string to sanitized HTML—use when you need browser-renderable output.", + parameters: { + type: "object", + properties: { + markdown: { type: "string", description: "Raw Markdown content" }, + }, + required: ["markdown"], + additionalProperties: false, + }, + output: '{"html":"

Hello World

This is great.

"}', + }, + { + type: "function", + name: "detect_language", + description: + "Identifies the ISO language code of the supplied text—use for routing text to language-specific models.", + parameters: { + type: "object", + properties: { + text: { + type: "string", + description: "Text whose language should be detected", + }, + }, + required: ["text"], + additionalProperties: false, + }, + output: '{"language":"de","confidence":0.98}', + }, + { + type: "function", + name: "generate_chart", + description: + "Creates a base64-encoded PNG chart from tabular data—use for quick visualizations inside chat.", + parameters: { + type: "object", + properties: { + data: { + type: "array", + items: { type: "array", items: { type: "number" } }, + description: "2-D numeric data matrix", + }, + chart_type: { + type: "string", + enum: ["line", "bar", "scatter"], + description: "Type of chart to generate", + }, + title: { + type: "string", + description: "Chart title", + default: "", + }, + x_label: { + type: "string", + description: "Label for the x-axis", + default: "", + }, + y_label: { + type: "string", + description: "Label for the y-axis", + default: "", + }, + }, + required: ["data", "chart_type"], + additionalProperties: false, + }, + output: '{"image_png_base64":"iVBORw0KGgoAAAANSUhEUgAA..."}', + }, + { + type: "function", + name: "query_database", + description: + "Runs a parameterized SQL SELECT on the internal analytics DB—use for lightweight data look-ups.", + parameters: { + type: "object", + properties: { + table: { type: "string", description: "Table name to query" }, + columns: { + type: "array", + items: { type: "string" }, + description: "Columns to return", + }, + filters: { + type: "string", + description: "SQL WHERE clause without the word WHERE", + default: "", + }, + limit: { + type: "integer", + minimum: 1, + maximum: 10000, + description: "Max rows to return", + default: 100, + }, + order_by: { + type: "string", + description: "Column to order by (optional)", + default: "", + }, + }, + required: ["table", "columns"], + additionalProperties: false, + }, + output: + '{"rows":[{"id":1,"email":"user@example.com"},{"id":2,"email":"foo@bar.com"}],"row_count":2}', + }, +]; + +export const TOOLS_MAP = TOOLS.reduce((acc, tool) => { + acc[tool.name] = convertToTool(tool); + return acc; +}, {} as Record); diff --git a/pyproject.toml b/pyproject.toml index 77d8d26c..25942405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ ] readme = "README.md" requires-python = ">=3.12,<3.13" -version = "0.0.1" +version = "0.0.2" [project.optional-dependencies] triton = ["triton", "safetensors>=0.5.3", "torch>=2.7.0"] From a4f98eaabfb804f6b251d44b57feb6805af93bcd Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Mon, 11 Aug 2025 14:29:08 -0700 Subject: [PATCH 39/91] Update awesome-gpt-oss.md --- awesome-gpt-oss.md | 1 + 1 file changed, 1 insertion(+) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index c8a57a22..75cdba42 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -36,6 +36,7 @@ This is a list of guides and resources to help you get started with the gpt-oss - vLLM - [How to run gpt-oss with vLLM](https://cookbook.openai.com/articles/gpt-oss/run-vllm) + - [vLLM & gpt-oss recipies](https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html) - NVIDIA - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/run-nvidia) - [Deploying gpt-oss on 
TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md) From e73da24d5be14b69f633c7b37679a834710a8165 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Tue, 12 Aug 2025 08:18:08 -0700 Subject: [PATCH 40/91] fix: Add channel parameter to PythonTool response handling (#33) Signed-off-by: Xinyuan Tong --- gpt_oss/tools/python_docker/docker_tool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py index f2d9183b..7067c1e1 100644 --- a/gpt_oss/tools/python_docker/docker_tool.py +++ b/gpt_oss/tools/python_docker/docker_tool.py @@ -92,9 +92,10 @@ def tool_config(self) -> ToolNamespaceConfig: def _make_response( self, output: str, + channel: str | None = None, ) -> Message: content = TextContent(text=output) - return self.make_response(content=content) + return self.make_response(content=content, channel=channel) def make_response( self, From 4a8a22eef39d85d37f1dc48bae59f8271aeebd45 Mon Sep 17 00:00:00 2001 From: CivaaBTW <98655324+CivaaBTW@users.noreply.github.com> Date: Tue, 12 Aug 2025 18:11:22 +0200 Subject: [PATCH 41/91] Fix: Corrected typos across 3 files in gpt-oss directory (#115) --- gpt_oss/evals/healthbench_eval.py | 4 ++-- gpt_oss/metal/include/gpt-oss/functions.h | 2 +- gpt_oss/responses_api/inference/ollama.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gpt_oss/evals/healthbench_eval.py b/gpt_oss/evals/healthbench_eval.py index 1862898b..09d184c1 100644 --- a/gpt_oss/evals/healthbench_eval.py +++ b/gpt_oss/evals/healthbench_eval.py @@ -8,7 +8,7 @@ You can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. To do so, run the following command: - To evaluate physician ideal completions: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completions` -- To evaluate reference model completions used by physicians: `python -m gpt_oss.evalshealthbench_eval --run_mode=physician_completion_references` +- To evaluate reference model completions used by physicians: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completion_references` """ import argparse @@ -48,7 +48,7 @@ # Instructions Return a json object with the following fields: "explanation" and "criteria_met". - The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item. -- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true is all of the criteria are met. +- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true if all of the criteria are met. - One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. 
For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria. # Example 1 diff --git a/gpt_oss/metal/include/gpt-oss/functions.h b/gpt_oss/metal/include/gpt-oss/functions.h index 9966493c..085ebe0d 100644 --- a/gpt_oss/metal/include/gpt-oss/functions.h +++ b/gpt_oss/metal/include/gpt-oss/functions.h @@ -218,7 +218,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens( * * On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of * cached tokens in the num_tokens_out argument. - * On failure, returns an error code and leaves the values specified by tokend_out and num_tokens_out unchanged. + * On failure, returns an error code and leaves the values specified by tokens_out and num_tokens_out unchanged. */ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens( gptoss_context_t context, diff --git a/gpt_oss/responses_api/inference/ollama.py b/gpt_oss/responses_api/inference/ollama.py index cab54adf..35eb1b2f 100644 --- a/gpt_oss/responses_api/inference/ollama.py +++ b/gpt_oss/responses_api/inference/ollama.py @@ -1,5 +1,5 @@ """ -NOTE: this is a stiched together implementation that uses Ollama for inference. It's primarily used +NOTE: this is a stitched together implementation that uses Ollama for inference. It's primarily used for testing and development. It does not leverage any prompt caching or other optimizations and can therefore be slow between turns. """ From 3e8be309966f51c6eb027bf61fef0b058375977c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 12 Aug 2025 09:15:19 -0700 Subject: [PATCH 42/91] fix editable build (#113) --- _build/gpt_oss_build_backend/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_build/gpt_oss_build_backend/backend.py b/_build/gpt_oss_build_backend/backend.py index 3255db6c..5cd76bdf 100644 --- a/_build/gpt_oss_build_backend/backend.py +++ b/_build/gpt_oss_build_backend/backend.py @@ -87,7 +87,7 @@ def prepare_metadata_for_build_wheel( # Optional hooks def build_editable( - editable_directory: str, config_settings: Mapping[str, Any] | None = None + editable_directory: str, config_settings: Mapping[str, Any] | None = None, metadata_directory: str | None = None ) -> str: be = _backend() fn = getattr(be, "build_editable", None) From 0c83ebe33bf1db1802393c7dbc88535a6f25fa2a Mon Sep 17 00:00:00 2001 From: Okey Amy Date: Tue, 12 Aug 2025 17:17:15 +0100 Subject: [PATCH 43/91] docs: add table of contents to README.md (#106) --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 039922e8..7d7be330 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,21 @@ We're releasing two flavors of these open models: Both models were trained using our [harmony response format][harmony] and should only be used with this format; otherwise, they will not work correctly. 
+## Table of Contents +- [Highlights](#highlights) +- [Inference examples](#inference-examples) +- [About this repository](#about-this-repository) +- [Setup](#setup) +- [Download the model](#download-the-model) +- [Reference PyTorch implementation](#reference-pytorch-implementation) +- [Reference Triton implementation (single GPU)](#reference-triton-implementation-single-gpu) +- [Reference Metal implementation](#reference-metal-implementation) +- [Harmony format & tools](#harmony-format--tools) +- [Clients](#clients) +- [Tools](#tools) +- [Other details](#other-details) +- [Contributing](#contributing) + ### Highlights - **Permissive Apache 2.0 license:** Build freely without copyleft restrictions or patent risk—ideal for experimentation, customization, and commercial deployment. From 9dd466c7019c68144fbfaefd279103f0857acc8c Mon Sep 17 00:00:00 2001 From: Okey Amy Date: Tue, 12 Aug 2025 17:18:09 +0100 Subject: [PATCH 44/91] fix: Markdown linting and cleanup (#107) * docs: add table of contents to README.md * fix: clean up markdown files --- awesome-gpt-oss.md | 8 ++++---- gpt-oss-mcp-server/README.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index 75cdba42..aba8a2b2 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -41,7 +41,7 @@ This is a list of guides and resources to help you get started with the gpt-oss - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/run-nvidia) - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md) - AMD - - [Running the Latest Open Models from OpenAI on AMD AI Hardware](https://rocm.blogs.amd.com/ecosystems-and-partners/openai-day-0/README.html) + - [Running the Latest Open Models from OpenAI on AMD AI Hardware](https://rocm.blogs.amd.com/ecosystems-and-partners/openai-day-0/README.html) ### Cloud @@ -50,18 +50,18 @@ This is a list of guides and resources to help you get started with the gpt-oss - [gpt-oss-120b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-120b) - [gpt-oss-20b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-20b) - [gpt-oss with built-in web search on GroqCloud](https://console.groq.com/docs/browser-search) - - [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution) + - [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution) - [Responses API on Groq](https://console.groq.com/docs/responses-api) - NVIDIA - [NVIDIA launch blog post](https://blogs.nvidia.com/blog/openai-gpt-oss/) - [NVIDIA & gpt-oss developer launch blog post](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/) - Use [gpt-oss-120b](https://build.nvidia.com/openai/gpt-oss-120b) and [gpt-oss-20b](https://build.nvidia.com/openai/gpt-oss-20b) on NVIDIA's Cloud - Cloudflare - - [Cloudflare & gpt-oss launch blog post](http://blog.cloudflare.com/openai-gpt-oss-on-workers-ai) + - [Cloudflare & gpt-oss launch blog post](https://blog.cloudflare.com/openai-gpt-oss-on-workers-ai) - [gpt-oss-120b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-120b) - [gpt-oss-20b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b) - AMD - - [gpt-oss-120B on 
AMD MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot) + - [gpt-oss-120B on AMD MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot) ## Examples & Tutorials diff --git a/gpt-oss-mcp-server/README.md b/gpt-oss-mcp-server/README.md index 6326b2e7..10aedd5f 100644 --- a/gpt-oss-mcp-server/README.md +++ b/gpt-oss-mcp-server/README.md @@ -1,8 +1,8 @@ # MCP Servers for gpt-oss reference tools This directory contains MCP servers for the reference tools in the [gpt-oss](https://github.com/openai/gpt-oss) repository. -You can set up these tools behind MCP servers and use them in your applications. -For inference service that integrates with MCP, you can also use these as reference tools. +You can set up these tools behind MCP servers and use them in your applications. +For inference service that integrates with MCP, you can also use these as reference tools. In particular, this directory contains a `build-system-prompt.py` script that will generate exactly the same system prompt as `reference-system-prompt.py`. The build system prompt script show case all the care needed to automatically discover the tools and construct the system prompt before feeding it into Harmony. @@ -22,8 +22,8 @@ mcp run -t sse browser_server.py:mcp mcp run -t sse python_server.py:mcp ``` -You can now use MCP inspector to play with the tools. +You can now use MCP inspector to play with the tools. Once opened, set SSE to `http://localhost:8001/sse` and `http://localhost:8000/sse` respectively. -To compare the system prompt and see how to construct it via MCP service discovery, see `build-system-prompt.py`. +To compare the system prompt and see how to construct it via MCP service discovery, see `build-system-prompt.py`. This script will generate exactly the same system prompt as `reference-system-prompt.py`. From 750cfe908fdc9dd1f0e9bfcd92a4bb1adb0aa81c Mon Sep 17 00:00:00 2001 From: Adarsh <44583199+adarsh-crafts@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:55:06 +0530 Subject: [PATCH 45/91] docs: add docstrings to utility and helper functions (#97) Co-authored-by: Adarsh N <44583199+adarshn656@users.noreply.github.com> --- gpt_oss/tools/simple_browser/page_contents.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/gpt_oss/tools/simple_browser/page_contents.py b/gpt_oss/tools/simple_browser/page_contents.py index e1e1b951..6fffd3f1 100644 --- a/gpt_oss/tools/simple_browser/page_contents.py +++ b/gpt_oss/tools/simple_browser/page_contents.py @@ -64,6 +64,7 @@ class Tokens: def get_domain(url: str) -> str: + """Extracts the domain from a URL.""" if "http" not in url: # If `get_domain` is called on a domain, add a scheme so that the # original domain is returned instead of the empty string. 
@@ -72,12 +73,14 @@ def get_domain(url: str) -> str: def multiple_replace(text: str, replacements: dict[str, str]) -> str: + """Performs multiple string replacements using regex pass.""" regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys()))) return regex.sub(lambda mo: replacements[mo.group(1)], text) @functools.lru_cache(maxsize=1024) def mark_lines(text: str) -> str: + """Adds line numbers (ex: 'L0:') to the beginning of each line in a string.""" # Split the string by newline characters lines = text.split("\n") @@ -88,16 +91,19 @@ def mark_lines(text: str) -> str: @functools.cache def _tiktoken_vocabulary_lengths(enc_name: str) -> list[int]: + """Gets the character lengths of all tokens in the specified TikToken vocabulary.""" encoding = tiktoken.get_encoding(enc_name) return [len(encoding.decode([i])) for i in range(encoding.n_vocab)] def warmup_caches(enc_names: list[str]) -> None: + """Warm up the cache by computing token length lists for the given TikToken encodings.""" for _ in map(_tiktoken_vocabulary_lengths, enc_names): pass def _replace_special_chars(text: str) -> str: + """Replaces specific special characters with visually similar alternatives.""" replacements = { "【": "〖", "】": "〗", @@ -110,16 +116,19 @@ def _replace_special_chars(text: str) -> str: def merge_whitespace(text: str) -> str: + """Replace newlines with spaces and merge consecutive whitespace into a single space.""" text = text.replace("\n", " ") text = re.sub(r"\s+", " ", text) return text def arxiv_to_ar5iv(url: str) -> str: + """Converts an arxiv.org URL to its ar5iv.org equivalent.""" return re.sub(r"arxiv.org", r"ar5iv.org", url) def _clean_links(root: lxml.html.HtmlElement, cur_url: str) -> dict[str, str]: + """Processes all anchor tags in the HTML, replaces them with a custom format and returns an ID-to-URL mapping.""" cur_domain = get_domain(cur_url) urls: dict[str, str] = {} urls_rev: dict[str, str] = {} @@ -156,10 +165,12 @@ def _clean_links(root: lxml.html.HtmlElement, cur_url: str) -> dict[str, str]: def _get_text(node: lxml.html.HtmlElement) -> str: + """Extracts all text from an HTML element and merges it into a whitespace-normalized string.""" return merge_whitespace(" ".join(node.itertext())) def _remove_node(node: lxml.html.HtmlElement) -> None: + """Removes a node from its parent in the lxml tree.""" node.getparent().remove(node) @@ -172,6 +183,7 @@ def _escape_md_section(text: str, snob: bool = False) -> str: def html_to_text(html: str) -> str: + """Converts an HTML string to clean plaintext.""" html = re.sub(HTML_SUP_RE, r"^{\2}", html) html = re.sub(HTML_SUB_RE, r"_{\2}", html) # add spaces between tags such as table cells @@ -195,6 +207,7 @@ def html_to_text(html: str) -> str: def _remove_math(root: lxml.html.HtmlElement) -> None: + """Removes all elements from the lxml tree.""" for node in root.findall(".//math"): _remove_node(node) @@ -209,6 +222,7 @@ def remove_unicode_smp(text: str) -> str: def replace_node_with_text(node: lxml.html.HtmlElement, text: str) -> None: + """Replaces an lxml node with a text string while preserving surrounding text.""" previous = node.getprevious() parent = node.getparent() tail = node.tail or "" @@ -224,6 +238,7 @@ def replace_images( base_url: str, session: aiohttp.ClientSession | None, ) -> None: + """Finds all image tags and replaces them with numbered placeholders (includes alt/title if available).""" cnt = 0 for img_tag in root.findall(".//img"): image_name = img_tag.get("alt", img_tag.get("title")) From 
1dcd7d06992d101e25962f6f3a61f2c4b4c11671 Mon Sep 17 00:00:00 2001 From: harshalmore31 Date: Tue, 12 Aug 2025 22:01:33 +0530 Subject: [PATCH 46/91] feat: Add Gradio chat interface example (#89) * feat: implement Gradio chatbot interface with function calling and browser search capabilities * updates ! --- examples/gradio/gradio_chat.py | 247 +++++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 examples/gradio/gradio_chat.py diff --git a/examples/gradio/gradio_chat.py b/examples/gradio/gradio_chat.py new file mode 100644 index 00000000..da742bd3 --- /dev/null +++ b/examples/gradio/gradio_chat.py @@ -0,0 +1,247 @@ +import json +import requests +import gradio as gr + +DEFAULT_FUNCTION_PROPERTIES = """ +{ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] +} +""".strip() + +def chat_with_model(message, history, model_choice, instructions, effort, use_functions, + function_name, function_description, function_parameters, + use_browser_search, temperature, max_output_tokens, debug_mode): + + if not message.strip(): + return history, "" + + # Append user message and empty assistant placeholder (idiomatic Gradio pattern) + history = history + [[message, ""]] + + # Build messages list from history (excluding the empty assistant placeholder) + messages = [] + + # Convert history to messages format (excluding the last empty assistant message) + for user_msg, assistant_msg in history[:-1]: + if user_msg: + messages.append({ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": user_msg}] + }) + if assistant_msg: + messages.append({ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": assistant_msg}] + }) + + # Add current user message + messages.append({ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": message}] + }) + + # Prepare tools + tools = [] + if use_functions: + try: + tools.append({ + "type": "function", + "name": function_name, + "description": function_description, + "parameters": json.loads(function_parameters), + }) + except json.JSONDecodeError: + pass + + if use_browser_search: + tools.append({"type": "browser_search"}) + + # Get URL based on model (matching streamlit logic) + options = ["large", "small"] + URL = ("http://localhost:8081/v1/responses" if model_choice == options[1] + else "http://localhost:8000/v1/responses") + + try: + response = requests.post( + URL, + json={ + "input": messages, + "stream": True, + "instructions": instructions, + "reasoning": {"effort": effort}, + "metadata": {"__debug": debug_mode}, + "tools": tools, + "temperature": temperature, + "max_output_tokens": max_output_tokens, + }, + stream=True, + ) + + full_content = "" + text_delta = "" + current_output_index = 0 + in_reasoning = False + + for line in response.iter_lines(decode_unicode=True): + if not line or not line.startswith("data:"): + continue + data_str = line[len("data:"):].strip() + if not data_str: + continue + + try: + data = json.loads(data_str) + except Exception: + continue + + event_type = data.get("type", "") + output_index = data.get("output_index", 0) + + if event_type == "response.output_item.added": + current_output_index = output_index + output_type = data.get("item", {}).get("type", "message") + text_delta = "" + + if output_type == "reasoning": + if not in_reasoning: + full_content += "🤔 **Thinking...**\n" + in_reasoning = True + 
elif output_type == "message": + if in_reasoning: + full_content += "\n\n" + in_reasoning = False + + elif event_type == "response.reasoning_text.delta": + delta = data.get("delta", "") + full_content += delta + + # Update last assistant message (idiomatic Gradio pattern) + history[-1][1] = full_content + yield history, "" + + elif event_type == "response.output_text.delta": + delta = data.get("delta", "") + full_content += delta + + # Update last assistant message (idiomatic Gradio pattern) + history[-1][1] = full_content + yield history, "" + + elif event_type == "response.output_item.done": + item = data.get("item", {}) + if item.get("type") == "function_call": + function_call_text = f"\n\n🔨 Called `{item.get('name')}`\n**Arguments**\n```json\n{item.get('arguments', '')}\n```" + full_content += function_call_text + + # Update last assistant message (idiomatic Gradio pattern) + history[-1][1] = full_content + yield history, "" + + elif item.get("type") == "web_search_call": + web_search_text = f"\n\n🌐 **Web Search**\n```json\n{json.dumps(item.get('action', {}), indent=2)}\n```\n✅ Done" + full_content += web_search_text + + # Update last assistant message (idiomatic Gradio pattern) + history[-1][1] = full_content + yield history, "" + + elif event_type == "response.completed": + response_data = data.get("response", {}) + if debug_mode: + debug_info = response_data.get("metadata", {}).get("__debug", "") + if debug_info: + full_content += f"\n\n**Debug**\n```\n{debug_info}\n```" + + # Update last assistant message (idiomatic Gradio pattern) + history[-1][1] = full_content + yield history, "" + break + + # Return final history and empty string to clear textbox + return history, "" + + except Exception as e: + error_message = f"❌ Error: {str(e)}" + history[-1][1] = error_message + return history, "" + + +# Create the Gradio interface +with gr.Blocks(title="💬 Chatbot") as demo: + gr.Markdown("# 💬 Chatbot") + + with gr.Row(): + with gr.Column(scale=3): + chatbot = gr.Chatbot(height=500) + + with gr.Row(): + msg = gr.Textbox(placeholder="Type a message...", scale=4, show_label=False) + send_btn = gr.Button("Send", scale=1) + + clear_btn = gr.Button("Clear Chat") + + with gr.Column(scale=1): + model_choice = gr.Radio(["large", "small"], value="small", label="Model") + + instructions = gr.Textbox( + label="Instructions", + value="You are a helpful assistant that can answer questions and help with tasks.", + lines=3 + ) + + effort = gr.Radio(["low", "medium", "high"], value="medium", label="Reasoning effort") + + gr.Markdown("#### Functions") + use_functions = gr.Checkbox(label="Use functions", value=False) + + with gr.Column(visible=False) as function_group: + function_name = gr.Textbox(label="Function name", value="get_weather") + function_description = gr.Textbox( + label="Function description", + value="Get the weather for a given city" + ) + function_parameters = gr.Textbox( + label="Function parameters", + value=DEFAULT_FUNCTION_PROPERTIES, + lines=6 + ) + + # Conditional browser search (matching Streamlit logic) + # In Streamlit: if "show_browser" in st.query_params: + # For Gradio, we'll always show it (simplified) + gr.Markdown("#### Built-in Tools") + use_browser_search = gr.Checkbox(label="Use browser search", value=False) + + temperature = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Temperature") + max_output_tokens = gr.Slider(1000, 20000, value=1024, step=100, label="Max output tokens") + + debug_mode = gr.Checkbox(label="Debug mode", value=False) + + # Event handlers + def 
toggle_function_group(use_funcs): + return gr.update(visible=use_funcs) + + use_functions.change(toggle_function_group, use_functions, function_group) + + # Chat functionality + inputs = [msg, chatbot, model_choice, instructions, effort, use_functions, + function_name, function_description, function_parameters, + use_browser_search, temperature, max_output_tokens, debug_mode] + + msg.submit(chat_with_model, inputs, [chatbot, msg]) + send_btn.click(chat_with_model, inputs, [chatbot, msg]) + clear_btn.click(lambda: [], outputs=chatbot) + + +if __name__ == "__main__": + demo.launch() \ No newline at end of file From 4195fb33a659dd796289347f56e1e27a5de21a31 Mon Sep 17 00:00:00 2001 From: SyedaAnshrahGillani <90501474+SyedaAnshrahGillani@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:32:27 +0500 Subject: [PATCH 47/91] Feat: add command-line arguments for backend parameters (#86) --- gpt_oss/generate.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py index dfaaa6f1..c0755805 100644 --- a/gpt_oss/generate.py +++ b/gpt_oss/generate.py @@ -19,10 +19,10 @@ def main(args): from gpt_oss.torch.utils import init_distributed from gpt_oss.triton.model import TokenGenerator as TritonGenerator device = init_distributed() - generator = TritonGenerator(args.checkpoint, context=4096, device=device) + generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device) case "vllm": from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator - generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2) + generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size) case _: raise ValueError(f"Invalid backend: {args.backend}") @@ -31,9 +31,9 @@ def main(args): max_tokens = None if args.limit == 0 else args.limit for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True): tokens.append(token) - decoded_token = tokenizer.decode([token]) + token_text = tokenizer.decode([token]) print( - f"Generated token: {repr(decoded_token)}, logprob: {logprob}" + f"Generated token: {repr(token_text)}, logprob: {logprob}" ) @@ -78,6 +78,18 @@ def main(args): choices=["triton", "torch", "vllm"], help="Inference backend", ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=2, + help="Tensor parallel size for vLLM backend", + ) + parser.add_argument( + "--context-length", + type=int, + default=4096, + help="Context length for Triton backend", + ) args = parser.parse_args() main(args) From 421dbe99eb1b710cff30f6426f73ffd21e78f77f Mon Sep 17 00:00:00 2001 From: xiejw Date: Tue, 12 Aug 2025 09:33:39 -0700 Subject: [PATCH 48/91] added GPTOSS_BUILD_METAL=1 for metal. 
(#84) Co-authored-by: Dominik Kundel --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d7be330..b800e877 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ Additionally we are providing a reference implementation for Metal to run on App The implementation will get automatically compiled when running the `.[metal]` installation on an Apple Silicon device: ```shell -pip install -e ".[metal]" +GPTOSS_BUILD_METAL=1 pip install -e ".[metal]" ``` To perform inference you'll need to first convert the SafeTensor weights from Hugging Face into the right format using: From 83e1b367679fc09f0d9c6a335e70fe4de2c6b18d Mon Sep 17 00:00:00 2001 From: Adarsh <44583199+adarsh-crafts@users.noreply.github.com> Date: Tue, 12 Aug 2025 22:04:27 +0530 Subject: [PATCH 49/91] chore: remove unused WeatherParams class and import (#82) Removes the unused `WeatherParams` class and its corresponding `BaseModel` import from `pydantic`. This code was not referenced anywhere in the script. Removing it reduces clutter. Co-authored-by: Adarsh N <44583199+adarshn656@users.noreply.github.com> --- examples/agents-sdk-python/example.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/agents-sdk-python/example.py b/examples/agents-sdk-python/example.py index 06aa7d70..b4f08d63 100644 --- a/examples/agents-sdk-python/example.py +++ b/examples/agents-sdk-python/example.py @@ -13,11 +13,6 @@ function_tool, ) from agents.mcp import MCPServerStdio -from pydantic import BaseModel - - -class WeatherParams(BaseModel): - location: str async def prompt_user(question: str) -> str: From 1246ff8e4e0b4a87ff306c917721ad24222dddd9 Mon Sep 17 00:00:00 2001 From: Adarsh <44583199+adarsh-crafts@users.noreply.github.com> Date: Tue, 12 Aug 2025 22:05:04 +0530 Subject: [PATCH 50/91] refactor: rename search_tool for clarity (#81) The function originally named `search_tool` has a specific purpose of fetching weather information, but the name was too generic. Renaming it to `get_weather` increases readability and maintainability for future developers. Co-authored-by: Adarsh N <44583199+adarshn656@users.noreply.github.com> --- examples/agents-sdk-python/example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/agents-sdk-python/example.py b/examples/agents-sdk-python/example.py index b4f08d63..af0be603 100644 --- a/examples/agents-sdk-python/example.py +++ b/examples/agents-sdk-python/example.py @@ -54,14 +54,14 @@ async def main(): # Define weather tool @function_tool - async def search_tool(location: str) -> str: + async def get_weather(location: str) -> str: return f"The weather in {location} is sunny." 
# Create agent agent = Agent( name="My Agent", instructions="You are a helpful assistant.", - tools=[search_tool], + tools=[get_weather], model="gpt-oss:20b-test", mcp_servers=[mcp_server], ) From 359b3ffd975b5bc8adc04ce54a0a4ea162f413eb Mon Sep 17 00:00:00 2001 From: Om Alve Date: Tue, 12 Aug 2025 22:07:43 +0530 Subject: [PATCH 51/91] fix invalid import in build-system-prompt.py (#32) fixes invalid tokenizer import in build-system-prompt.py --- gpt-oss-mcp-server/build-system-prompt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpt-oss-mcp-server/build-system-prompt.py b/gpt-oss-mcp-server/build-system-prompt.py index 58e953ad..1aca256a 100644 --- a/gpt-oss-mcp-server/build-system-prompt.py +++ b/gpt-oss-mcp-server/build-system-prompt.py @@ -1,7 +1,7 @@ import datetime import asyncio -from gpt_oss.tokenizer import tokenizer +from gpt_oss.tokenizer import get_tokenizer from openai_harmony import ( Conversation, @@ -66,6 +66,7 @@ def post_process_tools_description( return list_tools_result +tokenizer = get_tokenizer() tools_urls = [ "http://localhost:8001/sse", # browser From 100011251bfaa036f31093a410c7330a9d1edd82 Mon Sep 17 00:00:00 2001 From: Shubhankar Dixit Date: Tue, 12 Aug 2025 22:08:31 +0530 Subject: [PATCH 52/91] Update simple_browser_tool.py (#40) * Update simple_browser_tool.py Fixed a typo in variable name: lenghts -> lengths * Update api_server.py Added validation to get_reasoning_effort, documenting the function and raising a clear ValueError when an unsupported effort string is provided instead of returning None Updated the /v1/responses endpoint to catch invalid reasoning effort inputs and return an HTTP 422 error, preventing server-side crashes --- gpt_oss/responses_api/api_server.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py index 5ea7fc15..63504a65 100644 --- a/gpt_oss/responses_api/api_server.py +++ b/gpt_oss/responses_api/api_server.py @@ -64,10 +64,11 @@ def get_reasoning_effort(effort: Literal["low", "medium", "high"]) -> ReasoningEffort: if effort == "low": return ReasoningEffort.LOW - elif effort == "medium": + if effort == "medium": return ReasoningEffort.MEDIUM - elif effort == "high": + if effort == "high": return ReasoningEffort.HIGH + raise ValueError(f"Invalid reasoning effort: {effort}") def is_not_builtin_tool(recipient: str) -> bool: @@ -784,7 +785,12 @@ def _ensure_list(inp): ) if body.reasoning is not None: - reasoning_effort = get_reasoning_effort(body.reasoning.effort) + try: + + reasoning_effort = get_reasoning_effort(body.reasoning.effect) + except ValueError as e: + from fastapi import HTTP Exception + raise HTTPException(status_code=422, detail=str(e)) system_message_content = system_message_content.with_reasoning_effort(reasoning_effort) if use_browser_tool: From fa67988739c83199d6604c0d296ccac8439f9d5c Mon Sep 17 00:00:00 2001 From: BobHuang <30999153+sBobHuang@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:39:19 +0800 Subject: [PATCH 53/91] triton implementation need install triton_kernels (#45) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b800e877..08ed09b4 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,7 @@ git clone https://github.com/triton-lang/triton cd triton/ pip install -r python/requirements.txt pip install -e . 
--verbose --no-build-isolation +pip install -e python/triton_kernels # Install the gpt-oss triton implementation pip install -e ".[triton]" From 906a0efe82c7d92facc0437583922fe419a33716 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Tue, 12 Aug 2025 09:40:31 -0700 Subject: [PATCH 54/91] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 25942405..a52efdde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ ] readme = "README.md" requires-python = ">=3.12,<3.13" -version = "0.0.2" +version = "0.0.3" [project.optional-dependencies] triton = ["triton", "safetensors>=0.5.3", "torch>=2.7.0"] From a02c2ce83e4c9b46773ff13c1eecf91b5d29a937 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Tue, 12 Aug 2025 21:38:17 -0700 Subject: [PATCH 55/91] Update README.md Clarify MXFP4 quantization callout --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 08ed09b4..1796d38d 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Both models were trained using our [harmony response format][harmony] and should - **Full chain-of-thought:** Provides complete access to the model's reasoning process, facilitating easier debugging and greater trust in outputs. This information is not intended to be shown to end users. - **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning. - **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs. -- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, allowing `gpt-oss-120b` to run on a single 80GB GPU (like NVIDIA H100 or AMD MI300X) and `gpt-oss-20b` to run within 16GB of memory. +- **MXFP4 quantization:** The models were post-trained with MXFP4 quantization of the MoE weights, making `gpt-oss-120b` run on a single 80GB GPU (like NVIDIA H100 or AMD MI300X) and the `gpt-oss-20b` model run within 16GB of memory. All evals were performed with the same MXFP4 quantization. ### Inference examples From f8d21ad40617e8ddd857f2ec08812ca8ed06de9c Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Wed, 13 Aug 2025 10:09:58 -0700 Subject: [PATCH 56/91] fix streamlit & ollama demo. 
Add python tool (#131) * fix streamlit & ollama demo * add python to streamlit demo --- examples/streamlit/streamlit_chat.py | 46 ++- gpt_oss/responses_api/api_server.py | 309 ++++++++++++++++----- gpt_oss/responses_api/events.py | 59 +++- gpt_oss/responses_api/inference/ollama.py | 39 ++- gpt_oss/responses_api/types.py | 46 ++- gpt_oss/tools/python_docker/docker_tool.py | 13 +- 6 files changed, 392 insertions(+), 120 deletions(-) diff --git a/examples/streamlit/streamlit_chat.py b/examples/streamlit/streamlit_chat.py index f36c37b8..9185ff67 100644 --- a/examples/streamlit/streamlit_chat.py +++ b/examples/streamlit/streamlit_chat.py @@ -48,12 +48,10 @@ st.sidebar.subheader("Functions") use_functions = st.sidebar.toggle("Use functions", value=False) -if "show_browser" in st.query_params: - st.sidebar.subheader("Built-in Tools") +st.sidebar.subheader("Built-in Tools") # Built-in Tools section - use_browser_search = st.sidebar.toggle("Use browser search", value=False) -else: - use_browser_search = False +use_browser_search = st.sidebar.toggle("Use browser search", value=False) +use_code_interpreter = st.sidebar.toggle("Use code interpreter", value=False) if use_functions: function_name = st.sidebar.text_input("Function name", value="get_weather") @@ -72,7 +70,7 @@ "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01 ) max_output_tokens = st.sidebar.slider( - "Max output tokens", min_value=1000, max_value=20000, value=1024, step=100 + "Max output tokens", min_value=1, max_value=131072, value=30000, step=1000 ) st.sidebar.divider() debug_mode = st.sidebar.toggle("Debug mode", value=False) @@ -89,6 +87,7 @@ else "http://localhost:8000/v1/responses" ) + def trigger_fake_tool(container): function_output = st.session_state.get("function_output", "It's sunny!") last_call = st.session_state.messages[-1] @@ -117,6 +116,8 @@ def run(container): # Add browser_search tool if checkbox is checked if use_browser_search: tools.append({"type": "browser_search"}) + if use_code_interpreter: + tools.append({"type": "code_interpreter"}) response = requests.post( URL, json={ @@ -134,7 +135,7 @@ def run(container): text_delta = "" - current_output_index = 0 + _current_output_index = 0 for line in response.iter_lines(decode_unicode=True): if not line or not line.startswith("data:"): continue @@ -149,7 +150,7 @@ def run(container): event_type = data.get("type", "") output_index = data.get("output_index", 0) if event_type == "response.output_item.added": - current_output_index = output_index + _current_output_index = output_index output_type = data.get("item", {}).get("type", "message") if output_type == "message": output = container.chat_message("assistant") @@ -159,7 +160,13 @@ def run(container): placeholder = output.empty() elif output_type == "web_search_call": output = container.chat_message("web_search_call", avatar="🌐") - output.code(json.dumps(data.get("item", {}).get("action", {}), indent=4), language="json") + output.code( + json.dumps(data.get("item", {}).get("action", {}), indent=4), + language="json", + ) + placeholder = output.empty() + elif output_type == "code_interpreter_call": + output = container.chat_message("code_interpreter_call", avatar="🧪") placeholder = output.empty() text_delta = "" elif event_type == "response.reasoning_text.delta": @@ -178,6 +185,18 @@ def run(container): st.code(item.get("arguments", ""), language="json") if item.get("type") == "web_search_call": placeholder.markdown("✅ Done") + if item.get("type") == "code_interpreter_call": + placeholder.markdown("✅ Done") 
+ elif event_type == "response.code_interpreter_call.in_progress": + try: + placeholder.markdown("⏳ Running") + except Exception: + pass + elif event_type == "response.code_interpreter_call.completed": + try: + placeholder.markdown("✅ Done") + except Exception: + pass elif event_type == "response.completed": response = data.get("response", {}) if debug_mode: @@ -187,7 +206,7 @@ def run(container): st.session_state.messages.extend(response.get("output", [])) if st.session_state.messages[-1].get("type") == "function_call": with container.form("function_output_form"): - function_output = st.text_input( + _function_output = st.text_input( "Enter function output", value=st.session_state.get("function_output", "It's sunny!"), key="function_output", @@ -213,7 +232,9 @@ def run(container): st.markdown(item["text"]) if item.get("annotations"): annotation_lines = "\n".join( - f"- {annotation.get('url')}" for annotation in item["annotations"] if annotation.get("url") + f"- {annotation.get('url')}" + for annotation in item["annotations"] + if annotation.get("url") ) st.caption(f"**Annotations:**\n{annotation_lines}") elif msg.get("type") == "reasoning": @@ -234,6 +255,9 @@ def run(container): with st.chat_message("web_search_call", avatar="🌐"): st.code(json.dumps(msg.get("action", {}), indent=4), language="json") st.markdown("✅ Done") + elif msg.get("type") == "code_interpreter_call": + with st.chat_message("code_interpreter_call", avatar="🧪"): + st.markdown("✅ Done") if render_input: # Input field diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py index 63504a65..2934b011 100644 --- a/gpt_oss/responses_api/api_server.py +++ b/gpt_oss/responses_api/api_server.py @@ -1,8 +1,6 @@ -import asyncio import datetime import uuid from typing import Callable, Literal, Optional -import json from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse @@ -20,29 +18,32 @@ ToolDescription, ) +from gpt_oss.tools.python_docker.docker_tool import PythonTool from gpt_oss.tools.simple_browser import SimpleBrowserTool from gpt_oss.tools.simple_browser.backend import ExaBackend from .events import ( + ResponseCodeInterpreterCallCompleted, + ResponseCodeInterpreterCallInProgress, ResponseCompletedEvent, + ResponseContentPartAdded, + ResponseContentPartDone, ResponseCreatedEvent, - ResponseInProgressEvent, ResponseEvent, + ResponseInProgressEvent, ResponseOutputItemAdded, ResponseOutputItemDone, - ResponseContentPartAdded, - ResponseContentPartDone, - ResponseOutputTextDone, + ResponseOutputTextAnnotationAdded, ResponseOutputTextDelta, - ResponseReasoningTextDone, + ResponseOutputTextDone, ResponseReasoningTextDelta, + ResponseReasoningTextDone, + ResponseWebSearchCallCompleted, ResponseWebSearchCallInProgress, ResponseWebSearchCallSearching, - ResponseWebSearchCallCompleted, - ResponseOutputTextAnnotationAdded ) from .types import ( - UrlCitation, + CodeInterpreterCallItem, Error, FunctionCallItem, Item, @@ -51,11 +52,12 @@ ResponseObject, ResponsesRequest, TextContentItem, + UrlCitation, Usage, - WebSearchCallItem, - WebSearchActionSearch, - WebSearchActionOpenPage, WebSearchActionFind, + WebSearchActionOpenPage, + WebSearchActionSearch, + WebSearchCallItem, ) DEFAULT_TEMPERATURE = 0.0 @@ -72,7 +74,12 @@ def get_reasoning_effort(effort: Literal["low", "medium", "high"]) -> ReasoningE def is_not_builtin_tool(recipient: str) -> bool: - return not recipient.startswith("browser.") and not recipient == "python" and not recipient == "assistant" + return ( + not 
recipient.startswith("browser.") + and not recipient == "python" + and not recipient == "assistant" + ) + def create_api_server( infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding @@ -90,6 +97,8 @@ def generate_response( previous_response_id: Optional[str] = None, browser_tool: Optional[SimpleBrowserTool] = None, browser_call_ids: Optional[list[str]] = None, + python_tool: Optional[PythonTool] = None, + python_call_ids: Optional[list[str]] = None, ) -> ResponseObject: output = [] error = None @@ -113,9 +122,12 @@ def generate_response( fc_index = 0 browser_tool_index = 0 + python_tool_index = 0 for entry in entries: entry_dict = entry.to_dict() - if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(entry_dict["recipient"]): + if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool( + entry_dict["recipient"] + ): call = entry_dict["content"][0] arguments = call["text"] name = entry_dict["recipient"] @@ -139,12 +151,16 @@ def generate_response( call_id=call_id, ) ) - elif len(entry_dict.get("recipient", "")) > 0 and entry_dict["recipient"].startswith("browser.") and browser_tool is not None: + elif ( + len(entry_dict.get("recipient", "")) > 0 + and entry_dict["recipient"].startswith("browser.") + and browser_tool is not None + ): # Mirror event-based creation of WebSearchCallItems when the browser tool is invoked name = entry_dict["recipient"] call = entry_dict["content"][0] arguments = call["text"] - function_name = name[len("browser."):] + function_name = name[len("browser.") :] # Reconstruct a Message for argument parsing tool_msg = ( @@ -177,7 +193,9 @@ def generate_response( action = None if action is not None: - if browser_call_ids and browser_tool_index < len(browser_call_ids): + if browser_call_ids and browser_tool_index < len( + browser_call_ids + ): web_search_call_id = browser_call_ids[browser_tool_index] else: web_search_call_id = f"ws_{uuid.uuid4().hex}" @@ -189,11 +207,29 @@ def generate_response( action=action, ) ) + elif ( + len(entry_dict.get("recipient", "")) > 0 + and entry_dict["recipient"].startswith("python") + and python_tool is not None + ): + if python_call_ids and python_tool_index < len(python_call_ids): + code_call_id = python_call_ids[python_tool_index] + else: + code_call_id = f"ci_{uuid.uuid4().hex}" + python_tool_index += 1 + output.append( + CodeInterpreterCallItem( + type="code_interpreter_call", + id=code_call_id, + ) + ) elif entry_dict["channel"] == "final": content = [] - for content_entry in entry_dict["content"]: + for content_entry in entry_dict["content"]: if browser_tool: - text_content, annotation_entries, _has_partial_citations = browser_tool.normalize_citations(content_entry["text"]) + text_content, annotation_entries, _has_partial_citations = ( + browser_tool.normalize_citations(content_entry["text"]) + ) annotations = [UrlCitation(**a) for a in annotation_entries] else: text_content = content_entry["text"] @@ -288,7 +324,6 @@ class StreamResponsesEvents: request_body: ResponsesRequest request: Request sequence_number: int - def __init__( self, @@ -301,6 +336,7 @@ def __init__( Callable[[str, ResponsesRequest, ResponseObject], None] ] = None, browser_tool: Optional[SimpleBrowserTool] = None, + python_tool: Optional[PythonTool] = None, ): self.initial_tokens = initial_tokens self.tokens = initial_tokens.copy() @@ -327,6 +363,9 @@ def __init__( self.browser_tool = browser_tool self.use_browser_tool = browser_tool is not None self.browser_call_ids: list[str] = [] + self.python_tool = 
python_tool + self.use_code_interpreter = python_tool is not None + self.python_call_ids: list[str] = [] def _send_event(self, event: ResponseEvent): event.sequence_number = self.sequence_number @@ -346,6 +385,10 @@ async def run(self): function_call_ids=self.function_call_ids, response_id=self.response_id, previous_response_id=self.request_body.previous_response_id, + browser_tool=self.browser_tool, + browser_call_ids=self.browser_call_ids, + python_tool=self.python_tool, + python_call_ids=self.python_call_ids, ) initial_response.status = "in_progress" yield self._send_event( @@ -368,9 +411,9 @@ async def run(self): sent_output_item_added = False # we use this if the model outputs a citation to buffer until completed - output_delta_buffer = "" + output_delta_buffer = "" # we use this to track the current output text content for things like providing the right indices in citations - current_output_text_content = "" + current_output_text_content = "" current_annotations = [] while True: @@ -387,7 +430,7 @@ async def run(self): self.tokens.append(next_tok) try: self.parser.process(next_tok) - except Exception as e: + except Exception: pass if self.parser.state == StreamState.EXPECT_START: @@ -463,9 +506,17 @@ async def run(self): ) ) if previous_item.channel == "final": - annotations = [UrlCitation(**a) for a in current_annotations] + annotations = [ + UrlCitation(**a) for a in current_annotations + ] if browser_tool: - normalized_text, _annotations, _has_partial_citations = browser_tool.normalize_citations(previous_item.content[0].text) + ( + normalized_text, + _annotations, + _has_partial_citations, + ) = browser_tool.normalize_citations( + previous_item.content[0].text + ) else: normalized_text = previous_item.content[0].text annotations = [] @@ -531,14 +582,26 @@ async def run(self): should_send_output_text_delta = True if browser_tool: # we normalize on the full current text to get the right indices in citations - updated_output_text, annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text_content + output_delta_buffer) + updated_output_text, annotations, has_partial_citations = ( + browser_tool.normalize_citations( + current_output_text_content + output_delta_buffer + ) + ) # remove the current text to get back the delta but now normalized - output_delta_buffer = updated_output_text[len(current_output_text_content):] - + output_delta_buffer = updated_output_text[ + len(current_output_text_content) : + ] + # Filter annotations to only include those whose start_index is not already present in current_annotations # this is to avoid sending duplicate annotations as multiple annotations can't be in the same place - existing_start_indices = {a["start_index"] for a in current_annotations} - new_annotations = [a for a in annotations if a["start_index"] not in existing_start_indices] + existing_start_indices = { + a["start_index"] for a in current_annotations + } + new_annotations = [ + a + for a in annotations + if a["start_index"] not in existing_start_indices + ] for a in new_annotations: current_annotations.append(a) citation = UrlCitation(**a) @@ -555,7 +618,6 @@ async def run(self): if has_partial_citations: should_send_output_text_delta = False - if should_send_output_text_delta: yield self._send_event( ResponseOutputTextDelta( @@ -589,7 +651,9 @@ async def run(self): type="response.content_part.added", output_index=current_output_index, content_index=current_content_index, - part=ReasoningTextContentItem(type="reasoning_text", text=""), + 
part=ReasoningTextContentItem( + type="reasoning_text", text="" + ), ) ) yield self._send_event( @@ -618,7 +682,7 @@ async def run(self): and last_message.recipient is not None and last_message.recipient.startswith("browser.") ): - function_name = last_message.recipient[len("browser."):] + function_name = last_message.recipient[len("browser.") :] action = None parsed_args = browser_tool.process_arguments(last_message) if function_name == "search": @@ -629,32 +693,42 @@ async def run(self): elif function_name == "open": action = WebSearchActionOpenPage( type="open_page", - url=parsed_args["url"] if "url" in parsed_args else None, + url=( + parsed_args["url"] + if "url" in parsed_args + else None + ), ) elif function_name == "find": action = WebSearchActionFind( type="find", pattern=parsed_args["pattern"], - url=parsed_args["url"] if "url" in parsed_args else None, + url=( + parsed_args["url"] + if "url" in parsed_args + else None + ), ) if action is not None: web_search_call_id = f"ws_{uuid.uuid4().hex}" self.browser_call_ids.append(web_search_call_id) - yield self._send_event(ResponseOutputItemAdded( - type="response.output_item.added", - output_index=current_output_index, - item=WebSearchCallItem( - type="web_search_call", - id=web_search_call_id, - action=action, - ), - )) + yield self._send_event( + ResponseOutputItemAdded( + type="response.output_item.added", + output_index=current_output_index, + item=WebSearchCallItem( + type="web_search_call", + id=web_search_call_id, + action=action, + ), + ) + ) yield self._send_event( ResponseWebSearchCallInProgress( type="response.web_search_call.in_progress", output_index=current_output_index, - id=web_search_call_id + id=web_search_call_id, ) ) @@ -676,10 +750,12 @@ async def run_tool(): new_tokens = encoding.render_conversation_for_completion( Conversation.from_messages(result), Role.ASSISTANT ) - + print(encoding.decode_utf8(new_tokens)) self.output_tokens.append(next_tok) - self.tokens.append(encoding.encode('<|end|>', allowed_special="all")[0]) + self.tokens.append( + encoding.encode("<|end|>", allowed_special="all")[0] + ) for token in new_tokens: self.parser.process(token) @@ -693,19 +769,94 @@ async def run_tool(): id=web_search_call_id, ) ) - yield self._send_event(ResponseOutputItemDone( - type="response.output_item.done", - output_index=current_output_index, - item=WebSearchCallItem( - type="web_search_call", - id=web_search_call_id, - action=action, - ), - )) + yield self._send_event( + ResponseOutputItemDone( + type="response.output_item.done", + output_index=current_output_index, + item=WebSearchCallItem( + type="web_search_call", + id=web_search_call_id, + action=action, + ), + ) + ) current_output_index += 1 self.new_request = True - + + continue + + elif ( + self.use_code_interpreter + and last_message.recipient is not None + and last_message.recipient.startswith("python") + ): + code_call_id = f"ci_{uuid.uuid4().hex}" + self.python_call_ids.append(code_call_id) + yield self._send_event( + ResponseOutputItemAdded( + type="response.output_item.added", + output_index=current_output_index, + item=CodeInterpreterCallItem( + type="code_interpreter_call", + id=code_call_id, + ), + ) + ) + yield self._send_event( + ResponseCodeInterpreterCallInProgress( + type="response.code_interpreter_call.in_progress", + output_index=current_output_index, + id=code_call_id, + ) + ) + + async def run_python_tool(): + results = [] + async for msg in self.python_tool.process(last_message): + results.append(msg) + return results + + result = await 
run_python_tool() + + print(result) + + new_tokens = encoding.render_conversation_for_completion( + Conversation.from_messages(result), Role.ASSISTANT + ) + + print(encoding.decode_utf8(new_tokens)) + self.output_tokens.append(next_tok) + self.tokens.append( + encoding.encode("<|end|>", allowed_special="all")[0] + ) + + for token in new_tokens: + self.parser.process(token) + self.output_tokens.append(token) + self.tokens.append(token) + + yield self._send_event( + ResponseCodeInterpreterCallCompleted( + type="response.code_interpreter_call.completed", + output_index=current_output_index, + id=code_call_id, + ) + ) + yield self._send_event( + ResponseOutputItemDone( + type="response.output_item.done", + output_index=current_output_index, + item=CodeInterpreterCallItem( + type="code_interpreter_call", + id=code_call_id, + ), + ) + ) + + current_output_index += 1 + self.new_request = True + continue else: @@ -747,6 +898,10 @@ async def generate(body: ResponsesRequest, request: Request): getattr(tool, "type", None) == "browser_search" for tool in (body.tools or []) ) + use_code_interpreter = any( + getattr(tool, "type", None) == "code_interpreter" + for tool in (body.tools or []) + ) if use_browser_tool: backend = ExaBackend( @@ -756,6 +911,11 @@ async def generate(body: ResponsesRequest, request: Request): else: browser_tool = None + if use_code_interpreter: + python_tool = PythonTool() + else: + python_tool = None + if body.previous_response_id: prev = responses_store.get(body.previous_response_id) if prev: @@ -779,22 +939,30 @@ def _ensure_list(inp): body.instructions = prev_req.instructions body.input = merged_input - system_message_content = SystemContent.new().with_conversation_start_date( datetime.datetime.now().strftime("%Y-%m-%d") ) - + if body.reasoning is not None: try: - reasoning_effort = get_reasoning_effort(body.reasoning.effect) + reasoning_effort = get_reasoning_effort(body.reasoning.effort) except ValueError as e: - from fastapi import HTTP Exception + from fastapi import HTTPException + raise HTTPException(status_code=422, detail=str(e)) - system_message_content = system_message_content.with_reasoning_effort(reasoning_effort) + system_message_content = system_message_content.with_reasoning_effort( + reasoning_effort + ) if use_browser_tool: - system_message_content = system_message_content.with_tools(browser_tool.tool_config) + system_message_content = system_message_content.with_tools( + browser_tool.tool_config + ) + if use_code_interpreter: + system_message_content = system_message_content.with_tools( + python_tool.tool_config + ) system_message = Message.from_role_and_content( Role.SYSTEM, system_message_content @@ -818,8 +986,8 @@ def _ensure_list(inp): ) if tools: - developer_message_content = developer_message_content.with_function_tools( - tools + developer_message_content = ( + developer_message_content.with_function_tools(tools) ) developer_message = Message.from_role_and_content( @@ -852,7 +1020,9 @@ def _ensure_list(inp): else: for content_item in item.content: messages.append( - Message.from_role_and_content(item.role, content_item.text) + Message.from_role_and_content( + item.role, content_item.text + ) ) # add final channel to the last assistant message if it's from the assistant if item.role == Role.ASSISTANT: @@ -885,7 +1055,9 @@ def _ensure_list(inp): Message.from_author_and_content( Author.new(Role.TOOL, f"functions.{function_call.name}"), item.output, - ).with_recipient("assistant").with_channel("commentary") + ) + .with_recipient("assistant") + 
.with_channel("commentary") ) conversation = Conversation.from_messages(messages) @@ -907,6 +1079,7 @@ def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject): response_id=response_id, store_callback=store_callback, browser_tool=browser_tool, + python_tool=python_tool, ) if body.stream: diff --git a/gpt_oss/responses_api/events.py b/gpt_oss/responses_api/events.py index 7adecc64..fed4c6e6 100644 --- a/gpt_oss/responses_api/events.py +++ b/gpt_oss/responses_api/events.py @@ -4,14 +4,15 @@ from pydantic import BaseModel from .types import ( + CodeInterpreterCallItem, FunctionCallItem, Item, ReasoningItem, + ReasoningTextContentItem, ResponseObject, TextContentItem, - ReasoningTextContentItem, - WebSearchCallItem, UrlCitation, + WebSearchCallItem, ) @@ -67,13 +68,25 @@ class ResponseReasoningTextDone(ResponseEvent): class ResponseOutputItemAdded(ResponseEvent): type: Literal["response.output_item.added"] = "response.output_item.added" output_index: int = 0 - item: Union[Item, ReasoningItem, FunctionCallItem, WebSearchCallItem] + item: Union[ + Item, + ReasoningItem, + FunctionCallItem, + WebSearchCallItem, + CodeInterpreterCallItem, + ] class ResponseOutputItemDone(ResponseEvent): type: Literal["response.output_item.done"] = "response.output_item.done" output_index: int = 0 - item: Union[Item, ReasoningItem, FunctionCallItem, WebSearchCallItem] + item: Union[ + Item, + ReasoningItem, + FunctionCallItem, + WebSearchCallItem, + CodeInterpreterCallItem, + ] class ResponseInProgressEvent(ResponseEvent): @@ -105,25 +118,53 @@ class ResponseContentPartDone(ResponseEvent): content_index: int = 0 part: Union[TextContentItem, ReasoningTextContentItem] + class ResponseOutputTextAnnotationAdded(ResponseEvent): - type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added" + type: Literal["response.output_text.annotation.added"] = ( + "response.output_text.annotation.added" + ) item_id: str = "item_1234" output_index: int = 0 content_index: int = 0 annotation_index: int = 0 annotation: UrlCitation + class ResponseWebSearchCallInProgress(ResponseEvent): - type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress" + type: Literal["response.web_search_call.in_progress"] = ( + "response.web_search_call.in_progress" + ) output_index: int = 0 item_id: str = "item_1234" + class ResponseWebSearchCallSearching(ResponseEvent): - type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching" + type: Literal["response.web_search_call.searching"] = ( + "response.web_search_call.searching" + ) output_index: int = 0 item_id: str = "item_1234" + class ResponseWebSearchCallCompleted(ResponseEvent): - type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed" + type: Literal["response.web_search_call.completed"] = ( + "response.web_search_call.completed" + ) output_index: int = 0 - item_id: str = "item_1234" \ No newline at end of file + item_id: str = "item_1234" + + +class ResponseCodeInterpreterCallInProgress(ResponseEvent): + type: Literal["response.code_interpreter_call.in_progress"] = ( + "response.code_interpreter_call.in_progress" + ) + output_index: int = 0 + item_id: str = "item_1234" + + +class ResponseCodeInterpreterCallCompleted(ResponseEvent): + type: Literal["response.code_interpreter_call.completed"] = ( + "response.code_interpreter_call.completed" + ) + output_index: int = 0 + item_id: str = "item_1234" diff --git a/gpt_oss/responses_api/inference/ollama.py 
b/gpt_oss/responses_api/inference/ollama.py index 35eb1b2f..e0196c6d 100644 --- a/gpt_oss/responses_api/inference/ollama.py +++ b/gpt_oss/responses_api/inference/ollama.py @@ -1,6 +1,6 @@ """ NOTE: this is a stitched together implementation that uses Ollama for inference. It's primarily used -for testing and development. It does not leverage any prompt caching or other optimizations and +for testing and development. It does not leverage any prompt caching or other optimizations and can therefore be slow between turns. """ @@ -8,17 +8,17 @@ import threading import time from typing import Callable, Optional -import requests -from openai_harmony import load_harmony_encoding, HarmonyEncodingName +import requests +from openai_harmony import HarmonyEncodingName, load_harmony_encoding EOS_TOKEN = 200002 # only used on hard timeout # Tunables -POLL_INTERVAL_S = 0.01 # 10ms between buffer checks -CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call -NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS -FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS +POLL_INTERVAL_S = 0.01 # 10ms between buffer checks +CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call +NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS +FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS # Shared state _token_buffer: list[int] = [] @@ -26,9 +26,10 @@ _stream_thread: Optional[threading.Thread] = None _stream_done = threading.Event() _stream_error: Optional[Exception] = None -_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens +_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens _previous_request_tokens: list[int] = [] + def lcp(cache: list[int], inp: list[int]) -> list[int]: i = 0 max_len = min(len(cache), len(inp)) @@ -36,13 +37,16 @@ def lcp(cache: list[int], inp: list[int]) -> list[int]: i += 1 return cache[:i] + def _now(): return time.monotonic() + def _touch_progress(): global _last_progress_ts _last_progress_ts = _now() + def _reset_stream_state(): global _token_buffer, _stream_thread, _stream_error with _buffer_lock: @@ -52,12 +56,14 @@ def _reset_stream_state(): _stream_error = None _touch_progress() + def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]: encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) model_name = checkpoint def _start_stream(token_ids: list[int], temperature: float): prompt_text = encoding.decode(token_ids) + def run(): nonlocal prompt_text, temperature global _stream_error @@ -68,21 +74,13 @@ def run(): try: url = "http://localhost:11434/api/generate" - context = None - if len(_previous_request_tokens) > 0: - context = _previous_request_tokens - # cache_hit = lcp(_previous_request_tokens, token_ids) - # if len(cache_hit) > 0: - # context = cache_hit - # print(f"Cache hit: {encoding.decode(context)}") - # prompt_text = encoding.decode(token_ids[len(context):]) payload = { "model": model_name, "prompt": prompt_text, "stream": True, - "context": context, "options": {"temperature": temperature}, + "raw": True, } with requests.post(url, json=payload, stream=True, timeout=60) as resp: @@ -106,9 +104,6 @@ def run(): _token_buffer.append(EOS_TOKEN) last_len = len(toks) _touch_progress() - context = obj.get("context") - if context and len(context) > 0: - _previous_request_tokens = context break _stream_done.set() @@ -187,6 +182,8 @@ def infer_next_token( # If we reach here, we still haven't 
got a token—ask the caller to call again soon. # Return a harmless token that the server will replace/ignore if your interface supports it. # If your interface does NOT allow a sentinel, keep the short-blocking behavior above. - return EOS_TOKEN if False else 0 # replace `0` with a PAD/NOOP token your server ignores + return ( + EOS_TOKEN if False else 0 + ) # replace `0` with a PAD/NOOP token your server ignores return infer_next_token diff --git a/gpt_oss/responses_api/types.py b/gpt_oss/responses_api/types.py index 1d908e34..4ca72c56 100644 --- a/gpt_oss/responses_api/types.py +++ b/gpt_oss/responses_api/types.py @@ -8,6 +8,7 @@ REASONING_EFFORT = ReasoningEffort.LOW DEFAULT_MAX_OUTPUT_TOKENS = 10_000 + class UrlCitation(BaseModel): type: Literal["url_citation"] end_index: int @@ -15,6 +16,7 @@ class UrlCitation(BaseModel): url: str title: str + class TextContentItem(BaseModel): type: Union[Literal["text"], Literal["input_text"], Literal["output_text"]] text: str @@ -61,25 +63,37 @@ class FunctionCallOutputItem(BaseModel): call_id: str = "call_1234" output: str + class WebSearchActionSearch(BaseModel): type: Literal["search"] query: Optional[str] = None + class WebSearchActionOpenPage(BaseModel): type: Literal["open_page"] url: Optional[str] = None + class WebSearchActionFind(BaseModel): type: Literal["find"] pattern: Optional[str] = None url: Optional[str] = None + class WebSearchCallItem(BaseModel): type: Literal["web_search_call"] id: str = "ws_1234" status: Literal["in_progress", "completed", "incomplete"] = "completed" action: Union[WebSearchActionSearch, WebSearchActionOpenPage, WebSearchActionFind] + +class CodeInterpreterCallItem(BaseModel): + type: Literal["code_interpreter_call"] + id: str = "ci_1234" + status: Literal["in_progress", "completed", "incomplete"] = "completed" + input: Optional[str] = None + + class Error(BaseModel): code: str message: str @@ -107,6 +121,10 @@ class BrowserToolConfig(BaseModel): type: Literal["browser_search"] +class CodeInterpreterToolConfig(BaseModel): + type: Literal["code_interpreter"] + + class ReasoningConfig(BaseModel): effort: Literal["low", "medium", "high"] = REASONING_EFFORT @@ -115,11 +133,24 @@ class ResponsesRequest(BaseModel): instructions: Optional[str] = None max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS input: Union[ - str, list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]] + str, + list[ + Union[ + Item, + ReasoningItem, + FunctionCallItem, + FunctionCallOutputItem, + WebSearchCallItem, + ] + ], ] model: Optional[str] = MODEL_IDENTIFIER stream: Optional[bool] = False - tools: Optional[list[Union[FunctionToolDefinition, BrowserToolConfig]]] = [] + tools: Optional[ + list[ + Union[FunctionToolDefinition, BrowserToolConfig, CodeInterpreterToolConfig] + ] + ] = [] reasoning: Optional[ReasoningConfig] = ReasoningConfig() metadata: Optional[Dict[str, Any]] = {} tool_choice: Optional[Literal["auto", "none"]] = "auto" @@ -131,7 +162,16 @@ class ResponsesRequest(BaseModel): class ResponseObject(BaseModel): - output: list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]] + output: list[ + Union[ + Item, + ReasoningItem, + FunctionCallItem, + FunctionCallOutputItem, + WebSearchCallItem, + CodeInterpreterCallItem, + ] + ] created_at: int usage: Optional[Usage] = None status: Literal["completed", "failed", "incomplete", "in_progress"] = "in_progress" diff --git a/gpt_oss/tools/python_docker/docker_tool.py 
b/gpt_oss/tools/python_docker/docker_tool.py index 7067c1e1..c31680ea 100644 --- a/gpt_oss/tools/python_docker/docker_tool.py +++ b/gpt_oss/tools/python_docker/docker_tool.py @@ -1,5 +1,7 @@ # Run this before running the tool: # $ docker image pull python:3.11 +import io +import tarfile from typing import Any, AsyncIterator import docker @@ -11,12 +13,9 @@ TextContent, ToolNamespaceConfig, ) -import io -import tarfile from ..tool import Tool - _docker_client = None @@ -78,15 +77,13 @@ def name(self) -> str: def instruction(self) -> str: return """ Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files). -When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you. +When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you. You have to use print statements to access the output. """.strip() @property def tool_config(self) -> ToolNamespaceConfig: return ToolNamespaceConfig( - name=self.get_tool_name(), - description=self.instruction, - tools=[] + name=self.get_tool_name(), description=self.instruction, tools=[] ) def _make_response( @@ -111,7 +108,7 @@ def make_response( message = Message( author=author, content=[content], - ).with_recipient('assistant') + ).with_recipient("assistant") if channel: message = message.with_channel(channel) From 65b3d6bb0e8d0da8b3b4404e6d82855d49a459f2 Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Fri, 15 Aug 2025 00:32:58 +0800 Subject: [PATCH 57/91] Add some links to awesome-gpt-oss.md (#28) * add training resources * update * Update awesome-gpt-oss.md --- awesome-gpt-oss.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index aba8a2b2..f4491d56 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -10,6 +10,7 @@ This is a list of guides and resources to help you get started with the gpt-oss - [Cloud](#cloud) - [Examples / Tutorials](#examples--tutorials) - [Tools](#tools) +- [Training](#training) ## Inference @@ -72,6 +73,12 @@ This is a list of guides and resources to help you get started with the gpt-oss - [Example `python` tool for gpt-oss](./gpt_oss/tools/python_docker/) - [Example `browser` tool for gpt-oss](./gpt_oss/tools/simple_browser/) +## Training + +- [Hugging Face TRL examples](https://github.com/huggingface/gpt-oss-recipes) +- [LlamaFactory examples](https://llamafactory.readthedocs.io/en/latest/advanced/best_practice/gpt-oss.html) +- [Unsloth examples](https://docs.unsloth.ai/basics/gpt-oss-how-to-run-and-fine-tune) + ## Contributing Feel free to open a PR to add your own guides and resources on how to run gpt-oss. We will try to review it and add it here. From 53efd592078904ed13ada04ec11f5fd34601cdb7 Mon Sep 17 00:00:00 2001 From: Will <52027937+liuzhiqi71@users.noreply.github.com> Date: Thu, 14 Aug 2025 11:33:33 -0500 Subject: [PATCH 58/91] fix: fix f-string unmatched '(' bug in streamlit_chat.py (#31) From f018fab6eb9d9d3ea43e21fc139f28fc92ea066a Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Fri, 15 Aug 2025 20:27:56 +0100 Subject: [PATCH 59/91] Fix start_q use in upper bound calculation (#136) Noticed this suspicious use of `start_q` in the attention kernel. 
When calculating `lo`, `start_q` is not multiplied by `BLOCK_M` but when calculating `hi` it is despite both being in the same units. Looking at how `start_q` is defined, it is defined entirely based on the tensor's shape and doesn't know about `BLOCK_M` so shouldn't be multiplied by `BLOCK_M`: https://github.com/openai/gpt-oss/blob/a02c2ce83e4c9b46773ff13c1eecf91b5d29a937/gpt_oss/triton/model.py#L153 Because the loading and computation are all masked, I don't think this effects numerics but it should reduce the number of iterations of the for loop. --- gpt_oss/triton/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py index e9222f0c..6fa77fd2 100644 --- a/gpt_oss/triton/attention.py +++ b/gpt_oss/triton/attention.py @@ -111,9 +111,9 @@ def _attn_fwd( q = tl.load(Q_block_ptr) if BANDWIDTH: - lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), (start_q + start_m + 1) * BLOCK_M + lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: - lo, hi = start_q, (start_q + start_m + 1) * BLOCK_M + lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M # advance the KV block-pointers so they point at `lo` K_block_ptr = tl.advance(K_block_ptr, (0, lo)) @@ -299,4 +299,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q) o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q) - torch.testing.assert_close(o1, o2) \ No newline at end of file + torch.testing.assert_close(o1, o2) From cf427a62e2d80b33b87cbd1ab715730910f5aad0 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Fri, 15 Aug 2025 14:46:30 -0700 Subject: [PATCH 60/91] Process tokens in Context lazily (#138) - Process tokens when Context.process or Context.sample is called rather than when we accumulate max_batch_tokens unprocessed tokens. - Avoid re-creating command buffers for each batch. - Avoid redundantly computing output activations for the last token in each batch. - Avoid invalidating KV cache on Context.reset. Match longest common prefix to tokens in KV cache when tokens are appended following Context.reset. --- gpt_oss/metal/source/context.c | 709 ++++++++++-------- gpt_oss/metal/source/generate.c | 12 +- gpt_oss/metal/source/include/internal/math.h | 4 + gpt_oss/metal/source/include/internal/model.h | 7 - 4 files changed, 396 insertions(+), 336 deletions(-) diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index c356ea44..b58df99a 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -157,9 +157,20 @@ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens( return gptoss_status_success; } -static enum gptoss_status process_batch( - gptoss_context_t context) +// Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0. +// Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1. +// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens. 
+static enum gptoss_status process_tokens( + gptoss_context_t context, + size_t input_tokens_offset, + size_t num_input_tokens, + size_t num_output_tokens) { + assert(num_input_tokens != 0); + assert(num_input_tokens <= context->max_batch_tokens); + assert(num_output_tokens <= context->max_batch_tokens); + assert(num_input_tokens >= num_output_tokens); + enum gptoss_status status = gptoss_status_success; const struct gptoss_model* model = context->model; struct gptoss_metal_command_buffer command_buffer = {0}; @@ -170,322 +181,348 @@ static enum gptoss_status process_batch( if (status != gptoss_status_success) { goto cleanup; } - status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings( - &command_buffer, - &model->bf16_f32_embeddings_fn, - /*threadgroup_size=*/512, - &context->token_buffer, - (context->num_tokens - context->num_batch_tokens) * sizeof(uint32_t), - &model->shared_weight_buffer, - /*weight_offset=*/0, - &context->residual_activation_buffer, - /*output_offset=*/0, - /*num_tokens=*/context->num_batch_tokens, - /*num_channels=*/model->embedding_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch"); - goto cleanup; - } - for (uint32_t n = 0; n < model->num_blocks; n++) { - const bool last_block = n + 1 == model->num_blocks; - const size_t num_output_tokens = last_block ? 1 : context->num_batch_tokens; + const size_t input_tokens_end = input_tokens_offset + num_input_tokens; + for (size_t input_batch_start = input_tokens_offset; + input_batch_start < input_tokens_end; + input_batch_start += model->max_batch_tokens) + { + const size_t input_batch_size = math_min(model->max_batch_tokens, input_tokens_end - input_batch_start); + const size_t input_batch_end = input_batch_start + input_batch_size; + const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end); - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( + status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings( &command_buffer, - &model->f32_bf16w_rmsnorm_fn, - &context->residual_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, - &context->rmsnorm_activation_buffer, - /*output_offset=*/0, - /*num_tokens=*/context->num_batch_tokens, - /*num_channels=*/model->embedding_dim, - model->rmsnorm_epsilon); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); - goto cleanup; - } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - &command_buffer, - &model->f32_bf16w_matmul_fn, - /*threadgroup_size=*/256, - &context->rmsnorm_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n, + &model->bf16_f32_embeddings_fn, + /*threadgroup_size=*/512, + &context->token_buffer, + input_batch_start * sizeof(uint32_t), &model->shared_weight_buffer, - /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, - &context->qkv_activation_buffer, + /*weight_offset=*/0, + &context->residual_activation_buffer, /*output_offset=*/0, - /*num_tokens=*/context->num_batch_tokens, - /*num_cols=*/model->embedding_dim, - /*num_rows=*/attn_qkv_dim); + /*num_tokens=*/input_batch_size, + /*num_channels=*/model->embedding_dim); if (status != gptoss_status_success) { - 
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); + GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch"); goto cleanup; } + for (uint32_t n = 0; n < model->num_blocks; n++) { + const bool last_block = n + 1 == model->num_blocks; + const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size; - status = gptoss_metal_command_buffer_encode_launch_f32_rope( - &command_buffer, - &model->f32_rope_fn, - /*threadgroup_size=*/32, - &context->qkv_activation_buffer, - model->rope_theta, - model->interpolation_scale, - model->yarn_offset, - model->yarn_scale, - model->yarn_multiplier, - context->num_batch_tokens, - model->num_heads, - model->num_kv_heads, - model->head_dim, - /*token_offset=*/context->num_kv_tokens); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch"); - goto cleanup; - } - for (uint32_t t = 0; t < context->num_batch_tokens; t++) { - status = gptoss_metal_command_buffer_encode_copy_buffer( + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( &command_buffer, - &context->qkv_activation_buffer, - /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float), - &context->kvcache_buffer, - /*output_offset=*/(n * context->max_tokens + context->num_kv_tokens + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float), - /*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float)); + &model->f32_bf16w_rmsnorm_fn, + &context->residual_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, + &context->rmsnorm_activation_buffer, + /*output_offset=*/0, + /*num_tokens=*/input_batch_size, + /*num_channels=*/model->embedding_dim, + model->rmsnorm_epsilon); if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t); + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); goto cleanup; } - } - status = gptoss_metal_command_buffer_encode_launch_f32_sdpa( - &command_buffer, - &model->f32_sdpa_q8_d64_fn, - &context->qkv_activation_buffer, - /*q_offset=*/attn_qkv_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), - &context->kvcache_buffer, - /*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float), - &context->kvcache_buffer, - /*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float), - &model->shared_weight_buffer, - /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n, - &context->sdpa_activation_buffer, /*output_offset=*/0, - /*window=*/n % 2 == 0 ? 
model->attention_window : UINT32_MAX, - num_output_tokens, context->num_kv_tokens + (context->num_batch_tokens - num_output_tokens), - model->num_heads, model->num_kv_heads, model->head_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch"); - goto cleanup; - } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( - &command_buffer, - &model->f32_bf16w_matmul_fn, - /*threadgroup_size=*/256, - &context->sdpa_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n, - &model->shared_weight_buffer, - /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, - &context->residual_activation_buffer, - /*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), - /*num_tokens=*/num_output_tokens, - /*num_cols=*/model->num_heads * model->head_dim, - /*num_rows=*/model->embedding_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch"); - goto cleanup; - } + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( + &command_buffer, + &model->f32_bf16w_matmul_fn, + /*threadgroup_size=*/256, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, + &context->qkv_activation_buffer, + /*output_offset=*/0, + /*num_tokens=*/input_batch_size, + /*num_cols=*/model->embedding_dim, + /*num_rows=*/attn_qkv_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); + goto cleanup; + } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( - &command_buffer, - &model->f32_bf16w_rmsnorm_fn, - &context->residual_activation_buffer, - /*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), - &model->shared_weight_buffer, - /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, - &context->rmsnorm_activation_buffer, - /*output_offset=*/0, - num_output_tokens, - model->embedding_dim, - model->rmsnorm_epsilon); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); - goto cleanup; - } + status = gptoss_metal_command_buffer_encode_launch_f32_rope( + &command_buffer, + &model->f32_rope_fn, + /*threadgroup_size=*/32, + &context->qkv_activation_buffer, + model->rope_theta, + model->interpolation_scale, + model->yarn_offset, + model->yarn_scale, + model->yarn_multiplier, + input_batch_size, + model->num_heads, + model->num_kv_heads, + model->head_dim, + /*token_offset=*/input_batch_start); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch"); + goto cleanup; + } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - &command_buffer, - &model->f32_bf16w_matmul_fn, - /*threadgroup_size=*/256, - &context->rmsnorm_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n, - &model->shared_weight_buffer, - /*bias_offset=*/model->mlp_gate_bias_offset + 
model->per_block_shared_weights_size * n, - &context->gate_activation_buffer, - /*output_offset=*/0, - /*num_tokens=*/num_output_tokens, - /*num_cols=*/model->embedding_dim, - /*num_rows=*/model->num_experts); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); - goto cleanup; - } + for (uint32_t t = 0; t < input_batch_size; t++) { + status = gptoss_metal_command_buffer_encode_copy_buffer( + &command_buffer, + &context->qkv_activation_buffer, + /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float), + &context->kvcache_buffer, + /*output_offset=*/(n * context->max_tokens + input_batch_start + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float), + /*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float)); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t); + goto cleanup; + } + } - const char* kernel_name = NULL; - switch (model->num_experts) { - case 32: - kernel_name = "f32_topk_softmax_e32_k4_fn"; - status = gptoss_metal_command_buffer_encode_launch_f32_topk( + if (num_block_output_tokens != 0) { + status = gptoss_metal_command_buffer_encode_launch_f32_sdpa( &command_buffer, - &model->f32_topk_softmax_e32_k4_fn, - &context->gate_activation_buffer, /*input_offset=*/0, - &context->expert_activation_buffer, /*output_offset=*/0, - num_output_tokens, - model->num_experts, - model->num_active_experts); - break; - case 128: - kernel_name = "f32_topk_softmax_e128_k4_fn"; - status = gptoss_metal_command_buffer_encode_launch_f32_topk( + &model->f32_sdpa_q8_d64_fn, + &context->qkv_activation_buffer, + /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &context->kvcache_buffer, + /*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float), + &context->kvcache_buffer, + /*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float), + &model->shared_weight_buffer, + /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n, + &context->sdpa_activation_buffer, + /*output_offset=*/0, + /*window=*/n % 2 == 0 ? 
model->attention_window : UINT32_MAX, + num_block_output_tokens, + input_batch_start + input_batch_size - num_block_output_tokens, + model->num_heads, model->num_kv_heads, model->head_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch"); + goto cleanup; + } + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( + &command_buffer, + &model->f32_bf16w_matmul_fn, + /*threadgroup_size=*/256, + &context->sdpa_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, + &context->residual_activation_buffer, + /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + /*num_tokens=*/num_block_output_tokens, + /*num_cols=*/model->num_heads * model->head_dim, + /*num_rows=*/model->embedding_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch"); + goto cleanup; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( + &command_buffer, + &model->f32_bf16w_rmsnorm_fn, + &context->residual_activation_buffer, + /*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &model->shared_weight_buffer, + /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, + &context->rmsnorm_activation_buffer, + /*output_offset=*/0, + num_block_output_tokens, + model->embedding_dim, + model->rmsnorm_epsilon); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); + goto cleanup; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( + &command_buffer, + &model->f32_bf16w_matmul_fn, + /*threadgroup_size=*/256, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n, + &context->gate_activation_buffer, + /*output_offset=*/0, + /*num_tokens=*/num_block_output_tokens, + /*num_cols=*/model->embedding_dim, + /*num_rows=*/model->num_experts); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); + goto cleanup; + } + + const char* kernel_name = NULL; + switch (model->num_experts) { + case 32: + kernel_name = "f32_topk_softmax_e32_k4_fn"; + status = gptoss_metal_command_buffer_encode_launch_f32_topk( + &command_buffer, + &model->f32_topk_softmax_e32_k4_fn, + &context->gate_activation_buffer, /*input_offset=*/0, + &context->expert_activation_buffer, /*output_offset=*/0, + num_block_output_tokens, + model->num_experts, + model->num_active_experts); + break; + case 128: + kernel_name = "f32_topk_softmax_e128_k4_fn"; + status = gptoss_metal_command_buffer_encode_launch_f32_topk( + &command_buffer, + &model->f32_topk_softmax_e128_k4_fn, + &context->gate_activation_buffer, /*input_offset=*/0, + &context->expert_activation_buffer, /*output_offset=*/0, + num_block_output_tokens, + model->num_experts, + model->num_active_experts); + break; + default: + status = gptoss_status_unsupported_argument; + GPTOSS_LOG_ERROR("missing Top-K kernel for 
%" PRIu32 " experts", model->num_experts); + goto cleanup; + } + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name); + goto cleanup; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu( &command_buffer, - &model->f32_topk_softmax_e128_k4_fn, - &context->gate_activation_buffer, /*input_offset=*/0, - &context->expert_activation_buffer, /*output_offset=*/0, - num_output_tokens, - model->num_experts, + &model->f32_mf4w_moe_matmul_swiglu_fn, + /*threadgroup_size=*/512, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &context->expert_activation_buffer, + /*expert_offset=*/0, + &model->block_weight_buffers[n], + /*weight_block_offset=*/0, + &model->block_weight_buffers[n], + /*weight_scale_offset=*/model->mlp_swiglu_scale_offset, + &model->block_weight_buffers[n], + /*bias_offset=*/model->mlp_swiglu_bias_offset, + &context->swiglu_activation_buffer, + /*output_offset=*/0, + model->swiglu_limit, + model->per_expert_block_weight_size, + num_block_output_tokens, + model->num_active_experts, + model->embedding_dim, + model->mlp_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch"); + goto cleanup; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul( + &command_buffer, + &model->f32_mf4w_moe_matmul_fn, + /*threadgroup_size=*/512, + &context->swiglu_activation_buffer, + /*input_offset=*/0, + &context->expert_activation_buffer, + /*expert_offset=*/0, + &model->block_weight_buffers[n], + /*weight_block_offset=*/model->mlp_out_block_offset, + &model->block_weight_buffers[n], + /*weight_scale_offset=*/model->mlp_out_scale_offset, + &model->block_weight_buffers[n], + /*bias_offset=*/model->mlp_out_bias_offset, + &context->moe_activation_buffer, + /*output_offset=*/0, + model->per_expert_block_weight_size, + num_block_output_tokens, + model->num_active_experts, + model->mlp_dim, + model->embedding_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch"); + goto cleanup; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_accumulate( + &command_buffer, + &model->f32_accumulate_e4_fn, + /*threadgroup_size=*/256, + model->max_threadgroups, + &context->moe_activation_buffer, + /*input_offset=*/0, + &context->expert_activation_buffer, + /*expert_offset=*/0, + &context->residual_activation_buffer, + /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + model->embedding_dim, + num_block_output_tokens, model->num_active_experts); - break; - default: - status = gptoss_status_unsupported_argument; - GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts); - goto cleanup; - } - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name); - goto cleanup; + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch"); + goto cleanup; + } + } } - status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu( - &command_buffer, - &model->f32_mf4w_moe_matmul_swiglu_fn, - /*threadgroup_size=*/512, - &context->rmsnorm_activation_buffer, /*input_offset=*/0, - &context->expert_activation_buffer, /*expert_offset=*/0, - &model->block_weight_buffers[n], /*weight_block_offset=*/0, - &model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_swiglu_scale_offset, 
- &model->block_weight_buffers[n], /*bias_offset=*/model->mlp_swiglu_bias_offset, - &context->swiglu_activation_buffer, /*output_offset=*/0, - model->swiglu_limit, - model->per_expert_block_weight_size, - num_output_tokens, - model->num_active_experts, - model->embedding_dim, - model->mlp_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch"); - goto cleanup; - } + if (output_batch_size != 0) { + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( + &command_buffer, + &model->f32_bf16w_rmsnorm_fn, + &context->residual_activation_buffer, + /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float), + &model->shared_weight_buffer, + /*weight_offset=*/model->rmsnorm_weight_offset, + &context->rmsnorm_activation_buffer, + /*output_offset=*/0, + /*num_tokens=*/output_batch_size, + /*num_channels=*/model->embedding_dim, + model->rmsnorm_epsilon); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); + goto cleanup; + } - status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul( - &command_buffer, - &model->f32_mf4w_moe_matmul_fn, - /*threadgroup_size=*/512, - &context->swiglu_activation_buffer, /*input_offset=*/0, - &context->expert_activation_buffer, /*expert_offset=*/0, - &model->block_weight_buffers[n], /*weight_block_offset=*/model->mlp_out_block_offset, - &model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_out_scale_offset, - &model->block_weight_buffers[n], /*bias_offset=*/model->mlp_out_bias_offset, - &context->moe_activation_buffer, /*output_offset=*/0, - model->per_expert_block_weight_size, - num_output_tokens, - model->num_active_experts, - model->mlp_dim, - model->embedding_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch"); - goto cleanup; - } + status = gptoss_metal_command_buffer_encode_fill_buffer( + &command_buffer, + &context->argmax_buffer, + /*offset=*/0, + /*size=*/sizeof(uint64_t) * output_batch_size, + /*fill_value=*/0xFF); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode fill buffer command"); + goto cleanup; + } - status = gptoss_metal_command_buffer_encode_launch_f32_accumulate( - &command_buffer, - &model->f32_accumulate_e4_fn, - /*threadgroup_size=*/256, - model->max_threadgroups, - &context->moe_activation_buffer, - /*input_offset=*/0, - &context->expert_activation_buffer, - /*expert_offset=*/0, - &context->residual_activation_buffer, - /*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), - model->embedding_dim, - num_output_tokens, - model->num_active_experts); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch"); - goto cleanup; + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( + &command_buffer, + &model->f32_bf16w_unembedding_fn, + /*threadgroup_size=*/256, + model->max_threadgroups, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->unembedding_weight_offset, + &context->score_buffer, + /*output_offset=*/0, + &context->argmax_buffer, + /*argmax_offset=*/0, + /*num_tokens=*/output_batch_size, + /*num_cols=*/model->embedding_dim, + /*num_rows=*/model->vocabulary_size); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode 
f32_bf16w_unembedding kernel launch"); + goto cleanup; + } } } - const size_t num_output_tokens = 1; - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( - &command_buffer, - &model->f32_bf16w_rmsnorm_fn, - &context->residual_activation_buffer, - /*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float), - &model->shared_weight_buffer, - /*weight_offset=*/model->rmsnorm_weight_offset, - &context->rmsnorm_activation_buffer, - /*output_offset=*/0, - /*num_tokens=*/num_output_tokens, - /*num_channels=*/model->embedding_dim, - model->rmsnorm_epsilon); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); - goto cleanup; - } - - status = gptoss_metal_command_buffer_encode_fill_buffer( - &command_buffer, - &context->argmax_buffer, - /*offset=*/0, - /*size=*/sizeof(uint64_t) * num_output_tokens, - /*fill_value=*/0xFF); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode fill buffer command"); - goto cleanup; - } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( - &command_buffer, - &model->f32_bf16w_unembedding_fn, - /*threadgroup_size=*/256, - model->max_threadgroups, - &context->rmsnorm_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->unembedding_weight_offset, - &context->score_buffer, - /*output_offset=*/0, - &context->argmax_buffer, - /*argmax_offset=*/0, - /*num_tokens=*/num_output_tokens, - /*num_cols=*/model->embedding_dim, - /*num_rows=*/model->vocabulary_size); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch"); - goto cleanup; - } - gptoss_metal_command_buffer_commit(&command_buffer); gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL); - context->num_kv_tokens = context->num_tokens; - context->num_processed_tokens = num_output_tokens; - context->num_batch_tokens = 0; - cleanup: gptoss_metal_command_buffer_release(&command_buffer); return status; @@ -530,17 +567,18 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_chars( } uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr; - input_tokens[context->num_tokens] = best_token; - context->num_tokens++; - num_appended_tokens++; - if (++context->num_batch_tokens == model->max_batch_tokens) { - status = process_batch(context); - if (status != gptoss_status_success) { - break; + if (context->num_kv_tokens > context->num_tokens) { + if (input_tokens[context->num_tokens] != best_token) { + input_tokens[context->num_tokens] = best_token; + + // Invalidate the KV cache starting with the newly added token. 
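// Illustrative note (not part of the upstream patch): this branch implements the
// "match longest common prefix" behavior from the commit message. After
// gptoss_context_reset(), num_kv_tokens and the token buffer are preserved, so as
// long as re-appended tokens equal the tokens already stored at the same
// positions, the corresponding KV-cache entries stay valid and need no
// recomputation. The first mismatch overwrites the stored token and, on the next
// line, truncates the valid KV prefix by resetting num_kv_tokens.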
+ context->num_kv_tokens = context->num_tokens; } - assert(context->num_batch_tokens == 0); + context->num_tokens++; + } else { + input_tokens[context->num_tokens++] = best_token; } - assert(context->num_batch_tokens < model->max_batch_tokens); + num_appended_tokens++; text += best_token_length; text_length -= best_token_length; } @@ -570,27 +608,31 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens( enum gptoss_status status = gptoss_status_success; uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr; while (num_tokens != 0) { - assert(context->num_batch_tokens < model->max_batch_tokens); if (context->num_tokens == context->max_tokens) { status = gptoss_status_context_overflow; break; } - const size_t num_tokens_to_copy = - math_min(context->max_tokens - context->num_tokens, - math_min(num_tokens, model->max_batch_tokens - context->num_batch_tokens)); - memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t)); - context->num_tokens += num_tokens_to_copy; - context->num_batch_tokens += num_tokens_to_copy; - if (context->num_batch_tokens == model->max_batch_tokens) { - status = process_batch(context); - if (status != gptoss_status_success) { - break; + if (context->num_kv_tokens > context->num_tokens) { + const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens); + size_t num_verified_tokens = 0; + for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) { + if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) { + break; + } } - assert(context->num_batch_tokens == 0); + + context->num_tokens += num_verified_tokens; + context->num_kv_tokens = context->num_tokens; + tokens += num_verified_tokens; + num_tokens -= num_verified_tokens; + } else { + const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens); + memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t)); + context->num_tokens += num_tokens_to_copy; + tokens += num_tokens_to_copy; + num_tokens -= num_tokens_to_copy; } - tokens += num_tokens_to_copy; - num_tokens -= num_tokens_to_copy; } return status; @@ -599,10 +641,19 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens( enum gptoss_status GPTOSS_ABI gptoss_context_process( gptoss_context_t context) { - if (context->num_batch_tokens != 0) { - process_batch(context); - } + if (context->num_tokens > context->num_kv_tokens) { + enum gptoss_status status = process_tokens( + context, + /*input_tokens_offset=*/context->num_kv_tokens, + /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens, + /*num_output_tokens=*/0); + if (status != gptoss_status_success) { + return status; + } + context->num_kv_tokens = context->num_tokens; + } + return gptoss_status_success; } @@ -617,11 +668,22 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample( struct gptoss_metal_command_buffer command_buffer = {0}; *token_out = UINT32_MAX; - if (context->num_batch_tokens != 0) { - status = process_batch(context); - if (status != gptoss_status_success) { - return status; - } + if (context->num_kv_tokens < context->num_tokens) { + status = process_tokens( + context, + /*input_tokens_offset=*/context->num_kv_tokens, + /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens, + /*num_output_tokens=*/1); + context->num_kv_tokens = context->num_tokens; + } else { + status = process_tokens( + context, + /*input_tokens_offset=*/context->num_tokens - 1, + 
/*num_input_tokens=*/1, + /*num_output_tokens=*/1); + } + if (status != gptoss_status_success) { + return status; } if (temperature == 0.0f) { @@ -721,9 +783,10 @@ enum gptoss_status GPTOSS_ABI gptoss_context_reset( gptoss_context_t context) { context->num_tokens = 0; - context->num_kv_tokens = 0; - context->num_batch_tokens = 0; - context->num_processed_tokens = 0; + + // Note: context->num_kv_tokens is not reset and context->input_tokens_buffer is not cleared. + // If the subsequently added tokens match the tokens already in the KV cache, we reuse the KV cache. + return gptoss_status_success; } diff --git a/gpt_oss/metal/source/generate.c b/gpt_oss/metal/source/generate.c index 976046f6..1711410a 100644 --- a/gpt_oss/metal/source/generate.c +++ b/gpt_oss/metal/source/generate.c @@ -162,7 +162,7 @@ struct options parse_options(int argc, char** argv) { static void print_profile() { const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens); const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds); - const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens) - 1; + const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens); const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds); const uint64_t inference_bytes = atomic_load(&globals.inference_bytes); if (num_prefill_tokens != 0 || num_generated_tokens != 0) { @@ -173,10 +173,10 @@ static void print_profile() { num_prefill_tokens, (double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6); } - if (num_generated_tokens > 5) { - printf("Generation speed (%zu tokens, excluding the first 5): %.1f tokens/second\n", - (num_generated_tokens - 5), - (double) (num_generated_tokens - 5) / (double) generation_microseconds * 1.0e+6); + if (num_generated_tokens != 0) { + printf("Generation speed (%zu tokens): %.1f tokens/second\n", + num_generated_tokens, + (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6); } } @@ -292,7 +292,7 @@ int main(int argc, char *argv[]) { const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1); if (previous_num_generated_tokens == 0) { atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time)); - } else if (previous_num_generated_tokens > 5) { + } else { atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp)); } printf("%.*s", (int) token_size, (const char*) token_ptr); diff --git a/gpt_oss/metal/source/include/internal/math.h b/gpt_oss/metal/source/include/internal/math.h index 8d6a9040..d2a7b512 100644 --- a/gpt_oss/metal/source/include/internal/math.h +++ b/gpt_oss/metal/source/include/internal/math.h @@ -15,6 +15,10 @@ inline static size_t math_min(size_t a, size_t b) { return a < b ? a : b; } +inline static size_t math_sub_sat(size_t a, size_t b) { + return a > b ? a - b : 0; +} + static size_t math_round_up_po2(size_t bytes, size_t multiple) { const size_t multiple_mask = multiple - 1; if ((bytes & multiple_mask) != 0) { diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index af24419e..6b477745 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -114,13 +114,6 @@ struct gptoss_context { // Length of the context. size_t max_tokens; - // Current number of tokens in the batch. 
- // Always in the [0, max_batch_tokens) range. - size_t num_batch_tokens; - // Number of tokens processed in the last batch. - // Activations for [num_batch_tokens, num_processed_tokens) tokens can be accessed from internal structures. - size_t num_processed_tokens; - size_t kvcache_size; size_t allocation_size; From 11c01b2299f957aeca856018e69b9fc5f6bb31af Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Sun, 17 Aug 2025 17:48:25 -0700 Subject: [PATCH 61/91] Create CODEOWNERS --- .github/CODEOWNERS | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..d185e0a5 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,5 @@ +@openai/developer-experience +dkundel-openai +Maratyszcza +scott-oai +volsgd From 56930eb3c93d22b4ccbf1dfafe3ee66d79a52f8f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 17 Aug 2025 17:49:21 -0700 Subject: [PATCH 62/91] Replace '/' with '__' in model names (#142) Refs https://github.com/openai/gpt-oss/issues/141 --- gpt_oss/evals/__main__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gpt_oss/evals/__main__.py b/gpt_oss/evals/__main__.py index bb34e2c3..40d56c12 100644 --- a/gpt_oss/evals/__main__.py +++ b/gpt_oss/evals/__main__.py @@ -155,6 +155,7 @@ def get_evals(eval_name, debug_mode): now = datetime.now() date_str = now.strftime("%Y%m%d_%H%M%S") for model_name, sampler in models.items(): + model_name = model_name.replace("/", "__") for eval_name, eval_obj in evals.items(): result = eval_obj(sampler) # ^^^ how to use a sampler From 64f8a4b1737e8984082d2e4b779ddf2472819916 Mon Sep 17 00:00:00 2001 From: Jay Wang Date: Sun, 17 Aug 2025 17:49:51 -0700 Subject: [PATCH 63/91] Rename `with_browser` to `with_browser_tool` in README (#140) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1796d38d..876a58f0 100644 --- a/README.md +++ b/README.md @@ -360,7 +360,7 @@ Both gpt-oss models were trained with the capability to browse using the `browse #### Usage -To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_browser()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example: +To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_browser_tool()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example: ```python import datetime @@ -386,7 +386,7 @@ if use_browser_tool: # enables the tool system_message_content = system_message_content.with_tools(browser_tool.tool_config) # alternatively you could use the following if your tool is not stateless - system_message_content = system_message_content.with_browser() + system_message_content = system_message_content.with_browser_tool() # construct the system message system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content) From 69a0b1c8a867fa1c4b1277622ed3af5788e2e72e Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Mon, 18 Aug 2025 01:50:06 +0100 Subject: [PATCH 64/91] Update attention kernel to use TensorDescriptor (#137) * Update attention kernel to use TensorDescriptor Block pointer is deprecated in triton, so replacing with the TensorDescriptor API which also enables use of TMA hardware on hopper and newer GPUs. 
* Add minimum triton version --- gpt_oss/triton/attention.py | 105 ++++++------------------------------ pyproject.toml | 2 +- 2 files changed, 17 insertions(+), 90 deletions(-) diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py index 6fa77fd2..bf689055 100644 --- a/gpt_oss/triton/attention.py +++ b/gpt_oss/triton/attention.py @@ -11,6 +11,8 @@ import triton import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + @triton.jit @@ -23,22 +25,6 @@ def _attn_fwd( M, Out, # Start_q, - stride_qz, - stride_qh, - stride_qm, - stride_qk, # - stride_kz, - stride_kh, - stride_kn, - stride_kk, # - stride_vz, - stride_vh, - stride_vn, - stride_vk, # - stride_oz, - stride_oh, - stride_om, - stride_ok, # Z, H, N_Q_CTX, @@ -54,44 +40,6 @@ def _attn_fwd( off_hz = tl.program_id(1) off_z = off_hz // H off_h = off_hz % H - q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh - k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh - v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh - o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh - - # block pointers - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(N_Q_CTX, HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(N_KV_CTX, HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=(1, 0), - ) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(HEAD_DIM, N_KV_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_N), - order=(0, 1), - ) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(N_Q_CTX, HEAD_DIM), - strides=(stride_om, stride_ok), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) # load attention sinks if Sinks is not None: @@ -108,17 +56,13 @@ def _attn_fwd( acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) # load scales qk_scale = sm_scale - q = tl.load(Q_block_ptr) + q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M - # advance the KV block-pointers so they point at `lo` - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) - for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) @@ -128,7 +72,7 @@ def _attn_fwd( too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) mask = mask | too_old - k = tl.load(K_block_ptr) + k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T qk = tl.dot(q, k, allow_tf32=False) qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0) @@ -140,22 +84,21 @@ def _attn_fwd( l_ij = tl.sum(p, 1) acc = acc * alpha[:, None] - v = tl.load(V_block_ptr).to(tl.float32) + v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]) + v = v.to(tl.float32) acc = tl.dot(p, v, acc, allow_tf32=False) l_i = l_i * alpha + l_ij m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - sink = tl.math.exp(sink - m_i) z = l_i + sink acc = acc / z[:, None] m_i += tl.math.log(l_i) m_ptrs = M + off_hz * N_Q_CTX + offs_m 
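    # Illustrative note (not part of the upstream patch): with this change Q, K, V
    # and Out arrive as tensor descriptors built on the host side, e.g. (mirroring
    # the call in forward() further down):
    #
    #     TensorDescriptor.from_tensor(q, [1, 1, BLOCK_M, HEAD_DIM_K])
    #
    # Inside the kernel a tile is then addressed by a 4-D offset instead of the
    # deprecated block pointers and their tl.advance() bookkeeping, as in the Q
    # load above:
    #
    #     q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
    #
    # The output tile is written back the same way via Out.store() below.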
tl.store(m_ptrs, m_i) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + acc = acc.to(Out.dtype)[None, None, :, :] + Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) class _attention(torch.autograd.Function): @@ -189,35 +132,19 @@ def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q): M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32) grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1) _attn_fwd[grid]( - q, - k, - v, + TensorDescriptor.from_tensor(q, [1, 1, BLOCK_M, HEAD_DIM_K]), + TensorDescriptor.from_tensor(k, [1, 1, BLOCK_N, HEAD_DIM_K]), + TensorDescriptor.from_tensor(v, [1, 1, BLOCK_N, HEAD_DIM_K]), sinks, sm_scale, M, - o, # + TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, HEAD_DIM_K]), start_q, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), # - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # q.shape[0], - q.shape[1], # - N_Q_CTX=n_ctx + m_pad_size, # - N_KV_CTX=n_kv_ctx, # - HEAD_DIM=HEAD_DIM_K, # + q.shape[1], + N_Q_CTX=n_ctx + m_pad_size, + N_KV_CTX=n_kv_ctx, + HEAD_DIM=HEAD_DIM_K, BANDWIDTH=bandwidth, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, diff --git a/pyproject.toml b/pyproject.toml index a52efdde..f84c3a85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.12,<3.13" version = "0.0.3" [project.optional-dependencies] -triton = ["triton", "safetensors>=0.5.3", "torch>=2.7.0"] +triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"] torch = ["safetensors>=0.5.3", "torch>=2.7.0"] metal = ["numpy", "tqdm", "safetensors", "torch"] test = ["pytest>=8.4.1", "httpx>=0.28.1"] From 995e148feeff988fc172686656e0099b1ace0fe7 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Sun, 17 Aug 2025 17:50:58 -0700 Subject: [PATCH 65/91] feat(metail): Parallelize SDPA across multiple simdgroups (#144) --- gpt_oss/metal/source/include/internal/math.h | 21 +- gpt_oss/metal/source/include/internal/metal.h | 7 +- .../metal/source/include/internal/metal.hpp | 10 +- gpt_oss/metal/source/metal-kernels.c | 55 +++-- gpt_oss/metal/source/metal.m | 16 +- gpt_oss/metal/source/sdpa.metal | 198 +++++++++++++++--- 6 files changed, 238 insertions(+), 69 deletions(-) diff --git a/gpt_oss/metal/source/include/internal/math.h b/gpt_oss/metal/source/include/internal/math.h index d2a7b512..06f2b1f1 100644 --- a/gpt_oss/metal/source/include/internal/math.h +++ b/gpt_oss/metal/source/include/internal/math.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -19,11 +20,21 @@ inline static size_t math_sub_sat(size_t a, size_t b) { return a > b ? 
a - b : 0; } -static size_t math_round_up_po2(size_t bytes, size_t multiple) { +static size_t math_round_down_po2(size_t number, size_t multiple) { + assert(multiple != 0); + assert((multiple & (multiple - 1)) == 0); + + return number & -multiple; +} + +static size_t math_round_up_po2(size_t number, size_t multiple) { + assert(multiple != 0); + assert((multiple & (multiple - 1)) == 0); + const size_t multiple_mask = multiple - 1; - if ((bytes & multiple_mask) != 0) { - bytes |= multiple_mask; - bytes += 1; + if ((number & multiple_mask) != 0) { + number |= multiple_mask; + number += 1; } - return bytes; + return number; } diff --git a/gpt_oss/metal/source/include/internal/metal.h b/gpt_oss/metal/source/include/internal/metal.h index 41194bda..f38190f0 100644 --- a/gpt_oss/metal/source/include/internal/metal.h +++ b/gpt_oss/metal/source/include/internal/metal.h @@ -118,9 +118,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel( size_t num_threadgroups_z, size_t params_size, const void* params, - size_t num_buffers, - const struct gptoss_metal_buffer** buffers, - const size_t* buffer_offsets); + size_t num_device_buffers, + const struct gptoss_metal_buffer** device_buffers, + const size_t* device_buffer_offsets, + size_t threadgroup_buffer_size); enum gptoss_status gptoss_metal_command_buffer_commit( const struct gptoss_metal_command_buffer* command_buffer); diff --git a/gpt_oss/metal/source/include/internal/metal.hpp b/gpt_oss/metal/source/include/internal/metal.hpp index 9df7aed7..a143a11a 100644 --- a/gpt_oss/metal/source/include/internal/metal.hpp +++ b/gpt_oss/metal/source/include/internal/metal.hpp @@ -246,10 +246,11 @@ class CommandBuffer { const std::array& threadgroup_size, const std::array& num_threadgroups, size_t params_size, const void* params, - std::initializer_list buffers = {}) + std::initializer_list device_buffers = {}, + size_t threadgroup_buffer_size = 0) { - std::vector buffer_handles(buffers.size()); - std::transform(buffers.begin(), buffers.end(), buffer_handles.begin(), + std::vector buffer_handles(device_buffers.size()); + std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(), [](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); }); Check(gptoss_metal_command_buffer_encode_launch_kernel( &command_buffer_, function.handle(), @@ -258,7 +259,8 @@ class CommandBuffer { params_size, params, buffer_handles.size(), buffer_handles.data(), - /*buffer_offsets=*/nullptr), + /*buffer_offsets=*/nullptr, + threadgroup_buffer_size), "gptoss_metal_command_buffer_encode_launch_kernel"); } diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c index 61b9c973..46fd1586 100644 --- a/gpt_oss/metal/source/metal-kernels.c +++ b/gpt_oss/metal/source/metal-kernels.c @@ -46,7 +46,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random( threadgroup_size, 1, 1, num_threadgroups, 1, 1, sizeof(args), &args, - 1, &output_buffer, &output_offset); + 1, &output_buffer, &output_offset, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random( @@ -93,7 +94,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random( threadgroup_size, 1, 1, num_threadgroups, 1, 1, sizeof(args), &args, - 1, &output_buffer, &output_offset); + 1, &output_buffer, &output_offset, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random( @@ -140,7 +142,8 
@@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random( threadgroup_size, 1, 1, num_threadgroups, 1, 1, sizeof(args), &args, - 1, &output_buffer, &output_offset); + 1, &output_buffer, &output_offset, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert( @@ -180,7 +183,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert( threadgroup_size, 1, 1, num_threadgroups, 1, 1, sizeof(args), &args, - 3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL); + 3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings( @@ -222,7 +226,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings sizeof(args), &args, 3, (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer}, - (const size_t[]) {token_offset, weight_offset, output_offset}); + (const size_t[]) {token_offset, weight_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( @@ -268,7 +273,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( sizeof(args), &args, 3, (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer}, - (const size_t[]) {input_offset, weight_offset, output_offset}); + (const size_t[]) {input_offset, weight_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( @@ -325,7 +331,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( sizeof(args), &args, 4, (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset}); + (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( @@ -382,7 +389,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad sizeof(args), &args, 4, (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset}); + (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( @@ -437,7 +445,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi sizeof(args), &args, 4, (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer}, - (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset}); + (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu( @@ -510,7 +519,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul sizeof(args), &args, 6, (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, expert_offset, 
weight_block_offset, weight_scale_offset, bias_offset, output_offset}); + (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul( @@ -581,7 +591,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul sizeof(args), &args, 6, (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset}); + (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope( @@ -631,7 +642,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope( threadgroup_size, 1, 1, num_qk_heads / num_simdgroups, num_tokens, 1, sizeof(args), &args, - 1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL); + 1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate( @@ -680,7 +692,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate( sizeof(args), &args, 3, (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer}, - (const size_t[]) {input_offset, expert_offset, output_offset}); + (const size_t[]) {input_offset, expert_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk( @@ -715,7 +728,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk( sizeof(args), &args, 2, (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer}, - (const size_t[]) {input_offset, output_offset}); + (const size_t[]) {input_offset, output_offset}, + /*threadgroup_buffer_size=*/0); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( @@ -753,6 +767,11 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( return gptoss_status_invalid_argument; } + const size_t max_context_tokens = math_min(num_q_tokens + num_kv_tokens + 1, window); + const size_t threadgroup_size = math_min(f32_sdpa_fn->max_threadgroup_threads, + max_context_tokens * f32_sdpa_fn->simdgroup_threads); + const size_t half_threadgroup_size = math_round_down_po2(threadgroup_size / 2, f32_sdpa_fn->simdgroup_threads); + const struct gptoss_sdpa_args args = { .qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads), .num_kv_tokens = num_kv_tokens, @@ -761,12 +780,13 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( return gptoss_metal_command_buffer_encode_launch_kernel( command_buffer, f32_sdpa_fn, - /*threadgroup_size=*/32, 1, 1, + threadgroup_size, 1, 1, num_q_tokens, num_kv_heads, 1, sizeof(args), &args, 5, (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer}, - (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset}); + (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset}, + /*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float)); } enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax( @@ -813,5 +833,6 @@ enum gptoss_status 
gptoss_metal_command_buffer_encode_launch_f32_softmax( sizeof(args), &args, 4, (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer}, - (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset}); + (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset}, + /*threadgroup_buffer_size=*/0); } diff --git a/gpt_oss/metal/source/metal.m b/gpt_oss/metal/source/metal.m index 4f6cb35f..a873bb36 100644 --- a/gpt_oss/metal/source/metal.m +++ b/gpt_oss/metal/source/metal.m @@ -380,9 +380,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel( size_t num_threadgroups_z, size_t params_size, const void* params, - size_t num_buffers, - const struct gptoss_metal_buffer** buffers, - const size_t* buffer_offsets) + size_t num_device_buffers, + const struct gptoss_metal_buffer** device_buffers, + const size_t* device_buffer_offsets, + size_t threadgroup_buffer_size) { if (command_buffer->object == NULL || function->pipeline_state_object == NULL) { return gptoss_status_invalid_state; @@ -396,11 +397,14 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel( // Set kernel arguments [command_encoder_obj setComputePipelineState:pipeline_state_obj]; [command_encoder_obj setBytes:params length:params_size atIndex:0]; - for (size_t i = 0; i < num_buffers; ++i) { - id buffer_obj = (id) buffers[i]->object; - const NSUInteger offset = buffer_offsets == NULL ? 0 : (NSUInteger) buffer_offsets[i]; + for (size_t i = 0; i < num_device_buffers; ++i) { + id buffer_obj = (id) device_buffers[i]->object; + const NSUInteger offset = device_buffer_offsets == NULL ? 0 : (NSUInteger) device_buffer_offsets[i]; [command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1]; } + if (threadgroup_buffer_size != 0) { + [command_encoder_obj setThreadgroupMemoryLength:threadgroup_buffer_size atIndex:0]; + } // Dispatch kernel const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z); diff --git a/gpt_oss/metal/source/sdpa.metal b/gpt_oss/metal/source/sdpa.metal index 5564be6c..5050cb41 100644 --- a/gpt_oss/metal/source/sdpa.metal +++ b/gpt_oss/metal/source/sdpa.metal @@ -11,7 +11,6 @@ // Each threadgroup handles 8 Q heads / 1 KV head for 1 token -[[max_total_threads_per_threadgroup(32)]] kernel void gptoss_f32_sdpa_q8_d64( constant gptoss_sdpa_args& args [[ buffer(0) ]], const device float* q [[ buffer(1) ]], @@ -19,14 +18,22 @@ kernel void gptoss_f32_sdpa_q8_d64( const device float* v [[ buffer(3) ]], const device bfloat* s [[ buffer(4) ]], device float* output [[ buffer(5) ]], + threadgroup void* threadgroup_buffer [[ threadgroup(0) ]], uint2 gid [[threadgroup_position_in_grid]], - uint tid [[thread_index_in_threadgroup]]) + uint2 tid [[thread_position_in_threadgroup]], + uint simdgroup_tid [[thread_index_in_simdgroup]], + uint simdgroup_idx [[simdgroup_index_in_threadgroup]], + uint num_simdgroups [[simdgroups_per_threadgroup]]) { + const uint simdgroup_size = 32; + const uint num_q_heads = 64; const uint num_kv_heads = 8; const uint head_dim = 64; const uint qmul = 8; + const uint token_stride = 2 * num_kv_heads * head_dim; + const uint qt = gid.x; // Q token index const uint h = gid.y; // KV head index @@ -44,14 +51,14 @@ kernel void gptoss_f32_sdpa_q8_d64( float m6 = static_cast(s[h * qmul + 6]); float m7 = static_cast(s[h * qmul + 7]); - float l0 = 1.0f; - float l1 = 1.0f; - float l2 = 1.0f; - float l3 = 1.0f; - float l4 = 1.0f; - float l5 = 1.0f; - float l6 = 1.0f; - float l7 
= 1.0f; + float l0 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l1 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l2 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l3 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l4 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l5 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l6 = simdgroup_idx == 0 ? 1.0f : 0.0f; + float l7 = simdgroup_idx == 0 ? 1.0f : 0.0f; float2 out0 = 0.0f; float2 out1 = 0.0f; @@ -62,22 +69,22 @@ kernel void gptoss_f32_sdpa_q8_d64( float2 out6 = 0.0f; float2 out7 = 0.0f; - float2 q0 = reinterpret_cast(q + 0 * head_dim)[tid]; - float2 q1 = reinterpret_cast(q + 1 * head_dim)[tid]; - float2 q2 = reinterpret_cast(q + 2 * head_dim)[tid]; - float2 q3 = reinterpret_cast(q + 3 * head_dim)[tid]; - float2 q4 = reinterpret_cast(q + 4 * head_dim)[tid]; - float2 q5 = reinterpret_cast(q + 5 * head_dim)[tid]; - float2 q6 = reinterpret_cast(q + 6 * head_dim)[tid]; - float2 q7 = reinterpret_cast(q + 7 * head_dim)[tid]; + float2 q0 = reinterpret_cast(q + 0 * head_dim)[simdgroup_tid]; + float2 q1 = reinterpret_cast(q + 1 * head_dim)[simdgroup_tid]; + float2 q2 = reinterpret_cast(q + 2 * head_dim)[simdgroup_tid]; + float2 q3 = reinterpret_cast(q + 3 * head_dim)[simdgroup_tid]; + float2 q4 = reinterpret_cast(q + 4 * head_dim)[simdgroup_tid]; + float2 q5 = reinterpret_cast(q + 5 * head_dim)[simdgroup_tid]; + float2 q6 = reinterpret_cast(q + 6 * head_dim)[simdgroup_tid]; + float2 q7 = reinterpret_cast(q + 7 * head_dim)[simdgroup_tid]; const uint kt_end = qt + args.num_kv_tokens + 1; - const uint kt_start = metal::subsat(kt_end, args.window); - k += 2 * num_kv_heads * head_dim * kt_start; - v += 2 * num_kv_heads * head_dim * kt_start; - for (uint kt = kt_start; kt < kt_end; kt++) { - const float2 kval = reinterpret_cast(k)[tid]; - k += 2 * num_kv_heads * head_dim; + const uint kt_start = metal::subsat(kt_end, args.window) + simdgroup_idx; + k += token_stride * kt_start; + v += token_stride * kt_start; + for (uint kt = kt_start; kt < kt_end; kt += num_simdgroups) { + const float2 kval = reinterpret_cast(k)[simdgroup_tid]; + k += token_stride * num_simdgroups; float qk0 = metal::dot(q0, kval); float qk1 = metal::dot(q1, kval); @@ -142,8 +149,8 @@ kernel void gptoss_f32_sdpa_q8_d64( m6 = new_m6; m7 = new_m7; - const float2 vval = reinterpret_cast(v)[tid]; - v += 2 * num_kv_heads * head_dim; + const float2 vval = reinterpret_cast(v)[simdgroup_tid]; + v += token_stride * num_simdgroups; out0 = metal::fma(vval, qk0, out0 * alpha0); out1 = metal::fma(vval, qk1, out1 * alpha1); out2 = metal::fma(vval, qk2, out2 * alpha2); @@ -153,12 +160,135 @@ kernel void gptoss_f32_sdpa_q8_d64( out6 = metal::fma(vval, qk6, out6 * alpha6); out7 = metal::fma(vval, qk7, out7 * alpha7); } - reinterpret_cast(output + 0 * head_dim)[tid] = out0 / l0; - reinterpret_cast(output + 1 * head_dim)[tid] = out1 / l1; - reinterpret_cast(output + 2 * head_dim)[tid] = out2 / l2; - reinterpret_cast(output + 3 * head_dim)[tid] = out3 / l3; - reinterpret_cast(output + 4 * head_dim)[tid] = out4 / l4; - reinterpret_cast(output + 5 * head_dim)[tid] = out5 / l5; - reinterpret_cast(output + 6 * head_dim)[tid] = out6 / l6; - reinterpret_cast(output + 7 * head_dim)[tid] = out7 / l7; + if (num_simdgroups > 1) { + if (metal::simd_is_first()) { + static_cast(threadgroup_buffer)[0 * num_simdgroups + simdgroup_idx] = m0; + static_cast(threadgroup_buffer)[1 * num_simdgroups + simdgroup_idx] = m1; + static_cast(threadgroup_buffer)[2 * num_simdgroups + simdgroup_idx] = m2; + static_cast(threadgroup_buffer)[3 * num_simdgroups + 
simdgroup_idx] = m3; + static_cast(threadgroup_buffer)[4 * num_simdgroups + simdgroup_idx] = m4; + static_cast(threadgroup_buffer)[5 * num_simdgroups + simdgroup_idx] = m5; + static_cast(threadgroup_buffer)[6 * num_simdgroups + simdgroup_idx] = m6; + static_cast(threadgroup_buffer)[7 * num_simdgroups + simdgroup_idx] = m7; + + static_cast(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_idx] = l0; + static_cast(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_idx] = l1; + static_cast(threadgroup_buffer)[10 * num_simdgroups + simdgroup_idx] = l2; + static_cast(threadgroup_buffer)[11 * num_simdgroups + simdgroup_idx] = l3; + static_cast(threadgroup_buffer)[12 * num_simdgroups + simdgroup_idx] = l4; + static_cast(threadgroup_buffer)[13 * num_simdgroups + simdgroup_idx] = l5; + static_cast(threadgroup_buffer)[14 * num_simdgroups + simdgroup_idx] = l6; + static_cast(threadgroup_buffer)[15 * num_simdgroups + simdgroup_idx] = l7; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + // Note: simdgroup refers not to the thread's current simdgroup, but to one with simdgroup_idx == thread's simdgroup_tid. + float simdgroup_m0 = m0; + float simdgroup_m1 = m1; + float simdgroup_m2 = m2; + float simdgroup_m3 = m3; + float simdgroup_m4 = m4; + float simdgroup_m5 = m5; + float simdgroup_m6 = m6; + float simdgroup_m7 = m7; + if (simdgroup_tid < num_simdgroups) { + simdgroup_m0 = static_cast(threadgroup_buffer)[0 * num_simdgroups + simdgroup_tid]; + simdgroup_m1 = static_cast(threadgroup_buffer)[1 * num_simdgroups + simdgroup_tid]; + simdgroup_m2 = static_cast(threadgroup_buffer)[2 * num_simdgroups + simdgroup_tid]; + simdgroup_m3 = static_cast(threadgroup_buffer)[3 * num_simdgroups + simdgroup_tid]; + simdgroup_m4 = static_cast(threadgroup_buffer)[4 * num_simdgroups + simdgroup_tid]; + simdgroup_m5 = static_cast(threadgroup_buffer)[5 * num_simdgroups + simdgroup_tid]; + simdgroup_m6 = static_cast(threadgroup_buffer)[6 * num_simdgroups + simdgroup_tid]; + simdgroup_m7 = static_cast(threadgroup_buffer)[7 * num_simdgroups + simdgroup_tid]; + } + + const float threadgroup_m0 = metal::simd_max(simdgroup_m0); + const float threadgroup_m1 = metal::simd_max(simdgroup_m1); + const float threadgroup_m2 = metal::simd_max(simdgroup_m2); + const float threadgroup_m3 = metal::simd_max(simdgroup_m3); + const float threadgroup_m4 = metal::simd_max(simdgroup_m4); + const float threadgroup_m5 = metal::simd_max(simdgroup_m5); + const float threadgroup_m6 = metal::simd_max(simdgroup_m6); + const float threadgroup_m7 = metal::simd_max(simdgroup_m7); + + out0 *= metal::fast::exp(m0 - threadgroup_m0); + out1 *= metal::fast::exp(m1 - threadgroup_m1); + out2 *= metal::fast::exp(m2 - threadgroup_m2); + out3 *= metal::fast::exp(m3 - threadgroup_m3); + out4 *= metal::fast::exp(m4 - threadgroup_m4); + out5 *= metal::fast::exp(m5 - threadgroup_m5); + out6 *= metal::fast::exp(m6 - threadgroup_m6); + out7 *= metal::fast::exp(m7 - threadgroup_m7); + + if (simdgroup_idx == 0) { + l0 = 0.0f; + l1 = 0.0f; + l2 = 0.0f; + l3 = 0.0f; + l4 = 0.0f; + l5 = 0.0f; + l6 = 0.0f; + l7 = 0.0f; + if (simdgroup_tid < num_simdgroups) { + l0 = static_cast(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_tid]; + l1 = static_cast(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_tid]; + l2 = static_cast(threadgroup_buffer)[10 * num_simdgroups + simdgroup_tid]; + l3 = static_cast(threadgroup_buffer)[11 * num_simdgroups + simdgroup_tid]; + l4 = static_cast(threadgroup_buffer)[12 * num_simdgroups + simdgroup_tid]; + l5 = 
static_cast(threadgroup_buffer)[13 * num_simdgroups + simdgroup_tid]; + l6 = static_cast(threadgroup_buffer)[14 * num_simdgroups + simdgroup_tid]; + l7 = static_cast(threadgroup_buffer)[15 * num_simdgroups + simdgroup_tid]; + } + + l0 = metal::simd_sum(l0 * metal::fast::exp(simdgroup_m0 - threadgroup_m0)); + l1 = metal::simd_sum(l1 * metal::fast::exp(simdgroup_m1 - threadgroup_m1)); + l2 = metal::simd_sum(l2 * metal::fast::exp(simdgroup_m2 - threadgroup_m2)); + l3 = metal::simd_sum(l3 * metal::fast::exp(simdgroup_m3 - threadgroup_m3)); + l4 = metal::simd_sum(l4 * metal::fast::exp(simdgroup_m4 - threadgroup_m4)); + l5 = metal::simd_sum(l5 * metal::fast::exp(simdgroup_m5 - threadgroup_m5)); + l6 = metal::simd_sum(l6 * metal::fast::exp(simdgroup_m6 - threadgroup_m6)); + l7 = metal::simd_sum(l7 * metal::fast::exp(simdgroup_m7 - threadgroup_m7)); + } + + uint num_threads = num_simdgroups * simdgroup_size; + do { + const uint num_smem_threads = (num_threads / 2) & -simdgroup_size; + const uint num_half_threads = num_threads - num_smem_threads; + + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + const uint smem_tid = tid.x - num_half_threads; + if (smem_tid < num_smem_threads) { + static_cast(threadgroup_buffer)[num_smem_threads * 0 + smem_tid] = out0; + static_cast(threadgroup_buffer)[num_smem_threads * 1 + smem_tid] = out1; + static_cast(threadgroup_buffer)[num_smem_threads * 2 + smem_tid] = out2; + static_cast(threadgroup_buffer)[num_smem_threads * 3 + smem_tid] = out3; + static_cast(threadgroup_buffer)[num_smem_threads * 4 + smem_tid] = out4; + static_cast(threadgroup_buffer)[num_smem_threads * 5 + smem_tid] = out5; + static_cast(threadgroup_buffer)[num_smem_threads * 6 + smem_tid] = out6; + static_cast(threadgroup_buffer)[num_smem_threads * 7 + smem_tid] = out7; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + if (tid.x < num_smem_threads) { + out0 += static_cast(threadgroup_buffer)[num_smem_threads * 0 + tid.x]; + out1 += static_cast(threadgroup_buffer)[num_smem_threads * 1 + tid.x]; + out2 += static_cast(threadgroup_buffer)[num_smem_threads * 2 + tid.x]; + out3 += static_cast(threadgroup_buffer)[num_smem_threads * 3 + tid.x]; + out4 += static_cast(threadgroup_buffer)[num_smem_threads * 4 + tid.x]; + out5 += static_cast(threadgroup_buffer)[num_smem_threads * 5 + tid.x]; + out6 += static_cast(threadgroup_buffer)[num_smem_threads * 6 + tid.x]; + out7 += static_cast(threadgroup_buffer)[num_smem_threads * 7 + tid.x]; + } + + num_threads = num_half_threads; + } while (num_threads > simdgroup_size); + } + if (simdgroup_idx == 0) { + reinterpret_cast(output + 0 * head_dim)[simdgroup_tid] = out0 / l0; + reinterpret_cast(output + 1 * head_dim)[simdgroup_tid] = out1 / l1; + reinterpret_cast(output + 2 * head_dim)[simdgroup_tid] = out2 / l2; + reinterpret_cast(output + 3 * head_dim)[simdgroup_tid] = out3 / l3; + reinterpret_cast(output + 4 * head_dim)[simdgroup_tid] = out4 / l4; + reinterpret_cast(output + 5 * head_dim)[simdgroup_tid] = out5 / l5; + reinterpret_cast(output + 6 * head_dim)[simdgroup_tid] = out6 / l6; + reinterpret_cast(output + 7 * head_dim)[simdgroup_tid] = out7 / l7; + } } From 352cd3ca53f107f73b3daf284016bf1d4d00d01d Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Sun, 17 Aug 2025 18:02:48 -0700 Subject: [PATCH 66/91] chore: release 0.0.4 (#145) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f84c3a85..9ed47f92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-20,7 +20,7 @@ dependencies = [ ] readme = "README.md" requires-python = ">=3.12,<3.13" -version = "0.0.3" +version = "0.0.4" [project.optional-dependencies] triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"] From 18fd1870d0b628f01c996b222a0c556aa2e6e623 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Tue, 19 Aug 2025 14:40:44 -0700 Subject: [PATCH 67/91] Update awesome-gpt-oss.md with llama.cpp (#148) --- awesome-gpt-oss.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index f4491d56..82cf7071 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -32,6 +32,8 @@ This is a list of guides and resources to help you get started with the gpt-oss - [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) - AMD - [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html) +- llama.cpp + - [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396) ### Server From dbb76fa4dbf60f24bd9aa516af16ec44a45a9914 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Tue, 26 Aug 2025 10:26:52 -0700 Subject: [PATCH 68/91] Update README.md (#154) --- README.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 876a58f0..4ef20827 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

Try gpt-oss · Guides · - Model card · + Model card · OpenAI blog

@@ -498,3 +498,17 @@ We recommend sampling with `temperature=1.0` and `top_p=1.0`. The reference implementations in this repository are meant as a starting point and inspiration. Outside of bug fixes we do not intend to accept new feature contributions. If you build implementations based on this code such as new tool implementations you are welcome to contribute them to the [`awesome-gpt-oss.md`](./awesome-gpt-oss.md) file. [harmony]: https://github.com/openai/harmony + +## Citation + +```bibtex +@misc{openai2025gptoss120bgptoss20bmodel, + title={gpt-oss-120b & gpt-oss-20b Model Card}, + author={OpenAI}, + year={2025}, + eprint={2508.10925}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2508.10925}, +} +``` From 5ec1d16f423a735375a755eb9f511d738c02bbe3 Mon Sep 17 00:00:00 2001 From: Samagra Sharma Date: Wed, 27 Aug 2025 19:57:03 -0700 Subject: [PATCH 69/91] Added Tensorfuse (AWS) guide (#118) --- awesome-gpt-oss.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index 82cf7071..9cebe650 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -65,6 +65,8 @@ This is a list of guides and resources to help you get started with the gpt-oss - [gpt-oss-20b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b) - AMD - [gpt-oss-120B on AMD MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot) +- AWS (Deploy via Tensorfuse) + - [Deploy gpt-oss for both 20b and 120b models on AWS EKS](https://tensorfuse.io/docs/guides/modality/text/openai_oss) ## Examples & Tutorials From a19d0bc94d480505adb1e9b493d66e1a99d24443 Mon Sep 17 00:00:00 2001 From: Daniel Holanda Date: Thu, 28 Aug 2025 11:58:19 -0700 Subject: [PATCH 70/91] Add Lemonade to `awesome-gpt-oss` (#117) * Update awesome-gpt-oss.md * Update awesome-gpt-oss.md * Update awesome-gpt-oss.md --------- Co-authored-by: Dominik Kundel --- awesome-gpt-oss.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index 9cebe650..ac5a1c38 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -32,6 +32,8 @@ This is a list of guides and resources to help you get started with the gpt-oss - [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) - AMD - [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html) +- Lemonade + - [Running gpt-oss on STX Halo and Radeon dGPUs using Lemonade](https://lemonade-server.ai/news/gpt-oss.html) - llama.cpp - [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396) From 0c39f1da17df3a5f895802a4e9d130f033f79ba5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 28 Aug 2025 12:03:52 -0700 Subject: [PATCH 71/91] Add uv python backend (#156) * add uv python backend Co-authored-by: simon-mo * dangerously_use_uv --------- Co-authored-by: simon-mo --- gpt_oss/tools/python_docker/docker_tool.py | 32 +++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py index c31680ea..3d630cc1 100644 --- a/gpt_oss/tools/python_docker/docker_tool.py +++ b/gpt_oss/tools/python_docker/docker_tool.py @@ -3,6 +3,9 @@ import io import tarfile from typing import Any, AsyncIterator +import tempfile +import os +import subprocess import docker from openai_harmony import ( @@ -18,6 +21,11 @@ 
_docker_client = None +PYTHON_EXECUTION_BACKEND = "docker" + +if os.environ.get("PYTHON_EXECUTION_BACKEND") == "dangerously_use_uv": + PYTHON_EXECUTION_BACKEND = "dangerously_use_uv" + def call_python_script(script: str) -> str: """ @@ -58,6 +66,21 @@ def call_python_script(script: str) -> str: return output +def call_python_script_with_uv(script: str) -> str: + """ + Call a python script by writing it to a file to a temporary directory + and executing it with uv. + """ + with tempfile.TemporaryDirectory() as temp_dir: + script_path = os.path.join(temp_dir, "script.py") + with open(script_path, "w") as f: + f.write(script) + exec_result = subprocess.run( + ["uv", "run", "--no-project", "python", script_path], + capture_output=True) + return exec_result.stdout.decode("utf-8") + + class PythonTool(Tool): def __init__( self, @@ -118,5 +141,12 @@ def make_response( async def _process(self, message: Message) -> AsyncIterator[Message]: script = message.content[0].text channel = message.channel - output = call_python_script(script) + if PYTHON_EXECUTION_BACKEND == "docker": + output = call_python_script(script) + elif PYTHON_EXECUTION_BACKEND == "dangerously_use_uv": + output = call_python_script_with_uv(script) + else: + raise ValueError( + f"Invalid PYTHON_EXECUTION_BACKEND: {PYTHON_EXECUTION_BACKEND}" + ) yield self._make_response(output, channel=channel) From 7be9334950053a888e24887a57dac797a17d6e00 Mon Sep 17 00:00:00 2001 From: Dominik Kundel Date: Thu, 28 Aug 2025 12:04:11 -0700 Subject: [PATCH 72/91] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9ed47f92..fd38db07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ ] readme = "README.md" requires-python = ">=3.12,<3.13" -version = "0.0.4" +version = "0.0.5" [project.optional-dependencies] triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"] From 8ee92ec85de42a50deefb4075c45f014da5b2e22 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 08:21:27 -0700 Subject: [PATCH 73/91] Metal: add end-to-end benchmarks (#161) --- gpt_oss/metal/CMakeLists.txt | 4 ++ gpt_oss/metal/benchmark/end-to-end.cc | 82 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 gpt_oss/metal/benchmark/end-to-end.cc diff --git a/gpt_oss/metal/CMakeLists.txt b/gpt_oss/metal/CMakeLists.txt index c6a8e32b..d18708f4 100644 --- a/gpt_oss/metal/CMakeLists.txt +++ b/gpt_oss/metal/CMakeLists.txt @@ -147,6 +147,10 @@ add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc) target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels) target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include) +add_executable(end-to-end-bench benchmark/end-to-end.cc) +target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss) +target_include_directories(end-to-end-bench PRIVATE source/include) + # --- [ Python extension ] ----------------------------------------------- find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc new file mode 100644 index 00000000..4f73be7a --- /dev/null +++ b/gpt_oss/metal/benchmark/end-to-end.cc @@ -0,0 +1,82 @@ +#include + +#include +#include +#include +#include + +#include + + +static void end2end(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if 
(model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + + gptoss_tokenizer_t tokenizer_ptr = nullptr; + status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to retrieve Tokenizer"); + return; + } + std::unique_ptr, decltype(&gptoss_tokenizer_release)> tokenizer(tokenizer_ptr, gptoss_tokenizer_release); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), nullptr); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + + for (std::uint32_t i = 0; i < 3; i++) { + std::uint32_t predicted_token = std::numeric_limits::max(); + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + } + + for (auto _ : state) { + std::uint32_t predicted_token = std::numeric_limits::max(); + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + } + state.counters["tokens"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); +} + +BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond); +BENCHMARK_CAPTURE(end2end, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond); + +BENCHMARK_MAIN(); From 57e45b11b3a135e3e09d8fc5fea1bc793e003d44 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 08:22:01 -0700 Subject: [PATCH 74/91] Metal: simplify and optimize Reponses API adapter (#162) --- gpt_oss/responses_api/inference/metal.py | 62 ++---------------------- 1 file changed, 5 insertions(+), 57 deletions(-) diff --git a/gpt_oss/responses_api/inference/metal.py b/gpt_oss/responses_api/inference/metal.py index 9abe50db..ec84af7e 100644 --- a/gpt_oss/responses_api/inference/metal.py +++ b/gpt_oss/responses_api/inference/metal.py @@ -11,68 +11,16 @@ def setup_model(checkpoint: str) -> Callable[[list[int], float], int]: model = Model(checkpoint) context = Context(model) - def lcp(cache: list[int], inp: list[int]) -> list[int]: - i = 0 - max_len = min(len(cache), len(inp)) - while i < max_len and cache[i] == inp[i]: - i += 1 - return cache[:i] - - tokens_so_far = [] - def infer_next_token( tokens: 
list[int], temperature: float = 0.0, new_request: bool = False ) -> int: """Infer next token using incremental LCP caching when possible.""" - nonlocal tokens_so_far - - # Fast path: first call or explicitly new request. - if new_request or not tokens_so_far: - context.reset() - for t in tokens: - context.append(t) - tokens_so_far = tokens.copy() - context.process() - return int(context.sample(temperature=temperature)) - - # Longest common prefix length - overlap = lcp(tokens_so_far, tokens) - ol = len(overlap) - prev_len = len(tokens_so_far) - cur_len = len(tokens) - - diverged_midstream = (ol < prev_len) and ( - ol < cur_len - ) # mismatch not at the end - - if diverged_midstream: - # safest: rebuild - context.reset() - for t in tokens: - context.append(t) - tokens_so_far = tokens.copy() - context.process() - return int(context.sample(temperature=temperature)) - - if cur_len > prev_len: - # pure extension (good for KV reuse) - extension = tokens[prev_len:] - for t in extension: - context.append(t) - tokens_so_far = tokens.copy() - context.process() - return int(context.sample(temperature=temperature)) - - if cur_len < prev_len: - # truncation/backspace; easiest correct behavior is rebuild - context.reset() - for t in tokens: - context.append(t) - tokens_so_far = tokens.copy() - context.process() - return int(context.sample(temperature=temperature)) - # cur_len == prev_len and everything matches => no new tokens; just sample. + # Context handles LCP caching internally; if `tokens` matches the + # tokens in the KV cache, the KV cache is reused after reset+append. + context.reset() + for t in tokens: + context.append(t) return int(context.sample(temperature=temperature)) return infer_next_token From 38df14a605d27bb0e4fc473266e3c2094e66a42b Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 09:43:56 -0700 Subject: [PATCH 75/91] Metal: fix KV-cache invalidation after reset+append (#163) --- gpt_oss/metal/source/context.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index b58df99a..c0155d64 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -618,12 +618,13 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens( size_t num_verified_tokens = 0; for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) { if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) { + // Invalidate the KV cache starting with the newly added tokens. 
+ context->num_kv_tokens = context->num_tokens + num_verified_tokens; break; } } context->num_tokens += num_verified_tokens; - context->num_kv_tokens = context->num_tokens; tokens += num_verified_tokens; num_tokens -= num_verified_tokens; } else { From 24804a6ac991b0dae88e32f8f1335c94bdfbf285 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 13:00:01 -0700 Subject: [PATCH 76/91] Increase max output tokens in Reponses API to 131K (#165) --- gpt_oss/responses_api/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt_oss/responses_api/types.py b/gpt_oss/responses_api/types.py index 4ca72c56..454d8e07 100644 --- a/gpt_oss/responses_api/types.py +++ b/gpt_oss/responses_api/types.py @@ -6,7 +6,7 @@ MODEL_IDENTIFIER = "gpt-oss-120b" DEFAULT_TEMPERATURE = 0.0 REASONING_EFFORT = ReasoningEffort.LOW -DEFAULT_MAX_OUTPUT_TOKENS = 10_000 +DEFAULT_MAX_OUTPUT_TOKENS = 131072 class UrlCitation(BaseModel): From 942ef444ae25493ff99cf2223e21096877f24f21 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 13:44:13 -0700 Subject: [PATCH 77/91] Remove requirement on maximum Python version (#167) Codebase works fine with CPython 3.13, and the current stable is 3.13.7, so no good reason to restrict that --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fd38db07..88f0ac45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "termcolor", ] readme = "README.md" -requires-python = ">=3.12,<3.13" +requires-python = ">=3.12" version = "0.0.5" [project.optional-dependencies] From a8ce88fcaf376e336f2265a8cabdf3acacc69323 Mon Sep 17 00:00:00 2001 From: Daniel Holanda Date: Tue, 2 Sep 2025 13:45:07 -0700 Subject: [PATCH 78/91] Move Lemonade to AMD section of `awesome-gpt-oss` (#164) * Update awesome-gpt-oss.md * Update awesome-gpt-oss.md * Update awesome-gpt-oss.md * Add Lemonade to AMD section --------- Co-authored-by: Dominik Kundel --- awesome-gpt-oss.md | 1 - 1 file changed, 1 deletion(-) diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md index ac5a1c38..8b82ebf8 100644 --- a/awesome-gpt-oss.md +++ b/awesome-gpt-oss.md @@ -32,7 +32,6 @@ This is a list of guides and resources to help you get started with the gpt-oss - [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) - AMD - [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html) -- Lemonade - [Running gpt-oss on STX Halo and Radeon dGPUs using Lemonade](https://lemonade-server.ai/news/gpt-oss.html) - llama.cpp - [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396) From 864020abceb92dc5354ebd0b0f51be43bedf65ed Mon Sep 17 00:00:00 2001 From: hrithiksagar-tih Date: Wed, 3 Sep 2025 02:16:19 +0530 Subject: [PATCH 79/91] Added VLLM Offline Serve working code. 
(#150) --- README.md | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/README.md b/README.md index 4ef20827..c4612bca 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,82 @@ vllm serve openai/gpt-oss-20b [Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm) +Offline Serve Code: +- run this code after installing proper libraries as described, while additionally installing this: +- `uv pip install openai-harmony` +```python +# source .oss/bin/activate + +import os +os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0" + +import json +from openai_harmony import ( + HarmonyEncodingName, + load_harmony_encoding, + Conversation, + Message, + Role, + SystemContent, + DeveloperContent, +) + +from vllm import LLM, SamplingParams +import os + +# --- 1) Render the prefill with Harmony --- +encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + +convo = Conversation.from_messages( + [ + Message.from_role_and_content(Role.SYSTEM, SystemContent.new()), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Always respond in riddles"), + ), + Message.from_role_and_content(Role.USER, "What is the weather like in SF?"), + ] +) + +prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT) + +# Harmony stop tokens (pass to sampler so they won't be included in output) +stop_token_ids = encoding.stop_tokens_for_assistant_actions() + +# --- 2) Run vLLM with prefill --- +llm = LLM( + model="openai/gpt-oss-20b", + trust_remote_code=True, + gpu_memory_utilization = 0.95, + max_num_batched_tokens=4096, + max_model_len=5000, + tensor_parallel_size=1 +) + +sampling = SamplingParams( + max_tokens=128, + temperature=1, + stop_token_ids=stop_token_ids, +) + +outputs = llm.generate( + prompt_token_ids=[prefill_ids], # batch of size 1 + sampling_params=sampling, +) + +# vLLM gives you both text and token IDs +gen = outputs[0].outputs[0] +text = gen.text +output_tokens = gen.token_ids # <-- these are the completion token IDs (no prefill) + +# --- 3) Parse the completion token IDs back into structured Harmony messages --- +entries = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT) + +# 'entries' is a sequence of structured conversation entries (assistant messages, tool calls, etc.). +for message in entries: + print(f"{json.dumps(message.to_dict())}") +``` + #### PyTorch / Triton / Metal These implementations are largely reference implementations for educational purposes and are not expected to be run in production. 
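A minimal follow-on sketch for the offline vLLM example in the patch above: rather than dumping every parsed Harmony entry as JSON, it keeps only the text of the assistant's `final`-channel messages. It assumes the `entries` list returned by `parse_messages_from_completion_tokens` and the `message.channel` / `message.content[...].text` accessors used elsewhere in this repository; adjust the attribute names if your `openai_harmony` version exposes them differently.

```python
# Sketch: pull out only the user-facing answer from the parsed Harmony messages.
# Messages on the "analysis" channel carry the model's reasoning; the "final"
# channel carries the reply intended for the user.
final_text = "".join(
    part.text
    for message in entries
    if getattr(message, "channel", None) == "final"
    for part in message.content
    if hasattr(part, "text")
)

print(final_text)
```

When sampling with a non-zero temperature it can also be useful to keep the `analysis`-channel entries around, since they make unexpected final answers much easier to debug.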
From 95d7716e75bb0dd8966f234bad0e32c61fd1e851 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 21:33:40 -0700 Subject: [PATCH 80/91] Metal: indicate threadgroup is a multiple of simdgroup (#168) 2% speedup on gpt-oss-20b end-to-end sampling --- gpt_oss/metal/source/metal.m | 67 +++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/gpt_oss/metal/source/metal.m b/gpt_oss/metal/source/metal.m index a873bb36..03d69962 100644 --- a/gpt_oss/metal/source/metal.m +++ b/gpt_oss/metal/source/metal.m @@ -96,18 +96,19 @@ enum gptoss_status gptoss_metal_library_create_default( enum gptoss_status status = gptoss_status_success; id device_obj = (id) device->object; id library_obj = nil; - NSError* error_obj = nil; - NSString* error_string_obj = nil; + NSAutoreleasePool* autorelease_pool = nil; dispatch_data_t library_blob = NULL; unsigned long library_size = 0; uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size); if (library_data != NULL) { library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT); + + autorelease_pool = [[NSAutoreleasePool alloc] init]; + NSError* error_obj = nil; library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj]; if (library_obj == nil) { - error_string_obj = [error_obj localizedDescription]; - GPTOSS_LOG_ERROR("failed to create Metal library: %s", [error_string_obj UTF8String]); + GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]); status = gptoss_status_unsupported_system; goto cleanup; } @@ -129,11 +130,8 @@ enum gptoss_status gptoss_metal_library_create_default( if (library_blob != NULL) { dispatch_release(library_blob); } - if (error_string_obj != nil) { - [error_string_obj release]; - } - if (error_obj != nil) { - [error_obj release]; + if (autorelease_pool != nil) { + [autorelease_pool drain]; } return status; } @@ -154,14 +152,16 @@ enum gptoss_status gptoss_metal_function_create( const char* name, struct gptoss_metal_function* function_out) { - NSString* name_obj = nil; - NSError* error_obj = nil; - NSString* error_string_obj = nil; + __block NSString* error_string_obj = nil; id function_obj = nil; + MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil; + __block id pipeline_state_obj = nil; + dispatch_semaphore_t pipeline_build_semaphore = NULL; enum gptoss_status status = gptoss_status_success; + NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init]; id library_obj = (id) library->object; - name_obj = [NSString stringWithUTF8String:name]; + NSString* name_obj = [NSString stringWithUTF8String:name]; function_obj = [library_obj newFunctionWithName:name_obj]; if (function_obj == nil) { GPTOSS_LOG_ERROR("failed to create Metal function %s", name); @@ -169,11 +169,33 @@ enum gptoss_status gptoss_metal_function_create( goto cleanup; } id device_obj = [library_obj device]; - id pipeline_state_obj = [device_obj newComputePipelineStateWithFunction:function_obj error:&error_obj]; + pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init]; + [pipeline_descriptor_obj setComputeFunction:function_obj]; + [pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES]; + + pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0); + [device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj + options:MTLPipelineOptionNone + completionHandler:^(id _Nullable new_state, + 
MTLComputePipelineReflection* _Nullable reflection, + NSError* _Nullable error_obj) { + if (new_state != nil) { + pipeline_state_obj = [new_state retain]; + } + if (error_obj != nil) { + error_string_obj = [[error_obj localizedDescription] copy]; + } + dispatch_semaphore_signal(pipeline_build_semaphore); + }]; + dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER); + if (pipeline_state_obj == nil) { - error_string_obj = [error_obj localizedDescription]; + const char* error_string = "unknown error"; + if (error_string_obj != nil) { + error_string = [error_string_obj UTF8String]; + } GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s", - name, [error_string_obj UTF8String]); + name, error_string); status = gptoss_status_unsupported_system; goto cleanup; } @@ -189,17 +211,20 @@ enum gptoss_status gptoss_metal_function_create( pipeline_state_obj = nil; cleanup: - if (name_obj != nil) { - [name_obj release]; - } if (function_obj != nil) { [function_obj release]; } + if (pipeline_descriptor_obj != nil) { + [pipeline_descriptor_obj release]; + } if (error_string_obj != nil) { [error_string_obj release]; } - if (error_obj != nil) { - [error_obj release]; + if (pipeline_build_semaphore != NULL) { + dispatch_release(pipeline_build_semaphore); + } + if (autorelease_pool != nil) { + [autorelease_pool drain]; } return status; } From 7f3c896dad67c3d39c73372d9a0a16f2c8835755 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 2 Sep 2025 23:16:12 -0700 Subject: [PATCH 81/91] Metal: mlock model weights in memory (#170) --- gpt_oss/metal/source/include/internal/model.h | 3 +++ gpt_oss/metal/source/model.c | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index 6b477745..ae62a3ec 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -54,6 +55,8 @@ struct gptoss_model { // Once the batch size is reached, we process it to fill the KV cache. 
size_t max_batch_tokens; + bool lock_memory; + size_t weights_size; size_t allocation_size; diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c index e3aeb98f..70668639 100644 --- a/gpt_oss/metal/source/model.c +++ b/gpt_oss/metal/source/model.c @@ -290,6 +290,12 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( prefetch_fd(fd, model_mapping_start, model_mapping_size, path); + if (mlock(model_mapping_ptr, model_mapping_size) != 0) { + GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno); + } else { + model->lock_memory = true; + } + // Initialize Metal status = gptoss_metal_device_create_system_default(&model->device); if (status != gptoss_status_success) { @@ -497,6 +503,12 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release( // Weight buffers if (model->mapping_ptr != NULL && model->mapping_size != 0) { + if (model->lock_memory) { + if (munlock(model->mapping_ptr, model->mapping_size) != 0) { + GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno); + } + } + if (munmap(model->mapping_ptr, model->mapping_size) != 0) { GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno); } From a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd Mon Sep 17 00:00:00 2001 From: bojanbabic Date: Wed, 3 Sep 2025 15:30:02 -0700 Subject: [PATCH 82/91] Add You.com as tool for browser (#171) * Add You.com as tool for browser * change key name * update tests in order to mock API key * address changes * address changes * update README --- README.md | 13 ++- gpt-oss-mcp-server/browser_server.py | 12 ++- gpt-oss-mcp-server/reference-system-prompt.py | 4 +- gpt_oss/chat.py | 4 +- gpt_oss/responses_api/api_server.py | 13 ++- gpt_oss/tools/simple_browser/__init__.py | 3 +- gpt_oss/tools/simple_browser/backend.py | 102 ++++++++++++++++-- .../tools/simple_browser/test_backend.py | 70 ++++++++++++ 8 files changed, 197 insertions(+), 24 deletions(-) create mode 100644 tests/gpt_oss/tools/simple_browser/test_backend.py diff --git a/README.md b/README.md index c4612bca..0104cec4 100644 --- a/README.md +++ b/README.md @@ -426,7 +426,7 @@ codex -p oss ### Browser > [!WARNING] -> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`ExaBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment. +> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`YouComBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment. Currently we have available `YouComBackend` and `ExaBackend`. 
Both gpt-oss models were trained with the capability to browse using the `browser` tool that exposes the following three methods: @@ -441,15 +441,20 @@ To enable the browser tool, you'll have to place the definition into the `system ```python import datetime from gpt_oss.tools.simple_browser import SimpleBrowserTool -from gpt_oss.tools.simple_browser.backend import ExaBackend +from gpt_oss.tools.simple_browser.backend import YouComBackend from openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) -# Exa backend requires you to have set the EXA_API_KEY environment variable -backend = ExaBackend( +# Depending on the choice of the browser backend you need corresponding env variables setup +# In case you use You.com backend requires you to have set the YDC_API_KEY environment variable, +# while for Exa you might need EXA_API_KEY environment variable set +backend = YouComBackend( source="web", ) +# backend = ExaBackend( +# source="web", +# ) browser_tool = SimpleBrowserTool(backend=backend) # create a basic system prompt diff --git a/gpt-oss-mcp-server/browser_server.py b/gpt-oss-mcp-server/browser_server.py index 5d5ad4ad..b37a63a6 100644 --- a/gpt-oss-mcp-server/browser_server.py +++ b/gpt-oss-mcp-server/browser_server.py @@ -1,3 +1,4 @@ +import os from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -5,8 +6,7 @@ from mcp.server.fastmcp import Context, FastMCP from gpt_oss.tools.simple_browser import SimpleBrowserTool -from gpt_oss.tools.simple_browser.backend import ExaBackend - +from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend @dataclass class AppContext: @@ -14,7 +14,13 @@ class AppContext: def create_or_get_browser(self, session_id: str) -> SimpleBrowserTool: if session_id not in self.browsers: - backend = ExaBackend(source="web") + tool_backend = os.getenv("BROWSER_BACKEND", "exa") + if tool_backend == "youcom": + backend = YouComBackend(source="web") + elif tool_backend == "exa": + backend = ExaBackend(source="web") + else: + raise ValueError(f"Invalid tool backend: {tool_backend}") self.browsers[session_id] = SimpleBrowserTool(backend=backend) return self.browsers[session_id] diff --git a/gpt-oss-mcp-server/reference-system-prompt.py b/gpt-oss-mcp-server/reference-system-prompt.py index 98f171dd..6ddbf7c9 100644 --- a/gpt-oss-mcp-server/reference-system-prompt.py +++ b/gpt-oss-mcp-server/reference-system-prompt.py @@ -1,7 +1,7 @@ import datetime from gpt_oss.tools.simple_browser import SimpleBrowserTool -from gpt_oss.tools.simple_browser.backend import ExaBackend +from gpt_oss.tools.simple_browser.backend import YouComBackend from gpt_oss.tools.python_docker.docker_tool import PythonTool from gpt_oss.tokenizer import tokenizer @@ -22,7 +22,7 @@ ReasoningEffort.LOW).with_conversation_start_date( datetime.datetime.now().strftime("%Y-%m-%d"))) -backend = ExaBackend(source="web", ) +backend = YouComBackend(source="web") browser_tool = SimpleBrowserTool(backend=backend) system_message_content = system_message_content.with_tools( browser_tool.tool_config) diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index 5e40079d..4856a397 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -19,7 +19,7 @@ from gpt_oss.tools import apply_patch from gpt_oss.tools.simple_browser import SimpleBrowserTool -from gpt_oss.tools.simple_browser.backend import ExaBackend +from 
gpt_oss.tools.simple_browser.backend import YouComBackend from gpt_oss.tools.python_docker.docker_tool import PythonTool from openai_harmony import ( @@ -85,7 +85,7 @@ def main(args): ) if args.browser: - backend = ExaBackend( + backend = YouComBackend( source="web", ) browser_tool = SimpleBrowserTool(backend=backend) diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py index 2934b011..8eb053f1 100644 --- a/gpt_oss/responses_api/api_server.py +++ b/gpt_oss/responses_api/api_server.py @@ -1,3 +1,4 @@ +import os import datetime import uuid from typing import Callable, Literal, Optional @@ -20,7 +21,7 @@ from gpt_oss.tools.python_docker.docker_tool import PythonTool from gpt_oss.tools.simple_browser import SimpleBrowserTool -from gpt_oss.tools.simple_browser.backend import ExaBackend +from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend from .events import ( ResponseCodeInterpreterCallCompleted, @@ -904,9 +905,13 @@ async def generate(body: ResponsesRequest, request: Request): ) if use_browser_tool: - backend = ExaBackend( - source="web", - ) + tool_backend = os.getenv("BROWSER_BACKEND", "exa") + if tool_backend == "youcom": + backend = YouComBackend(source="web") + elif tool_backend == "exa": + backend = ExaBackend(source="web") + else: + raise ValueError(f"Invalid tool backend: {tool_backend}") browser_tool = SimpleBrowserTool(backend=backend) else: browser_tool = None diff --git a/gpt_oss/tools/simple_browser/__init__.py b/gpt_oss/tools/simple_browser/__init__.py index 9043cb18..da3ff280 100644 --- a/gpt_oss/tools/simple_browser/__init__.py +++ b/gpt_oss/tools/simple_browser/__init__.py @@ -1,7 +1,8 @@ from .simple_browser_tool import SimpleBrowserTool -from .backend import ExaBackend +from .backend import ExaBackend, YouComBackend __all__ = [ "SimpleBrowserTool", "ExaBackend", + "YouComBackend", ] diff --git a/gpt_oss/tools/simple_browser/backend.py b/gpt_oss/tools/simple_browser/backend.py index 03bdf566..33daf8d6 100644 --- a/gpt_oss/tools/simple_browser/backend.py +++ b/gpt_oss/tools/simple_browser/backend.py @@ -3,6 +3,7 @@ """ import functools +import asyncio import logging import os from abc import abstractmethod @@ -87,6 +88,24 @@ async def search( async def fetch(self, url: str, session: ClientSession) -> PageContents: pass + async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict: + headers = {"x-api-key": self._get_api_key()} + async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp: + if resp.status != 200: + raise BackendError( + f"{self.__class__.__name__} error {resp.status}: {await resp.text()}" + ) + return await resp.json() + + async def _get(self, session: ClientSession, endpoint: str, params: dict) -> dict: + headers = {"x-api-key": self._get_api_key()} + async with session.get(f"{self.BASE_URL}{endpoint}", params=params, headers=headers) as resp: + if resp.status != 200: + raise BackendError( + f"{self.__class__.__name__} error {resp.status}: {await resp.text()}" + ) + return await resp.json() + @chz.chz(typecheck=True) class ExaBackend(Backend): @@ -106,14 +125,6 @@ def _get_api_key(self) -> str: raise BackendError("Exa API key not provided") return key - async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict: - headers = {"x-api-key": self._get_api_key()} - async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp: - if resp.status != 200: - raise BackendError( - f"Exa API error 
{resp.status}: {await resp.text()}" - ) - return await resp.json() async def search( self, query: str, topn: int, session: ClientSession @@ -164,3 +175,78 @@ async def fetch(self, url: str, session: ClientSession) -> PageContents: display_urls=True, session=session, ) + +@chz.chz(typecheck=True) +class YouComBackend(Backend): + """Backend that uses the You.com Search API.""" + + source: str = chz.field(doc="Description of the backend source") + + BASE_URL: str = "https://api.ydc-index.io" + + def _get_api_key(self) -> str: + key = os.environ.get("YDC_API_KEY") + if not key: + raise BackendError("You.com API key not provided") + return key + + + async def search( + self, query: str, topn: int, session: ClientSession + ) -> PageContents: + data = await self._get( + session, + "/v1/search", + {"query": query, "count": topn}, + ) + # make a simple HTML page to work with browser format + web_titles_and_urls, news_titles_and_urls = [], [] + if "web" in data["results"]: + web_titles_and_urls = [ + (result["title"], result["url"], result["snippets"]) + for result in data["results"]["web"] + ] + if "news" in data["results"]: + news_titles_and_urls = [ + (result["title"], result["url"], result["description"]) + for result in data["results"]["news"] + ] + titles_and_urls = web_titles_and_urls + news_titles_and_urls + html_page = f""" + +

+<h1>Search Results</h1>
+<ul>
+{"".join([f"<li><a href='{url}'>{title}</a> {summary}</li>" for title, url, summary in titles_and_urls])}
+</ul>
+ +""" + + return process_html( + html=html_page, + url="", + title=query, + display_urls=True, + session=session, + ) + + async def fetch(self, url: str, session: ClientSession) -> PageContents: + is_view_source = url.startswith(VIEW_SOURCE_PREFIX) + if is_view_source: + url = url[len(VIEW_SOURCE_PREFIX) :] + data = await self._post( + session, + "/v1/contents", + {"urls": [url], "livecrawl_formats": "html"}, + ) + if not data: + raise BackendError(f"No contents returned for {url}") + if "html" not in data[0]: + raise BackendError(f"No HTML returned for {url}") + return process_html( + html=data[0].get("html", ""), + url=url, + title=data[0].get("title", ""), + display_urls=True, + session=session, + ) + diff --git a/tests/gpt_oss/tools/simple_browser/test_backend.py b/tests/gpt_oss/tools/simple_browser/test_backend.py new file mode 100644 index 00000000..ab0dc780 --- /dev/null +++ b/tests/gpt_oss/tools/simple_browser/test_backend.py @@ -0,0 +1,70 @@ +import pytest +from typing import Generator, Any +from unittest import mock +from aiohttp import ClientSession + +from gpt_oss.tools.simple_browser.backend import YouComBackend + +class MockAiohttpResponse: + """Mocks responses for get/post requests from async libraries.""" + + def __init__(self, json: dict, status: int): + self._json = json + self.status = status + + async def json(self): + return self._json + + async def __aexit__(self, exc_type, exc, tb): + pass + + async def __aenter__(self): + return self + +def mock_os_environ_get(name: str, default: Any = "test_api_key"): + assert name in ["YDC_API_KEY"] + return default + +def test_youcom_backend(): + backend = YouComBackend(source="web") + assert backend.source == "web" + +@pytest.mark.asyncio +@mock.patch("aiohttp.ClientSession.get") +async def test_youcom_backend_search(mock_session_get): + backend = YouComBackend(source="web") + api_response = { + "results": { + "web": [ + {"title": "Web Result 1", "url": "https://www.example.com/web1", "snippets": "Web Result 1 snippets"}, + {"title": "Web Result 2", "url": "https://www.example.com/web2", "snippets": "Web Result 2 snippets"}, + ], + "news": [ + {"title": "News Result 1", "url": "https://www.example.com/news1", "description": "News Result 1 description"}, + {"title": "News Result 2", "url": "https://www.example.com/news2", "description": "News Result 2 description"}, + ], + } + } + with mock.patch("os.environ.get", wraps=mock_os_environ_get): + mock_session_get.return_value = MockAiohttpResponse(api_response, 200) + async with ClientSession() as session: + result = await backend.search(query="test", topn=10, session=session) + assert result.title == "test" + assert result.urls == {"0": "https://www.example.com/web1", "1": "https://www.example.com/web2", "2": "https://www.example.com/news1", "3": "https://www.example.com/news2"} + +@pytest.mark.asyncio +@mock.patch("aiohttp.ClientSession.post") +async def test_youcom_backend_fetch(mock_session_get): + backend = YouComBackend(source="web") + api_response = [ + {"title": "Fetch Result 1", "url": "https://www.example.com/fetch1", "html": "
<html><body>Fetch Result 1 text</body></html>
"}, + ] + with mock.patch("os.environ.get", wraps=mock_os_environ_get): + mock_session_get.return_value = MockAiohttpResponse(api_response, 200) + async with ClientSession() as session: + result = await backend.fetch(url="https://www.example.com/fetch1", session=session) + assert result.title == "Fetch Result 1" + assert result.text == "\nURL: https://www.example.com/fetch1\nFetch Result 1 text" + + + \ No newline at end of file From b558ecc5534986fb73fd8555ca04e2436149eb12 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Mon, 8 Sep 2025 00:21:32 -0700 Subject: [PATCH 83/91] Evals: correctly pass temperature/max_tokens when using Responses API (#174) --- gpt_oss/evals/responses_sampler.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/gpt_oss/evals/responses_sampler.py b/gpt_oss/evals/responses_sampler.py index fd9daef3..134303f5 100644 --- a/gpt_oss/evals/responses_sampler.py +++ b/gpt_oss/evals/responses_sampler.py @@ -42,24 +42,17 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: trial = 0 while True: try: + request_kwargs = { + "model": self.model, + "input": message_list, + "temperature": self.temperature, + "max_output_tokens": self.max_tokens, + } if self.reasoning_model: - reasoning = ( - {"effort": self.reasoning_effort} - if self.reasoning_effort - else None - ) - response = self.client.responses.create( - model=self.model, - input=message_list, - reasoning=reasoning, - ) - else: - response = self.client.responses.create( - model=self.model, - input=message_list, - temperature=self.temperature, - max_output_tokens=self.max_tokens, + request_kwargs["reasoning"] = ( + {"effort": self.reasoning_effort} if self.reasoning_effort else None ) + response = self.client.responses.create(**request_kwargs) for output in response.output: if hasattr(output, "text"): From be0d32efaef14a33bb5fad0a9b1d87ca240ad85b Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Mon, 8 Sep 2025 13:34:35 -0700 Subject: [PATCH 84/91] Metal: move sampling to GPU (#175) --- gpt_oss/metal/source/context.c | 225 +++++++++--------- .../source/include/internal/kernel-args.h | 8 + .../source/include/internal/metal-kernels.h | 16 ++ gpt_oss/metal/source/include/internal/model.h | 1 + gpt_oss/metal/source/metal-kernels.c | 53 +++++ gpt_oss/metal/source/model.c | 5 + gpt_oss/metal/source/sample.metal | 141 +++++++++++ 7 files changed, 332 insertions(+), 117 deletions(-) diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index c0155d64..0791c3eb 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -162,6 +162,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens( // Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens. 
static enum gptoss_status process_tokens( gptoss_context_t context, + struct gptoss_metal_command_buffer* command_buffer, size_t input_tokens_offset, size_t num_input_tokens, size_t num_output_tokens) @@ -173,14 +174,9 @@ static enum gptoss_status process_tokens( enum gptoss_status status = gptoss_status_success; const struct gptoss_model* model = context->model; - struct gptoss_metal_command_buffer command_buffer = {0}; const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads); - status = gptoss_metal_command_buffer_create(&model->command_queue, &command_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } const size_t input_tokens_end = input_tokens_offset + num_input_tokens; for (size_t input_batch_start = input_tokens_offset; input_batch_start < input_tokens_end; @@ -191,7 +187,7 @@ static enum gptoss_status process_tokens( const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end); status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings( - &command_buffer, + command_buffer, &model->bf16_f32_embeddings_fn, /*threadgroup_size=*/512, &context->token_buffer, @@ -204,14 +200,14 @@ static enum gptoss_status process_tokens( /*num_channels=*/model->embedding_dim); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch"); - goto cleanup; + return status; } for (uint32_t n = 0; n < model->num_blocks; n++) { const bool last_block = n + 1 == model->num_blocks; const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size; status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( - &command_buffer, + command_buffer, &model->f32_bf16w_rmsnorm_fn, &context->residual_activation_buffer, /*input_offset=*/0, @@ -224,11 +220,11 @@ static enum gptoss_status process_tokens( model->rmsnorm_epsilon); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); - goto cleanup; + return status; } status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - &command_buffer, + command_buffer, &model->f32_bf16w_matmul_fn, /*threadgroup_size=*/256, &context->rmsnorm_activation_buffer, @@ -244,11 +240,11 @@ static enum gptoss_status process_tokens( /*num_rows=*/attn_qkv_dim); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); - goto cleanup; + return status; } status = gptoss_metal_command_buffer_encode_launch_f32_rope( - &command_buffer, + command_buffer, &model->f32_rope_fn, /*threadgroup_size=*/32, &context->qkv_activation_buffer, @@ -264,12 +260,12 @@ static enum gptoss_status process_tokens( /*token_offset=*/input_batch_start); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch"); - goto cleanup; + return status; } for (uint32_t t = 0; t < input_batch_size; t++) { status = gptoss_metal_command_buffer_encode_copy_buffer( - &command_buffer, + command_buffer, &context->qkv_activation_buffer, /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float), &context->kvcache_buffer, @@ -277,13 +273,13 @@ static enum gptoss_status process_tokens( /*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float)); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t); - goto cleanup; + return status; } } if (num_block_output_tokens != 0) { status = 
gptoss_metal_command_buffer_encode_launch_f32_sdpa( - &command_buffer, + command_buffer, &model->f32_sdpa_q8_d64_fn, &context->qkv_activation_buffer, /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), @@ -301,10 +297,11 @@ static enum gptoss_status process_tokens( model->num_heads, model->num_kv_heads, model->head_dim); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch"); - goto cleanup; + return status; } + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( - &command_buffer, + command_buffer, &model->f32_bf16w_matmul_fn, /*threadgroup_size=*/256, &context->sdpa_activation_buffer, @@ -320,11 +317,11 @@ static enum gptoss_status process_tokens( /*num_rows=*/model->embedding_dim); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch"); - goto cleanup; + return status; } - + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( - &command_buffer, + command_buffer, &model->f32_bf16w_rmsnorm_fn, &context->residual_activation_buffer, /*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), @@ -337,11 +334,11 @@ static enum gptoss_status process_tokens( model->rmsnorm_epsilon); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); - goto cleanup; + return status; } - + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - &command_buffer, + command_buffer, &model->f32_bf16w_matmul_fn, /*threadgroup_size=*/256, &context->rmsnorm_activation_buffer, @@ -357,15 +354,15 @@ static enum gptoss_status process_tokens( /*num_rows=*/model->num_experts); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); - goto cleanup; + return status; } - + const char* kernel_name = NULL; switch (model->num_experts) { case 32: kernel_name = "f32_topk_softmax_e32_k4_fn"; status = gptoss_metal_command_buffer_encode_launch_f32_topk( - &command_buffer, + command_buffer, &model->f32_topk_softmax_e32_k4_fn, &context->gate_activation_buffer, /*input_offset=*/0, &context->expert_activation_buffer, /*output_offset=*/0, @@ -376,7 +373,7 @@ static enum gptoss_status process_tokens( case 128: kernel_name = "f32_topk_softmax_e128_k4_fn"; status = gptoss_metal_command_buffer_encode_launch_f32_topk( - &command_buffer, + command_buffer, &model->f32_topk_softmax_e128_k4_fn, &context->gate_activation_buffer, /*input_offset=*/0, &context->expert_activation_buffer, /*output_offset=*/0, @@ -387,15 +384,15 @@ static enum gptoss_status process_tokens( default: status = gptoss_status_unsupported_argument; GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts); - goto cleanup; + return status; } if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name); - goto cleanup; + return status; } - + status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu( - &command_buffer, + command_buffer, &model->f32_mf4w_moe_matmul_swiglu_fn, /*threadgroup_size=*/512, &context->rmsnorm_activation_buffer, @@ -418,11 +415,11 @@ static enum gptoss_status process_tokens( model->mlp_dim); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch"); - goto cleanup; + return status; } - + status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul( - 
&command_buffer, + command_buffer, &model->f32_mf4w_moe_matmul_fn, /*threadgroup_size=*/512, &context->swiglu_activation_buffer, @@ -444,11 +441,11 @@ static enum gptoss_status process_tokens( model->embedding_dim); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch"); - goto cleanup; + return status; } - + status = gptoss_metal_command_buffer_encode_launch_f32_accumulate( - &command_buffer, + command_buffer, &model->f32_accumulate_e4_fn, /*threadgroup_size=*/256, model->max_threadgroups, @@ -463,14 +460,14 @@ static enum gptoss_status process_tokens( model->num_active_experts); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch"); - goto cleanup; + return status; } } } if (output_batch_size != 0) { status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( - &command_buffer, + command_buffer, &model->f32_bf16w_rmsnorm_fn, &context->residual_activation_buffer, /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float), @@ -483,22 +480,22 @@ static enum gptoss_status process_tokens( model->rmsnorm_epsilon); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); - goto cleanup; + return status; } status = gptoss_metal_command_buffer_encode_fill_buffer( - &command_buffer, + command_buffer, &context->argmax_buffer, /*offset=*/0, /*size=*/sizeof(uint64_t) * output_batch_size, /*fill_value=*/0xFF); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode fill buffer command"); - goto cleanup; + return status; } status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( - &command_buffer, + command_buffer, &model->f32_bf16w_unembedding_fn, /*threadgroup_size=*/256, model->max_threadgroups, @@ -515,17 +512,11 @@ static enum gptoss_status process_tokens( /*num_rows=*/model->vocabulary_size); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch"); - goto cleanup; + return status; } } } - - gptoss_metal_command_buffer_commit(&command_buffer); - gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL); - -cleanup: - gptoss_metal_command_buffer_release(&command_buffer); - return status; + return gptoss_status_success; } enum gptoss_status GPTOSS_ABI gptoss_context_append_chars( @@ -643,16 +634,38 @@ enum gptoss_status GPTOSS_ABI gptoss_context_process( gptoss_context_t context) { if (context->num_tokens > context->num_kv_tokens) { - enum gptoss_status status = process_tokens( + struct gptoss_metal_command_buffer command_buffer = {0}; + + enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + + status = process_tokens( context, + &command_buffer, /*input_tokens_offset=*/context->num_kv_tokens, /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens, /*num_output_tokens=*/0); if (status != gptoss_status_success) { - return status; + goto cleanup; + } + + status = gptoss_metal_command_buffer_commit(&command_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + + status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL); + if (status != gptoss_status_success) { + goto cleanup; } context->num_kv_tokens = context->num_tokens; + +cleanup: + gptoss_metal_command_buffer_release(&command_buffer); + return status; } return gptoss_status_success; 
@@ -669,9 +682,16 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample( struct gptoss_metal_command_buffer command_buffer = {0}; *token_out = UINT32_MAX; + + status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } + if (context->num_kv_tokens < context->num_tokens) { status = process_tokens( context, + &command_buffer, /*input_tokens_offset=*/context->num_kv_tokens, /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens, /*num_output_tokens=*/1); @@ -679,30 +699,23 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample( } else { status = process_tokens( context, + &command_buffer, /*input_tokens_offset=*/context->num_tokens - 1, /*num_input_tokens=*/1, /*num_output_tokens=*/1); } if (status != gptoss_status_success) { - return status; + goto cleanup; } - if (temperature == 0.0f) { - const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0]; - *token_out = (uint32_t) argmax_bits; - } else { + if (temperature != 0.0f) { assert(context->num_processed_tokens != 0); - status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } - uint32_t num_threadgroups = 0; uint32_t num_dims_per_threadgroup = 0; status = gptoss_metal_command_buffer_encode_launch_f32_softmax( &command_buffer, &model->f32_softmax_fn, - /*threadgroup_size=*/256, + /*threadgroup_size=*/512, model->max_threadgroups, &context->score_buffer, /*score_offset=*/0, @@ -719,65 +732,43 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample( &num_dims_per_threadgroup); if (status != gptoss_status_success) { GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch"); + goto cleanup; } - gptoss_metal_command_buffer_commit(&command_buffer); - gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL); - - const uint32_t sample_word = rng_squares32(context->num_tokens, seed + UINT64_C(0x123456789ABCDEF)); - float sample_cdf = (float) ((int32_t) sample_word & INT32_C(0x00FFFFFF)) * 0x1.0p-24f; - - const float* sum_ptr = (const float*) context->sum_buffer.ptr; - float sum = 0.0f; - for (uint32_t i = 0; i < num_threadgroups; i++) { - sum += sum_ptr[i]; - } - sample_cdf *= sum; - - uint32_t block_idx = 0, token_idx = 0; - if (sample_cdf == 0.0f) { - // Make sure we choose the first token with non-zero probability rather than just the first token - sample_cdf = FLT_TRUE_MIN; - } - - // Step 1: find block - float cumsum = 0.0f; - for (; block_idx < num_threadgroups; block_idx++) { - const float new_cumsum = cumsum + sum_ptr[block_idx]; - if (new_cumsum >= sample_cdf) { - break; - } - cumsum = new_cumsum; - } - if (block_idx == num_threadgroups) { - block_idx -= 1; - } - - // Step 2: find token - const float* prob_ptr = (const float*) context->prob_buffer.ptr + block_idx * num_dims_per_threadgroup; - assert(model->vocabulary_size > num_dims_per_threadgroup * block_idx); - uint32_t num_dims_per_block = math_min(num_dims_per_threadgroup, model->vocabulary_size - num_dims_per_threadgroup * block_idx); - for (; token_idx < num_dims_per_block; token_idx++) { - const float new_cumsum = cumsum + prob_ptr[token_idx]; - if (new_cumsum >= sample_cdf) { - break; - } - cumsum = new_cumsum; - } - if (token_idx == num_dims_per_block) { - token_idx -= 1; + status = gptoss_metal_command_buffer_encode_launch_f32_sample( + &command_buffer, + &model->f32_sample_fn, + /*min_threadgroup_size=*/512, + &context->prob_buffer, + 
/*prob_offset=*/0, + &context->sum_buffer, + /*sum_offset=*/0, + &context->argmax_buffer, + /*prediction_offset=*/0, + /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF), + /*num_blocks=*/num_threadgroups, + /*num_channels=*/model->vocabulary_size, + /*num_channels_per_block=*/num_dims_per_threadgroup, + /*token_offset=*/context->num_tokens); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch"); + goto cleanup; } + } - token_idx += block_idx * num_dims_per_threadgroup; - - *token_out = token_idx; + gptoss_metal_command_buffer_commit(&command_buffer); + gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL); -cleanup: - gptoss_metal_command_buffer_release(&command_buffer); - return status; + if (temperature == 0.0f) { + const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0]; + *token_out = (uint32_t) argmax_bits; + } else { + *token_out = ((uint32_t*) context->argmax_buffer.ptr)[0]; } - return gptoss_status_success; +cleanup: + gptoss_metal_command_buffer_release(&command_buffer); + return status; } enum gptoss_status GPTOSS_ABI gptoss_context_reset( diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h index 677ce488..a031902d 100644 --- a/gpt_oss/metal/source/include/internal/kernel-args.h +++ b/gpt_oss/metal/source/include/internal/kernel-args.h @@ -103,3 +103,11 @@ struct gptoss_softmax_args { uint32_t max_threadgroups; float temperature; }; + +struct gptoss_sample_args { + uint64_t seed; + uint32_t token_offset; + uint32_t num_blocks; + uint32_t num_dims; + uint32_t num_dims_per_block; +}; diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h index aa5a3ef7..64cb36e0 100644 --- a/gpt_oss/metal/source/include/internal/metal-kernels.h +++ b/gpt_oss/metal/source/include/internal/metal-kernels.h @@ -265,6 +265,22 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax( uint32_t* num_threadgroups_out, uint32_t* num_channels_per_threadgroup_out); +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_sample_fn, + size_t min_threadgroup_size, + const struct gptoss_metal_buffer* prob_buffer, + size_t prob_offset, + const struct gptoss_metal_buffer* sum_buffer, + size_t sum_offset, + const struct gptoss_metal_buffer* prediction_buffer, + size_t prediction_offset, + uint64_t rng_seed, + uint32_t num_blocks, + uint32_t num_channels, + uint32_t num_channels_per_block, + uint32_t token_offset); + #ifdef __cplusplus } // extern "C" #endif diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index ae62a3ec..c17510b8 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -77,6 +77,7 @@ struct gptoss_model { struct gptoss_metal_function f32_topk_softmax_e128_k4_fn; struct gptoss_metal_function f32_sdpa_q8_d64_fn; struct gptoss_metal_function f32_softmax_fn; + struct gptoss_metal_function f32_sample_fn; size_t per_block_shared_weights_size; size_t per_expert_block_weight_size; diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c index 46fd1586..a9a5253c 100644 --- a/gpt_oss/metal/source/metal-kernels.c +++ b/gpt_oss/metal/source/metal-kernels.c @@ -836,3 +836,56 @@ enum gptoss_status 
gptoss_metal_command_buffer_encode_launch_f32_softmax( (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset}, /*threadgroup_buffer_size=*/0); } + +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_sample_fn, + size_t min_threadgroup_size, + const struct gptoss_metal_buffer* prob_buffer, + size_t prob_offset, + const struct gptoss_metal_buffer* sum_buffer, + size_t sum_offset, + const struct gptoss_metal_buffer* prediction_buffer, + size_t prediction_offset, + uint64_t rng_seed, + uint32_t num_blocks, + uint32_t num_channels, + uint32_t num_channels_per_block, + uint32_t token_offset) +{ + if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) { + return gptoss_status_invalid_state; + } + + if (min_threadgroup_size > f32_sample_fn->max_threadgroup_threads) { + return gptoss_status_invalid_argument; + } + + if (min_threadgroup_size % f32_sample_fn->simdgroup_threads != 0) { + return gptoss_status_invalid_argument; + } + + if (num_blocks > f32_sample_fn->max_threadgroup_threads) { + return gptoss_status_invalid_argument; + } + + const struct gptoss_sample_args args = { + .seed = rng_seed, + .token_offset = token_offset, + .num_blocks = num_blocks, + .num_dims = num_channels, + .num_dims_per_block = num_channels_per_block, + }; + + const size_t threadgroup_size = math_max(min_threadgroup_size, + math_round_up_po2(num_blocks, f32_sample_fn->simdgroup_threads)); + return gptoss_metal_command_buffer_encode_launch_kernel( + command_buffer, f32_sample_fn, + threadgroup_size, 1, 1, + 1, 1, 1, + sizeof(args), &args, + 3, + (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, prediction_buffer}, + (const size_t[]) {prob_offset, sum_offset, prediction_offset}, + /*threadgroup_buffer_size=*/0); +} diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c index 70668639..7a0450ce 100644 --- a/gpt_oss/metal/source/model.c +++ b/gpt_oss/metal/source/model.c @@ -356,6 +356,10 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( if (status != gptoss_status_success) { goto cleanup; } + status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn); + if (status != gptoss_status_success) { + goto cleanup; + } status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn); if (status != gptoss_status_success) { goto cleanup; @@ -495,6 +499,7 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release( gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn); gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn); gptoss_metal_function_release(&model->f32_softmax_fn); + gptoss_metal_function_release(&model->f32_sample_fn); gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn); gptoss_metal_library_release(&model->library); diff --git a/gpt_oss/metal/source/sample.metal b/gpt_oss/metal/source/sample.metal index b739f72c..8ce4598b 100644 --- a/gpt_oss/metal/source/sample.metal +++ b/gpt_oss/metal/source/sample.metal @@ -9,6 +9,27 @@ #pragma METAL fp contract(off) +inline static uint rng_squares32(ulong offset, ulong seed) { + const ulong y = offset * seed; + const ulong z = y + seed; + + /* Round 1 */ + ulong x = y * y + y; + x = metal::rotate(x, 32ul); + + /* Round 2 */ + x = x * x + z; + x = metal::rotate(x, 32ul); + + /* Round 3 */ + x = x * x + y; + x = metal::rotate(x, 32ul); + + /* Round 4 */ + x = x * 
x + z; + return as_type(x).y; +} + kernel void gptoss_f32_softmax( constant gptoss_softmax_args& args [[ buffer(0) ]], const device float* score [[ buffer(1) ]], @@ -58,3 +79,123 @@ kernel void gptoss_f32_softmax( } } } + +[[max_total_threads_per_threadgroup(1024)]] +kernel void gptoss_f32_sample( + constant gptoss_sample_args& args [[ buffer(0) ]], + device const float* prob [[ buffer(1) ]], + device const float* sum [[ buffer(2) ]], + device uint* prediction [[ buffer(3) ]], + uint tid [[thread_position_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]], + uint simdgroup_tid [[thread_index_in_simdgroup]], + uint simdgroup_idx [[simdgroup_index_in_threadgroup]], + uint num_simdgroups [[simdgroups_per_threadgroup]]) +{ + threadgroup float threadgroup_sum_buffer[32]; + threadgroup uint threadgroup_idx_buffer[32]; + threadgroup float threadgroup_cumsum_buffer[32]; + + const uint sample_word = rng_squares32(args.token_offset, args.seed); + float sample_cdf = static_cast(sample_word & 0x00FFFFFFu) * 0x1.0p-24f; + + float cumsum = 0.0f; + if (tid < args.num_blocks) { + cumsum = sum[tid]; + } + cumsum = metal::simd_prefix_inclusive_sum(cumsum); + if (simdgroup_tid == 31) { + threadgroup_sum_buffer[simdgroup_idx] = cumsum; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f; + if (simdgroup_tid < num_simdgroups) { + threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid]; + if (simdgroup_tid < simdgroup_idx) { + threadgroup_cumsum = threadgroup_sum; + } + } + threadgroup_sum = metal::simd_sum(threadgroup_sum); + cumsum += metal::simd_sum(threadgroup_cumsum); + + sample_cdf *= threadgroup_sum; + sample_cdf = metal::max(sample_cdf, 0x1.0p-149f); + + // Find the block: the smallest tid where sample_cdf >= s + uint block_idx = args.num_blocks; + float block_sum = cumsum; + if (tid >= args.num_blocks - 1) { + block_idx = args.num_blocks - 1; + block_sum = 0.0f; + } else if (cumsum >= sample_cdf) { + block_idx = tid; + block_sum = 0.0f; + } + block_idx = metal::simd_min(block_idx); + block_sum = metal::simd_max(block_sum); + if (simdgroup_tid == 0) { + threadgroup_idx_buffer[simdgroup_idx] = block_idx; + threadgroup_cumsum_buffer[simdgroup_idx] = block_sum; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + if (simdgroup_tid < num_simdgroups) { + block_idx = threadgroup_idx_buffer[simdgroup_tid]; + block_sum = threadgroup_cumsum_buffer[simdgroup_tid]; + } + block_idx = metal::simd_min(block_idx); + block_sum = metal::simd_max(block_sum); + + const uint block_start = args.num_dims_per_block * block_idx; + const uint block_end = metal::min(block_start + args.num_dims_per_block, args.num_dims); + uint offset = block_start + tid; + float accumulated_sum = block_sum; + uint sample_idx; + + // This loop must be threadgroup-uniform. 
+ do { + // Find the token: the smallest tid where sample_cdf >= s + float cumsum = 0.0f; + if (offset < block_end) { + cumsum = prob[offset]; + } + cumsum = metal::simd_prefix_inclusive_sum(cumsum); + if (simdgroup_tid == 31) { + threadgroup_sum_buffer[simdgroup_idx] = cumsum; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f; + if (simdgroup_tid < num_simdgroups) { + threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid]; + if (simdgroup_tid < simdgroup_idx) { + threadgroup_cumsum = threadgroup_sum; + } + } + threadgroup_sum = metal::simd_sum(threadgroup_sum); + cumsum += metal::simd_sum(threadgroup_cumsum); + cumsum += accumulated_sum; + + sample_idx = block_end; + if (offset >= block_end) { + // Trigger loop exit, with the last token in the block being sampled if no other candidate was found. + sample_idx = block_end - 1; + } else if (cumsum >= sample_cdf) { + sample_idx = offset; + } + sample_idx = metal::simd_min(sample_idx); + if (simdgroup_tid == 0) { + threadgroup_idx_buffer[simdgroup_idx] = sample_idx; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + if (simdgroup_tid < num_simdgroups) { + sample_idx = threadgroup_idx_buffer[simdgroup_tid]; + } + sample_idx = metal::simd_min(sample_idx); + + offset += threadgroup_size; + accumulated_sum += threadgroup_sum; + } while (sample_idx == block_end); + + if (tid == 0) { + *prediction = sample_idx; + } +} From f2a1458a5625adafbb9fcc8e0df363ef76ef170b Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Mon, 8 Sep 2025 14:21:39 -0700 Subject: [PATCH 85/91] Metal: benchmark generation of 100 tokens instead of 1 (#178) --- gpt_oss/metal/benchmark/end-to-end.cc | 58 ++++++++++++++----- gpt_oss/metal/source/include/internal/model.h | 25 ++++---- 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc index 4f73be7a..f4168f94 100644 --- a/gpt_oss/metal/benchmark/end-to-end.cc +++ b/gpt_oss/metal/benchmark/end-to-end.cc @@ -1,13 +1,20 @@ #include +#include -#include +#include +#include #include +#include #include +#include #include #include +constexpr std::uint32_t num_generated_tokens = 100; + + static void end2end(benchmark::State& state, const char* env_var_name) { const char* model_path = getenv(env_var_name); if (model_path == NULL) { @@ -40,7 +47,8 @@ static void end2end(benchmark::State& state, const char* env_var_name) { std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); const char* prompt = "why did the chicken cross the road?"; - status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), nullptr); + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); if (status != gptoss_status_success) { state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); return; @@ -53,25 +61,49 @@ static void end2end(benchmark::State& state, const char* env_var_name) { return; } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + std::uint64_t rng_seed = 0; for (std::uint32_t i = 0; i < 3; i++) { - std::uint32_t predicted_token = std::numeric_limits::max(); - status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token); - if (status != gptoss_status_success) { - state.SkipWithError("failed to sample from the Context object"); - return; + context->num_kv_tokens = 
num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + for (std::uint32_t n = 0; n < num_generated_tokens; n++) { + std::uint32_t predicted_token = std::numeric_limits::max(); + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + status = gptoss_context_append_tokens(context.get(), 1, &predicted_token); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token)); + return; + } } } for (auto _ : state) { - std::uint32_t predicted_token = std::numeric_limits::max(); - status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/0, &predicted_token); - if (status != gptoss_status_success) { - state.SkipWithError("failed to sample from the Context object"); - return; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + for (std::uint32_t n = 0; n < num_generated_tokens; n++) { + std::uint32_t predicted_token = std::numeric_limits::max(); + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + status = gptoss_context_append_tokens(context.get(), 1, &predicted_token); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token)); + return; + } } } - state.counters["tokens"] = + state.counters["generations"] = benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * num_generated_tokens, benchmark::Counter::kIsRate); } BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH") diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index c17510b8..50ed201c 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -1,6 +1,8 @@ #pragma once -#include +#ifndef __cplusplus + #include +#endif #include #include #include @@ -9,7 +11,11 @@ struct gptoss_tokenizer { +#ifndef __cplusplus atomic_uint_least64_t ref_count; +#else + uint_least64_t ref_count; +#endif void* mapping_ptr; size_t mapping_size; @@ -24,7 +30,11 @@ struct gptoss_tokenizer { }; struct gptoss_model { +#ifndef __cplusplus atomic_uint_least64_t ref_count; +#else + uint_least64_t ref_count; +#endif struct gptoss_tokenizer* tokenizer; @@ -108,7 +118,11 @@ struct gptoss_model { #define GPTOSS_DEFAULT_BATCH_SIZE 128 struct gptoss_context { +#ifndef __cplusplus atomic_uint_least64_t ref_count; +#else + uint_least64_t ref_count; +#endif struct gptoss_model* model; // Number of tokens processed in the context. 
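Note on the model.h hunks above: wrapping the atomic ref_count fields in #ifndef __cplusplus lets the C++ benchmark include this header, since C11's <stdatomic.h> atomic types are typically not usable from C++ translation units. Below is a standalone C sketch of the same guard pattern, under the same layout assumption the patch makes (the C++ side only touches plain fields); struct example_object and its payload member are illustrative placeholders, not names from the repository.

/* ref-count guard sketch: one header parsed by both C and C++. */
#ifndef __cplusplus
    #include <stdatomic.h>
#endif
#include <stdint.h>

struct example_object {
#ifndef __cplusplus
    /* C translation units see a real C11 atomic counter... */
    atomic_uint_least64_t ref_count;
#else
    /* ...while C++ ones see a plain integer assumed to have the same
     * size and alignment, so the struct layout stays consistent and the
     * header still parses. Only C code is expected to modify ref_count. */
    uint_least64_t ref_count;
#endif
    void* payload;
};

In the benchmark above, the only context fields touched from C++ are plain integers such as num_tokens and num_kv_tokens, so the non-atomic view is sufficient on that side.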
@@ -140,12 +154,3 @@ struct gptoss_context { struct gptoss_metal_buffer argmax_buffer; struct gptoss_metal_buffer kvcache_buffer; }; - -struct gptoss_sampler { - atomic_uint_least64_t ref_count; - - float temperature; - float top_p; - float presence_penalty; - float frequency_penalty; -}; From 152fc0ce3b500752214c8d59440ef3a909e1e556 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Tue, 9 Sep 2025 10:20:33 -0700 Subject: [PATCH 86/91] Metal: support generating multiple tokens at once (#179) --- gpt_oss/metal/benchmark/end-to-end.cc | 52 ++--- gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc | 4 + gpt_oss/metal/include/gpt-oss/functions.h | 4 +- gpt_oss/metal/python/context.c | 57 +++-- gpt_oss/metal/source/accumulate.metal | 4 + gpt_oss/metal/source/context.c | 200 ++++++++++++------ gpt_oss/metal/source/embeddings.metal | 5 + gpt_oss/metal/source/generate.c | 3 +- .../source/include/internal/kernel-args.h | 8 +- .../source/include/internal/metal-kernels.h | 35 ++- gpt_oss/metal/source/include/internal/model.h | 1 + gpt_oss/metal/source/matmul.metal | 8 + gpt_oss/metal/source/metal-kernels.c | 115 ++++++---- gpt_oss/metal/source/moematmul.metal | 8 + gpt_oss/metal/source/rmsnorm.metal | 4 + gpt_oss/metal/source/rope.metal | 5 + gpt_oss/metal/source/sample.metal | 10 +- gpt_oss/metal/source/sdpa.metal | 4 + gpt_oss/metal/source/topk.metal | 8 + .../metal/test/embeddings-kernel-tester.hpp | 4 + gpt_oss/metal/test/matmul-kernel-tester.hpp | 4 + gpt_oss/metal/test/rmsnorm-kernel-tester.hpp | 4 + gpt_oss/metal/test/rope-kernel-tester.hpp | 5 + gpt_oss/responses_api/inference/metal.py | 29 ++- 24 files changed, 409 insertions(+), 172 deletions(-) diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc index f4168f94..0a242340 100644 --- a/gpt_oss/metal/benchmark/end-to-end.cc +++ b/gpt_oss/metal/benchmark/end-to-end.cc @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -12,7 +13,7 @@ #include -constexpr std::uint32_t num_generated_tokens = 100; +constexpr std::uint32_t kNumGeneratedTokens = 100; static void end2end(benchmark::State& state, const char* env_var_name) { @@ -30,14 +31,6 @@ static void end2end(benchmark::State& state, const char* env_var_name) { } std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); - gptoss_tokenizer_t tokenizer_ptr = nullptr; - status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr); - if (status != gptoss_status_success) { - state.SkipWithError("failed to retrieve Tokenizer"); - return; - } - std::unique_ptr, decltype(&gptoss_tokenizer_release)> tokenizer(tokenizer_ptr, gptoss_tokenizer_release); - gptoss_context_t context_ptr = nullptr; status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); if (status != gptoss_status_success) { @@ -60,50 +53,51 @@ static void end2end(benchmark::State& state, const char* env_var_name) { state.SkipWithError("failed to prefill Context object"); return; } - const std::size_t num_kvcache_tokens = context->num_kv_tokens; + std::uint64_t rng_seed = 0; for (std::uint32_t i = 0; i < 3; i++) { + const std::uint64_t current_rng_seed = rng_seed++; context->num_kv_tokens = num_prompt_tokens; context->num_tokens = num_prompt_tokens; - for (std::uint32_t n = 0; n < num_generated_tokens; n++) { - std::uint32_t predicted_token = std::numeric_limits::max(); - status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token); + std::array tokens; + std::size_t 
num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); if (status != gptoss_status_success) { state.SkipWithError("failed to sample from the Context object"); return; } - status = gptoss_context_append_tokens(context.get(), 1, &predicted_token); - if (status != gptoss_status_success) { - state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token)); - return; - } - } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); } for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; context->num_kv_tokens = num_prompt_tokens; context->num_tokens = num_prompt_tokens; - for (std::uint32_t n = 0; n < num_generated_tokens; n++) { - std::uint32_t predicted_token = std::numeric_limits::max(); - status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token); + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); if (status != gptoss_status_success) { state.SkipWithError("failed to sample from the Context object"); return; } - status = gptoss_context_append_tokens(context.get(), 1, &predicted_token); - if (status != gptoss_status_success) { - state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token)); - return; - } - } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); } + state.counters["generations"] = benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); state.counters["tokens"] = - benchmark::Counter(state.iterations() * num_generated_tokens, benchmark::Counter::kIsRate); + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); } BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH") diff --git a/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc b/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc index 17515942..ee7551c2 100644 --- a/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc +++ b/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc @@ -26,6 +26,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) { Buffer input_buffer{device, num_tokens * num_channels * sizeof(float)}; Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)}; Buffer output_buffer{device, num_tokens * num_channels * sizeof(float)}; + Buffer control_buffer{device, sizeof(gptoss_control)}; + std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control)); { CommandBuffer command_buffer{command_queue}; @@ -69,6 +71,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) { /*weight_offset=*/0, output_buffer.handle(), /*output_offset=*/0, + control_buffer.handle(), + /*control_offset=*/0, num_tokens, num_channels, kEpsilon), diff --git a/gpt_oss/metal/include/gpt-oss/functions.h b/gpt_oss/metal/include/gpt-oss/functions.h index 085ebe0d..6ddde253 100644 --- a/gpt_oss/metal/include/gpt-oss/functions.h +++ b/gpt_oss/metal/include/gpt-oss/functions.h @@ -290,7 +290,9 @@ enum gptoss_status GPTOSS_ABI 
gptoss_context_sample( gptoss_context_t context, float temperature, uint64_t seed, - uint32_t* token_out); + size_t max_tokens, + uint32_t* tokens_out, + size_t* num_tokens_out); /* * Increments a Context object's reference count. diff --git a/gpt_oss/metal/python/context.c b/gpt_oss/metal/python/context.c index d71cc396..abc031af 100644 --- a/gpt_oss/metal/python/context.c +++ b/gpt_oss/metal/python/context.c @@ -120,25 +120,54 @@ static PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) { } static PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) { - static char *kwlist[] = {"temperature", "seed", NULL}; + static char *kwlist[] = {"max_output_tokens", "temperature", "seed", NULL}; + PyObject* token_list_obj = NULL; + uint32_t* token_ptr = NULL; + unsigned int max_output_tokens = 0; unsigned long long seed = 0; float temperature = 1.0f; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$fK", kwlist, - &temperature, &seed)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "I|$fK", kwlist, + &max_output_tokens, &temperature, &seed)) { return NULL; } - uint32_t token_out = UINT32_MAX; - enum gptoss_status status = gptoss_context_sample( - self->handle, temperature, (uint64_t) seed, &token_out); + token_ptr = (uint32_t*) PyMem_Malloc(max_output_tokens * sizeof(uint32_t)); + if (token_ptr == NULL) { + goto error; + } + + size_t num_tokens = 0; + const enum gptoss_status status = gptoss_context_sample( + self->handle, temperature, (uint64_t) seed, + (size_t) max_output_tokens, token_ptr, &num_tokens); if (status != gptoss_status_success) { // TODO: set exception - return NULL; + goto error; } - return PyLong_FromUnsignedLong((unsigned long) token_out); + token_list_obj = PyList_New((Py_ssize_t) num_tokens); + if (token_list_obj == NULL) { + goto error; + } + + for (size_t t = 0; t < num_tokens; t++) { + PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]); + if (token_obj == NULL) { + goto error; + } + + PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj); + } + + PyMem_Free(token_ptr); + return token_list_obj; + +error: + PyMem_Free(token_ptr); + Py_XDECREF(token_list_obj); + return NULL; } static PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) { @@ -155,7 +184,7 @@ static PyMethodDef PyGPTOSSContext_methods[] = { {"__copy__", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, "Create a copy of the Context"}, {"append", (PyCFunction) PyGPTOSSContext_append, METH_O, "Append bytes to the Context"}, {"process", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, "Process tokens in the Context"}, - {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token prediction from the Context"}, + {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token predictions from the Context"}, {"reset", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, "Discard the content of the Context"}, {NULL}, }; @@ -184,7 +213,6 @@ static PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* clo static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) { PyObject* token_list_obj = NULL; - PyObject* token_obj = NULL; uint32_t* token_ptr = NULL; size_t num_tokens = 0; @@ -210,14 +238,12 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure } for (size_t t = 0; t < num_tokens; t++) { - token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]); + PyObject* token_obj = PyLong_FromUnsignedLong((unsigned 
long) token_ptr[t]); if (token_obj == NULL) { goto error; } - if (PyList_SetItem(token_list_obj, (Py_ssize_t) t, token_obj) < 0) { - goto error; - } - token_obj = NULL; // PyList_SetItem stole the reference + + PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj); } PyMem_Free(token_ptr); @@ -225,7 +251,6 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure error: PyMem_Free(token_ptr); - Py_XDECREF(token_obj); Py_XDECREF(token_list_obj); return NULL; } diff --git a/gpt_oss/metal/source/accumulate.metal b/gpt_oss/metal/source/accumulate.metal index f7ebc506..70dc4c2b 100644 --- a/gpt_oss/metal/source/accumulate.metal +++ b/gpt_oss/metal/source/accumulate.metal @@ -12,11 +12,15 @@ kernel void gptoss_f32_accumulate_e4( const device float4* input [[ buffer(1) ]], const device gptoss_expert_prediction* expert [[ buffer(2) ]], device float4* output [[ buffer(3) ]], + const device gptoss_control* control [[ buffer(4) ]], uint2 gid [[threadgroup_position_in_grid]], uint tid [[thread_index_in_threadgroup]], uint2 threadgroup_size [[ threads_per_threadgroup ]]) { const uint num_active_experts = 4; + if (control->abort != 0) { + return; + } const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup; const uint threadgroup_start = gid.x * num_vecs_per_threadgroup; diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index 0791c3eb..2d246294 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -82,6 +82,10 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create( } // Input/output buffers + status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer); + if (status != gptoss_status_success) { + goto cleanup; + } status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer); if (status != gptoss_status_success) { goto cleanup; @@ -196,6 +200,8 @@ static enum gptoss_status process_tokens( /*weight_offset=*/0, &context->residual_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/input_batch_size, /*num_channels=*/model->embedding_dim); if (status != gptoss_status_success) { @@ -215,6 +221,8 @@ static enum gptoss_status process_tokens( /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, &context->rmsnorm_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/input_batch_size, /*num_channels=*/model->embedding_dim, model->rmsnorm_epsilon); @@ -235,6 +243,8 @@ static enum gptoss_status process_tokens( /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, &context->qkv_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/input_batch_size, /*num_cols=*/model->embedding_dim, /*num_rows=*/attn_qkv_dim); @@ -248,6 +258,9 @@ static enum gptoss_status process_tokens( &model->f32_rope_fn, /*threadgroup_size=*/32, &context->qkv_activation_buffer, + /*input_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, model->rope_theta, model->interpolation_scale, model->yarn_offset, @@ -291,6 +304,8 @@ static enum gptoss_status process_tokens( /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n, &context->sdpa_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*window=*/n % 2 == 0 ? 
model->attention_window : UINT32_MAX, num_block_output_tokens, input_batch_start + input_batch_size - num_block_output_tokens, @@ -312,6 +327,8 @@ static enum gptoss_status process_tokens( /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, &context->residual_activation_buffer, /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/num_block_output_tokens, /*num_cols=*/model->num_heads * model->head_dim, /*num_rows=*/model->embedding_dim); @@ -329,6 +346,8 @@ static enum gptoss_status process_tokens( /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n, &context->rmsnorm_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, num_block_output_tokens, model->embedding_dim, model->rmsnorm_epsilon); @@ -349,6 +368,8 @@ static enum gptoss_status process_tokens( /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n, &context->gate_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/num_block_output_tokens, /*num_cols=*/model->embedding_dim, /*num_rows=*/model->num_experts); @@ -366,6 +387,7 @@ static enum gptoss_status process_tokens( &model->f32_topk_softmax_e32_k4_fn, &context->gate_activation_buffer, /*input_offset=*/0, &context->expert_activation_buffer, /*output_offset=*/0, + &context->control_buffer, /*control_offset=*/0, num_block_output_tokens, model->num_experts, model->num_active_experts); @@ -377,6 +399,7 @@ static enum gptoss_status process_tokens( &model->f32_topk_softmax_e128_k4_fn, &context->gate_activation_buffer, /*input_offset=*/0, &context->expert_activation_buffer, /*output_offset=*/0, + &context->control_buffer, /*control_offset=*/0, num_block_output_tokens, model->num_experts, model->num_active_experts); @@ -407,6 +430,8 @@ static enum gptoss_status process_tokens( /*bias_offset=*/model->mlp_swiglu_bias_offset, &context->swiglu_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, model->swiglu_limit, model->per_expert_block_weight_size, num_block_output_tokens, @@ -434,6 +459,8 @@ static enum gptoss_status process_tokens( /*bias_offset=*/model->mlp_out_bias_offset, &context->moe_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, model->per_expert_block_weight_size, num_block_output_tokens, model->num_active_experts, @@ -455,6 +482,8 @@ static enum gptoss_status process_tokens( /*expert_offset=*/0, &context->residual_activation_buffer, /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &context->control_buffer, + /*control_offset=*/0, model->embedding_dim, num_block_output_tokens, model->num_active_experts); @@ -475,6 +504,8 @@ static enum gptoss_status process_tokens( /*weight_offset=*/model->rmsnorm_weight_offset, &context->rmsnorm_activation_buffer, /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/output_batch_size, /*num_channels=*/model->embedding_dim, model->rmsnorm_epsilon); @@ -507,6 +538,8 @@ static enum gptoss_status process_tokens( /*output_offset=*/0, &context->argmax_buffer, /*argmax_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, /*num_tokens=*/output_batch_size, /*num_cols=*/model->embedding_dim, /*num_rows=*/model->vocabulary_size); @@ -641,6 +674,9 @@ enum 
gptoss_status GPTOSS_ABI gptoss_context_process( goto cleanup; } + struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr; + control->abort = 0; + status = process_tokens( context, &command_buffer, @@ -675,96 +711,121 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample( gptoss_context_t context, float temperature, uint64_t seed, - uint32_t* token_out) + size_t max_tokens, + uint32_t* tokens_out, + size_t* num_tokens_out) { enum gptoss_status status = gptoss_status_success; const struct gptoss_model* model = context->model; struct gptoss_metal_command_buffer command_buffer = {0}; - *token_out = UINT32_MAX; + *num_tokens_out = 0; - status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer); - if (status != gptoss_status_success) { - goto cleanup; - } + const uint32_t num_original_tokens = context->num_tokens; - if (context->num_kv_tokens < context->num_tokens) { - status = process_tokens( - context, - &command_buffer, - /*input_tokens_offset=*/context->num_kv_tokens, - /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens, - /*num_output_tokens=*/1); - context->num_kv_tokens = context->num_tokens; - } else { - status = process_tokens( - context, - &command_buffer, - /*input_tokens_offset=*/context->num_tokens - 1, - /*num_input_tokens=*/1, - /*num_output_tokens=*/1); - } + status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer); if (status != gptoss_status_success) { goto cleanup; } - if (temperature != 0.0f) { - assert(context->num_processed_tokens != 0); - uint32_t num_threadgroups = 0; - uint32_t num_dims_per_threadgroup = 0; - status = gptoss_metal_command_buffer_encode_launch_f32_softmax( - &command_buffer, - &model->f32_softmax_fn, - /*threadgroup_size=*/512, - model->max_threadgroups, - &context->score_buffer, - /*score_offset=*/0, - &context->argmax_buffer, - /*argmax_offset=*/0, - &context->prob_buffer, - /*prob_offset=*/0, - &context->sum_buffer, - /*sum_offset=*/0, - model->vocabulary_size, - /*num_tokens=*/1, - temperature, - &num_threadgroups, - &num_dims_per_threadgroup); + struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr; + control->abort = 0; + + for (size_t t = 0; t < max_tokens; t++) { + if (context->num_kv_tokens < context->num_tokens) { + status = process_tokens( + context, + &command_buffer, + /*input_tokens_offset=*/context->num_kv_tokens, + /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens, + /*num_output_tokens=*/1); + context->num_kv_tokens = context->num_tokens; + } else { + status = process_tokens( + context, + &command_buffer, + /*input_tokens_offset=*/context->num_tokens - 1, + /*num_input_tokens=*/1, + /*num_output_tokens=*/1); + } if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch"); goto cleanup; } - status = gptoss_metal_command_buffer_encode_launch_f32_sample( - &command_buffer, - &model->f32_sample_fn, - /*min_threadgroup_size=*/512, - &context->prob_buffer, - /*prob_offset=*/0, - &context->sum_buffer, - /*sum_offset=*/0, - &context->argmax_buffer, - /*prediction_offset=*/0, - /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF), - /*num_blocks=*/num_threadgroups, - /*num_channels=*/model->vocabulary_size, - /*num_channels_per_block=*/num_dims_per_threadgroup, - /*token_offset=*/context->num_tokens); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch"); - goto cleanup; + if (temperature != 
0.0f) { + assert(context->num_processed_tokens != 0); + uint32_t num_threadgroups = 0; + uint32_t num_dims_per_threadgroup = 0; + status = gptoss_metal_command_buffer_encode_launch_f32_softmax( + &command_buffer, + &model->f32_softmax_fn, + /*threadgroup_size=*/512, + model->max_threadgroups, + &context->score_buffer, + /*score_offset=*/0, + &context->argmax_buffer, + /*argmax_offset=*/0, + &context->prob_buffer, + /*prob_offset=*/0, + &context->sum_buffer, + /*sum_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, + model->vocabulary_size, + /*num_tokens=*/1, + temperature, + &num_threadgroups, + &num_dims_per_threadgroup); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch"); + goto cleanup; + } + + status = gptoss_metal_command_buffer_encode_launch_f32_sample( + &command_buffer, + &model->f32_sample_fn, + /*min_threadgroup_size=*/512, + &context->prob_buffer, + /*prob_offset=*/0, + &context->sum_buffer, + /*sum_offset=*/0, + &context->token_buffer, + /*token_offset=*/context->num_tokens * sizeof(uint32_t), + &context->control_buffer, + /*control_offset=*/0, + /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF), + /*rng_offset=*/context->num_tokens, + /*num_blocks=*/num_threadgroups, + /*num_channels=*/model->vocabulary_size, + /*num_channels_per_block=*/num_dims_per_threadgroup); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch"); + goto cleanup; + } + } else { + status = gptoss_metal_command_buffer_encode_copy_buffer( + &command_buffer, + &context->argmax_buffer, + /*input_offset=*/0, + &context->token_buffer, + /*output_offset=*/context->num_tokens * sizeof(uint32_t), + /*size=*/sizeof(uint32_t)); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode copy buffer"); + goto cleanup; + } } + context->num_tokens += 1; + context->num_kv_tokens = context->num_tokens; } gptoss_metal_command_buffer_commit(&command_buffer); gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL); - if (temperature == 0.0f) { - const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0]; - *token_out = (uint32_t) argmax_bits; - } else { - *token_out = ((uint32_t*) context->argmax_buffer.ptr)[0]; - } + const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr; + const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens; + memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t)); + *num_tokens_out = num_generated_tokens; cleanup: gptoss_metal_command_buffer_release(&command_buffer); @@ -805,6 +866,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release( gptoss_metal_buffer_release(&context->moe_activation_buffer); // Input/output buffers + gptoss_metal_buffer_release(&context->control_buffer); gptoss_metal_buffer_release(&context->token_buffer); gptoss_metal_buffer_release(&context->score_buffer); gptoss_metal_buffer_release(&context->prob_buffer); diff --git a/gpt_oss/metal/source/embeddings.metal b/gpt_oss/metal/source/embeddings.metal index b4541d21..9cc7d121 100644 --- a/gpt_oss/metal/source/embeddings.metal +++ b/gpt_oss/metal/source/embeddings.metal @@ -9,10 +9,15 @@ kernel void gptoss_bf16_f32_embeddings( const device uint* tokens [[ buffer(1) ]], const device bfloat4* weights [[ buffer(2) ]], device float4* output [[ buffer(3) ]], + const device gptoss_control* control [[ buffer(4) ]], uint gid [[threadgroup_position_in_grid]], uint tid 
[[thread_position_in_threadgroup]], uint threadgroup_size [[ threads_per_threadgroup ]]) { + if (control->abort != 0) { + return; + } + const uint t = tokens[gid]; weights += t * args.num_vecs; diff --git a/gpt_oss/metal/source/generate.c b/gpt_oss/metal/source/generate.c index 1711410a..36a5527b 100644 --- a/gpt_oss/metal/source/generate.c +++ b/gpt_oss/metal/source/generate.c @@ -268,8 +268,9 @@ int main(int argc, char *argv[]) { while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) { uint32_t predicted_token = UINT32_MAX; + size_t num_predicted_tokens = 0; const uint64_t inference_start_timestamp = mach_continuous_time(); - status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, &predicted_token); + status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens); if (status != gptoss_status_success) { fprintf(stderr, "Error: failed to sample from the Context object\n"); goto error; diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h index a031902d..259eaa8a 100644 --- a/gpt_oss/metal/source/include/internal/kernel-args.h +++ b/gpt_oss/metal/source/include/internal/kernel-args.h @@ -9,6 +9,10 @@ struct gptoss_expert_prediction { float score; }; +struct gptoss_control { + uint32_t abort; +}; + struct gptoss_topk_args { uint32_t num_vecs_per_token; }; @@ -105,8 +109,8 @@ struct gptoss_softmax_args { }; struct gptoss_sample_args { - uint64_t seed; - uint32_t token_offset; + uint64_t rng_seed; + uint32_t rng_offset; uint32_t num_blocks; uint32_t num_dims; uint32_t num_dims_per_block; diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h index 64cb36e0..269f025d 100644 --- a/gpt_oss/metal/source/include/internal/metal-kernels.h +++ b/gpt_oss/metal/source/include/internal/metal-kernels.h @@ -74,6 +74,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings size_t weight_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_channels); @@ -86,6 +88,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( size_t weight_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_channels, float epsilon); @@ -102,6 +106,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_cols, uint32_t num_rows); @@ -118,6 +124,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_cols, uint32_t num_rows); @@ -135,6 +143,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi size_t output_offset, const struct gptoss_metal_buffer* argmax_buffer, size_t argmax_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t 
control_offset, uint32_t num_tokens, uint32_t num_cols, uint32_t num_rows); @@ -155,6 +165,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, float swiglu_limit, uint32_t expert_stride, uint32_t num_tokens, @@ -178,6 +190,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t expert_stride, uint32_t num_tokens, uint32_t num_active_experts, @@ -189,6 +203,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope( const struct gptoss_metal_function* f32_rope_fn, size_t threadgroup_size, const struct gptoss_metal_buffer* activations_buffer, + size_t activations_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, float rope_base, float interpolation_scale, float yarn_offset, @@ -211,6 +228,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate( size_t expert_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_channels, uint32_t num_tokens, uint32_t num_experts); @@ -222,6 +241,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk( size_t input_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_experts, uint32_t num_active_experts); @@ -239,6 +260,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( size_t s_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t window, uint32_t num_q_tokens, uint32_t num_kv_tokens, @@ -259,6 +282,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax( size_t prob_offset, const struct gptoss_metal_buffer* sum_buffer, size_t sum_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_channels, uint32_t num_tokens, float temperature, @@ -273,13 +298,15 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample( size_t prob_offset, const struct gptoss_metal_buffer* sum_buffer, size_t sum_offset, - const struct gptoss_metal_buffer* prediction_buffer, - size_t prediction_offset, + const struct gptoss_metal_buffer* token_buffer, + size_t token_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint64_t rng_seed, + uint32_t rng_offset, uint32_t num_blocks, uint32_t num_channels, - uint32_t num_channels_per_block, - uint32_t token_offset); + uint32_t num_channels_per_block); #ifdef __cplusplus } // extern "C" diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index 50ed201c..34e273aa 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -147,6 +147,7 @@ struct gptoss_context { struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert) // Input/output buffers. 
+ struct gptoss_metal_buffer control_buffer; struct gptoss_metal_buffer token_buffer; // uint32 token IDs struct gptoss_metal_buffer score_buffer; // unembedding outputs struct gptoss_metal_buffer prob_buffer; diff --git a/gpt_oss/metal/source/matmul.metal b/gpt_oss/metal/source/matmul.metal index 6396f6cc..a4ec60d5 100644 --- a/gpt_oss/metal/source/matmul.metal +++ b/gpt_oss/metal/source/matmul.metal @@ -23,12 +23,16 @@ kernel void gptoss_f32_bf16w_matmul( const device bfloat4* weight [[ buffer(2) ]], const device bfloat* bias [[ buffer(3) ]], device float* output [[ buffer(4) ]], + const device gptoss_control* control [[ buffer(5) ]], uint2 gid [[threadgroup_position_in_grid]], uint simdgroup_tid [[thread_index_in_simdgroup]], uint simdgroup_idx [[simdgroup_index_in_threadgroup]], uint num_simdgroups [[simdgroups_per_threadgroup]]) { const uint simdgroup_size = 32; + if (control->abort != 0) { + return; + } const uint num_column_vecs = args.num_column_vecs; const uint row = gid.x * num_simdgroups + simdgroup_idx; @@ -68,6 +72,7 @@ kernel void gptoss_f32_bf16w_unembedding( const device bfloat4* weight [[ buffer(2) ]], device float* output [[ buffer(3) ]], device metal::atomic_ulong* argmax [[ buffer(4) ]], + const device gptoss_control* control [[ buffer(5) ]], uint2 gid [[threadgroup_position_in_grid]], uint simdgroup_tid [[thread_index_in_simdgroup]], uint simdgroup_idx [[simdgroup_index_in_threadgroup]], @@ -75,6 +80,9 @@ kernel void gptoss_f32_bf16w_unembedding( { const uint simdgroup_size = 32; threadgroup uint2 threadgroup_buffer[32]; + if (control->abort != 0) { + return; + } const uint num_column_vecs = args.num_column_vecs; const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx; diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c index a9a5253c..1316fa50 100644 --- a/gpt_oss/metal/source/metal-kernels.c +++ b/gpt_oss/metal/source/metal-kernels.c @@ -197,6 +197,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings size_t weight_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_channels) { @@ -224,9 +226,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings threadgroup_size, 1, 1, num_tokens, 1, 1, sizeof(args), &args, - 3, - (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer}, - (const size_t[]) {token_offset, weight_offset, output_offset}, + 4, + (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer, control_buffer}, + (const size_t[]) {token_offset, weight_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -239,6 +241,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( size_t weight_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_channels, float epsilon) @@ -271,9 +275,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( /*threadgroup_size=*/1024, 1, 1, num_tokens, 1, 1, sizeof(args), &args, - 3, - (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer}, - (const size_t[]) {input_offset, weight_offset, output_offset}, + 4, + (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, 
output_buffer, control_buffer}, + (const size_t[]) {input_offset, weight_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -289,6 +293,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_cols, uint32_t num_rows) @@ -329,9 +335,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( threadgroup_size, 1, 1, num_rows / num_simdgroups, num_tokens, 1, sizeof(args), &args, - 4, - (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset}, + 5, + (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer}, + (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -347,6 +353,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_cols, uint32_t num_rows) @@ -387,9 +395,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad threadgroup_size, 1, 1, num_rows / num_simdgroups, num_tokens, 1, sizeof(args), &args, - 4, - (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset}, + 5, + (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer}, + (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -406,6 +414,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi size_t output_offset, const struct gptoss_metal_buffer* argmax_buffer, size_t argmax_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_cols, uint32_t num_rows) @@ -443,9 +453,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi threadgroup_size, 1, 1, num_threadgroups, num_tokens, 1, sizeof(args), &args, - 4, - (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer}, - (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset}, + 5, + (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer, control_buffer}, + (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -465,6 +475,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, float swiglu_limit, uint32_t expert_stride, uint32_t num_tokens, @@ -517,9 +529,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul threadgroup_size, 1, 1, (2 * num_rows) / num_simdgroups, num_tokens, num_active_experts, sizeof(args), &args, - 6, - (const struct 
gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset}, + 7, + (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer}, + (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -539,6 +551,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul size_t bias_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t expert_stride, uint32_t num_tokens, uint32_t num_active_experts, @@ -589,9 +603,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul threadgroup_size, 1, 1, num_rows / num_simdgroups, num_tokens, num_active_experts, sizeof(args), &args, - 6, - (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer}, - (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset}, + 7, + (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer}, + (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -600,6 +614,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope( const struct gptoss_metal_function* f32_rope_fn, size_t threadgroup_size, const struct gptoss_metal_buffer* activations_buffer, + size_t activations_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, float rope_base, float interpolation_scale, float yarn_offset, @@ -642,7 +659,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope( threadgroup_size, 1, 1, num_qk_heads / num_simdgroups, num_tokens, 1, sizeof(args), &args, - 1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL, + 2, + (const struct gptoss_metal_buffer *[]) {activations_buffer, control_buffer}, + (const size_t[]) {activations_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -657,6 +676,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate( size_t expert_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_channels, uint32_t num_tokens, uint32_t num_experts) @@ -690,9 +711,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate( threadgroup_size, 1, 1, num_threadgroups, num_tokens, 1, sizeof(args), &args, - 3, - (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer}, - (const size_t[]) {input_offset, expert_offset, output_offset}, + 4, + (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer, control_buffer}, + (const size_t[]) {input_offset, expert_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -703,6 +724,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk( size_t input_offset, const 
struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_tokens, uint32_t num_experts, uint32_t num_active_experts) @@ -726,9 +749,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk( /*threadgroup_size=*/32, 1, 1, num_tokens, 1, 1, sizeof(args), &args, - 2, - (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer}, - (const size_t[]) {input_offset, output_offset}, + 3, + (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer, control_buffer}, + (const size_t[]) {input_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -745,6 +768,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( size_t s_offset, const struct gptoss_metal_buffer* output_buffer, size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t window, uint32_t num_q_tokens, uint32_t num_kv_tokens, @@ -783,9 +808,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( threadgroup_size, 1, 1, num_q_tokens, num_kv_heads, 1, sizeof(args), &args, - 5, - (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer}, - (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset}, + 6, + (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer, control_buffer}, + (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset, control_offset}, /*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float)); } @@ -802,6 +827,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax( size_t prob_offset, const struct gptoss_metal_buffer* sum_buffer, size_t sum_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint32_t num_channels, uint32_t num_tokens, float temperature, @@ -831,9 +858,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax( threadgroup_size, 1, 1, num_threadgroups, num_tokens, 1, sizeof(args), &args, - 4, - (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer}, - (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset}, + 5, + (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer, control_buffer}, + (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset, control_offset}, /*threadgroup_buffer_size=*/0); } @@ -845,13 +872,15 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample( size_t prob_offset, const struct gptoss_metal_buffer* sum_buffer, size_t sum_offset, - const struct gptoss_metal_buffer* prediction_buffer, - size_t prediction_offset, + const struct gptoss_metal_buffer* token_buffer, + size_t token_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, uint64_t rng_seed, + uint32_t rng_offset, uint32_t num_blocks, uint32_t num_channels, - uint32_t num_channels_per_block, - uint32_t token_offset) + uint32_t num_channels_per_block) { if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) { return gptoss_status_invalid_state; @@ -870,8 +899,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample( } const struct gptoss_sample_args args = { - .seed = rng_seed, - .token_offset = token_offset, + .rng_seed = rng_seed, + .rng_offset = rng_offset, .num_blocks = num_blocks, .num_dims 
= num_channels, .num_dims_per_block = num_channels_per_block, @@ -884,8 +913,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample( threadgroup_size, 1, 1, 1, 1, 1, sizeof(args), &args, - 3, - (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, prediction_buffer}, - (const size_t[]) {prob_offset, sum_offset, prediction_offset}, + 4, + (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, token_buffer, control_buffer}, + (const size_t[]) {prob_offset, sum_offset, token_offset, control_offset}, /*threadgroup_buffer_size=*/0); } diff --git a/gpt_oss/metal/source/moematmul.metal b/gpt_oss/metal/source/moematmul.metal index 6e2f6950..58247484 100644 --- a/gpt_oss/metal/source/moematmul.metal +++ b/gpt_oss/metal/source/moematmul.metal @@ -24,6 +24,7 @@ kernel void gptoss_f32_mf4w_moe_matmul_swiglu( const device uchar* weight_scales [[ buffer(4) ]], const device bfloat* bias [[ buffer(5) ]], device float* output [[ buffer(6) ]], + const device gptoss_control* control [[ buffer(7) ]], uint3 gid [[threadgroup_position_in_grid]], uint tid [[thread_index_in_threadgroup]], uint simdgroup_tid [[thread_index_in_simdgroup]], @@ -32,6 +33,9 @@ kernel void gptoss_f32_mf4w_moe_matmul_swiglu( { const uint simdgroup_size = 32; threadgroup float threadgroup_buffer[32]; + if (control->abort != 0) { + return; + } const uint num_column_vecs = args.num_column_vecs; const uint row = gid.x * num_simdgroups + simdgroup_idx; @@ -130,6 +134,7 @@ kernel void gptoss_f32_mf4w_moe_matmul( const device uchar* weight_scales [[ buffer(4) ]], const device bfloat* bias [[ buffer(5) ]], device float* output [[ buffer(6) ]], + const device gptoss_control* control [[ buffer(7) ]], uint3 gid [[threadgroup_position_in_grid]], uint tid [[thread_index_in_threadgroup]], uint simdgroup_tid [[thread_index_in_simdgroup]], @@ -137,6 +142,9 @@ kernel void gptoss_f32_mf4w_moe_matmul( uint num_simdgroups [[simdgroups_per_threadgroup]]) { const uint simdgroup_size = 32; + if (control->abort != 0) { + return; + } const uint num_column_vecs = args.num_column_vecs; const uint row = gid.x * num_simdgroups + simdgroup_idx; diff --git a/gpt_oss/metal/source/rmsnorm.metal b/gpt_oss/metal/source/rmsnorm.metal index ceb690f0..fc4bcaa2 100644 --- a/gpt_oss/metal/source/rmsnorm.metal +++ b/gpt_oss/metal/source/rmsnorm.metal @@ -14,12 +14,16 @@ kernel void gptoss_f32_bf16w_rmsnorm( const device float4* input [[ buffer(1) ]], const device bfloat4* weights [[ buffer(2) ]], device float4* output [[ buffer(3) ]], + const device gptoss_control* control [[ buffer(4) ]], uint gid [[threadgroup_position_in_grid]], uint tid [[thread_position_in_threadgroup]], uint threadgroup_size [[ threads_per_threadgroup ]]) { const uint simdgroup_size = 32; threadgroup float threadgroup_buffer[32]; + if (control->abort != 0) { + return; + } input += gid * args.num_vecs; output += gid * args.num_vecs; diff --git a/gpt_oss/metal/source/rope.metal b/gpt_oss/metal/source/rope.metal index 2739b5fa..ce4c3c8f 100644 --- a/gpt_oss/metal/source/rope.metal +++ b/gpt_oss/metal/source/rope.metal @@ -13,9 +13,14 @@ kernel void gptoss_f32_rope( constant gptoss_rope_args& args [[ buffer(0) ]], device float2* activations [[ buffer(1) ]], + const device gptoss_control* control [[ buffer(2) ]], uint2 gid [[thread_position_in_grid]]) { const uint num_head_dims = 64; + if (control->abort != 0) { + return; + } + const float head_idx = static_cast(gid.x % (num_head_dims / 2)); const uint token_idx = args.token_offset + gid.y; activations += gid.y * 
args.token_stride + gid.x; diff --git a/gpt_oss/metal/source/sample.metal b/gpt_oss/metal/source/sample.metal index 8ce4598b..4a0efe3b 100644 --- a/gpt_oss/metal/source/sample.metal +++ b/gpt_oss/metal/source/sample.metal @@ -36,6 +36,7 @@ kernel void gptoss_f32_softmax( const device uint2* argmax [[ buffer(2) ]], device float* prob [[ buffer(3) ]], device float* sum [[ buffer(4) ]], + const device gptoss_control* control [[ buffer(5) ]], uint tidx [[thread_index_in_threadgroup]], uint2 gid [[threadgroup_position_in_grid]], uint2 threadgroup_size [[threads_per_threadgroup]], @@ -44,6 +45,9 @@ kernel void gptoss_f32_softmax( uint num_simdgroups [[simdgroups_per_threadgroup]]) { threadgroup float threadgroup_sumexp[32]; + if (control->abort != 0) { + return; + } score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup; prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup; @@ -86,6 +90,7 @@ kernel void gptoss_f32_sample( device const float* prob [[ buffer(1) ]], device const float* sum [[ buffer(2) ]], device uint* prediction [[ buffer(3) ]], + device gptoss_control* control [[ buffer(4) ]], uint tid [[thread_position_in_threadgroup]], uint threadgroup_size [[threads_per_threadgroup]], uint simdgroup_tid [[thread_index_in_simdgroup]], @@ -95,8 +100,11 @@ kernel void gptoss_f32_sample( threadgroup float threadgroup_sum_buffer[32]; threadgroup uint threadgroup_idx_buffer[32]; threadgroup float threadgroup_cumsum_buffer[32]; + if (control->abort != 0) { + return; + } - const uint sample_word = rng_squares32(args.token_offset, args.seed); + const uint sample_word = rng_squares32(args.rng_offset, args.rng_seed); float sample_cdf = static_cast(sample_word & 0x00FFFFFFu) * 0x1.0p-24f; float cumsum = 0.0f; diff --git a/gpt_oss/metal/source/sdpa.metal b/gpt_oss/metal/source/sdpa.metal index 5050cb41..459bbe28 100644 --- a/gpt_oss/metal/source/sdpa.metal +++ b/gpt_oss/metal/source/sdpa.metal @@ -18,6 +18,7 @@ kernel void gptoss_f32_sdpa_q8_d64( const device float* v [[ buffer(3) ]], const device bfloat* s [[ buffer(4) ]], device float* output [[ buffer(5) ]], + const device gptoss_control* control [[ buffer(6) ]], threadgroup void* threadgroup_buffer [[ threadgroup(0) ]], uint2 gid [[threadgroup_position_in_grid]], uint2 tid [[thread_position_in_threadgroup]], @@ -26,6 +27,9 @@ kernel void gptoss_f32_sdpa_q8_d64( uint num_simdgroups [[simdgroups_per_threadgroup]]) { const uint simdgroup_size = 32; + if (control->abort != 0) { + return; + } const uint num_q_heads = 64; const uint num_kv_heads = 8; diff --git a/gpt_oss/metal/source/topk.metal b/gpt_oss/metal/source/topk.metal index d3532ac6..90f4e51c 100644 --- a/gpt_oss/metal/source/topk.metal +++ b/gpt_oss/metal/source/topk.metal @@ -14,11 +14,15 @@ kernel void gptoss_f32_topk_softmax_e128_k4( constant gptoss_topk_args& args [[ buffer(0) ]], const device float4* input [[ buffer(1) ]], device gptoss_expert_prediction* output [[ buffer(2) ]], + const device gptoss_control* control [[ buffer(3) ]], uint gid [[threadgroup_position_in_grid]], uint tid [[thread_position_in_threadgroup]]) { const uint num_experts = 128; const uint num_active_experts = 4; + if (control->abort != 0) { + return; + } input += gid * (num_experts / 4); output += gid * num_active_experts; @@ -132,11 +136,15 @@ kernel void gptoss_f32_topk_softmax_e32_k4( constant gptoss_topk_args& args [[ buffer(0) ]], const device float* input [[ buffer(1) ]], device gptoss_expert_prediction* output [[ buffer(2) ]], + const device gptoss_control* control [[ buffer(3) 
]], uint gid [[threadgroup_position_in_grid]], uint tid [[thread_position_in_threadgroup]]) { const uint num_experts = 32; const uint num_active_experts = 4; + if (control->abort != 0) { + return; + } input += gid * num_experts; output += gid * num_active_experts; diff --git a/gpt_oss/metal/test/embeddings-kernel-tester.hpp b/gpt_oss/metal/test/embeddings-kernel-tester.hpp index fd810c6d..83092a8c 100644 --- a/gpt_oss/metal/test/embeddings-kernel-tester.hpp +++ b/gpt_oss/metal/test/embeddings-kernel-tester.hpp @@ -69,6 +69,8 @@ class EmbeddingsKernelTester { metal::Buffer token_buffer{device_, sizeof(std::uint32_t)}; metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)}; metal::Buffer output_buffer{device_, num_channels() * sizeof(float)}; + metal::Buffer control_buffer{device_, sizeof(gptoss_control)}; + std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control)); std::uint32_t* token_ptr = static_cast(token_buffer.ptr()); for (std::uint32_t t = 0; t < num_tokens(); t++) { @@ -85,6 +87,8 @@ class EmbeddingsKernelTester { /*weight_offset=*/0, output_buffer.handle(), /*output_offset=*/0, + control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_channels()), "gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings"); diff --git a/gpt_oss/metal/test/matmul-kernel-tester.hpp b/gpt_oss/metal/test/matmul-kernel-tester.hpp index ec13af6b..30826f70 100644 --- a/gpt_oss/metal/test/matmul-kernel-tester.hpp +++ b/gpt_oss/metal/test/matmul-kernel-tester.hpp @@ -78,6 +78,8 @@ class MatMulKernelTester { metal::Buffer weight_buffer{device_, num_rows() * num_cols() * sizeof(gptoss_bfloat16)}; metal::Buffer bias_buffer{device_, num_rows() * sizeof(gptoss_bfloat16)}; metal::Buffer output_buffer{device_, num_tokens() * num_rows() * sizeof(float)}; + metal::Buffer control_buffer{device_, sizeof(gptoss_control)}; + std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control)); command_buffer.encode_launch_f32_fill_random( f32_fill_random_fn_, @@ -115,6 +117,8 @@ class MatMulKernelTester { /*bias_offset=*/0, output_buffer.handle(), /*output_offset=*/0, + control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_cols(), num_rows()), diff --git a/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp b/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp index 16a6da64..3111eecb 100644 --- a/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp +++ b/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp @@ -64,6 +64,8 @@ class RMSNormKernelTester { metal::Buffer input_buffer{device_, num_tokens() * num_channels() * sizeof(float)}; metal::Buffer weight_buffer{device_, num_channels() * sizeof(gptoss_bfloat16)}; metal::Buffer output_buffer{device_, num_tokens() * num_channels() * sizeof(float)}; + metal::Buffer control_buffer{device_, sizeof(gptoss_control)}; + std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control)); metal::CommandBuffer command_buffer{command_queue_}; @@ -90,6 +92,8 @@ class RMSNormKernelTester { /*weight_offset=*/0, output_buffer.handle(), /*output_offset=*/0, + control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_channels(), epsilon()), diff --git a/gpt_oss/metal/test/rope-kernel-tester.hpp b/gpt_oss/metal/test/rope-kernel-tester.hpp index 602912a1..cb930621 100644 --- a/gpt_oss/metal/test/rope-kernel-tester.hpp +++ b/gpt_oss/metal/test/rope-kernel-tester.hpp @@ -112,6 +112,8 @@ class RoPEKernelTester { metal::Buffer activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)}; metal::Buffer 
ref_activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)}; + metal::Buffer control_buffer{device_, sizeof(gptoss_control)}; + std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control)); metal::CommandBuffer command_buffer{command_queue_}; @@ -138,6 +140,9 @@ class RoPEKernelTester { f32_rope_fn_.handle(), threadgroup_size(), activations_buffer.handle(), + /*activations_offset=*/0, + control_buffer.handle(), + /*control_offset=*/0, frequency_base(), /*interpolation_scale=*/1.0f, /*yarn_offset=*/0.0f, diff --git a/gpt_oss/responses_api/inference/metal.py b/gpt_oss/responses_api/inference/metal.py index ec84af7e..9b62b660 100644 --- a/gpt_oss/responses_api/inference/metal.py +++ b/gpt_oss/responses_api/inference/metal.py @@ -5,22 +5,39 @@ from gpt_oss.metal import Context, Model +# Tunables +MAX_OUTPUT_TOKENS = 100 + + def setup_model(checkpoint: str) -> Callable[[list[int], float], int]: """Load the Metal model and return an inference function.""" model = Model(checkpoint) context = Context(model) + seed = 0 + output_tokens = [] + def infer_next_token( tokens: list[int], temperature: float = 0.0, new_request: bool = False ) -> int: """Infer next token using incremental LCP caching when possible.""" + nonlocal output_tokens + + if new_request: + output_tokens = [] + + if len(output_tokens) == 0: + # Context handles LCP caching internally; if `tokens` matches the + # tokens in the KV cache, the KV cache is reused after reset+append. + context.reset() + for t in tokens: + context.append(t) + + output_tokens = context.sample(max_output_tokens=MAX_OUTPUT_TOKENS, + temperature=temperature, + seed=seed) - # Context handles LCP caching internally; if `tokens` matches the - # tokens in the KV cache, the KV cache is reused after reset+append. 
- context.reset() - for t in tokens: - context.append(t) - return int(context.sample(temperature=temperature)) + return int(output_tokens.pop(0)) return infer_next_token From 1b5b45a77fb1175fc9b165d7efe8af84721e7569 Mon Sep 17 00:00:00 2001 From: ibahmed-oai Date: Tue, 9 Sep 2025 23:14:19 -0700 Subject: [PATCH 87/91] Adding prefill benchmarking for metal backend (#181) --- gpt_oss/metal/benchmark/end-to-end.cc | 155 ++++++++++++++++++++-- gpt_oss/metal/include/gpt-oss/functions.h | 6 +- gpt_oss/metal/python/model.c | 2 +- gpt_oss/metal/source/generate.c | 2 +- gpt_oss/metal/source/model.c | 5 +- 5 files changed, 154 insertions(+), 16 deletions(-) diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc index 0a242340..6637de67 100644 --- a/gpt_oss/metal/benchmark/end-to-end.cc +++ b/gpt_oss/metal/benchmark/end-to-end.cc @@ -2,9 +2,10 @@ #include #include -#include #include +#include #include +#include #include #include #include @@ -12,11 +13,9 @@ #include - constexpr std::uint32_t kNumGeneratedTokens = 100; - -static void end2end(benchmark::State& state, const char* env_var_name) { +static void end2end_decode(benchmark::State& state, const char* env_var_name) { const char* model_path = getenv(env_var_name); if (model_path == NULL) { state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); @@ -24,7 +23,7 @@ static void end2end(benchmark::State& state, const char* env_var_name) { } gptoss_model_t model_ptr = nullptr; - gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr); + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, 0); if (status != gptoss_status_success) { state.SkipWithError(std::format("failed to load model from file {}", model_path)); return; @@ -66,7 +65,7 @@ static void end2end(benchmark::State& state, const char* env_var_name) { do { std::size_t num_current_generated_tokens = 0; status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, - /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); if (status != gptoss_status_success) { state.SkipWithError("failed to sample from the Context object"); return; @@ -85,7 +84,7 @@ static void end2end(benchmark::State& state, const char* env_var_name) { do { std::size_t num_current_generated_tokens = 0; status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, - /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); if (status != gptoss_status_success) { state.SkipWithError("failed to sample from the Context object"); return; @@ -100,9 +99,143 @@ static void end2end(benchmark::State& state, const char* env_var_name) { benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); } -BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH") - ->UseRealTime()->Unit(benchmark::kMillisecond); -BENCHMARK_CAPTURE(end2end, gpt_oss_120b, "GPT_OSS_120B_PATH") - ->UseRealTime()->Unit(benchmark::kMillisecond); +static void end2end_prefill(benchmark::State& state, + const char* model_path_env_var_name, + const char* prompt_env_var_name, + size_t context_length = 0) { + const char* model_path = 
getenv(model_path_env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", + model_path_env_var_name)); + return; + } + + const char* prompt_file_path = getenv(prompt_env_var_name); + if (prompt_file_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", + prompt_env_var_name)); + return; + } + + // Read prompt contents from file into a std::string + std::ifstream prompt_file(prompt_file_path, + std::ios::in | std::ios::binary); + if (!prompt_file) { + state.SkipWithError( + std::format("failed to open prompt file {}", prompt_file_path)); + return; + } + std::string prompt_str; + prompt_file.seekg(0, std::ios::end); + std::streampos file_size = prompt_file.tellg(); + if (file_size < 0) { + state.SkipWithError(std::format("failed to read prompt file size {}", + prompt_file_path)); + return; + } + prompt_str.resize(static_cast(file_size)); + prompt_file.seekg(0, std::ios::beg); + if (file_size > 0) { + prompt_file.read(prompt_str.data(), file_size); + } + if (!prompt_file) { + state.SkipWithError( + std::format("failed to read prompt file {}", prompt_file_path)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = + gptoss_model_create_from_file(model_path, &model_ptr, 1024); + if (status != gptoss_status_success) { + state.SkipWithError( + std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, + decltype(&gptoss_model_release)> + model(model_ptr, gptoss_model_release); + + gptoss_tokenizer_t tokenizer_ptr = nullptr; + status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to retrieve Tokenizer"); + return; + } + std::unique_ptr, + decltype(&gptoss_tokenizer_release)> + tokenizer(tokenizer_ptr, gptoss_tokenizer_release); + + gptoss_context_t context_ptr = nullptr; + status = + gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, + decltype(&gptoss_context_release)> + context(context_ptr, gptoss_context_release); + + const char* prompt = prompt_str.c_str(); + status = gptoss_context_append_chars(context.get(), prompt, + prompt_str.size(), nullptr); + if (status != gptoss_status_success) { + state.SkipWithError(std::format( + "failed to tokenize prompt from file {}", prompt_file_path)); + return; + } + + size_t num_tokens; + status = gptoss_context_get_num_tokens(context.get(), &num_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to get number of tokens"); + return; + } + if (context_length != 0) { + assert(context_length <= num_tokens); + context->num_tokens = context_length; + } + // Prefill + for (auto _ : state) { + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + context->num_kv_tokens = 0; + } + + state.counters["tokens"] = num_tokens; + state.counters["tokens/s"] = benchmark::Counter( + state.iterations() * num_tokens, benchmark::Counter::kIsRate); +} + +// Decode end-to-end benchmark +BENCHMARK_CAPTURE(end2end_decode, gpt_oss_20b_decode, "GPT_OSS_20B_PATH") + ->UseRealTime() + ->Unit(benchmark::kMillisecond); +BENCHMARK_CAPTURE(end2end_decode, gpt_oss_120b_decode, "GPT_OSS_120B_PATH") + ->UseRealTime() + ->Unit(benchmark::kMillisecond); + +// Prefill 
end-to-end benchmark +BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_1024, + "GPT_OSS_120B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 1024) + ->UseRealTime() + ->Unit(benchmark::kMillisecond); +BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_1024, "GPT_OSS_20B_PATH", + "GPT_OSS_PROMPT_FILE_PATH", 1024) + ->UseRealTime() + ->Unit(benchmark::kMillisecond); + +BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_3072, + "GPT_OSS_120B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 3072) + ->UseRealTime() + ->Unit(benchmark::kMillisecond); +BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_3072, "GPT_OSS_20B_PATH", + "GPT_OSS_PROMPT_FILE_PATH", 3072) + ->UseRealTime() + ->Unit(benchmark::kMillisecond); BENCHMARK_MAIN(); diff --git a/gpt_oss/metal/include/gpt-oss/functions.h b/gpt_oss/metal/include/gpt-oss/functions.h index 6ddde253..5b0d83ea 100644 --- a/gpt_oss/metal/include/gpt-oss/functions.h +++ b/gpt_oss/metal/include/gpt-oss/functions.h @@ -15,13 +15,17 @@ extern "C" { * * @param path Path to the file containing the model in GPT-OSS format. * @param model_out Pointer to the Model object that will be created. Must be released with gptoss_release_model. + * @param max_batch_tokens Maximum number of tokens that can be processed in a single batch. + * Larger values may improve prefill performance, but require more memory. + * Specify 0 to use the default value. * * On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument. * On failure, returns an error code and stores null pointer in the model_out argument. */ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( const char* path, - gptoss_model_t* model_out); + gptoss_model_t* model_out, + size_t max_batch_tokens); /* * Query the Tokenizer object associated with the Model. 
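A minimal caller sketch for the updated loader signature follows. It is a sketch, not part of the patch: the include path, the model filename, and the 1024-token value are assumptions chosen for illustration, while gptoss_model_create_from_file, gptoss_model_release, gptoss_model_t, and gptoss_status_success all appear elsewhere in this series. Passing 0 for max_batch_tokens keeps the library default (GPTOSS_DEFAULT_BATCH_SIZE, as set in model.c below); a larger value trades memory for prefill throughput, which the new end2end_prefill benchmark measures.

#include <stdio.h>
#include <gpt-oss/functions.h> /* assumed install path of the header patched above */

int main(void) {
    gptoss_model_t model = NULL;
    /* 1024 is an illustrative batch size; pass 0 to keep the built-in default. */
    enum gptoss_status status =
        gptoss_model_create_from_file("gpt-oss-20b.bin", &model, /*max_batch_tokens=*/1024);
    if (status != gptoss_status_success) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }
    /* ... create a context, append tokens, and sample, as generate.c does ... */
    gptoss_model_release(model);
    return 0;
}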
diff --git a/gpt_oss/metal/python/model.c b/gpt_oss/metal/python/model.c index 49202a2c..a1713be7 100644 --- a/gpt_oss/metal/python/model.c +++ b/gpt_oss/metal/python/model.c @@ -12,7 +12,7 @@ static int PyGPTOSSModel_init(PyGPTOSSModel* self, PyObject* args, PyObject* kwa if (!PyArg_ParseTuple(args, "s", &filepath)) { return -1; } - status = gptoss_model_create_from_file(filepath, &self->handle); + status = gptoss_model_create_from_file(filepath, &self->handle, 0); if (status != gptoss_status_success) { // TODO: set exception return -1; diff --git a/gpt_oss/metal/source/generate.c b/gpt_oss/metal/source/generate.c index 36a5527b..63a8569c 100644 --- a/gpt_oss/metal/source/generate.c +++ b/gpt_oss/metal/source/generate.c @@ -200,7 +200,7 @@ int main(int argc, char *argv[]) { struct options options = parse_options(argc, argv); const uint64_t load_start_time = mach_continuous_time(); - status = gptoss_model_create_from_file(options.model, &model); + status = gptoss_model_create_from_file(options.model, &model, 0); if (status != gptoss_status_success) { fprintf(stderr, "Error: failed to load model from file %s\n", options.model); goto error; diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c index 7a0450ce..ef346a89 100644 --- a/gpt_oss/metal/source/model.c +++ b/gpt_oss/metal/source/model.c @@ -79,7 +79,8 @@ static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) { enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( const char* path, - gptoss_model_t* model_out) + gptoss_model_t* model_out, + size_t max_batch_tokens) { *model_out = NULL; @@ -192,7 +193,7 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( model->yarn_multiplier = model_header.yarn_multiplier; model->rmsnorm_epsilon = model_header.rmsnorm_epsilon; - model->max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE; + model->max_batch_tokens = max_batch_tokens == 0 ? 
GPTOSS_DEFAULT_BATCH_SIZE : max_batch_tokens; struct gptoss_uuid tokenizer_uuid; status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path); From 0b1fb061e0859f12aab5a1ffd60172ab39e11ab8 Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Wed, 10 Sep 2025 14:58:17 -0700 Subject: [PATCH 88/91] Metal: tune threadgroup sizes (#180) --- gpt_oss/metal/CMakeLists.txt | 4 + .../metal/benchmark/end-to-end-threadgroup.cc | 590 ++++++++++++++++++ gpt_oss/metal/benchmark/end-to-end.cc | 20 - gpt_oss/metal/source/context.c | 16 +- gpt_oss/metal/source/include/internal/model.h | 9 + gpt_oss/metal/source/model.c | 10 + 6 files changed, 621 insertions(+), 28 deletions(-) create mode 100644 gpt_oss/metal/benchmark/end-to-end-threadgroup.cc diff --git a/gpt_oss/metal/CMakeLists.txt b/gpt_oss/metal/CMakeLists.txt index d18708f4..52f83b0f 100644 --- a/gpt_oss/metal/CMakeLists.txt +++ b/gpt_oss/metal/CMakeLists.txt @@ -151,6 +151,10 @@ add_executable(end-to-end-bench benchmark/end-to-end.cc) target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss) target_include_directories(end-to-end-bench PRIVATE source/include) +add_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc) +target_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss) +target_include_directories(end-to-end-threadgroup-bench PRIVATE source/include) + # --- [ Python extension ] ----------------------------------------------- find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module diff --git a/gpt_oss/metal/benchmark/end-to-end-threadgroup.cc b/gpt_oss/metal/benchmark/end-to-end-threadgroup.cc new file mode 100644 index 00000000..93fb1647 --- /dev/null +++ b/gpt_oss/metal/benchmark/end-to-end-threadgroup.cc @@ -0,0 +1,590 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +constexpr std::uint32_t kNumGeneratedTokens = 100; + + +static void attn_qkv_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->attn_qkv_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens 
= context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void AttnQKVThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto attn_qkv_threadgroup_size = 32; attn_qkv_threadgroup_size <= 1024; attn_qkv_threadgroup_size += 32) { + const auto num_simdgroups = attn_qkv_threadgroup_size / 32; + if (5120 % num_simdgroups != 0) { + // Skip incompatible threadgroup sizes + continue; + } + b->Args({attn_qkv_threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments); +BENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments); + +static void attn_out_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->attn_out_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + 
context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void AttnOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto attn_out_threadgroup_size = 32; attn_out_threadgroup_size <= 1024; attn_out_threadgroup_size += 32) { + const auto num_simdgroups = attn_out_threadgroup_size / 32; + if (2880 % num_simdgroups != 0) { + // Skip incompatible threadgroup sizes + continue; + } + b->Args({attn_out_threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnOutThreadgroupSizeArguments); +BENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnOutThreadgroupSizeArguments); + +static void mlp_gate_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->mlp_gate_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = 
gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void MlpGateThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto mlp_gate_threadgroup_size = 32; mlp_gate_threadgroup_size <= 1024; mlp_gate_threadgroup_size += 32) { + const auto num_simdgroups = mlp_gate_threadgroup_size / 32; + if (128 % num_simdgroups != 0) { + // Skip incompatible threadgroup sizes + continue; + } + b->Args({mlp_gate_threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpGateThreadgroupSizeArguments); +BENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpGateThreadgroupSizeArguments); + +static void mlp_swiglu_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->mlp_swiglu_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), 
&num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void MlpSwigluThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto threadgroup_size = 64; threadgroup_size <= 1024; threadgroup_size += 64) { + const auto num_simdgroups = threadgroup_size / 32; + if (5760 % num_simdgroups != 0) { + // Skip incompatible threadgroup sizes + continue; + } + b->Args({threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpSwigluThreadgroupSizeArguments); +BENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpSwigluThreadgroupSizeArguments); + +static void mlp_out_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->mlp_out_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while 
(num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void MlpOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto threadgroup_size = 64; threadgroup_size <= 1024; threadgroup_size += 64) { + const auto num_simdgroups = threadgroup_size / 32; + if (5760 % num_simdgroups != 0) { + // Skip incompatible threadgroup sizes + continue; + } + b->Args({threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpOutThreadgroupSizeArguments); +BENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpOutThreadgroupSizeArguments); + +static void mlp_acc_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->mlp_acc_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * 
kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void MlpAccThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto threadgroup_size = 32; threadgroup_size <= 1024; threadgroup_size += 32) { + b->Args({threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpAccThreadgroupSizeArguments); +BENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpAccThreadgroupSizeArguments); + +static void unembedding_tgsize(benchmark::State& state, const char* env_var_name) { + const char* model_path = getenv(env_var_name); + if (model_path == NULL) { + state.SkipWithError(std::format("environment variable {} is not set", env_var_name)); + return; + } + + gptoss_model_t model_ptr = nullptr; + gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to load model from file {}", model_path)); + return; + } + std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release); + model->unembedding_threadgroup_size = static_cast(state.range(0)); + + gptoss_context_t context_ptr = nullptr; + status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr); + if (status != gptoss_status_success) { + state.SkipWithError("failed to create Context object"); + return; + } + std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release); + + const char* prompt = "why did the chicken cross the road?"; + std::size_t num_prompt_tokens = 0; + status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens); + if (status != gptoss_status_success) { + state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt)); + return; + } + + // Prefill + status = gptoss_context_process(context.get()); + if (status != gptoss_status_success) { + state.SkipWithError("failed to prefill Context object"); + return; + } + const std::size_t num_kvcache_tokens = context->num_kv_tokens; + + std::uint64_t rng_seed = 0; + for (auto _ : state) { + const std::uint64_t current_rng_seed = rng_seed++; + context->num_kv_tokens = num_prompt_tokens; + context->num_tokens = num_prompt_tokens; + + std::array tokens; + std::size_t num_generated_tokens = 0; + do { + std::size_t num_current_generated_tokens = 0; + status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, + /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); + if (status != gptoss_status_success) { + state.SkipWithError("failed to sample from the Context object"); + return; + } + num_generated_tokens += num_current_generated_tokens; + } while (num_generated_tokens < kNumGeneratedTokens); + } + + state.counters["generations"] = + benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate); + state.counters["tokens"] = + benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate); +} + +static void UnembeddingThreadgroupSizeArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"tgsize"}); + for (auto threadgroup_size = 32; threadgroup_size <= 1024; threadgroup_size += 32) { + b->Args({threadgroup_size}); + } +} + +BENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH") + 
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(UnembeddingThreadgroupSizeArguments); +BENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH") + ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(UnembeddingThreadgroupSizeArguments); + +BENCHMARK_MAIN(); diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc index 6637de67..b0f4367c 100644 --- a/gpt_oss/metal/benchmark/end-to-end.cc +++ b/gpt_oss/metal/benchmark/end-to-end.cc @@ -52,27 +52,7 @@ static void end2end_decode(benchmark::State& state, const char* env_var_name) { state.SkipWithError("failed to prefill Context object"); return; } - const std::size_t num_kvcache_tokens = context->num_kv_tokens; - std::uint64_t rng_seed = 0; - for (std::uint32_t i = 0; i < 3; i++) { - const std::uint64_t current_rng_seed = rng_seed++; - context->num_kv_tokens = num_prompt_tokens; - context->num_tokens = num_prompt_tokens; - - std::array tokens; - std::size_t num_generated_tokens = 0; - do { - std::size_t num_current_generated_tokens = 0; - status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed, - /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens); - if (status != gptoss_status_success) { - state.SkipWithError("failed to sample from the Context object"); - return; - } - num_generated_tokens += num_current_generated_tokens; - } while (num_generated_tokens < kNumGeneratedTokens); - } for (auto _ : state) { const std::uint64_t current_rng_seed = rng_seed++; diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index 2d246294..2b82d81d 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -193,7 +193,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings( command_buffer, &model->bf16_f32_embeddings_fn, - /*threadgroup_size=*/512, + model->embeddings_threadgroup_size, &context->token_buffer, input_batch_start * sizeof(uint32_t), &model->shared_weight_buffer, @@ -234,7 +234,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( command_buffer, &model->f32_bf16w_matmul_fn, - /*threadgroup_size=*/256, + model->attn_qkv_threadgroup_size, &context->rmsnorm_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, @@ -318,7 +318,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( command_buffer, &model->f32_bf16w_matmul_fn, - /*threadgroup_size=*/256, + model->attn_out_threadgroup_size, &context->sdpa_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, @@ -359,7 +359,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( command_buffer, &model->f32_bf16w_matmul_fn, - /*threadgroup_size=*/256, + model->mlp_gate_threadgroup_size, &context->rmsnorm_activation_buffer, /*input_offset=*/0, &model->shared_weight_buffer, @@ -417,7 +417,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu( command_buffer, &model->f32_mf4w_moe_matmul_swiglu_fn, - /*threadgroup_size=*/512, + model->mlp_swiglu_threadgroup_size, &context->rmsnorm_activation_buffer, /*input_offset=*/0, &context->expert_activation_buffer, @@ -446,7 +446,7 @@ static enum gptoss_status process_tokens( status = 
gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul( command_buffer, &model->f32_mf4w_moe_matmul_fn, - /*threadgroup_size=*/512, + model->mlp_out_threadgroup_size, &context->swiglu_activation_buffer, /*input_offset=*/0, &context->expert_activation_buffer, @@ -474,7 +474,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_f32_accumulate( command_buffer, &model->f32_accumulate_e4_fn, - /*threadgroup_size=*/256, + model->mlp_acc_threadgroup_size, model->max_threadgroups, &context->moe_activation_buffer, /*input_offset=*/0, @@ -528,7 +528,7 @@ static enum gptoss_status process_tokens( status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( command_buffer, &model->f32_bf16w_unembedding_fn, - /*threadgroup_size=*/256, + model->unembedding_threadgroup_size, model->max_threadgroups, &context->rmsnorm_activation_buffer, /*input_offset=*/0, diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index 34e273aa..8b382221 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -92,6 +92,15 @@ struct gptoss_model { size_t per_block_shared_weights_size; size_t per_expert_block_weight_size; + size_t embeddings_threadgroup_size; + size_t attn_qkv_threadgroup_size; + size_t attn_out_threadgroup_size; + size_t mlp_gate_threadgroup_size; + size_t mlp_swiglu_threadgroup_size; + size_t mlp_out_threadgroup_size; + size_t mlp_acc_threadgroup_size; + size_t unembedding_threadgroup_size; + size_t attn_rmsnorm_gain_offset; size_t attn_qkv_weight_offset; size_t attn_qkv_bias_offset; diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c index ef346a89..3def7967 100644 --- a/gpt_oss/metal/source/model.c +++ b/gpt_oss/metal/source/model.c @@ -366,6 +366,16 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( goto cleanup; } + // Kernel launch parameters + model->embeddings_threadgroup_size = 512; + model->attn_qkv_threadgroup_size = 1024; + model->attn_out_threadgroup_size = 768; + model->mlp_gate_threadgroup_size = 256; + model->mlp_swiglu_threadgroup_size = 192; + model->mlp_out_threadgroup_size = 192; + model->mlp_acc_threadgroup_size = 768; + model->unembedding_threadgroup_size = 416; + // Weight buffers const char* current_ptr = (const char*) model->mapping_ptr; From bbc5c482418409c443969a2c2b9428553bdafb5b Mon Sep 17 00:00:00 2001 From: ibahmed-oai Date: Wed, 10 Sep 2025 17:07:29 -0700 Subject: [PATCH 89/91] Metal: Adding optimized dense matmul kernel to optimize prefill perf (#183) --- gpt_oss/metal/source/context.c | 193 ++++++++++++------ .../source/include/internal/kernel-args.h | 25 +++ .../source/include/internal/metal-kernels.h | 54 +++++ gpt_oss/metal/source/include/internal/model.h | 3 + gpt_oss/metal/source/matmul.metal | 189 +++++++++++++++++ gpt_oss/metal/source/metal-kernels.c | 176 ++++++++++++++++ gpt_oss/metal/source/model.c | 15 ++ gpt_oss/metal/test/f32-bf16w-matmul.cc | 27 +++ gpt_oss/metal/test/matmul-kernel-tester.hpp | 161 ++++++++++++--- 9 files changed, 752 insertions(+), 91 deletions(-) diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index 2b82d81d..704ff7f4 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -175,6 +175,7 @@ static enum gptoss_status process_tokens( assert(num_input_tokens <= context->max_batch_tokens); assert(num_output_tokens <= context->max_batch_tokens); assert(num_input_tokens >= 
num_output_tokens); + const size_t dense_matmul_kernel_token_multiple_constraint = 64; enum gptoss_status status = gptoss_status_success; const struct gptoss_model* model = context->model; @@ -231,28 +232,50 @@ static enum gptoss_status process_tokens( return status; } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - command_buffer, - &model->f32_bf16w_matmul_fn, - model->attn_qkv_threadgroup_size, - &context->rmsnorm_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n, - &model->shared_weight_buffer, - /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, - &context->qkv_activation_buffer, - /*output_offset=*/0, - &context->control_buffer, - /*control_offset=*/0, - /*num_tokens=*/input_batch_size, - /*num_cols=*/model->embedding_dim, - /*num_rows=*/attn_qkv_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); - return status; + if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) { + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv( + command_buffer, + &model->f32_bf16w_dense_matmul_qkv_fn, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, + &context->qkv_activation_buffer, + /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, + /*num_tokens=*/input_batch_size, + /*num_cols=*/model->embedding_dim, + /*num_rows=*/attn_qkv_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch"); + return status; + } + } else { + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( + command_buffer, + &model->f32_bf16w_matmul_fn, + model->attn_qkv_threadgroup_size, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, + &context->qkv_activation_buffer, + /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, + /*num_tokens=*/input_batch_size, + /*num_cols=*/model->embedding_dim, + /*num_rows=*/attn_qkv_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); + return status; + } } - status = gptoss_metal_command_buffer_encode_launch_f32_rope( command_buffer, &model->f32_rope_fn, @@ -315,28 +338,50 @@ static enum gptoss_status process_tokens( return status; } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( - command_buffer, - &model->f32_bf16w_matmul_fn, - model->attn_out_threadgroup_size, - &context->sdpa_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n, - &model->shared_weight_buffer, - /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, - &context->residual_activation_buffer, - /*output_offset=*/model->embedding_dim * (input_batch_size - 
num_block_output_tokens) * sizeof(float), - &context->control_buffer, - /*control_offset=*/0, - /*num_tokens=*/num_block_output_tokens, - /*num_cols=*/model->num_heads * model->head_dim, - /*num_rows=*/model->embedding_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch"); - return status; + if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) { + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output( + command_buffer, + &model->f32_bf16w_dense_matmul_attn_output_fn, + &context->sdpa_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, + &context->residual_activation_buffer, + /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &context->control_buffer, + /*control_offset=*/0, + /*num_tokens=*/num_block_output_tokens, + /*num_cols=*/model->num_heads * model->head_dim, + /*num_rows=*/model->embedding_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch"); + return status; + } + } else { + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( + command_buffer, + &model->f32_bf16w_matmul_fn, + model->attn_out_threadgroup_size, + &context->sdpa_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n, + &context->residual_activation_buffer, + /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), + &context->control_buffer, + /*control_offset=*/0, + /*num_tokens=*/num_block_output_tokens, + /*num_cols=*/model->num_heads * model->head_dim, + /*num_rows=*/model->embedding_dim); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch"); + return status; + } } - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm( command_buffer, &model->f32_bf16w_rmsnorm_fn, @@ -355,27 +400,49 @@ static enum gptoss_status process_tokens( GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch"); return status; } - - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - command_buffer, - &model->f32_bf16w_matmul_fn, - model->mlp_gate_threadgroup_size, - &context->rmsnorm_activation_buffer, - /*input_offset=*/0, - &model->shared_weight_buffer, - /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n, - &model->shared_weight_buffer, - /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n, - &context->gate_activation_buffer, - /*output_offset=*/0, - &context->control_buffer, - /*control_offset=*/0, - /*num_tokens=*/num_block_output_tokens, - /*num_cols=*/model->embedding_dim, - /*num_rows=*/model->num_experts); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); - return status; + if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) { + status = 
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate( + command_buffer, + &model->f32_bf16w_dense_matmul_mlp_gate_fn, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n, + &context->gate_activation_buffer, + /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, + num_block_output_tokens, + model->embedding_dim, + model->num_experts); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch"); + return status; + } + } else { + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( + command_buffer, + &model->f32_bf16w_matmul_fn, + model->mlp_gate_threadgroup_size, + &context->rmsnorm_activation_buffer, + /*input_offset=*/0, + &model->shared_weight_buffer, + /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n, + &model->shared_weight_buffer, + /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n, + &context->gate_activation_buffer, + /*output_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, + /*num_tokens=*/num_block_output_tokens, + /*num_cols=*/model->embedding_dim, + /*num_rows=*/model->num_experts); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); + return status; + } } const char* kernel_name = NULL; diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h index 259eaa8a..7dc0b564 100644 --- a/gpt_oss/metal/source/include/internal/kernel-args.h +++ b/gpt_oss/metal/source/include/internal/kernel-args.h @@ -4,6 +4,25 @@ #include #endif +// TODO(ibahmed): specalize using metal function constants. 
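// Editor's note (sketch, not part of the patch): the tile-shape macros below fix the
// blocking of the new dense matmul kernels:
//   *_Bm x *_Bn       - output tile computed by one threadgroup
//   *_Bk              - step taken along the reduction (K) dimension per iteration
//   *_Sg_Bm x *_Sg_Bn - sub-tile owned by a single simdgroup
// The host-side encode helper added later in this patch derives the launch geometry
// from them: (Bm / Sg_Bm) * (Bn / Sg_Bn) simdgroups per threadgroup, times the
// simdgroup width (32 threads on Apple GPUs), with a grid of
// (output_rows / Bn) x (tokens / Bm) threadgroups. For example, the QKV shape
// (64x64 tile, 32x32 sub-tiles) gives 4 simdgroups = 128 threads per threadgroup,
// and the 1024-token, 5120-row QKV case exercised by the new test dispatches an
// 80 x 16 grid.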
+#define QKV_Bm 64 +#define QKV_Bn 64 +#define QKV_Bk 32 +#define QKV_Sg_Bm 32 +#define QKV_Sg_Bn 32 + +#define ATTN_OUTPUT_Bm 32 +#define ATTN_OUTPUT_Bn 64 +#define ATTN_OUTPUT_Bk 64 +#define ATTN_OUTPUT_Sg_Bm 32 +#define ATTN_OUTPUT_Sg_Bn 16 + +#define MLP_GATE_Bm 64 +#define MLP_GATE_Bn 16 +#define MLP_GATE_Bk 64 +#define MLP_GATE_Sg_Bm 16 +#define MLP_GATE_Sg_Bn 16 + struct gptoss_expert_prediction { uint32_t expert_id; float score; @@ -66,6 +85,12 @@ struct gptoss_matmul_args { uint32_t add; }; +struct gptoss_dense_matmul_args { + uint32_t m; + uint32_t n; + uint32_t k; +}; + struct gptoss_unembedding_args { uint32_t num_column_vecs; uint32_t num_rows_per_threadgroup; diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h index 269f025d..6837bbd2 100644 --- a/gpt_oss/metal/source/include/internal/metal-kernels.h +++ b/gpt_oss/metal/source/include/internal/metal-kernels.h @@ -130,6 +130,60 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad uint32_t num_cols, uint32_t num_rows); +enum gptoss_status +gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows); + +enum gptoss_status +gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows); + +enum gptoss_status +gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows); + enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( const struct gptoss_metal_command_buffer* command_buffer, const struct gptoss_metal_function* f32_bf16w_matmul_fn, diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index 8b382221..b2c080b1 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -78,6 +78,9 @@ struct 
gptoss_model { struct gptoss_metal_function bf16_f32_embeddings_fn; struct gptoss_metal_function f32_bf16w_rmsnorm_fn; struct gptoss_metal_function f32_bf16w_matmul_fn; + struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn; + struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn; + struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn; struct gptoss_metal_function f32_bf16w_unembedding_fn; struct gptoss_metal_function f32_rope_fn; struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn; diff --git a/gpt_oss/metal/source/matmul.metal b/gpt_oss/metal/source/matmul.metal index a4ec60d5..1fc3c408 100644 --- a/gpt_oss/metal/source/matmul.metal +++ b/gpt_oss/metal/source/matmul.metal @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -143,3 +144,191 @@ kernel void gptoss_f32_bf16w_unembedding( } } } + +// Current constraints for the dense matmul kernel: +// 1- All B* and Sg_* are a multiple of 8. +// 2- Bm is divisible by Sg_n and Bn is divisible by Sg_n. +// 3- M, N and K are all divisible by 8.. +template +inline void _gptoss_f32_bf16w_dense_matmul_impl( + constant gptoss_dense_matmul_args& args, const device float* lhs, + const device bfloat* rhs, const device bfloat* __restrict__ bias, + device float* out, const device gptoss_control* control, threadgroup float* scratch, threadgroup float* bias_tile, + uint sg_id, uint sg_count_per_tg, uint3 gid, uint3 tg_id, uint3 local_tid, + uint3 threadgroup_size) { + + if (control->abort != 0) { + return; + } + + // The kernel assumes that M, K, and N are divisible by 8. + const uint M = args.m; + const uint K = args.k; + const uint N = args.n; + static_assert((Bm % 8u) == 0u, "Bm must be a multiple of 8"); + static_assert((Bn % 8u) == 0u, "Bn must be a multiple of 8"); + static_assert((Bk % 8u) == 0u, "Bk must be a multiple of 8"); + static_assert((Sg_Bm % 8u) == 0u, "Bk must be a multiple of 8"); + static_assert((Sg_Bn % 8u) == 0u, "Bk must be a multiple of 8"); + static_assert((Bn % Sg_Bn) == 0u, "Bn must be a multiple of Sg_Bn"); + static_assert((Bm % Sg_Bm) == 0u, "Bm must be a multiple of Sg_Bm"); + + // Get row and col tg. + const uint row_tg = tg_id.y; + const uint col_tg = tg_id.x; + // Get row and col local tid. + const uint row_tg_offset = row_tg * Bm; + const uint col_tg_offset = col_tg * Bn; + + const uint sg_col_count = Bn / Sg_Bn; + const uint row_sg = sg_id / sg_col_count; + const uint col_sg = sg_id % sg_col_count; + + const uint row_sg_offset = row_sg * Sg_Bm; + const uint col_sg_offset = col_sg * Sg_Bn; + constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8); + // Create an array of simdgroup_float8x8 to hold temp results. 
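// Each simdgroup keeps its Sg_Bm x Sg_Bn slice of the output tile in registers as a
// row-major grid of 8x8 simdgroup matrices: OutTiles[r * (Sg_Bn / 8) + c] covers rows
// [r*8, r*8 + 8) and columns [c*8, c*8 + 8) of that slice. For the QKV shape
// (Sg_Bm = Sg_Bn = 32) this is a 4x4 grid of 16 accumulators that persist across the
// whole K loop.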
+ metal::simdgroup_float8x8 OutTiles[temp_result_size]; +#pragma clang loop unroll(full) + for (uint i = 0; i < temp_result_size; i++) { + OutTiles[i] = metal::make_filled_simdgroup_matrix( + static_cast(0.0)); + } + + for (uint k_offset = 0; k_offset < K; k_offset += Bk) { +#pragma clang loop unroll(full) + for (uint k = 0; k < Bk; k += 8) { +#pragma clang loop unroll(full) + for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) { + // const uint m_subtile = row_sg_offset + m_subtile_; + // const uint row_index_in_out_tile = (m_subtile - row_sg_offset) / 8; + const uint row_index_in_out_tile = m_subtile_ / 8; + metal::simdgroup_float8x8 LHStile; + const uint k_id = k + k_offset; + const uint row_offset = row_tg_offset + row_sg_offset + m_subtile_; + metal::simdgroup_load(LHStile, lhs, K, ulong2(k_id, row_offset)); + metal::simdgroup_bfloat8x8 RHStile; +#pragma clang loop unroll(full) + for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) { + const uint col_index_in_out_tile = n_subtile_ / 8; + const uint current_index_out_tile = + row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile; + const uint col_offset = col_tg_offset + col_sg_offset + n_subtile_; + simdgroup_load(RHStile, rhs, K, ulong2(k_id, col_offset), /*transpose=*/true); + // If rhs was not transposed, use the following instead: + // simdgroup_load(RHStile, rhs, N, ulong2(col_offset, k_id)); + simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], + LHStile, RHStile, + OutTiles[current_index_out_tile]); + } + } + } + } + // Epilogue. +#pragma clang loop unroll(full) + for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) { + const uint col_index_in_out_tile = n_subtile_ / 8; + const uint local_col_offset = col_sg_offset + n_subtile_; +#pragma clang loop unroll(full) + for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) { + const uint row_index_in_out_tile = m_subtile_ / 8; + const uint local_row_offset = row_sg_offset + m_subtile_; + const uint current_index_out_tile = + row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile; + simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn, + ulong2(local_col_offset, local_row_offset)); + } + } + // TODO(ibahmed): vectorize these loads an maybe unroll the loop. + const uint thread_count_per_tg = + threadgroup_size.x * threadgroup_size.y * threadgroup_size.z; + for (uint c_local = local_tid.x; c_local < Bn; + c_local += thread_count_per_tg) { + const uint c_global = col_tg_offset + c_local; + bias_tile[c_local] = + (c_global < N) ? static_cast(bias[c_global]) : 0.0f; + } + + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + + // TODO(ibahmed): vectorize these stores and maybe unroll the loop. 
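// At this point the Bm x Bn result tile sits in threadgroup scratch (written via
// simdgroup_store above) and the per-column bias is staged in bias_tile. The loop
// below has all threads of the threadgroup cooperatively stream the tile to device
// memory, adding the bias and, when `add` is set (the attention-output variant),
// accumulating into the existing values in `out`; the M/N bounds checks guard
// partial tiles at the matrix edges.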
+ for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) { + const uint r = idx / Bn; + const uint c = idx % Bn; + + const uint out_row = row_tg_offset + r; + const uint out_col = col_tg_offset + c; + + if (out_row < M && out_col < N) { + float acc = scratch[idx] + bias_tile[c]; + if (add) { + acc += out[out_row * N + out_col]; + } + out[out_row * N + out_col] = acc; + } + } +} + +kernel void gptoss_f32_bf16w_dense_matmul_qkv( + constant gptoss_dense_matmul_args& args [[buffer(0)]], + const device float* lhs [[buffer(1)]], + const device bfloat* rhs [[buffer(2)]], + const device bfloat* __restrict__ bias [[buffer(3)]], + device float* out [[buffer(4)]], + const device gptoss_control* control [[buffer(5)]], + uint sg_id [[simdgroup_index_in_threadgroup]], + uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 tg_id [[threadgroup_position_in_grid]], + uint3 local_tid [[thread_position_in_threadgroup]], + uint3 threadgroup_size [[threads_per_threadgroup]]) { + threadgroup float scratch[QKV_Bm * QKV_Bn]; + threadgroup float bias_tile[QKV_Bn]; + _gptoss_f32_bf16w_dense_matmul_impl( + args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg, + gid, tg_id, local_tid, threadgroup_size); +} + +kernel void gptoss_f32_bf16w_dense_matmul_attn_output( + constant gptoss_dense_matmul_args& args [[buffer(0)]], + const device float* lhs [[buffer(1)]], + const device bfloat* rhs [[buffer(2)]], + const device bfloat* __restrict__ bias [[buffer(3)]], + device float* out [[buffer(4)]], + const device gptoss_control* control [[buffer(5)]], + uint sg_id [[simdgroup_index_in_threadgroup]], + uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 tg_id [[threadgroup_position_in_grid]], + uint3 local_tid [[thread_position_in_threadgroup]], + uint3 threadgroup_size [[threads_per_threadgroup]]) { + threadgroup float scratch[ATTN_OUTPUT_Bm * ATTN_OUTPUT_Bn]; + threadgroup float bias_tile[ATTN_OUTPUT_Bn]; + _gptoss_f32_bf16w_dense_matmul_impl( + args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg, + gid, tg_id, local_tid, threadgroup_size); +} + +kernel void gptoss_f32_bf16w_dense_matmul_mlp_gate( + constant gptoss_dense_matmul_args& args [[buffer(0)]], + const device float* lhs [[buffer(1)]], + const device bfloat* rhs [[buffer(2)]], + const device bfloat* __restrict__ bias [[buffer(3)]], + device float* out [[buffer(4)]], + const device gptoss_control* control [[buffer(5)]], + uint sg_id [[simdgroup_index_in_threadgroup]], + uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 tg_id [[threadgroup_position_in_grid]], + uint3 local_tid [[thread_position_in_threadgroup]], + uint3 threadgroup_size [[threads_per_threadgroup]]) { + threadgroup float scratch[MLP_GATE_Bm * MLP_GATE_Bn]; + threadgroup float bias_tile[MLP_GATE_Bn]; + _gptoss_f32_bf16w_dense_matmul_impl( + args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg, + gid, tg_id, local_tid, threadgroup_size); +} diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c index 1316fa50..dd5e640a 100644 --- a/gpt_oss/metal/source/metal-kernels.c +++ b/gpt_oss/metal/source/metal-kernels.c @@ -401,6 +401,182 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad /*threadgroup_buffer_size=*/0); } +enum gptoss_status 
_gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows, + uint32_t Bm, + uint32_t Bn, + uint32_t Bk, + uint32_t Sg_Bm, + uint32_t Sg_Bn) +{ + + if (command_buffer->object == NULL || f32_bf16w_dense_matmul_fn->pipeline_state_object == NULL) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: invalid command buffer or pipeline state object"); + return gptoss_status_invalid_state; + } + + if (num_cols % 8 != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 8", + num_cols); + return gptoss_status_invalid_argument; + } + if (num_rows % 8 != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by 8", + num_rows); + return gptoss_status_invalid_argument; + } + if (num_tokens % 8 != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of tokens (%" PRIu32 ") is not divisible by 8", + num_tokens); + return gptoss_status_invalid_argument; + } + + const struct gptoss_dense_matmul_args args = { + .m = num_tokens, + .n = num_rows, + .k = num_cols, + }; + const size_t threads_per_simdgroup = f32_bf16w_dense_matmul_fn->simdgroup_threads; + const uint32_t m = args.m; + const uint32_t n = args.n; + const uint32_t k = args.k; + if (Bm % Sg_Bm != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")", + Bm, Sg_Bm); + return gptoss_status_invalid_argument; + } + if (Bn % Sg_Bn != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")", + Bn, Sg_Bn); + return gptoss_status_invalid_argument; + } + const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup; + const size_t threadgroup_size_y = 1; + const size_t threadgroup_size_z = 1; + const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z; + if (total_threadgroup_size > f32_bf16w_dense_matmul_fn->max_threadgroup_threads) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)", + total_threadgroup_size, f32_bf16w_dense_matmul_fn->max_threadgroup_threads); + return gptoss_status_invalid_argument; + } + if (m % Bm != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: m (%" PRIu32 ") is not divisible by Bm (%" PRIu32 ")", + m, Bm); + return gptoss_status_invalid_argument; + } + if (n % Bn != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")", + n, Bn); + return gptoss_status_invalid_argument; + } + if (k % Bk != 0) { + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")", + k, Bk); + return 
gptoss_status_invalid_argument; + } + const size_t grid_x = n / Bn; + const size_t grid_y = m / Bm; + const size_t grid_z = 1; + + return gptoss_metal_command_buffer_encode_launch_kernel( + command_buffer, f32_bf16w_dense_matmul_fn, + threadgroup_size_x, threadgroup_size_y, threadgroup_size_z, + grid_x, grid_y, grid_z, + sizeof(args), &args, + 5, + (const struct gptoss_metal_buffer *[]){input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer}, + (const size_t[]){input_offset, weight_offset, bias_offset, output_offset, control_offset}, + /*threadgroup_buffer_size=*/0); + return gptoss_status_success; +} + +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows) +{ + return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl( + command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset, + weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer, + output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, QKV_Bm, QKV_Bn, QKV_Bk, + QKV_Sg_Bm, QKV_Sg_Bn); +} + +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows) +{ + return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl( + command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset, + weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer, + output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, ATTN_OUTPUT_Bm, + ATTN_OUTPUT_Bn, ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm, ATTN_OUTPUT_Sg_Bn); +} + +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_rows) +{ + return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl( + command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset, + weight_buffer, weight_offset, 
bias_buffer, bias_offset, output_buffer, + output_offset, control_buffer, control_offset, num_tokens, num_cols, + num_rows, MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk, MLP_GATE_Sg_Bm, + MLP_GATE_Sg_Bn); +} + enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding( const struct gptoss_metal_command_buffer* command_buffer, const struct gptoss_metal_function* f32_bf16w_unembedding_fn, diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c index 3def7967..5f79c030 100644 --- a/gpt_oss/metal/source/model.c +++ b/gpt_oss/metal/source/model.c @@ -325,6 +325,18 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file( if (status != gptoss_status_success) { goto cleanup; } + status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_attn_output", &model->f32_bf16w_dense_matmul_attn_output_fn); + if (status != gptoss_status_success) { + goto cleanup; + } + status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_mlp_gate", &model->f32_bf16w_dense_matmul_mlp_gate_fn); + if (status != gptoss_status_success) { + goto cleanup; + } status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn); if (status != gptoss_status_success) { goto cleanup; @@ -502,6 +514,9 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release( gptoss_metal_function_release(&model->bf16_f32_embeddings_fn); gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn); gptoss_metal_function_release(&model->f32_bf16w_matmul_fn); + gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn); + gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn); + gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn); gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn); gptoss_metal_function_release(&model->f32_rope_fn); gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn); diff --git a/gpt_oss/metal/test/f32-bf16w-matmul.cc b/gpt_oss/metal/test/f32-bf16w-matmul.cc index 9692e6a7..745bff2e 100644 --- a/gpt_oss/metal/test/f32-bf16w-matmul.cc +++ b/gpt_oss/metal/test/f32-bf16w-matmul.cc @@ -58,3 +58,30 @@ TEST(F32_BF16W_MATMUL, multiple_tokens) { .threadgroup_size(threadgroup_size) .TestF32_BF16W(); } + +TEST(F32_BF16W_DENSE_MATMUL_QKV, seq_len_1024) { + MatMulKernelTester() + .num_tokens(1024) + .num_rows(5120) + .num_cols(2880) + .TestF32_BF16W( + MatMulKernelTester::MatMulKernelType::PREFILL_QKV_OPTIMIZED); +} + +TEST(F32_BF16W_DENSE_MATMUL_ATTN_OUTPUT, seq_len_1024) { + MatMulKernelTester() + .num_tokens(1024) + .num_rows(2880) + .num_cols(4096) + .TestF32_BF16W( + MatMulKernelTester::MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED); +} + +TEST(F32_BF16W_DENSE_MATMUL_MLP_GATE, seq_len_1024) { + MatMulKernelTester() + .num_tokens(1024) + .num_rows(128) + .num_cols(2880) + .TestF32_BF16W( + MatMulKernelTester::MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED); +} \ No newline at end of file diff --git a/gpt_oss/metal/test/matmul-kernel-tester.hpp b/gpt_oss/metal/test/matmul-kernel-tester.hpp index 30826f70..f5958c7b 100644 --- a/gpt_oss/metal/test/matmul-kernel-tester.hpp +++ b/gpt_oss/metal/test/matmul-kernel-tester.hpp @@ -10,9 +10,39 @@ #include #include - namespace gptoss { +template +::testing::AssertionResult 
+IsNearAbsRel(const char* a_expr, const char* b_expr, const char* abs_expr, + const char* rel_expr, T a, T b, T abs_tol, T rel_tol = 1.0) { + + using std::abs; + if (!std::isfinite(a) || !std::isfinite(b)) { + return ::testing::AssertionFailure() + << "Non-finite value(s): " << a_expr << "=" << a << ", " << b_expr + << "=" << b; + // At least one of abs_tol and rel_tol must be provided + } + const T diff = abs(a - b); + const T rel = rel_tol * std::max(abs(a), abs(b)); + const T thr = std::max(abs_tol, rel); + + if (diff <= thr) + return ::testing::AssertionSuccess(); + + return ::testing::AssertionFailure() + << a_expr << " vs " << b_expr << " differ by " << diff + << " > max(abs_tol=" << abs_tol << ", rel_tol*max(|a|,|b|)=" << rel + << ") with " << abs_expr << "=" << abs_tol << ", " << rel_expr << "=" + << rel_tol << ". \n" + << a_expr << "=" << a << ". \n" + << b_expr << "=" << b; +} + +#define ASSERT_NEAR_ABS_REL(a, b, abs_tol, rel_tol) \ + ASSERT_PRED_FORMAT4(IsNearAbsRel, a, b, abs_tol, rel_tol) + class MatMulKernelTester { public: MatMulKernelTester() { } @@ -70,18 +100,26 @@ class MatMulKernelTester { ASSERT_NE(threadgroup_size(), 0); } - void TestF32_BF16W() const { + enum class MatMulKernelType { + DECODE_OPTIMIZED, + PREFILL_QKV_OPTIMIZED, + PREFILL_ATTN_OUTPUT_OPTIMIZED, + PREFILL_MLP_GATE_OPTIMIZED, + }; + + void TestF32_BF16W(MatMulKernelType kernel_type = MatMulKernelType::DECODE_OPTIMIZED) const { Validate(/*vec_size=*/4); - metal::CommandBuffer command_buffer{command_queue_}; + metal::CommandBuffer command_buffer_initialize{command_queue_}; metal::Buffer input_buffer{device_, num_tokens() * num_cols() * sizeof(float)}; metal::Buffer weight_buffer{device_, num_rows() * num_cols() * sizeof(gptoss_bfloat16)}; metal::Buffer bias_buffer{device_, num_rows() * sizeof(gptoss_bfloat16)}; metal::Buffer output_buffer{device_, num_tokens() * num_rows() * sizeof(float)}; + metal::Buffer output_buffer_copy{device_, num_tokens() * num_rows() * sizeof(float)}; metal::Buffer control_buffer{device_, sizeof(gptoss_control)}; std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control)); - command_buffer.encode_launch_f32_fill_random( + command_buffer_initialize.encode_launch_f32_fill_random( f32_fill_random_fn_, /*threadgroup_size=*/0, /*max_threadgroups=*/kFillRandomMaxThreadgroups, @@ -89,7 +127,7 @@ class MatMulKernelTester { /*output_offset=*/0, num_tokens() * num_cols(), kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0); - command_buffer.encode_launch_bf16_fill_random( + command_buffer_initialize.encode_launch_bf16_fill_random( bf16_fill_random_fn_, /*threadgroup_size=*/0, /*max_threadgroups=*/kFillRandomMaxThreadgroups, @@ -97,7 +135,7 @@ class MatMulKernelTester { /*output_offset=*/0, num_rows() * num_cols(), kSeed + 1, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0); - command_buffer.encode_launch_bf16_fill_random( + command_buffer_initialize.encode_launch_bf16_fill_random( bf16_fill_random_fn_, /*threadgroup_size=*/0, /*max_threadgroups=*/kFillRandomMaxThreadgroups, @@ -105,32 +143,90 @@ class MatMulKernelTester { /*output_offset=*/0, num_rows(), kSeed + 2, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0); - Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( - command_buffer.handle(), - f32_bf16w_matmul_fn_.handle(), - /*threadgroup_size=*/threadgroup_size(), - input_buffer.handle(), - /*input_offset=*/0, - weight_buffer.handle(), - /*weight_offset=*/0, - bias_buffer.handle(), - /*bias_offset=*/0, - output_buffer.handle(), - /*output_offset=*/0, - control_buffer.handle(), - 
/*control_offset=*/0, - num_tokens(), - num_cols(), - num_rows()), - "gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul"); - - command_buffer.commit(); - command_buffer.wait_completion(); + // Fill output buffer with random values to test matmul with add. + command_buffer_initialize.encode_launch_f32_fill_random( + f32_fill_random_fn_, + /*threadgroup_size=*/0, + /*max_threadgroups=*/kFillRandomMaxThreadgroups, + /*output_buffer=*/output_buffer, + /*output_offset=*/0, num_tokens() * num_rows(), kSeed + 3, + /*offset=*/0, + /*min=*/-1.0f, /*max=*/1.0); + command_buffer_initialize.commit(); + command_buffer_initialize.wait_completion(); + if (kernel_type == + MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED) { + // Copy output buffer to output buffer copy to use when calculating reference. + const uint64_t bytes = + uint64_t(num_tokens()) * uint64_t(num_rows()) * sizeof(float); + void* src = output_buffer.ptr(); + void* dst = output_buffer_copy.ptr(); + assert(src && dst && "Buffers must be CPU-mappable for memcpy"); + + std::memcpy(reinterpret_cast(dst), + reinterpret_cast(src), bytes); + } + + metal::CommandBuffer command_buffer_compute{command_queue_}; + switch (kernel_type) { + case MatMulKernelType::DECODE_OPTIMIZED: + Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( + command_buffer_compute.handle(), f32_bf16w_matmul_fn_.handle(), + /*threadgroup_size=*/threadgroup_size(), input_buffer.handle(), + /*input_offset=*/0, weight_buffer.handle(), + /*weight_offset=*/0, bias_buffer.handle(), + /*bias_offset=*/0, output_buffer.handle(), + /*output_offset=*/0, control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_cols(), num_rows()), + "gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul"); + break; + case MatMulKernelType::PREFILL_QKV_OPTIMIZED: + Check( + gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv( + command_buffer_compute.handle(), + f32_bf16w_dense_matmul_qkv_fn_.handle(), input_buffer.handle(), + /*input_offset=*/0, weight_buffer.handle(), + /*weight_offset=*/0, bias_buffer.handle(), + /*bias_offset=*/0, output_buffer.handle(), + /*output_offset=*/0, control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_cols(), num_rows()), + "gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv"); + break; + case MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED: + Check( + gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output( + command_buffer_compute.handle(), + f32_bf16w_dense_matmul_attn_output_fn_.handle(), + input_buffer.handle(), + /*input_offset=*/0, weight_buffer.handle(), + /*weight_offset=*/0, bias_buffer.handle(), + /*bias_offset=*/0, output_buffer.handle(), + /*output_offset=*/0, control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_cols(), num_rows()), + "gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output"); + break; + case MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED: + Check( + gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate( + command_buffer_compute.handle(), + f32_bf16w_dense_matmul_mlp_gate_fn_.handle(), + input_buffer.handle(), + /*input_offset=*/0, weight_buffer.handle(), + /*weight_offset=*/0, bias_buffer.handle(), + /*bias_offset=*/0, output_buffer.handle(), + /*output_offset=*/0, control_buffer.handle(), + /*control_offset=*/0, num_tokens(), num_cols(), num_rows()), + "gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate"); + break; + } + command_buffer_compute.commit(); + 
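// Once the GPU work completes, the reference check below recomputes every output
// element on the CPU in double precision (std::fma over the bf16 weights plus the
// bias), adds the randomly pre-filled output captured in output_buffer_copy for the
// attention-output variant (which accumulates into its destination), and compares
// against the kernel result via ASSERT_NEAR_ABS_REL with an absolute tolerance of
// 2.0e-4 and a relative tolerance of 1.0e-4.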
command_buffer_compute.wait_completion(); const float* input_ptr = static_cast(input_buffer.ptr()); const gptoss_bfloat16* weight_ptr = static_cast(weight_buffer.ptr()); const gptoss_bfloat16* bias_ptr = static_cast(bias_buffer.ptr()); const float* output_ptr = static_cast(output_buffer.ptr()); + const float* output_ptr_copy = static_cast(output_buffer_copy.ptr()); for (size_t t = 0; t < num_tokens(); t++) { for (size_t r = 0; r < num_rows(); r++) { double ref_sum = upcast(bias_ptr[r]); @@ -139,7 +235,13 @@ class MatMulKernelTester { const double input_value = upcast(input_ptr[t * num_cols() + c]); ref_sum = std::fma(input_value, ref_weight, ref_sum); } - ASSERT_NEAR(upcast(output_ptr[t * num_rows() + r]), ref_sum, std::abs(ref_sum) * 1.0e-5) + + if (kernel_type == + MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED) { + ref_sum += upcast(output_ptr_copy[t * num_rows() + r]); + } + ASSERT_NEAR_ABS_REL(upcast(output_ptr[t * num_rows() + r]), + ref_sum, 2.0e-4, 1.0e-4) << "token " << t; } } @@ -159,6 +261,9 @@ class MatMulKernelTester { metal::Function f32_fill_random_fn_{library_, "gptoss_f32_fill_random"}; metal::Function bf16_fill_random_fn_{library_, "gptoss_bf16_fill_random"}; metal::Function f32_bf16w_matmul_fn_{library_, "gptoss_f32_bf16w_matmul"}; + metal::Function f32_bf16w_dense_matmul_qkv_fn_{library_, "gptoss_f32_bf16w_dense_matmul_qkv"}; + metal::Function f32_bf16w_dense_matmul_attn_output_fn_{library_, "gptoss_f32_bf16w_dense_matmul_attn_output"}; + metal::Function f32_bf16w_dense_matmul_mlp_gate_fn_{library_, "gptoss_f32_bf16w_dense_matmul_mlp_gate"}; std::uint32_t num_tokens_{1}; std::uint32_t num_rows_{1}; std::uint32_t num_cols_{32}; From 35eb3cc90f58a82fb814ed6904eac1f4c0c7d85e Mon Sep 17 00:00:00 2001 From: Maratyszcza Date: Thu, 11 Sep 2025 16:02:53 -0700 Subject: [PATCH 90/91] Metal: fused QKV projection (matmul+RoPE+KV cache init) kernel (#184) --- gpt_oss/metal/source/context.c | 101 +++++++++------- .../source/include/internal/kernel-args.h | 13 +++ .../source/include/internal/metal-kernels.h | 36 +++++- gpt_oss/metal/source/include/internal/model.h | 1 + gpt_oss/metal/source/matmul.metal | 88 ++++++++++++++ gpt_oss/metal/source/metal-kernels.c | 109 ++++++++++++++++-- gpt_oss/metal/source/model.c | 5 + gpt_oss/metal/source/rope.metal | 10 +- gpt_oss/metal/source/sdpa.metal | 23 ++-- 9 files changed, 313 insertions(+), 73 deletions(-) diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c index 704ff7f4..5cdaee7f 100644 --- a/gpt_oss/metal/source/context.c +++ b/gpt_oss/metal/source/context.c @@ -253,10 +253,51 @@ static enum gptoss_status process_tokens( GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch"); return status; } + + status = gptoss_metal_command_buffer_encode_launch_f32_rope( + command_buffer, + &model->f32_rope_fn, + /*threadgroup_size=*/32, + &context->qkv_activation_buffer, + /*input_offset=*/0, + &context->control_buffer, + /*control_offset=*/0, + model->rope_theta, + model->interpolation_scale, + model->yarn_offset, + model->yarn_scale, + model->yarn_multiplier, + input_batch_size, + model->num_heads, + model->num_kv_heads, + model->head_dim, + /*token_offset=*/input_batch_start); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch"); + return status; + } + + for (uint32_t t = 0; t < input_batch_size; t++) { + for (uint32_t kv = 0; kv < 2; kv++) { + for (uint32_t h = 0; h < model->num_kv_heads; h++) { + status = 
gptoss_metal_command_buffer_encode_copy_buffer( + command_buffer, + &context->qkv_activation_buffer, + /*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float), + &context->kvcache_buffer, + /*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float), + /*size=*/model->head_dim * sizeof(float)); + if (status != gptoss_status_success) { + GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t); + return status; + } + } + } + } } else { - status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( + status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv( command_buffer, - &model->f32_bf16w_matmul_fn, + &model->f32_bf16w_matmul_qkv_fn, model->attn_qkv_threadgroup_size, &context->rmsnorm_activation_buffer, /*input_offset=*/0, @@ -266,49 +307,24 @@ static enum gptoss_status process_tokens( /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n, &context->qkv_activation_buffer, /*output_offset=*/0, + &context->kvcache_buffer, + /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float), &context->control_buffer, /*control_offset=*/0, /*num_tokens=*/input_batch_size, /*num_cols=*/model->embedding_dim, - /*num_rows=*/attn_qkv_dim); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch"); - return status; - } - } - status = gptoss_metal_command_buffer_encode_launch_f32_rope( - command_buffer, - &model->f32_rope_fn, - /*threadgroup_size=*/32, - &context->qkv_activation_buffer, - /*input_offset=*/0, - &context->control_buffer, - /*control_offset=*/0, - model->rope_theta, - model->interpolation_scale, - model->yarn_offset, - model->yarn_scale, - model->yarn_multiplier, - input_batch_size, - model->num_heads, - model->num_kv_heads, - model->head_dim, - /*token_offset=*/input_batch_start); - if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch"); - return status; - } - - for (uint32_t t = 0; t < input_batch_size; t++) { - status = gptoss_metal_command_buffer_encode_copy_buffer( - command_buffer, - &context->qkv_activation_buffer, - /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float), - &context->kvcache_buffer, - /*output_offset=*/(n * context->max_tokens + input_batch_start + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float), - /*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float)); + /*num_q_heads=*/model->num_heads, + /*num_kv_heads=*/model->num_kv_heads, + /*attn_head_dim=*/model->head_dim, + /*token_offset=*/input_batch_start, + /*max_tokens=*/context->max_tokens, + /*rope_base=*/model->rope_theta, + /*interpolation_scale=*/model->interpolation_scale, + /*yarn_offset=*/model->yarn_offset, + /*yarn_scale=*/model->yarn_scale, + /*yarn_multiplier=*/model->yarn_multiplier); if (status != gptoss_status_success) { - GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t); + GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch"); return status; } } @@ -320,9 +336,7 @@ static enum gptoss_status process_tokens( &context->qkv_activation_buffer, /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float), &context->kvcache_buffer, - /*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * 
sizeof(float), - &context->kvcache_buffer, - /*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float), + /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float), &model->shared_weight_buffer, /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n, &context->sdpa_activation_buffer, @@ -330,6 +344,7 @@ static enum gptoss_status process_tokens( &context->control_buffer, /*control_offset=*/0, /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX, + /*kv_stride=*/2 * context->max_tokens * model->head_dim, num_block_output_tokens, input_batch_start + input_batch_size - num_block_output_tokens, model->num_heads, model->num_kv_heads, model->head_dim); diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h index 7dc0b564..90dbdcf7 100644 --- a/gpt_oss/metal/source/include/internal/kernel-args.h +++ b/gpt_oss/metal/source/include/internal/kernel-args.h @@ -39,6 +39,7 @@ struct gptoss_topk_args { struct gptoss_sdpa_args { uint32_t qkv_dim; uint32_t num_kv_tokens; + uint32_t kv_stride; uint32_t window; }; @@ -126,6 +127,18 @@ struct gptoss_rope_args { float yarn_multiplier; }; +struct gptoss_qkv_args { + uint32_t num_column_vecs; + uint32_t num_rows; + uint32_t token_offset; + float freq_scale; + float interpolation_scale; + float yarn_offset; + float yarn_scale; + float yarn_multiplier; + uint32_t max_tokens; +}; + struct gptoss_softmax_args { uint32_t num_vecs; uint32_t num_vecs_per_threadgroup; diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h index 6837bbd2..c12a834d 100644 --- a/gpt_oss/metal/source/include/internal/metal-kernels.h +++ b/gpt_oss/metal/source/include/internal/metal-kernels.h @@ -112,6 +112,35 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul( uint32_t num_cols, uint32_t num_rows); +enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv( + const struct gptoss_metal_command_buffer* command_buffer, + const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn, + size_t threadgroup_size, + const struct gptoss_metal_buffer* input_buffer, + size_t input_offset, + const struct gptoss_metal_buffer* weight_buffer, + size_t weight_offset, + const struct gptoss_metal_buffer* bias_buffer, + size_t bias_offset, + const struct gptoss_metal_buffer* output_buffer, + size_t output_offset, + const struct gptoss_metal_buffer* kv_buffer, + size_t kv_offset, + const struct gptoss_metal_buffer* control_buffer, + size_t control_offset, + uint32_t num_tokens, + uint32_t num_cols, + uint32_t num_q_heads, + uint32_t num_kv_heads, + uint32_t attn_head_dim, + uint32_t token_offset, + uint32_t max_tokens, + float rope_base, + float interpolation_scale, + float yarn_offset, + float yarn_scale, + float yarn_multiplier); + enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add( const struct gptoss_metal_command_buffer* command_buffer, const struct gptoss_metal_function* f32_bf16w_matmul_fn, @@ -306,10 +335,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( const struct gptoss_metal_function* f32_sdpa_fn, const struct gptoss_metal_buffer* q_buffer, size_t q_offset, - const struct gptoss_metal_buffer* k_buffer, - size_t k_offset, - const struct gptoss_metal_buffer* v_buffer, - size_t v_offset, + const struct gptoss_metal_buffer* kv_buffer, + size_t 
kv_offset, const struct gptoss_metal_buffer* s_buffer, size_t s_offset, const struct gptoss_metal_buffer* output_buffer, @@ -317,6 +344,7 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa( const struct gptoss_metal_buffer* control_buffer, size_t control_offset, uint32_t window, + uint32_t kv_stride, uint32_t num_q_tokens, uint32_t num_kv_tokens, uint32_t num_q_heads, diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h index b2c080b1..c63578a7 100644 --- a/gpt_oss/metal/source/include/internal/model.h +++ b/gpt_oss/metal/source/include/internal/model.h @@ -78,6 +78,7 @@ struct gptoss_model { struct gptoss_metal_function bf16_f32_embeddings_fn; struct gptoss_metal_function f32_bf16w_rmsnorm_fn; struct gptoss_metal_function f32_bf16w_matmul_fn; + struct gptoss_metal_function f32_bf16w_matmul_qkv_fn; struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn; struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn; struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn; diff --git a/gpt_oss/metal/source/matmul.metal b/gpt_oss/metal/source/matmul.metal index 1fc3c408..8831563b 100644 --- a/gpt_oss/metal/source/matmul.metal +++ b/gpt_oss/metal/source/matmul.metal @@ -67,6 +67,94 @@ kernel void gptoss_f32_bf16w_matmul( } } +kernel void gptoss_f32_bf16w_matmul_qkv( + constant gptoss_qkv_args& args [[ buffer(0) ]], + const device float4* input [[ buffer(1) ]], + const device bfloat4* weight [[ buffer(2) ]], + const device bfloat* bias [[ buffer(3) ]], + device float* q [[ buffer(4) ]], + device float* kv [[ buffer(5) ]], + const device gptoss_control* control [[ buffer(6) ]], + threadgroup void* scratch [[ threadgroup(0) ]], + uint2 gid [[threadgroup_position_in_grid]], + uint simdgroup_tid [[thread_index_in_simdgroup]], + uint simdgroup_idx [[simdgroup_index_in_threadgroup]], + uint num_simdgroups [[simdgroups_per_threadgroup]]) +{ + const uint simdgroup_size = 32; + const uint head_dim = 64; + const uint num_q_heads = 64; + const uint num_kv_heads = 8; + if (control->abort != 0) { + return; + } + + const uint num_column_vecs = args.num_column_vecs; + const uint row = gid.x * num_simdgroups + simdgroup_idx; + + input += gid.y * num_column_vecs + simdgroup_tid; + weight += num_column_vecs * row + simdgroup_tid; + bias += row; + q += gid.y * args.num_rows; + + uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size; + + float4 sum4 = 0.0f; + do { + const bfloat4 w = *weight; + const float4 i = *input; + sum4 = metal::fma(static_cast(w), i, sum4); + + weight += simdgroup_size; + input += simdgroup_size; + } while (--num_iter != 0); + const float2 sum2 = sum4.xy + sum4.zw; + float sum = sum2.x + sum2.y; + sum = metal::simd_sum(sum); + if (metal::simd_is_first()) { + sum += static_cast(*bias); + static_cast(scratch)[simdgroup_idx] = sum; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + if (simdgroup_idx == 0) { + const uint num_half_simdgroups = num_simdgroups / 2; + if (simdgroup_tid < num_half_simdgroups) { + float2 vals = static_cast(scratch)[simdgroup_tid]; + const uint idx = gid.x * num_half_simdgroups + simdgroup_tid; + const uint head_idx = idx / (head_dim / 2); + const uint token_idx = args.token_offset + gid.y; + const uint dim_idx = idx % (head_dim / 2); + if (head_idx < num_q_heads + num_kv_heads) { + const float dim_idx_val = static_cast(dim_idx); + const float inv_extrapolation_freq = metal::precise::exp(dim_idx_val * args.freq_scale); + 
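// YaRN-style RoPE applied in registers: inv_extrapolation_freq is the plain RoPE
// inverse frequency for this dimension pair (args.freq_scale is filled in by the
// host-side encode call from rope_base), and it is blended with the interpolated
// frequency below using alpha = saturate(dim * yarn_scale + yarn_offset). The angle
// phi = token_index * inv_freq then rotates the (x, y) pair held in `vals` for Q and
// K rows (V rows skip the rotation), with sin/cos scaled by yarn_multiplier. Q heads
// are written to the activation buffer, while K and V heads go straight into this
// token's slot of the per-head KV cache.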
+                const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
+                const float alpha = metal::saturate(metal::fma(dim_idx_val, args.yarn_scale, args.yarn_offset));
+                const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
+
+                const float phi = static_cast<float>(token_idx) * inv_freq;
+                const float yarn_multiplier = args.yarn_multiplier;
+                float cosphi;
+                const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
+                cosphi *= yarn_multiplier;
+
+                vals = (float2) {
+                    vals.x * cosphi - vals.y * sinphi,
+                    vals.x * sinphi + vals.y * cosphi,
+                };
+            }
+            if (head_idx < num_q_heads) {
+                reinterpret_cast<device float2*>(q)[idx] = vals;
+            } else if (head_idx < num_q_heads + num_kv_heads) {
+                const uint h = head_idx - num_q_heads;
+                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim)[dim_idx] = vals;
+            } else {
+                const uint h = head_idx - num_q_heads - num_kv_heads;
+                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim + head_dim)[dim_idx] = vals;
+            }
+        }
+    }
+}
+
 kernel void gptoss_f32_bf16w_unembedding(
     constant gptoss_unembedding_args& args [[ buffer(0) ]],
     const device float4* input [[ buffer(1) ]],
diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c
index dd5e640a..3aaeb32f 100644
--- a/gpt_oss/metal/source/metal-kernels.c
+++ b/gpt_oss/metal/source/metal-kernels.c
@@ -341,6 +341,101 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
         /*threadgroup_buffer_size=*/0);
 }
 
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
+    size_t threadgroup_size,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* weight_buffer,
+    size_t weight_offset,
+    const struct gptoss_metal_buffer* bias_buffer,
+    size_t bias_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    const struct gptoss_metal_buffer* kv_buffer,
+    size_t kv_offset,
+    const struct gptoss_metal_buffer* control_buffer,
+    size_t control_offset,
+    uint32_t num_tokens,
+    uint32_t num_cols,
+    uint32_t num_q_heads,
+    uint32_t num_kv_heads,
+    uint32_t attn_head_dim,
+    uint32_t token_offset,
+    uint32_t max_tokens,
+    float rope_base,
+    float interpolation_scale,
+    float yarn_offset,
+    float yarn_scale,
+    float yarn_multiplier)
+{
+    if (command_buffer->object == NULL || f32_bf16w_matmul_qkv_fn->pipeline_state_object == NULL) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: invalid command buffer or pipeline state object");
+        return gptoss_status_invalid_state;
+    }
+
+    if (threadgroup_size == 0) {
+        threadgroup_size = f32_bf16w_matmul_qkv_fn->simdgroup_threads;
+    } else if (threadgroup_size > f32_bf16w_matmul_qkv_fn->max_threadgroup_threads) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
+            threadgroup_size, f32_bf16w_matmul_qkv_fn->max_threadgroup_threads);
+        return gptoss_status_invalid_argument;
+    }
+
+    if (num_cols % 4 != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
+            num_cols);
+        return gptoss_status_invalid_argument;
+    }
+
+    if (num_q_heads != 64) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of Q heads (%" PRIu32 ") must be 64",
+            num_q_heads);
+        return gptoss_status_invalid_argument;
+    }
+    if (num_kv_heads != 8) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of KV heads (%" PRIu32 ") must be 8",
+            num_kv_heads);
+        return gptoss_status_invalid_argument;
+    }
+    if (attn_head_dim != 64) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: attention head dimension (%" PRIu32 ") must be 64",
+            attn_head_dim);
+        return gptoss_status_invalid_argument;
+    }
+
+    const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_qkv_fn->simdgroup_threads;
+    const uint32_t num_rows = (num_q_heads + 2 * num_kv_heads) * attn_head_dim;
+    if (num_rows % num_simdgroups != 0) {
+        GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
+            num_rows, num_simdgroups);
+        return gptoss_status_invalid_argument;
+    }
+
+    const struct gptoss_qkv_args args = {
+        .num_column_vecs = num_cols / 4,
+        .num_rows = num_rows,
+        .token_offset = token_offset,
+        .freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
+        .interpolation_scale = interpolation_scale,
+        .yarn_offset = yarn_offset,
+        .yarn_scale = yarn_scale,
+        .yarn_multiplier = yarn_multiplier,
+        .max_tokens = max_tokens,
+    };
+
+    return gptoss_metal_command_buffer_encode_launch_kernel(
+        command_buffer, f32_bf16w_matmul_qkv_fn,
+        threadgroup_size, 1, 1,
+        num_rows / num_simdgroups, num_tokens, 1,
+        sizeof(args), &args,
+        6,
+        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, kv_buffer, control_buffer},
+        (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, kv_offset, control_offset},
+        /*threadgroup_buffer_size=*/num_simdgroups * sizeof(float));
+}
+
 enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
     const struct gptoss_metal_command_buffer* command_buffer,
     const struct gptoss_metal_function* f32_bf16w_matmul_fn,
@@ -936,10 +1031,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
     const struct gptoss_metal_function* f32_sdpa_fn,
     const struct gptoss_metal_buffer* q_buffer,
     size_t q_offset,
-    const struct gptoss_metal_buffer* k_buffer,
-    size_t k_offset,
-    const struct gptoss_metal_buffer* v_buffer,
-    size_t v_offset,
+    const struct gptoss_metal_buffer* kv_buffer,
+    size_t kv_offset,
     const struct gptoss_metal_buffer* s_buffer,
     size_t s_offset,
     const struct gptoss_metal_buffer* output_buffer,
@@ -947,6 +1040,7 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
     const struct gptoss_metal_buffer* control_buffer,
     size_t control_offset,
     uint32_t window,
+    uint32_t kv_stride,
     uint32_t num_q_tokens,
     uint32_t num_kv_tokens,
     uint32_t num_q_heads,
@@ -976,6 +1070,7 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
     const struct gptoss_sdpa_args args = {
         .qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),
         .num_kv_tokens = num_kv_tokens,
+        .kv_stride = kv_stride,
         .window = window,
     };
 
@@ -984,9 +1079,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
         threadgroup_size, 1, 1,
         num_q_tokens, num_kv_heads, 1,
         sizeof(args), &args,
-        6,
-        (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer, control_buffer},
-        (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset, control_offset},
+        5,
+        (const struct gptoss_metal_buffer *[]) {q_buffer, kv_buffer, s_buffer, output_buffer, control_buffer},
+        (const size_t[]) {q_offset, kv_offset, s_offset, output_offset, control_offset},
         /*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));
 }
 
diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c
index 5f79c030..469ef232 100644
--- a/gpt_oss/metal/source/model.c
+++ b/gpt_oss/metal/source/model.c
@@ -325,6 +325,10 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
     if (status != gptoss_status_success) {
         goto cleanup;
     }
+    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul_qkv", &model->f32_bf16w_matmul_qkv_fn);
+    if (status != gptoss_status_success) {
+        goto cleanup;
+    }
     status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn);
     if (status != gptoss_status_success) {
         goto cleanup;
@@ -514,6 +518,7 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
     gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
     gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
     gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
+    gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);
     gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);
     gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);
     gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);
diff --git a/gpt_oss/metal/source/rope.metal b/gpt_oss/metal/source/rope.metal
index ce4c3c8f..8bd2f568 100644
--- a/gpt_oss/metal/source/rope.metal
+++ b/gpt_oss/metal/source/rope.metal
@@ -21,14 +21,14 @@ kernel void gptoss_f32_rope(
         return;
     }
 
-    const float head_idx = static_cast<float>(gid.x % (num_head_dims / 2));
+    const float dim_idx = static_cast<float>(gid.x % (num_head_dims / 2));
     const uint token_idx = args.token_offset + gid.y;
     activations += gid.y * args.token_stride + gid.x;
     const float2 input_vals = *activations;
 
-    const float inv_extrapolation_freq = metal::precise::exp(head_idx * args.freq_scale);
+    const float inv_extrapolation_freq = metal::precise::exp(dim_idx * args.freq_scale);
     const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
-    const float alpha = metal::saturate(metal::fma(head_idx, args.yarn_scale, args.yarn_offset));
+    const float alpha = metal::saturate(metal::fma(dim_idx, args.yarn_scale, args.yarn_offset));
     const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
 
     const float phi = static_cast<float>(token_idx) * inv_freq;
@@ -37,7 +37,7 @@ kernel void gptoss_f32_rope(
     const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
     cosphi *= yarn_multiplier;
 
-    const float output_re = metal::fma(-input_vals.y, sinphi, input_vals.x * cosphi);
-    const float output_im = metal::fma(input_vals.y, cosphi, input_vals.x * sinphi);
+    const float output_re = input_vals.x * cosphi - input_vals.y * sinphi;
+    const float output_im = input_vals.x * sinphi + input_vals.y * cosphi;
     *activations = (float2) { output_re, output_im };
 }
diff --git a/gpt_oss/metal/source/sdpa.metal b/gpt_oss/metal/source/sdpa.metal
index 459bbe28..d112569f 100644
--- a/gpt_oss/metal/source/sdpa.metal
+++ b/gpt_oss/metal/source/sdpa.metal
@@ -14,10 +14,9 @@
 kernel void gptoss_f32_sdpa_q8_d64(
     constant gptoss_sdpa_args& args [[ buffer(0) ]],
     const device float* q [[ buffer(1) ]],
-    const device float* k [[ buffer(2) ]],
-    const device float* v [[ buffer(3) ]],
-    const device bfloat* s [[ buffer(4) ]],
-    device float* output [[ buffer(5) ]],
+    const device float* kv [[ buffer(2) ]],
+    const device bfloat* s [[ buffer(3) ]],
+    device float* output [[ buffer(4) ]],
     const device gptoss_control* control [[ buffer(6) ]],
     threadgroup void* threadgroup_buffer [[ threadgroup(0) ]],
     uint2 gid [[threadgroup_position_in_grid]],
@@ -32,18 +31,16 @@ kernel void gptoss_f32_sdpa_q8_d64(
     }
 
     const uint num_q_heads = 64;
-    const uint num_kv_heads = 8;
     const uint head_dim = 64;
     const uint qmul = 8;
 
-    const uint token_stride = 2 * num_kv_heads * head_dim;
+    const uint token_stride = 2 * head_dim;
 
     const uint qt = gid.x; // Q token index
     const uint h = gid.y; // KV head index
 
     q += qt * args.qkv_dim + h * (qmul * head_dim);
-    k += h * head_dim;
-    v += h * head_dim;
+    kv += h * args.kv_stride;
     output += qt * (num_q_heads * head_dim) + h * (qmul * head_dim);
 
     float m0 = static_cast<float>(s[h * qmul + 0]);
@@ -84,11 +81,9 @@ kernel void gptoss_f32_sdpa_q8_d64(
 
     const uint kt_end = qt + args.num_kv_tokens + 1;
     const uint kt_start = metal::subsat(kt_end, args.window) + simdgroup_idx;
-    k += token_stride * kt_start;
-    v += token_stride * kt_start;
+    kv += token_stride * kt_start;
     for (uint kt = kt_start; kt < kt_end; kt += num_simdgroups) {
-        const float2 kval = reinterpret_cast<const device float2*>(k)[simdgroup_tid];
-        k += token_stride * num_simdgroups;
+        const float2 kval = reinterpret_cast<const device float2*>(kv)[simdgroup_tid];
 
         float qk0 = metal::dot(q0, kval);
         float qk1 = metal::dot(q1, kval);
@@ -153,8 +148,8 @@ kernel void gptoss_f32_sdpa_q8_d64(
         m6 = new_m6;
         m7 = new_m7;
 
-        const float2 vval = reinterpret_cast<const device float2*>(v)[simdgroup_tid];
-        v += token_stride * num_simdgroups;
+        const float2 vval = reinterpret_cast<const device float2*>(kv + head_dim)[simdgroup_tid];
+        kv += token_stride * num_simdgroups;
         out0 = metal::fma(vval, qk0, out0 * alpha0);
         out1 = metal::fma(vval, qk1, out1 * alpha1);
         out2 = metal::fma(vval, qk2, out2 * alpha2);

From 758e904af5f9f5065621b2c2610498bfbe4b3122 Mon Sep 17 00:00:00 2001
From: shaeenhaque
Date: Sun, 14 Sep 2025 03:14:44 +0600
Subject: [PATCH 91/91] Create devcontainer.json

---
 .devcontainer/devcontainer.json | 4 ++++
 1 file changed, 4 insertions(+)
 create mode 100644 .devcontainer/devcontainer.json

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 00000000..39bbd268
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,4 @@
+{
+  "image": "mcr.microsoft.com/devcontainers/universal:2",
+  "features": {}
+}
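Editor's note on the fused QKV patch above: both gptoss_f32_rope and the new gptoss_f32_bf16w_matmul_qkv kernel apply the same YaRN-scaled rotary embedding to each (re, im) pair; in the fused kernel, Q and K rows leave the matmul already rotated, while V rows are written unrotated into the second half of each KV-cache token slot. As a reference, here is a rough host-side C sketch of that rotation. It is not part of any patch in this series; the function name and standalone form are hypothetical, and the parameters mirror the fields of gptoss_qkv_args as they appear in the diff.

/*
 * Illustrative sketch only, not repository code. It restates the YaRN-scaled
 * RoPE math used by gptoss_f32_rope and gptoss_f32_bf16w_matmul_qkv; the helper
 * name rope_rotate_pair is hypothetical.
 */
#include <math.h>
#include <stdint.h>

static void rope_rotate_pair(
    float* re, float* im,
    uint32_t token_idx, uint32_t dim_idx, uint32_t head_dim,
    float rope_base, float interpolation_scale,
    float yarn_offset, float yarn_scale, float yarn_multiplier)
{
    /* Host-side setup in the encoder: freq_scale = -log(rope_base) / (head_dim / 2). */
    const float freq_scale = -logf(rope_base) / (float) (head_dim / 2);
    const float inv_extrapolation_freq = expf((float) dim_idx * freq_scale);
    const float inv_interpolation_freq = inv_extrapolation_freq * interpolation_scale;

    /* metal::saturate(metal::fma(...)): clamp the blend factor to [0, 1]. */
    float alpha = (float) dim_idx * yarn_scale + yarn_offset;
    alpha = fminf(fmaxf(alpha, 0.0f), 1.0f);
    /* metal::mix(a, b, t) = a + (b - a) * t. */
    const float inv_freq = inv_extrapolation_freq
        + (inv_interpolation_freq - inv_extrapolation_freq) * alpha;

    /* Rotation angle for this token and dimension, scaled by the YaRN multiplier. */
    const float phi = (float) token_idx * inv_freq;
    const float cosphi = cosf(phi) * yarn_multiplier;
    const float sinphi = sinf(phi) * yarn_multiplier;

    /* Same explicit rotation the kernels now perform. */
    const float x = *re;
    const float y = *im;
    *re = x * cosphi - y * sinphi;
    *im = x * sinphi + y * cosphi;
}

As far as the diff shows, this is why the standalone RoPE pass is no longer needed for the fused path: the rotation happens on the first num_half_simdgroups lanes immediately after the per-row reduction, before the results are stored to the Q activations or the packed KV cache.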