Skip to content

Commit e91453b

Browse files
authored
Fix ollama arguments (#395)
* Fix ollama arguments
* Fix CI + update doc and examples
* CHANGELOG
1 parent e4a1f5c commit e91453b

File tree

5 files changed

+99
-21
lines changed

5 files changed

+99
-21
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
### Fixed
66

77
- Fixed documentation for PdfLoader
8+
- Fixed a bug where the `format` argument for `OllamaLLM` was not propagated to the client.
9+
810

911
## 1.9.0
1012

docs/source/user_guide_rag.rst

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ it can be queried using the following:
225225
from neo4j_graphrag.llm import OllamaLLM
226226
llm = OllamaLLM(
227227
model_name="orca-mini",
228+
# model_params={"options": {"temperature": 0}, "format": "json"},
228229
# host="...", # when using a remote server
229230
)
230231
llm.invoke("say something")
@@ -305,17 +306,17 @@ Default Rate Limit Handler
305306
Rate limiting is enabled by default for all LLM instances with the following configuration:
306307

307308
- **Max attempts**: 3
308-
- **Min wait**: 1.0 seconds
309+
- **Min wait**: 1.0 seconds
309310
- **Max wait**: 60.0 seconds
310311
- **Multiplier**: 2.0 (exponential backoff)
311312

312313
.. code:: python
313314
314315
from neo4j_graphrag.llm import OpenAILLM
315-
316+
316317
# Rate limiting is automatically enabled
317318
llm = OpenAILLM(model_name="gpt-4o")
318-
319+
319320
# The LLM will automatically retry on rate limit errors
320321
response = llm.invoke("Hello, world!")
321322
@@ -327,7 +328,7 @@ Rate limiting is enabled by default for all LLM instances with the following con
327328
328329
from neo4j_graphrag.llm import OpenAILLM
329330
from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler
330-
331+
331332
# Customize rate limiting parameters
332333
llm = OpenAILLM(
333334
model_name="gpt-4o",
@@ -348,15 +349,15 @@ You can customize the rate limiting behavior by creating your own rate limit han
348349
349350
from neo4j_graphrag.llm import AnthropicLLM
350351
from neo4j_graphrag.llm.rate_limit import RateLimitHandler
351-
352+
352353
class CustomRateLimitHandler(RateLimitHandler):
353354
"""Implement your custom rate limiting strategy."""
354355
# Implement required methods: handle_sync, handle_async
355356
pass
356-
357+
357358
# Create custom rate limit handler and pass it to the LLM interface
358359
custom_handler = CustomRateLimitHandler()
359-
360+
360361
llm = AnthropicLLM(
361362
model_name="claude-3-sonnet-20240229",
362363
rate_limit_handler=custom_handler,
@@ -370,7 +371,7 @@ For high-throughput applications or when you handle rate limiting externally, yo
370371
.. code:: python
371372
372373
from neo4j_graphrag.llm import CohereLLM, NoOpRateLimitHandler
373-
374+
374375
# Disable rate limiting completely
375376
llm = CohereLLM(
376377
model_name="command-r-plus",

examples/customize/llms/ollama_llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
llm = OllamaLLM(
88
model_name="<model_name>",
9+
# model_params={"options": {"temperature": 0}, "format": "json"},
910
# host="...", # if using a remote server
1011
)
1112
res: LLMResponse = llm.invoke("What is the additive color model?")

src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import warnings
1718
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
1819

1920
from pydantic import ValidationError
@@ -59,6 +60,19 @@ def __init__(
5960
self.async_client = ollama.AsyncClient(
6061
**kwargs,
6162
)
63+
if "stream" in self.model_params:
64+
raise ValueError("Streaming is not supported by the OllamaLLM wrapper")
65+
# bug-fix with backward compatibility:
66+
# we mistakenly passed all "model_params" under the options argument
67+
# next two lines to be removed in 2.0
68+
if not any(
69+
key in self.model_params for key in ("options", "format", "keep_alive")
70+
):
71+
warnings.warn(
72+
"""Passing options directly without including them in an 'options' key is deprecated. Ie you must use model_params={"options": {"temperature": 0}}""",
73+
DeprecationWarning,
74+
)
75+
self.model_params = {"options": self.model_params}
6276

6377
def get_messages(
6478
self,
@@ -104,7 +118,7 @@ def invoke(
104118
response = self.client.chat(
105119
model=self.model_name,
106120
messages=self.get_messages(input, message_history, system_instruction),
107-
options=self.model_params,
121+
**self.model_params,
108122
)
109123
content = response.message.content or ""
110124
return LLMResponse(content=content)

tests/unit/llm/test_ollama_llm.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,80 @@ def test_ollama_llm_missing_dependency(mock_import: Mock) -> None:
3535

3636

3737
@patch("builtins.__import__")
38-
def test_ollama_llm_happy_path(mock_import: Mock) -> None:
38+
def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None:
3939
mock_ollama = get_mock_ollama()
4040
mock_import.return_value = mock_ollama
4141
mock_ollama.Client.return_value.chat.return_value = MagicMock(
4242
message=MagicMock(content="ollama chat response"),
4343
)
4444
model = "gpt"
4545
model_params = {"temperature": 0.3}
46+
with pytest.warns(DeprecationWarning) as record:
47+
llm = OllamaLLM(
48+
model,
49+
model_params=model_params,
50+
)
51+
assert len(record) == 1
52+
assert isinstance(record[0].message, Warning)
53+
assert (
54+
'you must use model_params={"options": {"temperature": 0}}'
55+
in record[0].message.args[0]
56+
)
57+
58+
question = "What is graph RAG?"
59+
res = llm.invoke(question)
60+
assert isinstance(res, LLMResponse)
61+
assert res.content == "ollama chat response"
62+
messages = [
63+
{"role": "user", "content": question},
64+
]
65+
llm.client.chat.assert_called_once_with( # type: ignore[attr-defined]
66+
model=model, messages=messages, options={"temperature": 0.3}
67+
)
68+
69+
70+
@patch("builtins.__import__")
71+
def test_ollama_llm_unsupported_streaming(mock_import: Mock) -> None:
72+
mock_ollama = get_mock_ollama()
73+
mock_import.return_value = mock_ollama
74+
mock_ollama.Client.return_value.chat.return_value = MagicMock(
75+
message=MagicMock(content="ollama chat response"),
76+
)
77+
model = "gpt"
78+
model_params = {"stream": True}
79+
with pytest.raises(ValueError):
80+
OllamaLLM(
81+
model,
82+
model_params=model_params,
83+
)
84+
85+
86+
@patch("builtins.__import__")
87+
def test_ollama_llm_happy_path(mock_import: Mock) -> None:
88+
mock_ollama = get_mock_ollama()
89+
mock_import.return_value = mock_ollama
90+
mock_ollama.Client.return_value.chat.return_value = MagicMock(
91+
message=MagicMock(content="ollama chat response"),
92+
)
93+
model = "gpt"
94+
options = {"temperature": 0.3}
95+
model_params = {"options": options, "format": "json"}
4696
question = "What is graph RAG?"
4797
llm = OllamaLLM(
48-
model,
98+
model_name=model,
4999
model_params=model_params,
50100
)
51-
52101
res = llm.invoke(question)
53102
assert isinstance(res, LLMResponse)
54103
assert res.content == "ollama chat response"
55104
messages = [
56105
{"role": "user", "content": question},
57106
]
58107
llm.client.chat.assert_called_once_with( # type: ignore[attr-defined]
59-
model=model, messages=messages, options=model_params
108+
model=model,
109+
messages=messages,
110+
options=options,
111+
format="json",
60112
)
61113

62114

@@ -68,7 +120,8 @@ def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) ->
68120
message=MagicMock(content="ollama chat response"),
69121
)
70122
model = "gpt"
71-
model_params = {"temperature": 0.3}
123+
options = {"temperature": 0.3}
124+
model_params = {"options": options, "format": "json"}
72125
llm = OllamaLLM(
73126
model,
74127
model_params=model_params,
@@ -81,7 +134,10 @@ def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) ->
81134
messages = [{"role": "system", "content": system_instruction}]
82135
messages.append({"role": "user", "content": question})
83136
llm.client.chat.assert_called_once_with( # type: ignore[attr-defined]
84-
model=model, messages=messages, options=model_params
137+
model=model,
138+
messages=messages,
139+
options=options,
140+
format="json",
85141
)
86142

87143

@@ -93,7 +149,8 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
93149
message=MagicMock(content="ollama chat response"),
94150
)
95151
model = "gpt"
96-
model_params = {"temperature": 0.3}
152+
options = {"temperature": 0.3}
153+
model_params = {"options": options}
97154
llm = OllamaLLM(
98155
model,
99156
model_params=model_params,
@@ -109,7 +166,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
109166
messages = [m for m in message_history]
110167
messages.append({"role": "user", "content": question})
111168
llm.client.chat.assert_called_once_with( # type: ignore[attr-defined]
112-
model=model, messages=messages, options=model_params
169+
model=model, messages=messages, options=options
113170
)
114171

115172

@@ -123,7 +180,8 @@ def test_ollama_invoke_with_message_history_and_system_instruction(
123180
message=MagicMock(content="ollama chat response"),
124181
)
125182
model = "gpt"
126-
model_params = {"temperature": 0.3}
183+
options = {"temperature": 0.3}
184+
model_params = {"options": options}
127185
system_instruction = "You are a helpful assistant."
128186
llm = OllamaLLM(
129187
model,
@@ -145,7 +203,7 @@ def test_ollama_invoke_with_message_history_and_system_instruction(
145203
messages.extend(message_history)
146204
messages.append({"role": "user", "content": question})
147205
llm.client.chat.assert_called_once_with( # type: ignore[attr-defined]
148-
model=model, messages=messages, options=model_params
206+
model=model, messages=messages, options=options
149207
)
150208
assert llm.client.chat.call_count == 1 # type: ignore
151209

@@ -156,7 +214,8 @@ def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock)
156214
mock_import.return_value = mock_ollama
157215
mock_ollama.ResponseError = ollama.ResponseError
158216
model = "gpt"
159-
model_params = {"temperature": 0.3}
217+
options = {"temperature": 0.3}
218+
model_params = {"options": options}
160219
system_instruction = "You are a helpful assistant."
161220
llm = OllamaLLM(
162221
model,
@@ -187,7 +246,8 @@ async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock:
187246

188247
mock_ollama.AsyncClient.return_value.chat = mock_chat_async
189248
model = "gpt"
190-
model_params = {"temperature": 0.3}
249+
options = {"temperature": 0.3}
250+
model_params = {"options": options}
191251
question = "What is graph RAG?"
192252
llm = OllamaLLM(
193253
model,

0 commit comments

Comments (0)