diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index 64fd45ac..00e2175a 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -1,23 +1,25 @@
 """MLX Chat Wrapper."""
+import json
+import re
 from typing import (
     Any,
     Callable,
     Dict,
     Iterator,
     List,
-    Literal,
     Optional,
     Sequence,
     Type,
     Union,
 )
+from uuid import uuid4
+
+from pydantic import PrivateAttr
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -26,48 +28,32 @@
     HumanMessage,
     SystemMessage,
 )
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
     ChatResult,
     LLMResult,
 )
-from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_community.llms.mlx_pipeline import MLXPipeline
 
-DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
+DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful, and honest assistant."
 
 
 class ChatMLX(BaseChatModel):
-    """MLX chat models.
-
-    Works with `MLXPipeline` LLM.
-
-    To use, you should have the ``mlx-lm`` python package installed.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_community.chat_models import chatMLX
-            from langchain_community.llms import MLXPipeline
-
-            llm = MLXPipeline.from_model_id(
-                model_id="mlx-community/quantized-gemma-2b-it",
-            )
-            chat = chatMLX(llm=llm)
-
-    """
+    """MLX chat model wrapper.
+
+    Wraps an ``MLXPipeline`` LLM and adds prompt-based tool calling.
+    """
+
+    @property
+    def _llm_type(self) -> str:
+        """Identifier for this chat model type (required by ``BaseChatModel``)."""
+        return "mlx"
 
     llm: MLXPipeline
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
-    tokenizer: Any = None
+
+    _tokenizer: Any = PrivateAttr()
+    _tools: Optional[List[dict]] = PrivateAttr(default=None)
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
-        self.tokenizer = self.llm.tokenizer
+        # Stash the MLX tokenizer so prompt building and streaming can reuse it.
+        self._tokenizer = self.llm.tokenizer
 
     def _generate(
         self,
@@ -76,11 +62,9 @@
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
-        llm_result = self.llm._generate(
-            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-        )
-        return self._to_chat_result(llm_result)
+        prompt = self._to_chat_prompt(messages)
+        result = self.llm._generate(
+            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(result)
 
     async def _agenerate(
         self,
@@ -89,11 +73,9 @@
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
-        llm_result = await self.llm._agenerate(
-            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-        )
-        return self._to_chat_result(llm_result)
+        prompt = self._to_chat_prompt(messages)
+        result = await self.llm._agenerate(
+            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(result)
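# Minimal usage sketch, adapted from the docstring example removed above. It
# assumes `mlx-lm` is installed and that the "mlx-community/quantized-gemma-2b-it"
# model id is still downloadable; the `pipeline_kwargs` keys mirror what
# `_stream` reads (temp, max_tokens, top_p, ...).
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.messages import HumanMessage

llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={"max_tokens": 256, "temp": 0.1},
)
chat = ChatMLX(llm=llm)
reply = chat.invoke([HumanMessage(content="Explain MLX in one sentence.")])
print(reply.content)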
 
     def _to_chat_prompt(
         self,
@@ -101,52 +83,81 @@ def _to_chat_prompt(
         tokenize: bool = False,
         return_tensors: Optional[str] = None,
     ) -> str:
-        """Convert a list of messages into a prompt format expected by wrapped LLM."""
-        if not messages:
-            raise ValueError("At least one HumanMessage must be provided!")
-
-        if not isinstance(messages[-1], HumanMessage):
+        """Convert messages into the prompt format expected by the wrapped LLM."""
+        if not messages or not isinstance(messages[-1], HumanMessage):
             raise ValueError("Last message must be a HumanMessage!")
 
-        messages_dicts = [self._to_chatml_format(m) for m in messages]
+        chunks: List[Dict[str, Any]] = []
+
+        # If tools are bound, describe them in an extra system message.
+        if self._tools:
+            names = ", ".join(t["function"]["name"] for t in self._tools)
+            tools_json = json.dumps({"tools": self._tools}, indent=2)
+            chunks.append(
+                {
+                    "role": "system",
+                    "content": (
+                        f"You have access to the following tools:\n{tools_json}\n\n"
+                        "When needed, respond with a JSON object like this:\n"
+                        '{"name": "<tool_name>", '
+                        '"arguments": {"arg1": "...", "arg2": "..."}}\n'
+                        f"Available tools: {names}."
+                    ),
+                }
+            )
+
+        # Add the actual conversation history.
+        chunks.extend(self._to_chatml_format(m) for m in messages)
 
-        return self.tokenizer.apply_chat_template(
-            messages_dicts,
+        # Let the MLX tokenizer apply its chat template.
+        return self._tokenizer.apply_chat_template(
+            chunks,
             tokenize=tokenize,
             add_generation_prompt=True,
             return_tensors=return_tensors,
         )
 
-    def _to_chatml_format(self, message: BaseMessage) -> dict:
-        """Convert LangChain message to ChatML format."""
-
-        if isinstance(message, SystemMessage):
+    def _to_chatml_format(self, msg: BaseMessage) -> Dict[str, Any]:
+        """Convert a LangChain message to ChatML format."""
+        if isinstance(msg, SystemMessage):
             role = "system"
-        elif isinstance(message, AIMessage):
-            role = "assistant"
-        elif isinstance(message, HumanMessage):
+        elif isinstance(msg, HumanMessage):
             role = "user"
+        elif isinstance(msg, AIMessage):
+            role = "assistant"
         else:
-            raise ValueError(f"Unknown message type: {type(message)}")
-
-        return {"role": role, "content": message.content}
+            raise ValueError(f"Unknown message type: {type(msg)}")
+        return {"role": role, "content": msg.content}
 
     @staticmethod
     def _to_chat_result(llm_result: LLMResult) -> ChatResult:
-        chat_generations = []
-
-        for g in llm_result.generations[0]:
-            chat_generation = ChatGeneration(
-                message=AIMessage(content=g.text), generation_info=g.generation_info
+        gens: List[ChatGeneration] = []
+        for gen in llm_result.generations[0]:
+            raw = gen.text.strip()
+            tool_calls: List[Dict[str, Any]] = []
+
+            # Try a full JSON parse first.
+            try:
+                parsed = json.loads(raw)
+                if isinstance(parsed, dict) and "name" in parsed:
+                    tool_calls = [
+                        {
+                            "name": parsed["name"],
+                            "args": parsed.get("arguments", {}),
+                            "id": str(uuid4()),
+                            "type": "tool_call",
+                        }
+                    ]
+            except json.JSONDecodeError:
+                # Fallback: pull the call out of surrounding prose with regexes.
+                name_m = re.search(r'"name"\s*:\s*"([^"]+)"', raw)
+                if name_m:
+                    args_m = re.search(r'"arguments"\s*:\s*({.*?})', raw, re.DOTALL)
+                    args: Dict[str, Any] = {}
+                    if args_m:
+                        try:
+                            args = json.loads(args_m.group(1))
+                        except json.JSONDecodeError:
+                            pass
+                    tool_calls = [
+                        {
+                            "name": name_m.group(1),
+                            "args": args,
+                            "id": str(uuid4()),
+                            "type": "tool_call",
+                        }
+                    ]
+
+            gens.append(
+                ChatGeneration(
+                    message=AIMessage(content=raw, tool_calls=tool_calls),
+                    generation_info=gen.generation_info,
+                )
             )
-            chat_generations.append(chat_generation)
-
-        return ChatResult(
-            generations=chat_generations, llm_output=llm_result.llm_output
-        )
-
-    @property
-    def _llm_type(self) -> str:
-        return "mlx-chat-wrapper"
+        return ChatResult(generations=gens, llm_output=llm_result.llm_output)
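# Sketch of the parsing strategy used in `_to_chat_result` above, applied to a
# made-up model reply that wraps the requested JSON in prose (so the plain
# `json.loads` attempt fails and the regex fallback kicks in).
import json
import re

raw = 'Sure! {"name": "get_weather", "arguments": {"city": "Paris"}}'

try:
    call = json.loads(raw)
except json.JSONDecodeError:
    name_m = re.search(r'"name"\s*:\s*"([^"]+)"', raw)
    args_m = re.search(r'"arguments"\s*:\s*({.*?})', raw, re.DOTALL)
    call = {
        "name": name_m.group(1) if name_m else None,
        "args": json.loads(args_m.group(1)) if args_m else {},
    }

print(call)  # {'name': 'get_weather', 'args': {'city': 'Paris'}}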
 
     def _stream(
         self,
@@ -156,122 +167,51 @@
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        try:
-            import mlx.core as mx
-            from mlx_lm.sample_utils import make_logits_processors, make_sampler
-            from mlx_lm.utils import generate_step
-
-        except ImportError:
-            raise ImportError(
-                "Could not import mlx_lm python package. "
-                "Please install it with `pip install mlx_lm`."
-            )
+        import mlx.core as mx
+        from mlx_lm.sample_utils import make_logits_processors, make_sampler
+        from mlx_lm.utils import generate_step
 
-        model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
-        temp: float = model_kwargs.get("temp", 0.0)
-        max_new_tokens: int = model_kwargs.get("max_tokens", 100)
-        repetition_penalty: Optional[float] = model_kwargs.get(
-            "repetition_penalty", None
-        )
-        repetition_context_size: Optional[int] = model_kwargs.get(
-            "repetition_context_size", None
-        )
-        top_p: float = model_kwargs.get("top_p", 1.0)
-        min_p: float = model_kwargs.get("min_p", 0.0)
-        min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)
-
-        llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
-
-        prompt_tokens = mx.array(llm_input[0])
-
-        eos_token_id = self.tokenizer.eos_token_id
-
-        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
-
-        logits_processors = make_logits_processors(
-            None, repetition_penalty, repetition_context_size
-        )
+        mk = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
+        temp = mk.get("temp", 0.0)
+        max_tokens = mk.get("max_tokens", 100)
+        rep_pen = mk.get("repetition_penalty")
+        rep_ctx = mk.get("repetition_context_size")
+        top_p = mk.get("top_p", 1.0)
+        min_p = mk.get("min_p", 0.0)
+        keep = mk.get("min_tokens_to_keep", 1)
+
+        inp = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
+        prompt_tokens = mx.array(inp[0])
+        eos_id = self._tokenizer.eos_token_id
+
+        sampler = make_sampler(temp, top_p, min_p, keep)
+        proc = make_logits_processors(None, rep_pen, rep_ctx)
 
-        for (token, prob), n in zip(
+        for (token, _), _ in zip(
             generate_step(
                 prompt_tokens,
                 self.llm.model,
                 sampler=sampler,
-                logits_processors=logits_processors,
+                logits_processors=proc,
             ),
-            range(max_new_tokens),
+            range(max_tokens),
         ):
-            # identify text to yield
-            text: Optional[str] = None
-            if not isinstance(token, int):
-                text = self.tokenizer.decode(token.item())
-            else:
-                text = self.tokenizer.decode(token)
-
-            # yield text, if any
-            if text:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+            # Decode the new token (mx.array or plain int) to text.
+            txt = self._tokenizer.decode(
+                token.item() if hasattr(token, "item") else token
+            )
+            if txt:
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=txt))
                 if run_manager:
-                    run_manager.on_llm_new_token(text, chunk=chunk)
+                    run_manager.on_llm_new_token(txt, chunk=chunk)
                 yield chunk
-
-            # break if stop sequence found
-            if token == eos_token_id or (stop is not None and text in stop):
+            # Stop on EOS or when an explicit stop sequence is produced.
+            if token == eos_id or (stop and txt in stop):
                 break
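# Streaming sketch, reusing `chat` from the first example. The sampling knobs
# (`temp`, `top_p`, `max_tokens`, ...) default to the pipeline's `pipeline_kwargs`;
# passing `model_kwargs={...}` on the call is assumed to reach `_stream` via
# **kwargs, so treat per-call overrides as an assumption rather than a guarantee.
from langchain_core.messages import HumanMessage

messages = [HumanMessage(content="Write a haiku about Apple Silicon.")]
for chunk in chat.stream(messages):
    print(chunk.content, end="", flush=True)
print()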
 
     def bind_tools(
         self,
         tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
         *,
-        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
+        tool_choice: Optional[Union[dict, str, bool]] = None,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, BaseMessage]:
-        """Bind tool-like objects to this chat model.
-
-        Assumes model is compatible with OpenAI tool-calling API.
-
-        Args:
-            tools: A list of tool definitions to bind to this chat model.
-                Supports any tool definition handled by
-                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
-            tool_choice: Which tool to require the model to call.
-                Must be the name of the single provided function or
-                "auto" to automatically determine which function to call
-                (if any), or a dict of the form:
-                {"type": "function", "function": {"name": <>}}.
-            **kwargs: Any additional parameters to pass to the
-                :class:`~langchain.runnable.Runnable` constructor.
-        """
-
-        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
-        if tool_choice is not None and tool_choice:
-            if len(formatted_tools) != 1:
-                raise ValueError(
-                    "When specifying `tool_choice`, you must provide exactly one "
-                    f"tool. Received {len(formatted_tools)} tools."
-                )
-            if isinstance(tool_choice, str):
-                if tool_choice not in ("auto", "none"):
-                    tool_choice = {
-                        "type": "function",
-                        "function": {"name": tool_choice},
-                    }
-            elif isinstance(tool_choice, bool):
-                tool_choice = formatted_tools[0]
-            elif isinstance(tool_choice, dict):
-                if (
-                    formatted_tools[0]["function"]["name"]
-                    != tool_choice["function"]["name"]
-                ):
-                    raise ValueError(
-                        f"Tool choice {tool_choice} was specified, but the only "
-                        f"provided tool was {formatted_tools[0]['function']['name']}."
-                    )
-            else:
-                raise ValueError(
-                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
-                    f"Received: {tool_choice}"
-                )
-            kwargs["tool_choice"] = tool_choice
-        return super().bind(tools=formatted_tools, **kwargs)
+    ) -> "ChatMLX":
+        """Bind tools whose JSON schemas are injected into the system prompt."""
+        formatted = [convert_to_openai_tool(t) for t in tools]
+        if tool_choice and len(formatted) != 1:
+            raise ValueError(
+                "When `tool_choice` is specified, exactly one tool must be "
+                f"bound; received {len(formatted)}."
+            )
+        # `_to_chat_prompt` renders these schemas into the system prompt.
+        self._tools = formatted
+        return self
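# End-to-end sketch of the prompt-based tool calling wired up above, reusing
# `chat` from the first example. Whether the model actually emits the requested
# JSON depends on the underlying MLX model, so `tool_calls` may come back empty;
# `get_weather` is a made-up tool used purely for illustration.
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Return the current weather for a city."""
    return f"It is sunny in {city}."


chat_with_tools = chat.bind_tools([get_weather])
ai_msg = chat_with_tools.invoke(
    [HumanMessage(content="What is the weather in Paris right now?")]
)
print(ai_msg.tool_calls)
# e.g. [{'name': 'get_weather', 'args': {'city': 'Paris'}, 'id': '...', 'type': 'tool_call'}]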