Skip to content

Commit 57387fd

Browse files
Merge pull request #25 from MITLibraries/USE-138-write-embeddings-to-timdex-dataset
Write embeddings to timdex dataset using TDA library
2 parents 1cfd0af + d46682a commit 57387fd

File tree

3 files changed

+93
-34
lines changed

3 files changed

+93
-34
lines changed

embeddings/cli.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,18 @@
22
import json
33
import logging
44
import time
5-
from collections.abc import Callable
5+
from collections.abc import Callable, Iterator
66
from datetime import timedelta
77
from pathlib import Path
88
from typing import TYPE_CHECKING
99

1010
import click
1111
import jsonlines
1212
import smart_open
13-
from timdex_dataset_api import TIMDEXDataset
13+
from timdex_dataset_api import DatasetEmbedding, TIMDEXDataset, TIMDEXEmbeddings
1414

1515
from embeddings.config import configure_logger, configure_sentry
16+
from embeddings.models.base import Embedding
1617
from embeddings.models.registry import get_model_class
1718
from embeddings.strategies.processor import create_embedding_inputs
1819
from embeddings.strategies.registry import STRATEGY_REGISTRY
@@ -222,6 +223,7 @@ def create_embeddings(
222223
"""Create embeddings for TIMDEX records."""
223224
model: BaseEmbeddingModel = ctx.obj["model"]
224225
model.load()
226+
timdex_dataset: TIMDEXDataset | None = None
225227

226228
# read input records from TIMDEX dataset (default) or a JSONLines file
227229
if input_jsonl:
@@ -230,7 +232,6 @@ def create_embeddings(
230232
jsonlines.Reader(file_obj) as reader,
231233
):
232234
timdex_records = iter(list(reader))
233-
234235
else:
235236
if not dataset_location or not run_id:
236237
raise click.UsageError(
@@ -273,14 +274,26 @@ def create_embeddings(
273274
for embedding in embeddings:
274275
writer.write(embedding.to_dict())
275276
else:
276-
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
277-
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
278-
# Embedding instance will nearly 1:1 map to.
279-
raise NotImplementedError
277+
if not timdex_dataset:
278+
# if input_jsonl, init TIMDEXDataset
279+
timdex_dataset = TIMDEXDataset(dataset_location)
280+
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset)
281+
timdex_embeddings.write(_dataset_embedding_iter(embeddings))
280282

281283
logger.info("Embeddings creation complete.")
282284

283285

284-
if __name__ == "__main__": # pragma: no cover
285-
logger = logging.getLogger("embeddings.main")
286-
main()
286+
def _dataset_embedding_iter(
287+
embeddings: Iterator[Embedding],
288+
) -> Iterator[DatasetEmbedding]:
289+
"""Lazily convert model Embedding objects into DatasetEmbedding objects.

Yields one DatasetEmbedding per input Embedding, in input order. Fields are
copied across by name, with two exceptions: ``model_uri`` is stored as
``embedding_model``, and ``embedding_token_weights`` is JSON-serialized and
encoded to bytes (``str.encode`` default, UTF-8) for ``embedding_object``.
"""
290+
for embedding in embeddings:
291+
yield DatasetEmbedding(
292+
timdex_record_id=embedding.timdex_record_id,
293+
run_id=embedding.run_id,
294+
run_record_offset=embedding.run_record_offset,
295+
embedding_model=embedding.model_uri,
296+
embedding_strategy=embedding.embedding_strategy,
297+
embedding_vector=embedding.embedding_vector,
298+
embedding_object=json.dumps(embedding.embedding_token_weights).encode(),
299+
)

tests/test_cli.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,9 @@
1+
from pathlib import Path
2+
from unittest.mock import patch
3+
4+
from timdex_dataset_api import TIMDEXDataset
5+
from timdex_dataset_api.embeddings import TIMDEXEmbeddings
6+
17
from embeddings.cli import main
28

39

@@ -133,6 +139,46 @@ def test_model_required_decorator_works_across_commands(
133139
assert "OK" in result.output
134140

135141

142+
@patch("timdex_dataset_api.TIMDEXDataset.read_dicts_iter")
143+
def test_create_embeddings_writes_to_timdex_dataset(
144+
mock_timdex_dataset_read_dicts_iter, register_mock_model, runner, tmp_path
145+
):
"""Check that ``create-embeddings`` without a JSONLines output writes to the dataset.

``TIMDEXDataset.read_dicts_iter`` is patched to return a single record, the
CLI is invoked against a dataset location under ``tmp_path``, and the test
asserts a zero exit code and that the embeddings root path exists afterwards.
"""
146+
mock_timdex_dataset_read_dicts_iter.return_value = iter(
147+
[
148+
{
149+
"timdex_record_id": "record:1",
150+
"run_id": "run-1",
151+
"run_record_offset": 0,
152+
"transformed_record": '{"title":"Record 1","description":"This is a record about coffee in the mountains."}', # noqa: E501
153+
}
154+
]
155+
)
156+
157+
# build dataset/embeddings objects at the same location passed to the CLI below,
# so the embeddings root path can be checked after the command runs
158+
timdex_dataset = TIMDEXDataset(location=str(tmp_path / "dataset"))
159+
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset)
160+
161+
result = runner.invoke(
162+
main,
163+
[
164+
"create-embeddings",
165+
"--model-uri",
166+
"test/mock-model",
167+
"--dataset-location",
168+
str(tmp_path / "dataset"),
169+
"--run-id",
170+
"run-1",
171+
"--strategy",
172+
"full_record",
173+
],
174+
)
175+
176+
# TODO @jonavellecuerdo: Update to use TIMDEXEmbeddings # noqa: FIX002
177+
# read method when ready
178+
assert result.exit_code == 0
179+
assert Path(timdex_embeddings.data_embeddings_root).exists()
180+
181+
136182
def test_create_embeddings_requires_strategy(register_mock_model, runner):
137183
result = runner.invoke(
138184
main,

uv.lock

Lines changed: 24 additions & 24 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)