@@ -38,18 +38,22 @@ def _encode(
 class HuggingFaceSentenceEmbedder(TransformerSentenceEmbedder):
     def __init__(self, config_string: str, batch_size: int = 128):
         super().__init__(config_string, batch_size)
+        self.config_string = config_string

     @staticmethod
     def load(embedder: dict) -> "HuggingFaceSentenceEmbedder":
+        if os.path.exists(embedder["config_string"]):
+            config_string = embedder["config_string"]
+        else:
+            config_string = request_util.get_model_path(embedder["config_string"])
         return HuggingFaceSentenceEmbedder(
-            config_string=request_util.get_model_path(embedder["config_string"]),
-            batch_size=embedder["batch_size"],
+            config_string=config_string, batch_size=embedder["batch_size"]
         )

     def to_json(self) -> dict:
         return {
             "cls": "HuggingFaceSentenceEmbedder",
-            "config_string": self.model.model_card_data.base_model,
+            "config_string": self.config_string,
             "batch_size": self.batch_size,
         }

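With this hunk, `to_json` and `load` round-trip cleanly: serialization keeps the original `config_string`, and loading accepts either a local path or a remote model name (note the new `os.path.exists` call requires `import os` at the top of the module, not shown in this hunk). A minimal sketch of that dispatch, assuming `request_util.get_model_path` resolves a remote name to a local model directory (stubbed here):

import os

def resolve_config_string(embedder: dict, get_model_path) -> str:
    # A local path is used verbatim; anything else is treated as a
    # remote model name and resolved to a local path first.
    if os.path.exists(embedder["config_string"]):
        return embedder["config_string"]
    return get_model_path(embedder["config_string"])

serialized = {
    "cls": "HuggingFaceSentenceEmbedder",
    "config_string": "sentence-transformers/all-MiniLM-L6-v2",  # illustrative name
    "batch_size": 128,
}
# With a stub resolver, a remote name takes the else branch:
print(resolve_config_string(serialized, lambda name: f"/models/{name}"))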
@@ -239,7 +243,9 @@ def _encode(
         self, documents: List[Union[str, Doc]], fit_model: bool
     ) -> Generator[List[List[float]], None, None]:
         for documents_batch in util.batch(documents, self.batch_size):
-            documents_batch = [self._trim_length(doc.replace("\n", " ")) for doc in documents_batch]
+            documents_batch = [
+                self._trim_length(doc.replace("\n", " ")) for doc in documents_batch
+            ]
             try:
                 response = self.openai_client.embeddings.create(
                     input=documents_batch, model=self.model_name
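The reflowed comprehension strips newlines and trims each document before a batch is sent to the embeddings endpoint. For context, a minimal sketch of what a `util.batch`-style helper could look like, assuming it simply yields fixed-size chunks (the repo's actual implementation is not shown in this diff):

from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def batch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    # Yield consecutive chunks of at most batch_size items.
    chunk: List[T] = []
    for item in items:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:  # emit the final, possibly smaller, chunk
        yield chunk

# list(batch(range(5), 2)) -> [[0, 1], [2, 3], [4]]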
@@ -270,11 +276,13 @@ def dump(self, project_id: str, embedding_id: str) -> None:
         export_file.parent.mkdir(parents=True, exist_ok=True)
         util.write_json(self.to_json(), export_file, indent=2)

-    def _trim_length(self, text: str, max_length: int= 512) -> str:
+    def _trim_length(self, text: str, max_length: int = 512) -> str:
         tokens = self._auto_tokenizer(
             text,
             truncation=True,
             max_length=max_length,
-            return_tensors=None # No tensors needed for just truncating
+            return_tensors=None,  # No tensors needed for just truncating
+        )
+        return self._auto_tokenizer.decode(
+            tokens["input_ids"], skip_special_tokens=True
         )
-        return self._auto_tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
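The corrected `_trim_length` now closes the tokenizer call (trailing comma plus `)`) before decoding, instead of leaving the parenthesization dangling. A self-contained sketch of the same truncate-then-decode pattern with Hugging Face `transformers` (the model name is illustrative; the class in this diff uses its own `self._auto_tokenizer`):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative choice

def trim_length(text: str, max_length: int = 512) -> str:
    # Tokenize with truncation, then decode the surviving ids back to text.
    tokens = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors=None,  # plain Python lists suffice for truncation
    )
    return tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)

print(trim_length("a very long document " * 300))  # trimmed to at most 512 tokens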