From 18880a1051da1fdf6ddcc49f1cbfc654fb6d1519 Mon Sep 17 00:00:00 2001 From: Alonso Astroza Tagle Date: Tue, 15 Apr 2025 11:42:23 -0400 Subject: [PATCH] feature: update smol agent demo with basic mcp server connection --- its-a-smol-world-mcp/README.md | 30 ++ .../mcp-server/.python-version | 1 + its-a-smol-world-mcp/mcp-server/README.md | 0 its-a-smol-world-mcp/mcp-server/main.py | 6 + .../mcp-server/pyproject.toml | 9 + its-a-smol-world-mcp/mcp-server/server.py | 50 +++ its-a-smol-world-mcp/requirements.txt | 6 + its-a-smol-world-mcp/src/__init__.py | 0 its-a-smol-world-mcp/src/app.py | 74 +++++ its-a-smol-world-mcp/src/constants.py | 3 + its-a-smol-world-mcp/src/smol_mind.py | 306 ++++++++++++++++++ 11 files changed, 485 insertions(+) create mode 100644 its-a-smol-world-mcp/README.md create mode 100644 its-a-smol-world-mcp/mcp-server/.python-version create mode 100644 its-a-smol-world-mcp/mcp-server/README.md create mode 100644 its-a-smol-world-mcp/mcp-server/main.py create mode 100644 its-a-smol-world-mcp/mcp-server/pyproject.toml create mode 100644 its-a-smol-world-mcp/mcp-server/server.py create mode 100644 its-a-smol-world-mcp/requirements.txt create mode 100644 its-a-smol-world-mcp/src/__init__.py create mode 100644 its-a-smol-world-mcp/src/app.py create mode 100644 its-a-smol-world-mcp/src/constants.py create mode 100644 its-a-smol-world-mcp/src/smol_mind.py diff --git a/its-a-smol-world-mcp/README.md b/its-a-smol-world-mcp/README.md new file mode 100644 index 0000000..591c9fe --- /dev/null +++ b/its-a-smol-world-mcp/README.md @@ -0,0 +1,30 @@ +# Outlines MCP Demo + +This is a small update to the [It's a Smol World](https://github.com/dottxt-ai/demos/tree/main/its-a-smol-world) demo, adding Model Context Protocol (MCP) connectivity. + +The core concept remains the same: using a small language model for function calling, but now the client can connect to any MCP-compatible server instead of just using local functions. 
def main():
    """Print the placeholder greeting of the mcp-server package."""
    greeting = "Hello from mcp-server!"
    print(greeting)


# Only greet when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# NOTE(review): the docstrings below are kept byte-for-byte; they are presumably
# what the server advertises as each tool's description, so they are runtime text.

# Add an addition tool
@mcp.tool()
def add(a: int, b: int) -> int:
    """Add two numbers"""
    total = a + b
    return total


# Add a text messaging tool
@mcp.tool()
def send_text(to: str, message: str) -> str:
    """Send a text message to a contact"""
    # Demo stub: a real implementation would hand off to a messaging service.
    confirmation = f"Message sent to {to}: {message}"
    return confirmation


# Add a food ordering tool
@mcp.tool()
def order_food(restaurant: str, item: str, quantity: int) -> str:
    """Order food from a restaurant"""
    # Demo stub: a real implementation would call a food ordering service.
    receipt = f"Ordered {quantity} {item}(s) from {restaurant}."
    return receipt


# Add a ride ordering tool
@mcp.tool()
def order_ride(dest: str) -> str:
    """Order a ride from a ride sharing service"""
    # Demo stub: a real implementation would call a ride sharing service.
    status = f"Ride ordered to {dest}. Your driver will arrive in 5 minutes."
    return status
# Add a weather information tool
@mcp.tool()
def get_weather(city: str) -> str:
    """Get the weather for a city"""
    # NOTE(review): the docstring is left untouched on purpose; it is presumably
    # published as the tool description that ends up in the model prompt.
    #
    # In a real application, this would integrate with a weather API.
    # Using placeholder responses for demo purposes.
    weather_data = {
        "New York": "Partly cloudy, 72°F",
        "San Francisco": "Foggy, 58°F",
        "Los Angeles": "Sunny, 82°F",
        "Chicago": "Windy, 55°F",
        "Miami": "Rainy, 80°F"
    }
    # The caller is a small language model that tends to echo the user's own
    # casing ("san francisco"), so match case-insensitively and ignore
    # surrounding whitespace instead of requiring the exact table key.
    by_folded_name = {name.casefold(): report for name, report in weather_data.items()}
    wanted = city.strip().casefold()
    if wanted in by_folded_name:
        return by_folded_name[wanted]
    return f"Weather information for {city} is not available."


if __name__ == "__main__":
    # Initialize and run the server
    mcp.run(transport='stdio')
def spinner(stop_event):
    """Animate a one-character console spinner until *stop_event* is set.

    Each cycle writes the next frame, flushes, writes a backspace so the next
    frame overwrites it, and then sleeps for 0.1 seconds.

    Args:
        stop_event: object exposing ``is_set()`` (the caller passes a
            ``threading.Event``); the loop ends once it returns true.
    """
    frames = itertools.cycle(['-', '/', '|', '\\'])
    while True:
        if stop_event.is_set():
            return
        sys.stdout.write(next(frames))
        sys.stdout.flush()
        # Step the cursor back so the following frame lands on the same cell.
        sys.stdout.write('\b')
        time.sleep(0.1)
async def main():
    """Run the interactive SmolMind MCP client.

    Parses the command line (server script path and ``-d/--debug``), loads the
    model, connects to the MCP server and then answers prompts read from stdin
    until the user types ``exit``/``quit`` or stdin is closed. The server
    connection is always closed on the way out.
    """
    # Add command-line argument parsing
    parser = argparse.ArgumentParser(description="SmolMind MCP Client")
    parser.add_argument('server_path', help='Path to the MCP server script (.py or .js)')
    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug mode')
    args = parser.parse_args()

    print("Loading SmolMind MCP client...")
    sm = SmolMind(args.server_path, model_name=MODEL_NAME, debug=args.debug)

    try:
        # Connect to the server
        tools = await sm.connect_to_server()
        if args.debug:
            print("Using model:", sm.model_name)
            print("Debug mode:", "Enabled" if args.debug else "Disabled")
            print(f"Available tools: {[tool.name for tool in tools]}")

        print(bunny_ascii)
        print("Welcome to the Bunny B1 MCP Client! What do you need?")

        while True:
            try:
                user_input = input("> ")
            except EOFError:
                # stdin was closed (Ctrl-D or piped input ran out): leave
                # cleanly instead of surfacing a traceback.
                print("\nGoodbye!")
                break

            command = user_input.strip()
            if command.lower() in ("exit", "quit"):
                print("Goodbye!")
                break
            if not command:
                # Nothing to route; do not spend a model call on an empty prompt.
                continue

            # Create a shared event to stop the spinner
            stop_event = threading.Event()

            # Start the spinner in a separate thread
            spinner_thread = threading.Thread(target=spinner, args=(stop_event,))
            spinner_thread.daemon = True
            spinner_thread.start()

            try:
                response = await sm.process_query(user_input)
            finally:
                # Stop the spinner even if the query raised
                stop_event.set()
                spinner_thread.join()
                sys.stdout.write(' \b')  # Erase the spinner

            print(response)
    finally:
        # Ensure we close the connection
        await sm.close()


if __name__ == "__main__":
    asyncio.run(main())
def format_functions(functions):
    """Render function descriptors as the plain-text list used in the prompt.

    Each function becomes a ``name: description`` line followed by one
    ``- arg: description`` line per entry in ``parameters.properties``
    (``No description provided`` when a property has no description).
    Functions are separated by a blank line.

    Args:
        functions: iterable of dicts with ``name`` and ``description`` keys and
            an optional ``parameters`` dict holding ``properties``.

    Returns:
        The formatted text; an empty string for an empty iterable.
    """
    sections = []
    for func in functions:
        lines = [f"{func['name']}: {func['description']}\n"]
        if 'parameters' in func and 'properties' in func['parameters']:
            lines.extend(
                f"- {arg}: {details.get('description', 'No description provided')}\n"
                for arg, details in func['parameters']['properties'].items()
            )
        sections.append("".join(lines))
    return "\n".join(sections)


# Instruction prompt. The call-format line now shows a balanced call,
# "[func1(params_name=params_value)]", matching the worked examples below it
# (the closing parenthesis was previously missing).
SYSTEM_PROMPT_FOR_CHAT_MODEL = dedent("""
    You are an expert designed to call the correct function to solve a problem based on the user's request.
    The functions available (with required parameters) to you are:
    {functions}

    You will be given a user prompt and you need to decide which function to call.
    You will then need to format the function call correctly and return it in the correct format.
    The format for the function call is:
    [func1(params_name=params_value)]
    NO other text MUST be included.

    For example:
    Request: I want to order a cheese pizza from Pizza Hut.
    Response: [order_food(restaurant="Pizza Hut", item="cheese pizza", quantity=1)]

    Request: Is it raining in NY.
    Response: [get_weather(city="New York")]

    Request: I need a ride to SFO.
    Response: [order_ride(dest="SFO")]

    Request: I want to send a text to John saying Hello.
    Response: [send_text(to="John", message="Hello!")]
""")


# Canned assistant acknowledgement placed between the instructions and the request.
ASSISTANT_PROMPT_FOR_CHAT_MODEL = dedent("""
    I understand and will only return the function call in the correct format.
    """
)
# Template wrapping the end user's request.
USER_PROMPT_FOR_CHAT_MODEL = dedent("""
    Request: {user_prompt}.
""")
def continue_prompt(question, functions, tokenizer):
    """Build a plain-text (non-chat) prompt for a base/completion model.

    Args:
        question: the user's request.
        functions: function descriptors passed to ``format_functions``.
        tokenizer: unused; accepted so both prompt builders share a signature.

    Returns:
        The system instructions, a blank line, then the formatted request.
    """
    instructions = SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=format_functions(functions))
    request = USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=question)
    return instructions + "\n\n" + request


def instruct_prompt(question, functions, tokenizer):
    """Build a chat-template prompt for an instruction-tuned model.

    The conversation is: the instructions (sent as a user turn), a canned
    assistant acknowledgement, then the actual request.

    Args:
        question: the user's request.
        functions: function descriptors passed to ``format_functions``.
        tokenizer: tokenizer exposing ``apply_chat_template``.

    Returns:
        The rendered prompt string (not tokenized).
    """
    messages = [
        {"role": "user", "content": SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=format_functions(functions))},
        {"role": "assistant", "content": ASSISTANT_PROMPT_FOR_CHAT_MODEL},
        {"role": "user", "content": USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=question)},
    ]
    # add_generation_prompt=True makes the template open the assistant turn, so
    # the constrained generation is produced as the assistant's reply rather
    # than as text following the closed user turn.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
# Regex fragments for the literal values that may appear as call arguments.
INTEGER = r"(-)?(0|[1-9][0-9]*)"
STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
# We'll limit this to just a max of 42 characters
STRING = f'"{STRING_INNER}{{1,42}}"'
# i.e. 1 is a not a float but 1.0 is.
FLOAT = rf"({INTEGER})(\.[0-9]+)([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

# Schema type name -> regex for a literal of that type.
simple_type_map = {
    "string": STRING,
    "any": STRING,
    "integer": INTEGER,
    "number": FLOAT,
    "float": FLOAT,
    "boolean": BOOLEAN,
    "null": NULL,
}


def build_dict_regex(props):
    """Return a regex matching ``{"key": value, ...}`` for the given properties.

    Args:
        props: mapping of property name to its schema dict.
    """
    entries = [f'"{prop}": ' + type_to_regex(meta) for prop, meta in props.items()]
    return r"\{" + ", ".join(entries) + r"\}"


def type_to_regex(arg_meta):
    """Translate one schema fragment into a regex for its literal form.

    Args:
        arg_meta: schema dict. Normally carries ``type``; a fragment without
            ``type`` but with ``anyOf``/``oneOf`` becomes an alternation of its
            members, and a fragment with neither is treated as ``"any"``.

    Returns:
        The regex source string.

    Raises:
        ValueError: an ``object``/``dict`` schema has no ``properties``.
        KeyError: the type name is not in ``simple_type_map`` or an
            ``array``/``tuple`` schema has no ``items``.
    """
    if "type" not in arg_meta:
        variants = arg_meta.get("anyOf") or arg_meta.get("oneOf")
        if variants:
            return "(" + "|".join(type_to_regex(variant) for variant in variants) + ")"
        arg_type = "any"
    else:
        arg_type = arg_meta["type"]

    if arg_type in ("object", "dict"):
        if "properties" not in arg_meta:
            # This message used to be *returned* and was then embedded in the
            # grammar as if it were a regex; fail loudly instead.
            raise ValueError("Definition does not contain 'properties' value.")
        return build_dict_regex(arg_meta["properties"])
    if arg_type in ("array", "tuple"):
        pattern = type_to_regex(arg_meta["items"])
        # One to nine comma-separated items.
        return r"\[(" + pattern + ", ){0,8}" + pattern + r"\]"
    return simple_type_map[arg_type]


def build_standard_fc_regex(function_data):
    """Return the regex for one call of the form ``[name(arg=value, ...)]``.

    Required arguments appear in property order and must all be present;
    optional arguments may follow, also in property order. When there are no
    required arguments the first optional argument that is present is written
    without a leading ``", "``.

    Args:
        function_data: dict with ``name`` and an optional ``parameters`` dict
            holding ``properties`` and ``required``.

    Returns:
        The regex source string.
    """
    params = function_data.get("parameters") or {}
    properties = params.get("properties") or {}
    required = set(params.get("required") or ())

    required_parts = []
    optional_parts = []
    for arg, meta in properties.items():
        piece = f"{arg}=" + type_to_regex(meta)
        if arg in required:
            required_parts.append(piece)
        else:
            optional_parts.append(piece)

    def optional_tail(parts):
        # Each remaining optional argument may follow, introduced by ", ".
        return "".join(f"(, {part})?" for part in parts)

    if required_parts:
        args_re = ", ".join(required_parts) + optional_tail(optional_parts)
    elif optional_parts:
        # No required argument to anchor the separators: enumerate which
        # optional argument comes first so it carries no leading comma.
        alternatives = [
            part + optional_tail(optional_parts[index + 1:])
            for index, part in enumerate(optional_parts)
        ]
        args_re = "(" + "|".join(alternatives) + ")?"
    else:
        args_re = ""

    # Escape the name so characters such as "." are matched literally.
    return r"\[" + re.escape(function_data["name"]) + r"\(" + args_re + r"\)\]"


def multi_function_fc_regex(fs):
    """Return a regex matching a call to any one of the functions in *fs*."""
    return "|".join(rf"({build_standard_fc_regex(f)})" for f in fs)
# "[name(" at the start of a generated call.
_CALL_HEAD_RE = re.compile(r'\[([^\[\]()]+)\(')
# "arg=" introducing one keyword argument.
_ARG_NAME_RE = re.compile(r'(\w+)=')


def _parse_function_call(response):
    """Split a generated ``[name(arg=value, ...)]`` call into name and arguments.

    Every value the generation grammar in this file can emit (quoted strings
    with ``\\"``/``\\\\`` escapes, signed integers, floats, ``true``/``false``/
    ``null``, arrays and objects) is also a valid JSON value, so values are
    decoded with the JSON decoder rather than a hand-written pattern. That
    keeps negative numbers and floats intact and copes with quotes, commas or
    ``)]`` inside strings.

    Args:
        response: text produced by the constrained generator.

    Returns:
        ``(function_name, args_dict)``, or ``None`` when the text does not have
        the expected shape.
    """
    import json  # local import: keeps this block independent of the file header

    head = _CALL_HEAD_RE.match(response)
    if head is None:
        return None

    decoder = json.JSONDecoder()
    args = {}
    pos = head.end()
    while not response.startswith(")]", pos):
        key = _ARG_NAME_RE.match(response, pos)
        if key is None:
            return None
        try:
            value, pos = decoder.raw_decode(response, key.end())
        except ValueError:
            return None
        args[key.group(1)] = value
        if response.startswith(", ", pos):
            pos += 2
    return head.group(1), args


class SmolMind:
    """Route natural-language requests to MCP tools with a small local model.

    The model's output is constrained by a regex built from the tools the MCP
    server advertises, so it can only produce a well-formed call to one of
    them; that call is then executed through the MCP session.
    """

    def __init__(self, server_path, model_name=MODEL_NAME, debug=False):
        """Load the model and tokenizer; no server connection is made yet.

        Args:
            server_path: path of the MCP server script (``.py`` or ``.js``).
            model_name: model identifier handed to outlines/transformers.
            debug: when true, log prompts, generations and tool calls.
        """
        self.model_name = model_name
        self.debug = debug
        self.server_path = server_path
        self.instruct = True  # Always use instruct mode for MCP
        self.functions = []  # function descriptors built from the MCP tools
        self.session = None  # MCP ClientSession, set by connect_to_server
        self.exit_stack = AsyncExitStack()  # owns the transport and the session
        self.generator = None  # regex-constrained generator, set on connect

        logger.info(f"Initializing model on device: {DEVICE}")
        self.model = outlines.models.transformers(
            model_name,
            device=DEVICE,
            model_kwargs={
                "trust_remote_code": True,
                "torch_dtype": T_TYPE,
            }
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    async def connect_to_server(self):
        """Start the MCP server over stdio and prepare the call grammar.

        Returns:
            The tool list reported by the server (including any tool that was
            skipped because its input schema cannot be turned into a grammar).

        Raises:
            ValueError: ``server_path`` is neither a ``.py`` nor a ``.js`` file.
        """
        logger.info(f"Connecting to MCP server: {self.server_path}")

        # Determine server type
        is_python = self.server_path.endswith('.py')
        is_js = self.server_path.endswith('.js')

        if not (is_python or is_js):
            raise ValueError("Server script must be a .py or .js file")

        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[self.server_path],
            env=None
        )

        # Connect to the server using AsyncExitStack
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        # Initialize the session
        await self.session.initialize()

        # List available tools
        response = await self.session.list_tools()
        mcp_tools = response.tools

        # Convert MCP tools to function format
        self.functions = []
        for tool in mcp_tools:
            func = {
                "name": tool.name,
                "description": tool.description or f"Function {tool.name}",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": []
                }
            }

            # Convert input schema to function properties
            schema = getattr(tool, "inputSchema", None)
            if schema and isinstance(schema, dict):
                func["parameters"]["properties"] = schema.get("properties", {})
                func["parameters"]["required"] = schema.get("required", [])

            # A single tool whose schema cannot be expressed as a grammar must
            # not take every other tool down with it: skip it and say so.
            try:
                build_standard_fc_regex(func)
            except (KeyError, ValueError) as exc:
                logger.warning(f"Skipping tool {tool.name}: unsupported input schema ({exc!r})")
                continue

            self.functions.append(func)

        # Initialize regex generator (only meaningful with at least one tool)
        self.fc_regex = multi_function_fc_regex(self.functions)
        if self.functions:
            self.generator = outlines.generate.regex(self.model, self.fc_regex, sampler=greedy())
        else:
            self.generator = None
            logger.warning("No functions found from MCP server")

        if self.debug:
            tool_names = [tool.name for tool in mcp_tools]
            logger.info(f"Connected to server with tools: {tool_names}")

        return mcp_tools

    async def close(self):
        """Close the connection to the server"""
        await self.exit_stack.aclose()

    def get_function_call(self, user_prompt):
        """Generate the function-call text for *user_prompt*.

        Requires ``connect_to_server`` to have built the generator.

        Returns:
            The generated call, shaped like ``[name(arg=value, ...)]``.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            build_prompt = instruct_prompt if self.instruct else continue_prompt
            prompt = build_prompt(user_prompt, self.functions, self.tokenizer)
            response = self.generator(prompt)

            if self.debug:
                logger.info(f"Functions: {self.functions}")
                logger.info(f"Prompt: {prompt}")
                logger.info(f"Generated response: {response}")

            return response

    async def process_query(self, user_prompt):
        """Turn a user request into an MCP tool call and execute it.

        Returns:
            ``result.content`` of the tool call on success; otherwise a string
            describing why the query could not be processed. Failures are
            reported as text rather than raised so the interactive loop
            survives them.
        """
        if not self.functions:
            return "No functions available. Please connect to an MCP server first."

        try:
            # Generate the function call using regex generator
            response = self.get_function_call(user_prompt)

            parsed = _parse_function_call(response)
            if parsed is None:
                return f"Could not parse function call: {response}"
            function_name, args_dict = parsed

            # Execute the MCP tool call
            if self.debug:
                logger.info(f"Calling MCP tool: {function_name} with args: {args_dict}")

            result = await self.session.call_tool(function_name, args_dict)

            return result.content

        except Exception as e:
            if self.debug:
                logger.exception("Error processing query")
            return f"Error processing query: {str(e)}"