Commit a65e58e

docs(langchain): rewrite custom LLM provider guide with BaseChatModel support
Complete rewrite of the custom LLM provider documentation with:

- Separate comprehensive guides for BaseLLM (text completion) and BaseChatModel (chat)
- Correct method signatures (_call vs _generate)
- Proper async implementations
- Clear registration instructions (register_llm_provider vs register_chat_provider)
- Working code examples with correct langchain-core imports
- Important notes on choosing the right base class

This addresses the gap where users were not properly guided on implementing custom chat models and were being directed to the wrong interface.
1 parent c4a2b4e commit a65e58e

File tree: 1 file changed, +135 −26 lines


docs/user-guides/configuration-guide/custom-initialization.md

## Custom LLM Provider Registration

NeMo Guardrails supports two types of custom LLM providers:

1. **Text Completion Models** (`BaseLLM`) - For models that work with string prompts
2. **Chat Models** (`BaseChatModel`) - For models that work with message-based conversations

### Custom Text Completion LLM (BaseLLM)

To register a custom text completion LLM provider, create a class that implements the text completion interface and register it using `register_llm_provider`. The example below inherits from `LLM`, the convenience subclass of `BaseLLM` exported from `langchain_core.language_models`: `LLM` requires only `_call`, whereas subclassing `BaseLLM` directly requires implementing the batch-oriented `_generate` method instead.

**Required methods:**

- `_call` - Synchronous text completion
- `_llm_type` - Returns the LLM type identifier

**Optional methods:**

- `_acall` - Asynchronous text completion (recommended)
- `_stream` - Streaming text completion
- `_astream` - Async streaming text completion
- `_identifying_params` - Returns parameters for model identification

```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM
from langchain_core.outputs import GenerationChunk

from nemoguardrails.llm.providers import register_llm_provider


class MyCustomTextLLM(LLM):
    """Custom text completion LLM."""

    @property
    def _llm_type(self) -> str:
        return "custom_text_llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Synchronous text completion."""
        # Your implementation here
        return "Generated text response"

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronous text completion (recommended)."""
        # Your async implementation here
        return "Generated text response"

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Optional: Streaming text completion."""
        # Yield chunks of text
        yield GenerationChunk(text="chunk1")
        yield GenerationChunk(text="chunk2")


register_llm_provider("custom_text_llm", MyCustomTextLLM)
```
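
The optional `_identifying_params` property is not shown in the example above. A minimal sketch, extending the class from the example (the subclass name and the `model_name` value are illustrative placeholders; in practice you would define the property directly on `MyCustomTextLLM`):

```python
from typing import Any, Mapping


class MyIdentifiedTextLLM(MyCustomTextLLM):
    """MyCustomTextLLM plus model-identification metadata."""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        # LangChain uses these parameters when caching and tracing model calls.
        return {"model_name": "my-custom-model"}
```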

### Custom Chat Model (BaseChatModel)

To register a custom chat model, create a class that inherits from `BaseChatModel` and register it using `register_chat_provider`.

**Required methods:**

- `_generate` - Synchronous chat completion
- `_llm_type` - Returns the LLM type identifier

**Optional methods:**

- `_agenerate` - Asynchronous chat completion (recommended)
- `_stream` - Streaming chat completion
- `_astream` - Async streaming chat completion (see the sketch after the example below)
```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from nemoguardrails.llm.providers import register_chat_provider


class MyCustomChatModel(BaseChatModel):
    """Custom chat model."""

    @property
    def _llm_type(self) -> str:
        return "custom_chat_model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous chat completion."""
        # Convert messages to your model's format and generate response
        response_text = "Generated chat response"

        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous chat completion (recommended)."""
        # Your async implementation
        response_text = "Generated chat response"

        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Optional: Streaming chat completion."""
        # Yield chunks
        chunk = ChatGenerationChunk(message=AIMessageChunk(content="chunk1"))
        yield chunk


register_chat_provider("custom_chat_model", MyCustomChatModel)
```
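
The optional `_astream` method mirrors `_stream` as an async generator. A minimal sketch, continuing from the example above (the subclass name is illustrative; in practice you would define `_astream` directly on your chat model):

```python
from typing import AsyncIterator  # the remaining names come from the example above


class MyStreamingChatModel(MyCustomChatModel):
    """MyCustomChatModel plus async streaming."""

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Optional: Async streaming chat completion."""
        # Yield chunks as your backend produces them.
        yield ChatGenerationChunk(message=AIMessageChunk(content="chunk1"))
        yield ChatGenerationChunk(message=AIMessageChunk(content="chunk2"))
```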

### Using Custom LLM Providers

After registering your custom provider, you can use it in your configuration:

```yaml
models:
  - type: main
    engine: custom_text_llm  # or custom_chat_model
```
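
Registration typically happens before the configuration is loaded, for example from the `init` function in your config's `config.py`. A minimal end-to-end sketch, assuming the configuration directory is `./config`:

```python
from nemoguardrails import LLMRails, RailsConfig

# The registration call (register_llm_provider / register_chat_provider)
# must run before the models are initialized, e.g. in config.py.
config = RailsConfig.from_path("./config")
rails = LLMRails(config)

response = rails.generate(messages=[{"role": "user", "content": "Hello!"}])
print(response["content"])
```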

### Important Notes

1. **Import from langchain-core:** Always import the base classes from `langchain_core.language_models`:

   ```python
   from langchain_core.language_models import LLM, BaseChatModel
   ```

2. **Implement async methods:** For better performance, always implement `_acall` (for `LLM`/`BaseLLM`) or `_agenerate` (for `BaseChatModel`).

3. **Choose the right base class:**
   - Use `LLM` (or `BaseLLM` with `_generate`) for text completion models (prompt → text)
   - Use `BaseChatModel` for chat models (messages → message)

4. **Registration functions:**
   - Use `register_llm_provider()` for `BaseLLM` subclasses
   - Use `register_chat_provider()` for `BaseChatModel` subclasses
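
Since both base classes implement LangChain's `Runnable` interface, you can also sanity-check a provider directly with `invoke` before wiring it into a guardrails configuration (using the example classes above):

```python
llm = MyCustomTextLLM()
print(llm.invoke("Hello"))           # -> "Generated text response"

chat = MyCustomChatModel()
print(chat.invoke("Hello").content)  # -> "Generated chat response"
```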

## Custom Embedding Provider Registration

You can also register a custom embedding provider by using the `LLMRails.register_embedding_provider` function.
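
As a rough sketch of what such a provider can look like: the snippet below assumes an `EmbeddingModel` base class in `nemoguardrails.embeddings.providers.base` with `encode`/`encode_async` methods. The import path, attribute names, and exact registration signature are assumptions to verify against the NeMo Guardrails source.

```python
from typing import List

from nemoguardrails import LLMRails
from nemoguardrails.embeddings.providers.base import EmbeddingModel  # assumed path


class CustomEmbeddingModel(EmbeddingModel):
    """Toy embedding model returning fixed-size zero vectors."""

    engine_name = "custom_embeddings"  # assumed identifier attribute

    def encode(self, documents: List[str]) -> List[List[float]]:
        # Replace with a real embedding computation.
        return [[0.0] * 768 for _ in documents]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        return self.encode(documents)


# The exact signature of register_embedding_provider may differ by version.
LLMRails.register_embedding_provider(CustomEmbeddingModel, "custom_embeddings")
```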
