1515
1616import asyncio
1717import logging
18- from typing import Any , Dict , List , Optional , Union
18+ from typing import Any , Dict , List , Optional , Union , cast
1919
2020from annoy import AnnoyIndex # type: ignore
2121
@@ -73,8 +73,14 @@ def __init__(
7373 self ._model : Optional [EmbeddingModel ] = None
7474 self ._items : List [IndexItem ] = []
7575 self ._embeddings : List [List [float ]] = []
76- self .embedding_model : Optional [str ] = embedding_model
77- self .embedding_engine : Optional [str ] = embedding_engine
76+ self .embedding_model : str = (
77+ embedding_model
78+ if embedding_model
79+ else "sentence-transformers/all-MiniLM-L6-v2"
80+ )
81+ self .embedding_engine : str = (
82+ embedding_engine if embedding_engine else "SentenceTransformers"
83+ )
7884 self .embedding_params = embedding_params or {}
7985 self ._embedding_size = 0
8086 self .search_threshold = search_threshold or float ("inf" )
@@ -124,9 +130,8 @@ def embeddings(self):
124130
125131 def _init_model (self ):
126132 """Initialize the model used for computing the embeddings."""
127- # Provide defaults if not specified
128- model = self .embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
129- engine = self .embedding_engine or "SentenceTransformers"
133+ model = self .embedding_model
134+ engine = self .embedding_engine
130135
131136 self ._model = init_embedding_model (
132137 embedding_model = model ,
@@ -152,9 +157,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
152157 if self ._model is None :
153158 self ._init_model ()
154159
155- if not self ._model :
156- raise Exception ( "Couldn't initialize embedding model" )
157- embeddings = await self . _model .encode_async (texts )
160+ # self._model can't be None here, or self._init_model() would throw a ValueError
161+ model : EmbeddingModel = cast ( EmbeddingModel , self . _model )
162+ embeddings = await model .encode_async (texts )
158163 return embeddings
159164
160165 async def add_item (self , item : IndexItem ):
@@ -218,7 +223,6 @@ async def _run_batch(self):
218223 if not self ._current_batch_finished_event :
219224 raise Exception ("self._current_batch_finished_event not initialized" )
220225
221- assert self ._current_batch_finished_event is not None
222226 batch_event : asyncio .Event = self ._current_batch_finished_event
223227 self ._current_batch_finished_event = None
224228
0 commit comments