diff --git a/examples/gradio/gradio_chat.py b/examples/gradio/gradio_chat.py
index da742bd3..99a46f09 100644
--- a/examples/gradio/gradio_chat.py
+++ b/examples/gradio/gradio_chat.py
@@ -15,25 +15,25 @@
 }
 """.strip()
 
-def chat_with_model(message, history, model_choice, instructions, effort, use_functions, 
+def chat_with_model(message, history, model_choice, instructions, effort, use_functions,
                     function_name, function_description, function_parameters, use_browser_search,
                     temperature, max_output_tokens, debug_mode):
-    
+
     if not message.strip():
         return history, ""
-    
+
     # Append user message and empty assistant placeholder (idiomatic Gradio pattern)
     history = history + [[message, ""]]
-    
+
     # Build messages list from history (excluding the empty assistant placeholder)
     messages = []
-    
+
     # Convert history to messages format (excluding the last empty assistant message)
     for user_msg, assistant_msg in history[:-1]:
         if user_msg:
             messages.append({
                 "type": "message",
-                "role": "user", 
+                "role": "user",
                 "content": [{"type": "input_text", "text": user_msg}]
             })
         if assistant_msg:
@@ -42,14 +42,14 @@ def chat_with_model(message, history, model_choice, instructions, effort, use_fu
                 "role": "assistant",
                 "content": [{"type": "output_text", "text": assistant_msg}]
             })
-    
+
     # Add current user message
     messages.append({
         "type": "message",
         "role": "user",
         "content": [{"type": "input_text", "text": message}]
     })
-    
+
     # Prepare tools
     tools = []
     if use_functions:
@@ -62,18 +62,18 @@ def chat_with_model(message, history, model_choice, instructions, effort, use_fu
             })
         except json.JSONDecodeError:
             pass
-    
+
     if use_browser_search:
         tools.append({"type": "browser_search"})
-    
+
     # Get URL based on model (matching streamlit logic)
    options = ["large", "small"]
-    URL = ("http://localhost:8081/v1/responses" if model_choice == options[1] 
+    url = ("http://localhost:8081/v1/responses" if model_choice == options[1]
           else "http://localhost:8000/v1/responses")
-    
+
     try:
        response = requests.post(
-            URL,
+            url,
            json={
                "input": messages,
                "stream": True,
@@ -86,32 +86,31 @@ def chat_with_model(message, history, model_choice, instructions, effort, use_fu
             },
            stream=True,
        )
-        
+
         full_content = ""
-        text_delta = ""
-        current_output_index = 0
+        in_reasoning = False
-        
+
         for line in response.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if not data_str:
                continue
-            
+
             try:
                data = json.loads(data_str)
            except Exception:
                continue
-            
+
             event_type = data.get("type", "")
-            output_index = data.get("output_index", 0)
-            
+
             if event_type == "response.output_item.added":
-                current_output_index = output_index
+                output_type = data.get("item", {}).get("type", "message")
-                text_delta = ""
-                
+
                 if output_type == "reasoning":
                    if not in_reasoning:
                        full_content += "šŸ¤” **Thinking...**\n"
@@ -120,56 +119,56 @@ def chat_with_model(message, history, model_choice, instructions, effort, use_fu
                     if in_reasoning:
                        full_content += "\n\n"
                        in_reasoning = False
-            
+
             elif event_type == "response.reasoning_text.delta":
                delta = data.get("delta", "")
                full_content += delta
-                
+
                 # Update last assistant message (idiomatic Gradio pattern)
                history[-1][1] = full_content
                yield history, ""
-                
+
             elif event_type == "response.output_text.delta":
                delta = data.get("delta", "")
                full_content += delta
-                
-                # Update last assistant message (idiomatic Gradio pattern) 
+
+                # Update last assistant message (idiomatic Gradio pattern)
                history[-1][1] = full_content
                yield history, ""
-                
+
             elif event_type == "response.output_item.done":
                item = data.get("item", {})
                if item.get("type") == "function_call":
                    function_call_text = f"\n\nšŸ”Ø Called `{item.get('name')}`\n**Arguments**\n```json\n{item.get('arguments', '')}\n```"
                    full_content += function_call_text
-                    
+
                     # Update last assistant message (idiomatic Gradio pattern)
                    history[-1][1] = full_content
                    yield history, ""
-                    
+
                 elif item.get("type") == "web_search_call":
                    web_search_text = f"\n\n🌐 **Web Search**\n```json\n{json.dumps(item.get('action', {}), indent=2)}\n```\nāœ… Done"
                    full_content += web_search_text
-                    
+
                     # Update last assistant message (idiomatic Gradio pattern)
                    history[-1][1] = full_content
                    yield history, ""
-                    
+
             elif event_type == "response.completed":
                response_data = data.get("response", {})
                if debug_mode:
                    debug_info = response_data.get("metadata", {}).get("__debug", "")
                    if debug_info:
                        full_content += f"\n\n**Debug**\n```\n{debug_info}\n```"
-                
+
                 # Update last assistant message (idiomatic Gradio pattern)
                history[-1][1] = full_content
                yield history, ""
                break
-        
+
         # Return final history and empty string to clear textbox
        return history, ""
-        
+
     except Exception as e:
        error_message = f"āŒ Error: {str(e)}"
        history[-1][1] = error_message
@@ -179,69 +178,69 @@ def chat_with_model(message, history, model_choice, instructions, effort, use_fu
 # Create the Gradio interface
 with gr.Blocks(title="šŸ’¬ Chatbot") as demo:
     gr.Markdown("# šŸ’¬ Chatbot")
-    
+
     with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)
-            
+
             with gr.Row():
                msg = gr.Textbox(placeholder="Type a message...", scale=4, show_label=False)
                send_btn = gr.Button("Send", scale=1)
-            
+
             clear_btn = gr.Button("Clear Chat")
-        
+
         with gr.Column(scale=1):
            model_choice = gr.Radio(["large", "small"], value="small", label="Model")
-            
+
             instructions = gr.Textbox(
-                label="Instructions", 
+                label="Instructions",
                value="You are a helpful assistant that can answer questions and help with tasks.",
                lines=3
            )
-            
+
             effort = gr.Radio(["low", "medium", "high"], value="medium", label="Reasoning effort")
-            
+
             gr.Markdown("#### Functions")
            use_functions = gr.Checkbox(label="Use functions", value=False)
-            
+
             with gr.Column(visible=False) as function_group:
                function_name = gr.Textbox(label="Function name", value="get_weather")
                function_description = gr.Textbox(
-                    label="Function description", 
+                    label="Function description",
                    value="Get the weather for a given city"
                )
                function_parameters = gr.Textbox(
-                    label="Function parameters", 
+                    label="Function parameters",
                    value=DEFAULT_FUNCTION_PROPERTIES,
                    lines=6
                )
-            
+
             # Conditional browser search (matching Streamlit logic)
            # In Streamlit: if "show_browser" in st.query_params:
            # For Gradio, we'll always show it (simplified)
-            gr.Markdown("#### Built-in Tools") 
+            gr.Markdown("#### Built-in Tools")
            use_browser_search = gr.Checkbox(label="Use browser search", value=False)
-            
+
             temperature = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Temperature")
            max_output_tokens = gr.Slider(1000, 20000, value=1024, step=100, label="Max output tokens")
-            
+
             debug_mode = gr.Checkbox(label="Debug mode", value=False)
-    
+
     # Event handlers
    def toggle_function_group(use_funcs):
        return gr.update(visible=use_funcs)
-    
+
     use_functions.change(toggle_function_group, use_functions, function_group)
-    
+
     # Chat functionality
-    inputs = [msg, chatbot, model_choice, instructions, effort, use_functions, 
+    inputs = [msg, chatbot, model_choice, instructions, effort, use_functions,
              function_name, function_description, function_parameters, use_browser_search,
              temperature, max_output_tokens, debug_mode]
-    
+
     msg.submit(chat_with_model,
               inputs, [chatbot, msg])
     send_btn.click(chat_with_model, inputs, [chatbot, msg])
    clear_btn.click(lambda: [], outputs=chatbot)
 
 if __name__ == "__main__":
-    demo.launch()
\ No newline at end of file
+    demo.launch()
diff --git a/gpt_oss/evals/aime_eval.py b/gpt_oss/evals/aime_eval.py
index c6e9d64b..2b4bd805 100644
--- a/gpt_oss/evals/aime_eval.py
+++ b/gpt_oss/evals/aime_eval.py
@@ -44,9 +44,9 @@ def __init__(
         num_examples: int | None = None,  # restrict to a subset of the data for debugging
         n_threads: int = 1,
     ):
-        path1 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl"
+        path1 = "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl"
         df1 = pandas.read_json(path1, lines=True)
-        path2 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl"
+        path2 = "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl"
         df2 = pandas.read_json(path2, lines=True)
         examples = [row.to_dict() for _, row in df1.iterrows()] + [row.to_dict() for _, row in df2.iterrows()]
         examples = [{
@@ -94,4 +94,3 @@ def fn(row: dict):
 
         results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
         return report.aggregate_results(results)
-
diff --git a/gpt_oss/responses_api/inference/triton.py b/gpt_oss/responses_api/inference/triton.py
index cb08be31..be88b648 100644
--- a/gpt_oss/responses_api/inference/triton.py
+++ b/gpt_oss/responses_api/inference/triton.py
@@ -1,12 +1,12 @@
-import datetime
+
 import os
 from typing import Callable
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 import torch
-import torch.distributed as dist
-from gpt_oss.triton.model import Cache, ModelConfig, Transformer
+
+from gpt_oss.triton.model import Cache, Transformer
 
 DEFAULT_TEMPERATURE = 0.0
 CONTEXT = 16_384
@@ -73,7 +73,7 @@ def infer_next_token(
         tokens_so_far = lcp(tokens_so_far, tokens)
        for cache in caches:
            cache.truncate(len(tokens_so_far))
-        all_tokens = tokens # for pdb
+
         tokens = tokens[len(tokens_so_far) :]
 
         if len(tokens) > 1:
diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py
index 7fd354bb..ca4fafc8 100644
--- a/tests/test_api_endpoints.py
+++ b/tests/test_api_endpoints.py
@@ -1,12 +1,12 @@
-import pytest
+
 import json
-import asyncio
+
 from fastapi import status
-from unittest.mock import patch, MagicMock, AsyncMock
+
 
 
 class TestResponsesEndpoint:
-    
+
     def test_basic_response_creation(self, api_client, sample_request_data):
         response = api_client.post("/v1/responses", json=sample_request_data)
        assert response.status_code == status.HTTP_200_OK
@@ -14,7 +14,7 @@ def test_basic_response_creation(self, api_client, sample_request_data):
         assert "id" in data
        assert data["object"] == "response"
        assert data["model"] == sample_request_data["model"]
-    
+
     def test_response_with_high_reasoning(self, api_client, sample_request_data):
         sample_request_data["reasoning_effort"] = "high"
        response = api_client.post("/v1/responses", json=sample_request_data)
@@ -22,7 +22,7 @@ def test_response_with_high_reasoning(self, api_client, sample_request_data):
         data = response.json()
        assert "id" in data
        assert data["status"] == "completed"
-    
+
     def test_response_with_medium_reasoning(self, api_client, sample_request_data):
         sample_request_data["reasoning_effort"] = "medium"
        response = api_client.post("/v1/responses", json=sample_request_data)
@@ -30,18 +30,18 @@ def test_response_with_medium_reasoning(self, api_client, sample_request_data):
         data = response.json()
        assert "id" in data
        assert data["status"] == "completed"
"completed" - + def test_response_with_invalid_model(self, api_client, sample_request_data): sample_request_data["model"] = "invalid-model" response = api_client.post("/v1/responses", json=sample_request_data) # Should still accept but might handle differently assert response.status_code == status.HTTP_200_OK - + def test_response_with_empty_input(self, api_client, sample_request_data): sample_request_data["input"] = "" response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK - + def test_response_with_tools(self, api_client, sample_request_data): sample_request_data["tools"] = [ { @@ -50,7 +50,7 @@ def test_response_with_tools(self, api_client, sample_request_data): ] response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK - + def test_response_with_custom_temperature(self, api_client, sample_request_data): for temp in [0.0, 0.5, 1.0, 1.5, 2.0]: sample_request_data["temperature"] = temp @@ -58,7 +58,7 @@ def test_response_with_custom_temperature(self, api_client, sample_request_data) assert response.status_code == status.HTTP_200_OK data = response.json() assert "usage" in data - + def test_streaming_response(self, api_client, sample_request_data): sample_request_data["stream"] = True with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response: @@ -73,32 +73,32 @@ def test_streaming_response(self, api_client, sample_request_data): class TestResponsesWithSession: - + def test_response_with_session_id(self, api_client, sample_request_data): session_id = "test-session-123" sample_request_data["session_id"] = session_id - + # First request response1 = api_client.post("/v1/responses", json=sample_request_data) assert response1.status_code == status.HTTP_200_OK data1 = response1.json() - + # Second request with same session sample_request_data["input"] = "Follow up question" response2 = api_client.post("/v1/responses", json=sample_request_data) assert response2.status_code == status.HTTP_200_OK data2 = response2.json() - + # Should have different response IDs assert data1["id"] != data2["id"] - + def test_response_continuation(self, api_client, sample_request_data): # Create initial response response1 = api_client.post("/v1/responses", json=sample_request_data) assert response1.status_code == status.HTTP_200_OK data1 = response1.json() response_id = data1["id"] - + # Continue the response continuation_request = { "model": sample_request_data["model"], @@ -110,18 +110,18 @@ def test_response_continuation(self, api_client, sample_request_data): class TestErrorHandling: - + def test_missing_required_fields(self, api_client): # Model field has default, so test with empty JSON response = api_client.post("/v1/responses", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - + def test_invalid_reasoning_effort(self, api_client, sample_request_data): sample_request_data["reasoning_effort"] = "invalid" response = api_client.post("/v1/responses", json=sample_request_data) # May handle gracefully or return error assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY] - + def test_malformed_json(self, api_client): response = api_client.post( "/v1/responses", @@ -129,7 +129,7 @@ def test_malformed_json(self, api_client): headers={"Content-Type": "application/json"} ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - + def test_extremely_long_input(self, api_client, 
         # Test with very long input
        sample_request_data["input"] = "x" * 100000
@@ -138,7 +138,7 @@
 
 
 class TestToolIntegration:
-    
+
     def test_browser_search_tool(self, api_client, sample_request_data):
         sample_request_data["tools"] = [
            {
@@ -147,7 +147,7 @@ def test_browser_search_tool(self, api_client, sample_request_data):
         ]
        response = api_client.post("/v1/responses", json=sample_request_data)
        assert response.status_code == status.HTTP_200_OK
-    
+
     def test_function_tool_integration(self, api_client, sample_request_data):
         sample_request_data["tools"] = [
            {
@@ -159,7 +159,7 @@ def test_function_tool_integration(self, api_client, sample_request_data):
         ]
        response = api_client.post("/v1/responses", json=sample_request_data)
        assert response.status_code == status.HTTP_200_OK
-    
+
     def test_multiple_tools(self, api_client, sample_request_data):
         sample_request_data["tools"] = [
            {
@@ -177,16 +177,16 @@
 
 
 class TestPerformance:
-    
+
     def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
         performance_timer.start()
        response = api_client.post("/v1/responses", json=sample_request_data)
        elapsed = performance_timer.stop()
-        
+
         assert response.status_code == status.HTTP_200_OK
        # Response should be reasonably fast for mock inference
        assert elapsed < 5.0 # 5 seconds threshold
-    
+
     def test_multiple_sequential_requests(self, api_client, sample_request_data):
         # Test multiple requests work correctly
        for i in range(3):
@@ -197,12 +197,12 @@ def test_multiple_sequential_requests(self, api_client, sample_request_data):
 
 
 class TestUsageTracking:
-    
+
     def test_usage_object_structure(self, api_client, sample_request_data):
         response = api_client.post("/v1/responses", json=sample_request_data)
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
-        
+
         assert "usage" in data
        usage = data["usage"]
        assert "input_tokens" in usage
@@ -210,21 +210,21 @@ def test_usage_object_structure(self, api_client, sample_request_data):
         assert "total_tokens" in usage
        # reasoning_tokens may not always be present
        # assert "reasoning_tokens" in usage
-        
+
         # Basic validation
        assert usage["input_tokens"] >= 0
        assert usage["output_tokens"] >= 0
        assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
-    
+
     def test_usage_increases_with_longer_input(self, api_client, sample_request_data):
         # Short input
        response1 = api_client.post("/v1/responses", json=sample_request_data)
        usage1 = response1.json()["usage"]
-        
+
         # Longer input
        sample_request_data["input"] = sample_request_data["input"] * 10
        response2 = api_client.post("/v1/responses", json=sample_request_data)
        usage2 = response2.json()["usage"]
-        
+
         # Longer input should use more tokens
-        assert usage2["input_tokens"] > usage1["input_tokens"]
\ No newline at end of file
+        assert usage2["input_tokens"] > usage1["input_tokens"]