Commit 33ca397

Address Traian and Pouyan's feedback on redundant None-checks and defaults in EmbeddingsCacheConfig

1 parent: a53be3c

File tree (3 files changed, +20 -16 lines):

  nemoguardrails/embeddings/basic.py
  nemoguardrails/embeddings/cache.py
  nemoguardrails/rails/llm/config.py

nemoguardrails/embeddings/basic.py

Lines changed: 14 additions & 10 deletions
@@ -15,7 +15,7 @@

 import asyncio
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast

 from annoy import AnnoyIndex  # type: ignore

@@ -73,8 +73,14 @@ def __init__(
         self._model: Optional[EmbeddingModel] = None
         self._items: List[IndexItem] = []
         self._embeddings: List[List[float]] = []
-        self.embedding_model: Optional[str] = embedding_model
-        self.embedding_engine: Optional[str] = embedding_engine
+        self.embedding_model: str = (
+            embedding_model
+            if embedding_model
+            else "sentence-transformers/all-MiniLM-L6-v2"
+        )
+        self.embedding_engine: str = (
+            embedding_engine if embedding_engine else "SentenceTransformers"
+        )
         self.embedding_params = embedding_params or {}
         self._embedding_size = 0
         self.search_threshold = search_threshold or float("inf")
@@ -124,9 +130,8 @@ def embeddings(self):

     def _init_model(self):
         """Initialize the model used for computing the embeddings."""
-        # Provide defaults if not specified
-        model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
-        engine = self.embedding_engine or "SentenceTransformers"
+        model = self.embedding_model
+        engine = self.embedding_engine

         self._model = init_embedding_model(
             embedding_model=model,
@@ -152,9 +157,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         if self._model is None:
             self._init_model()

-        if not self._model:
-            raise Exception("Couldn't initialize embedding model")
-        embeddings = await self._model.encode_async(texts)
+        # self._model can't be None here, or self._init_model() would throw a ValueError
+        model: EmbeddingModel = cast(EmbeddingModel, self._model)
+        embeddings = await model.encode_async(texts)
         return embeddings

     async def add_item(self, item: IndexItem):
@@ -218,7 +223,6 @@ async def _run_batch(self):
         if not self._current_batch_finished_event:
             raise Exception("self._current_batch_finished_event not initialized")

-        assert self._current_batch_finished_event is not None
         batch_event: asyncio.Event = self._current_batch_finished_event
         self._current_batch_finished_event = None
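
Note on the `cast` in `_get_embeddings`: `typing.cast` performs no runtime check; it only tells the type checker that, after `_init_model()` has run (which, per the comment above, raises rather than leaving `self._model` unset), the attribute is no longer `None`. A minimal self-contained sketch of the same pattern, with illustrative names that are not from the library:

    from typing import List, Optional, cast

    class Model:
        def encode(self, text: str) -> List[float]:
            return [float(len(text))]

    class Index:
        def __init__(self) -> None:
            self._model: Optional[Model] = None

        def _init_model(self) -> None:
            # In the real code this raises if initialization fails.
            self._model = Model()

        def encode(self, text: str) -> List[float]:
            if self._model is None:
                self._init_model()
            # No runtime effect; only narrows Optional[Model] to Model for the type checker.
            model = cast(Model, self._model)
            return model.encode(text)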

nemoguardrails/embeddings/cache.py

Lines changed: 3 additions & 3 deletions
@@ -244,9 +244,9 @@ def from_config(cls, config: EmbeddingsCacheConfig):

     def get_config(self):
         return EmbeddingsCacheConfig(
-            key_generator=self._key_generator.name if self._key_generator else "sha256",
-            store=self._cache_store.name if self._cache_store else "filesystem",
-            store_config=self._store_config,
+            key_generator=self._key_generator.name if self._key_generator else None,
+            store=self._cache_store.name if self._cache_store else None,
+            store_config=self._store_config if self._store_config else None,
         )

     @singledispatchmethod

nemoguardrails/rails/llm/config.py

Lines changed: 3 additions & 3 deletions
@@ -358,15 +358,15 @@ class EmbeddingsCacheConfig(BaseModel):
         default=False,
         description="Whether caching of the embeddings should be enabled or not.",
     )
-    key_generator: str = Field(
+    key_generator: Optional[str] = Field(
         default="sha256",
         description="The method to use for generating the cache keys.",
     )
-    store: str = Field(
+    store: Optional[str] = Field(
         default="filesystem",
         description="What type of store to use for the cached embeddings.",
     )
-    store_config: Dict[str, Any] = Field(
+    store_config: Optional[Dict[str, Any]] = Field(
         default_factory=dict,
         description="Any additional configuration options required for the store. "
         "For example, path for `filesystem` or `host`/`port`/`db` for redis.",
