22import json
33import logging
44import time
5- from collections .abc import Callable
5+ from collections .abc import Callable , Iterator
66from datetime import timedelta
77from pathlib import Path
88from typing import TYPE_CHECKING
99
1010import click
1111import jsonlines
1212import smart_open
13- from timdex_dataset_api import TIMDEXDataset
13+ from timdex_dataset_api import DatasetEmbedding , TIMDEXDataset , TIMDEXEmbeddings
1414
1515from embeddings .config import configure_logger , configure_sentry
16+ from embeddings .models .base import Embedding
1617from embeddings .models .registry import get_model_class
1718from embeddings .strategies .processor import create_embedding_inputs
1819from embeddings .strategies .registry import STRATEGY_REGISTRY
@@ -222,6 +223,7 @@ def create_embeddings(
222223 """Create embeddings for TIMDEX records."""
223224 model : BaseEmbeddingModel = ctx .obj ["model" ]
224225 model .load ()
226+ timdex_dataset : TIMDEXDataset | None = None
225227
226228 # read input records from TIMDEX dataset (default) or a JSONLines file
227229 if input_jsonl :
@@ -230,7 +232,6 @@ def create_embeddings(
230232 jsonlines .Reader (file_obj ) as reader ,
231233 ):
232234 timdex_records = iter (list (reader ))
233-
234235 else :
235236 if not dataset_location or not run_id :
236237 raise click .UsageError (
@@ -273,14 +274,26 @@ def create_embeddings(
273274 for embedding in embeddings :
274275 writer .write (embedding .to_dict ())
275276 else :
276- # WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
277- # NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
278- # Embedding instance will nearly 1:1 map to.
279- raise NotImplementedError
277+ if not timdex_dataset :
278+ # if input_jsonl, init TIMDEXDataset
279+ timdex_dataset = TIMDEXDataset (dataset_location )
280+ timdex_embeddings = TIMDEXEmbeddings (timdex_dataset )
281+ timdex_embeddings .write (_dataset_embedding_iter (embeddings ))
280282
281283 logger .info ("Embeddings creation complete." )
282284
283285
284- if __name__ == "__main__" : # pragma: no cover
285- logger = logging .getLogger ("embeddings.main" )
286- main ()
def _dataset_embedding_iter(
    embeddings: Iterator[Embedding],
) -> Iterator[DatasetEmbedding]:
    """Lazily convert model embeddings into DatasetEmbedding objects.

    Each incoming ``Embedding`` is mapped field-for-field onto a
    ``DatasetEmbedding``; nothing is materialized up front, so the input
    iterator is consumed one item at a time as the caller pulls results.

    Args:
        embeddings: Iterator of ``Embedding`` instances produced by the model.

    Yields:
        One ``DatasetEmbedding`` per input embedding, in input order.
    """
    for source in embeddings:
        # Build the keyword arguments in the same order the fields are read,
        # then hand them to the DatasetEmbedding constructor in one go.
        fields = {
            "timdex_record_id": source.timdex_record_id,
            "run_id": source.run_id,
            "run_record_offset": source.run_record_offset,
            # the model's URI is stored under the "embedding_model" field
            "embedding_model": source.model_uri,
            "embedding_strategy": source.embedding_strategy,
            "embedding_vector": source.embedding_vector,
            # token weights are serialized to a JSON string and then encoded
            # to bytes (str.encode() defaults to UTF-8)
            "embedding_object": json.dumps(source.embedding_token_weights).encode(),
        }
        yield DatasetEmbedding(**fields)
0 commit comments