fix: enable tool calling for ChatMLX (fixes langchain-ai/langchain#31… #133

Open · wants to merge 3 commits into main
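A minimal usage sketch of what this change is intended to enable (illustrative only: the get_weather tool, the model id, and the printed output are assumptions, not part of this diff):

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Return a (fake) weather report for a city."""
    return f"It is sunny in {city}."


llm = MLXPipeline.from_model_id(model_id="mlx-community/quantized-gemma-2b-it")
chat = ChatMLX(llm=llm)

# bind_tools injects the OpenAI-format tool schemas into the system prompt (see diff below)
chat_with_tools = chat.bind_tools([get_weather])

response = chat_with_tools.invoke("What is the weather in Paris?")
print(response.tool_calls)
# expected (model permitting): [{"name": "get_weather", "args": {"city": "Paris"}, "id": None, "type": "tool_call"}]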
Changes from all commits
300 changes: 120 additions & 180 deletions libs/community/langchain_community/chat_models/mlx.py
@@ -1,23 +1,25 @@
"""MLX Chat Wrapper."""

import json
import re
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
)

from pydantic import PrivateAttr

from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -26,48 +28,32 @@
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult, LLMResult
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_community.llms.mlx_pipeline import MLXPipeline # adjust import as needed

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful, and honest assistant."


class ChatMLX(BaseChatModel):
"""MLX chat models.

Works with `MLXPipeline` LLM.

To use, you should have the ``mlx-lm`` python package installed.

Example:
.. code-block:: python

from langchain_community.chat_models import chatMLX
from langchain_community.llms import MLXPipeline
"""MLX chat model wrapper."""

llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
)
chat = chatMLX(llm=llm)

"""
@property
def _llm_type(self) -> str:
"""Identifier for this LLM type (satisfies BaseChatModel)."""
return "mlx"

llm: MLXPipeline
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None

_tokenizer: Any = PrivateAttr()
_tools: Optional[List[dict]] = PrivateAttr(default=None)

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tokenizer = self.llm.tokenizer
# stash the MLX tokenizer for later
self._tokenizer = self.llm.tokenizer

def _generate(
self,
@@ -76,11 +62,9 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
prompt = self._to_chat_prompt(messages)
result = self.llm._generate(prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs)
return self._to_chat_result(result)

async def _agenerate(
self,
@@ -89,64 +73,91 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
prompt = self._to_chat_prompt(messages)
result = await self.llm._agenerate(prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs)
return self._to_chat_result(result)

def _to_chat_prompt(
self,
messages: List[BaseMessage],
tokenize: bool = False,
return_tensors: Optional[str] = None,
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")

if not isinstance(messages[-1], HumanMessage):
if not messages or not isinstance(messages[-1], HumanMessage):
raise ValueError("Last message must be a HumanMessage!")

messages_dicts = [self._to_chatml_format(m) for m in messages]
return self.tokenizer.apply_chat_template(
messages_dicts,
chunks: List[Dict[str, Any]] = []

# If tools are bound, inject them into the system prompt
if self._tools:
names = ", ".join(t["function"]["name"] for t in self._tools)
tools_json = json.dumps({"tools": self._tools}, indent=2)
chunks.append({
"role": "system",
"content": (
f"You have access to the following tools:\n{tools_json}\n\n"
f"When needed, respond with a JSON object like this:\n"
f'{{"name": "<tool_name>", "arguments": {{"arg1": "...", "arg2": "..."}}}}\n'
f"Available tools: {names}."
)
})

# Add the actual conversation history
chunks.extend(self._to_chatml_format(m) for m in messages)

# Let MLX tokenize/apply its chat template
return self._tokenizer.apply_chat_template(
chunks,
tokenize=tokenize,
add_generation_prompt=True,
return_tensors=return_tensors,
)
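# Illustrative only (not part of this diff): each entry in self._tools is assumed
# to be the OpenAI-style dict produced by convert_to_openai_tool, e.g. for a
# hypothetical get_weather tool:
#
#   {
#       "type": "function",
#       "function": {
#           "name": "get_weather",
#           "description": "Return a weather report for a city.",
#           "parameters": {
#               "type": "object",
#               "properties": {"city": {"type": "string"}},
#               "required": ["city"],
#           },
#       },
#   }
#
# which is why the system-prompt block above reads t["function"]["name"].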

def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""

if isinstance(message, SystemMessage):
def _to_chatml_format(self, msg: BaseMessage) -> Dict[str, Any]:
if isinstance(msg, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, HumanMessage):
elif isinstance(msg, HumanMessage):
role = "user"
elif isinstance(msg, AIMessage):
role = "assistant"
else:
raise ValueError(f"Unknown message type: {type(message)}")

return {"role": role, "content": message.content}
raise ValueError(f"Unknown message type: {type(msg)}")
return {"role": role, "content": msg.content}

@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []

for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
gens: List[ChatGeneration] = []
for gen in llm_result.generations[0]:
raw = gen.text.strip()
tool_calls: List[Dict[str, Any]] = []

# Try full JSON parse first
try:
parsed = json.loads(raw)
if isinstance(parsed, dict) and "name" in parsed:
# Normalize to LangChain's ToolCall shape: "args" (not "arguments"), plus id/type.
tool_calls = [{"name": parsed["name"], "args": parsed.get("arguments") or {}, "id": None, "type": "tool_call"}]
except json.JSONDecodeError:
# Fallback: regex extraction
name_m = re.search(r'"name"\s*:\s*"([^"]+)"', raw)
if name_m:
name = name_m.group(1)
args_m = re.search(r'"arguments"\s*:\s*({.*?})', raw, re.DOTALL)
args: Dict[str, Any] = {}
if args_m:
try:
args = json.loads(args_m.group(1))
except json.JSONDecodeError:
pass
tool_calls = [{"name": name, "args": args, "id": None, "type": "tool_call"}]

gens.append(
ChatGeneration(
message=AIMessage(content=raw, tool_calls=tool_calls),
generation_info=gen.generation_info,
)
)
chat_generations.append(chat_generation)

return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)

@property
def _llm_type(self) -> str:
return "mlx-chat-wrapper"
return ChatResult(generations=gens, llm_output=llm_result.llm_output)
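# Worked example (illustrative, not part of this diff): if the model emits
#   {"name": "get_weather", "arguments": {"city": "Paris"}}
# then the JSON branch above produces
#   tool_calls == [{"name": "get_weather", "args": {"city": "Paris"}, "id": None, "type": "tool_call"}]
# and the resulting AIMessage exposes the call via .tool_calls while keeping
# the raw model text in .content.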

def _stream(
self,
@@ -156,122 +167,51 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

try:
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

except ImportError:
raise ImportError(
"Could not import mlx_lm python package. "
"Please install it with `pip install mlx_lm`."
)
model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
temp: float = model_kwargs.get("temp", 0.0)
max_new_tokens: int = model_kwargs.get("max_tokens", 100)
repetition_penalty: Optional[float] = model_kwargs.get(
"repetition_penalty", None
)
repetition_context_size: Optional[int] = model_kwargs.get(
"repetition_context_size", None
)
top_p: float = model_kwargs.get("top_p", 1.0)
min_p: float = model_kwargs.get("min_p", 0.0)
min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)

llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")

prompt_tokens = mx.array(llm_input[0])

eos_token_id = self.tokenizer.eos_token_id

sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)

logits_processors = make_logits_processors(
None, repetition_penalty, repetition_context_size
)

for (token, prob), n in zip(
generate_step(
prompt_tokens,
self.llm.model,
sampler=sampler,
logits_processors=logits_processors,
),
range(max_new_tokens),
mk = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
temp = mk.get("temp", 0.0)
max_tokens = mk.get("max_tokens", 100)
rep_pen = mk.get("repetition_penalty")
rep_ctx = mk.get("repetition_context_size")
top_p = mk.get("top_p", 1.0)
min_p = mk.get("min_p", 0.0)
keep = mk.get("min_tokens_to_keep", 1)

inp = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
prompt_tokens = mx.array(inp[0])
eos_id = self._tokenizer.eos_token_id

sampler = make_sampler(temp, top_p, min_p, keep)
proc = make_logits_processors(None, rep_pen, rep_ctx)

for (token, _), _ in zip(
# pass sampler and logits_processors by keyword (they are keyword-only in recent mlx_lm)
generate_step(prompt_tokens, self.llm.model, sampler=sampler, logits_processors=proc),
range(max_tokens),
):
# identify text to yield
text: Optional[str] = None
if not isinstance(token, int):
text = self.tokenizer.decode(token.item())
else:
text = self.tokenizer.decode(token)

# yield text, if any
if text:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
txt = self._tokenizer.decode(token.item() if hasattr(token, "item") else token)
if txt:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=txt))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
run_manager.on_llm_new_token(txt, chunk=chunk)
yield chunk

# break if stop sequence found
if token == eos_token_id or (stop is not None and text in stop):
if token == eos_id or (stop and txt in stop):
break
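# Streaming usage sketch (illustrative; assumes an MLXPipeline instance named llm):
#
#   chat = ChatMLX(llm=llm)
#   for chunk in chat.stream("Tell me a joke about ducks."):
#       print(chunk.content, end="", flush=True)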

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[Dict[str, Any], Type, Callable, Any]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
tool_choice: Optional[Union[dict, str, bool]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.

Assumes model is compatible with OpenAI tool-calling API.

Args:
tools: A list of tool definitions to bind to this chat model.
Supports any tool definition handled by
:meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
tool_choice: Which tool to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any), or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""

formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if isinstance(tool_choice, str):
if tool_choice not in ("auto", "none"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
elif isinstance(tool_choice, bool):
tool_choice = formatted_tools[0]
elif isinstance(tool_choice, dict):
if (
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
) -> "ChatMLX":
formatted = [convert_to_openai_tool(t) for t in tools]
self._tools = formatted
if tool_choice and len(formatted) != 1:
raise ValueError(
f"Tool choice specified but {len(formatted)} tools were bound; only one allowed."
)
return super().bind(tools=formatted, **kwargs)
# def unbind_tools(self) -> "ChatMLX":
# """Unbind any tools."""