Commit 0c2a65e

feat(runnable): complete rewrite of RunnableRails with full LangChain Runnable protocol support
- Implement comprehensive async/sync invoke, batch, and streaming support
- Add robust input/output transformation for all LangChain formats (ChatPromptValue, BaseMessage, dict, string)
- Enhance chaining behavior with intelligent __or__ method handling RunnableBinding and complex chains
- Add concurrency controls, error handling, and configurable blocking messages
- Implement proper tool calling support with tool call passthrough
- Add extensive test suite (14 test files, 2800+ lines) covering all major functionality including batching, streaming, composition, piping, and tool calling
- Reorganize and expand test structure for better maintainability

apply review suggestions
1 parent d809788 commit 0c2a65e

14 files changed: +2883 -124 lines

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 764 additions & 111 deletions
Large diffs are not rendered by default.
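The implementation diff itself is not rendered here, but the composition behavior the commit message describes (the `__or__` handling) is the pattern the tests below exercise. A minimal sketch of that pattern, reusing the `FakeLLM` helper the tests import; the response text is illustrative:

from langchain_core.prompts import ChatPromptTemplate

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM

# RunnableRails wraps a model so the guardrails run around every call, and it
# composes with `|` like any other LangChain Runnable.
llm = FakeLLM(responses=["Paris is the capital of France."])
config = RailsConfig.from_content(config={"models": []})
chain = ChatPromptTemplate.from_template("{question}") | RunnableRails(config, llm=llm)

# Per the chat-template test below, prompt-value input yields an AIMessage.
result = chain.invoke({"question": "What's the capital of France?"})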
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for basic RunnableRails operations (invoke, async, batch, stream).
"""

import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


def test_runnable_rails_basic():
    """Test basic functionality of updated RunnableRails."""
    llm = FakeLLM(
        responses=[
            "Hello there! How can I help you today?",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    result = model_with_rails.invoke("Hi there")

    assert isinstance(result, str)
    assert "Hello there" in result


@pytest.mark.asyncio
async def test_runnable_rails_async():
    """Test async functionality of updated RunnableRails."""
    llm = FakeLLM(
        responses=[
            "Hello there! How can I help you today?",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    result = await model_with_rails.ainvoke("Hi there")

    assert isinstance(result, str)
    assert "Hello there" in result


def test_runnable_rails_batch():
    """Test batch functionality of updated RunnableRails."""
    llm = FakeLLM(
        responses=[
            "Response 1",
            "Response 2",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    results = model_with_rails.batch(["Question 1", "Question 2"])

    assert len(results) == 2
    assert results[0] == "Response 1"
    assert results[1] == "Response 2"


def test_updated_runnable_rails_stream():
    """Test streaming functionality of updated RunnableRails."""
    llm = FakeLLM(
        responses=[
            "Hello there!",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    chunks = []
    for chunk in model_with_rails.stream("Hi there"):
        chunks.append(chunk)

    assert len(chunks) == 2
    assert chunks[0].content == "Hello "
    assert chunks[1].content == "there!"


def test_runnable_rails_with_message_history():
    """Test handling of message history with updated RunnableRails."""
    llm = FakeLLM(
        responses=[
            "Yes, Paris is the capital of France.",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    history = [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        HumanMessage(content="What's the capital of France?"),
    ]

    result = model_with_rails.invoke(history)

    assert isinstance(result, AIMessage)
    assert "Paris" in result.content


def test_runnable_rails_with_chat_template():
    """Test updated RunnableRails with chat templates."""
    llm = FakeLLM(
        responses=[
            "Yes, Paris is the capital of France.",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    chain = prompt | model_with_rails

    result = chain.invoke(
        {
            "history": [
                HumanMessage(content="Hello"),
                AIMessage(content="Hi there!"),
            ],
            "question": "What's the capital of France?",
        }
    )

    assert isinstance(result, AIMessage)
    assert "Paris" in result.content
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for batch_as_completed methods."""

import pytest

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


@pytest.fixture
def rails():
    """Create a RunnableRails instance for testing."""
    config = RailsConfig.from_content(config={"models": []})
    llm = FakeLLM(responses=["response 1", "response 2", "response 3"])
    return RunnableRails(config, llm=llm)


def test_batch_as_completed_exists(rails):
    """Test that batch_as_completed method exists."""
    assert hasattr(rails, "batch_as_completed")


@pytest.mark.asyncio
async def test_abatch_as_completed_exists(rails):
    """Test that abatch_as_completed method exists."""
    assert hasattr(rails, "abatch_as_completed")
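The two tests above only assert that the methods exist. For context, a hedged usage sketch based on the base LangChain Runnable protocol rather than on these tests: `batch_as_completed` yields `(index, output)` pairs as each input finishes, so completion order is not guaranteed.

# Hypothetical usage, assuming the `rails` fixture above; results may complete
# out of order, so the yielded index is used to restore the input ordering.
inputs = ["Question 1", "Question 2", "Question 3"]
results = [None] * len(inputs)
for idx, output in rails.batch_as_completed(inputs):
    results[idx] = output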
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


def test_batch_processing():
    """Test batch processing of multiple inputs."""
    llm = FakeLLM(
        responses=[
            "Paris.",
            "Rome.",
            "Berlin.",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    inputs = [
        "What's the capital of France?",
        "What's the capital of Italy?",
        "What's the capital of Germany?",
    ]

    results = model_with_rails.batch(inputs)

    assert len(results) == 3
    assert results[0] == "Paris."
    assert results[1] == "Rome."
    assert results[2] == "Berlin."


@pytest.mark.asyncio
async def test_abatch_processing():
    """Test async batch processing of multiple inputs."""
    llm = FakeLLM(
        responses=[
            "Paris.",
            "Rome.",
            "Berlin.",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    inputs = [
        "What's the capital of France?",
        "What's the capital of Italy?",
        "What's the capital of Germany?",
    ]

    results = await model_with_rails.abatch(inputs)

    assert len(results) == 3
    assert results[0] == "Paris."
    assert results[1] == "Rome."
    assert results[2] == "Berlin."


def test_batch_with_different_input_types():
    """Test batch processing with different input types."""
    llm = FakeLLM(
        responses=[
            "Paris.",
            "Rome.",
            "Berlin.",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    inputs = [
        "What's the capital of France?",
        HumanMessage(content="What's the capital of Italy?"),
        {"input": "What's the capital of Germany?"},
    ]

    results = model_with_rails.batch(inputs)

    assert len(results) == 3
    assert results[0] == "Paris."
    assert isinstance(results[1], AIMessage)
    assert results[1].content == "Rome."
    assert isinstance(results[2], dict)
    assert results[2]["output"] == "Berlin."


def test_stream_output():
    """Test streaming output (simplified for now)."""
    llm = FakeLLM(
        responses=[
            "Paris.",
        ]
    )
    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    chunks = []
    for chunk in model_with_rails.stream("What's the capital of France?"):
        chunks.append(chunk)

    # Currently, stream just yields the full response as a single chunk
    assert len(chunks) == 1
    assert chunks[0].content == "Paris."


@pytest.mark.asyncio
async def test_astream_output():
    """Test async streaming output (simplified for now)."""
    llm = FakeLLM(
        responses=[
            "hello what can you do?",
        ],
        streaming=True,
    )
    config = RailsConfig.from_content(config={"models": [], "streaming": True})
    model_with_rails = RunnableRails(config, llm=llm)

    # Collect all chunks from the stream
    chunks = []
    async for chunk in model_with_rails.astream("What's the capital of France?"):
        chunks.append(chunk)

    # Stream should yield individual word chunks
    assert len(chunks) == 5
    assert chunks[0].content == "hello "
    assert chunks[1].content == "what "
    assert chunks[2].content == "can "
    assert chunks[3].content == "you "
    assert chunks[4].content == "do?"
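A note on the contrast in this file, as a hedged reading of these tests: with the default config, `stream` yields the full response as a single chunk, while setting `streaming: True` in the rails config (with a streaming-capable LLM) lets `astream` yield incremental chunks. A minimal consumption sketch under that assumption:

import asyncio

async def consume(model_with_rails):
    # With streaming enabled, chunks arrive incrementally rather than at once.
    async for chunk in model_with_rails.astream("Hi there"):
        print(chunk.content, end="", flush=True)

# asyncio.run(consume(model_with_rails))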
