2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
@@ -83,7 +83,7 @@ class LLMGenerationActions:
     def __init__(
         self,
         config: RailsConfig,
-        llm: Union[BaseLLM, BaseChatModel],
+        llm: Optional[Union[BaseLLM, BaseChatModel]],
         llm_task_manager: LLMTaskManager,
         get_embedding_search_provider_instance: Callable[
             [Optional[EmbeddingSearchProvider]], EmbeddingsIndex
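Since `llm` can now be `None`, downstream code has to guard before calling the model. A minimal sketch of the calling contract this change creates, assuming the usual langchain-core import path; `describe_llm` is a hypothetical helper for illustration, not code from this PR:

from typing import Optional, Union

from langchain_core.language_models import BaseChatModel, BaseLLM


# Hypothetical illustration of the Optional contract above; the real
# LLMGenerationActions constructor has many more dependencies.
def describe_llm(llm: Optional[Union[BaseLLM, BaseChatModel]]) -> str:
    if llm is None:
        # Callers may now pass no main model, e.g. in passthrough setups.
        return "no main LLM configured"
    return f"using {type(llm).__name__}"


print(describe_llm(None))  # -> "no main LLM configured"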
31 changes: 24 additions & 7 deletions nemoguardrails/context.py
@@ -14,25 +14,42 @@
 # limitations under the License.
 
 import contextvars
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
-streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
+if TYPE_CHECKING:
+    from nemoguardrails.logging.explain import ExplainInfo
+    from nemoguardrails.rails.llm.options import GenerationOptions, LLMStats
+    from nemoguardrails.streaming import StreamingHandler
+
+streaming_handler_var: contextvars.ContextVar[
+    Optional["StreamingHandler"]
+] = contextvars.ContextVar("streaming_handler", default=None)
 
 # The object that holds additional explanation information.
-explain_info_var = contextvars.ContextVar("explain_info", default=None)
+explain_info_var: contextvars.ContextVar[
+    Optional["ExplainInfo"]
+] = contextvars.ContextVar("explain_info", default=None)
 
 # The current LLM call.
-llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
+llm_call_info_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
+    "llm_call_info", default=None
+)
 
 # All the generation options applicable to the current context.
-generation_options_var = contextvars.ContextVar("generation_options", default=None)
+generation_options_var: contextvars.ContextVar[
+    Optional["GenerationOptions"]
+] = contextvars.ContextVar("generation_options", default=None)
 
 # The stats about the LLM calls.
-llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
+llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
+    "llm_stats", default=None
+)
 
 # The raw LLM request that comes from the user.
 # This is used in passthrough mode.
-raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)
+raw_llm_request: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar(
+    "raw_llm_request", default=None
+)
 
 reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
     "reasoning_trace", default=None
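The new annotations combine `if TYPE_CHECKING:` imports with string (forward-reference) type arguments, so static checkers see the concrete types while the module avoids importing them at runtime and risking circular imports. A self-contained sketch of the same pattern, reusing the `reasoning_trace_var` shape from this file:

import contextvars
from typing import Optional

# Same typed-ContextVar pattern as the diff, in isolation.
reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "reasoning_trace", default=None
)

# set() returns a Token that can restore the previous value.
token = reasoning_trace_var.set("step 1: retrieved context")
assert reasoning_trace_var.get() == "step 1: retrieved context"
reasoning_trace_var.reset(token)
assert reasoning_trace_var.get() is None  # back to the default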
17 changes: 8 additions & 9 deletions nemoguardrails/rails/llm/buffer.py
@@ -14,7 +14,10 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, NamedTuple
+from typing import TYPE_CHECKING, AsyncGenerator, List, NamedTuple
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
 
 from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig
 
@@ -111,9 +114,7 @@ def format_chunks(self, chunks: List[str]) -> str:
         ...
 
     @abstractmethod
-    async def process_stream(
-        self, streaming_handler
-    ) -> AsyncGenerator[ChunkBatch, None]:
+    async def process_stream(self, streaming_handler):
         """Process streaming chunks and yield chunk batches.
 
         This is the main method that concrete buffer strategies must implement.
@@ -138,9 +139,9 @@ async def process_stream(
         ... print(f"Processing: {context_formatted}")
         ... print(f"User: {user_formatted}")
         """
-        ...
+        yield ChunkBatch([], [])  # pragma: no cover
 
-    async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]:
+    async def __call__(self, streaming_handler):
         """Callable interface that delegates to process_stream.
 
         It delegates to the `process_stream` method and can
@@ -256,9 +257,7 @@ def from_config(cls, config: OutputRailsStreamingConfig):
             buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size
         )
 
-    async def process_stream(
-        self, streaming_handler
-    ) -> AsyncGenerator[ChunkBatch, None]:
+    async def process_stream(self, streaming_handler):
        """Process streaming chunks using rolling buffer strategy.
 
        This method implements the rolling buffer logic, accumulating chunks
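Two details worth noting here: dropping the `-> AsyncGenerator[ChunkBatch, None]` annotations sidesteps checker disagreements over async-generator return types, and replacing the abstract body's `...` with a `yield` makes `process_stream` a genuine async generator function at runtime, since Python classifies a function by the presence of `yield` in its body. A toy concrete strategy, under the assumption that `ChunkBatch` is a two-field named tuple of string lists (the field names below are illustrative, not the real ones):

import asyncio
from typing import AsyncGenerator, List, NamedTuple


class ChunkBatch(NamedTuple):
    # Field names are assumptions for this sketch.
    context_chunks: List[str]
    new_chunks: List[str]


class PassthroughBuffer:
    """Toy strategy: emits every chunk as its own single-chunk batch."""

    async def process_stream(
        self, streaming_handler
    ) -> AsyncGenerator[ChunkBatch, None]:
        async for chunk in streaming_handler:
            yield ChunkBatch(context_chunks=[chunk], new_chunks=[chunk])


async def demo():
    async def fake_stream():
        for piece in ("Hello", ", ", "world"):
            yield piece

    async for batch in PassthroughBuffer().process_stream(fake_stream()):
        print(batch.new_chunks)


asyncio.run(demo())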
8 changes: 5 additions & 3 deletions nemoguardrails/rails/llm/config.py
@@ -1048,7 +1048,9 @@ def _load_path(
 
     # the first .railsignore file found from cwd down to its subdirectories
     railsignore_path = utils.get_railsignore_path(config_path)
-    ignore_patterns = utils.get_railsignore_patterns(railsignore_path)
+    ignore_patterns = (
+        utils.get_railsignore_patterns(railsignore_path) if railsignore_path else set()
+    )
 
     if os.path.isdir(config_path):
         for root, _, files in os.walk(config_path, followlinks=True):
@@ -1165,8 +1167,8 @@ def _parse_colang_files_recursively(
         current_file, current_path = colang_files[len(parsed_colang_files)]
 
         with open(current_path, "r", encoding="utf-8") as f:
-            content = f.read()
             try:
+                content = f.read()
                 _parsed_config = parse_colang_file(
                     current_file, content=content, version=colang_version
                 )
@@ -1668,7 +1670,7 @@ def streaming_supported(self):
         # if we have output rails streaming enabled
         # we keep it in case it was needed when we have
         # support per rails
-        if self.rails.output.streaming.enabled:
+        if self.rails.output.streaming and self.rails.output.streaming.enabled:
             return True
         return False
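All three hunks are defensive guards: skip `.railsignore` parsing when no file is found, move the file read inside the `try` so I/O errors flow through the same error path as parse errors, and short-circuit `streaming_supported` when no output-streaming config exists. A standalone sketch of the railsignore guard; `load_ignore_patterns` is a hypothetical stand-in, not the real `utils.get_railsignore_patterns`:

from pathlib import Path
from typing import Optional, Set


def load_ignore_patterns(railsignore_path: Optional[Path]) -> Set[str]:
    # Guard mirrors the diff: a missing .railsignore yields an empty set
    # instead of passing None further down.
    if not railsignore_path:
        return set()
    return {
        line.strip()
        for line in railsignore_path.read_text(encoding="utf-8").splitlines()
        if line.strip() and not line.strip().startswith("#")
    }


assert load_ignore_patterns(None) == set()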