
Commit 2c2e2a2

Add embeddings-create and embeddings-load steps to pipeline
Why these changes are being introduced:

Now that we have AWS Batch pipelines that can generate embeddings, and TIM is prepared to load them into OpenSearch, the pipeline lambda needs to prepare commands both for running the AWS Batch job that creates embeddings and for the TIM command that then loads them. This dovetails with work in https://mitlibraries.atlassian.net/browse/USE-215, which proposes updates to the StepFunction: two new pipeline lambda invocations there will use the two new allowed 'next-step' values introduced in this commit.

How this addresses that need:

* Add "embeddings-create" and "embeddings-load" as valid steps in config
* Add SKIP_EMBEDDINGS_SOURCES config for sources that don't need embeddings (alma, gisogm)
* Add generate_embeddings_create_command(), which determines the compute env (cpu vs gpu-spot) based on a record count threshold
* Add generate_embeddings_load_command() for the TIM bulk-update-embeddings command
* Add handlers for both new steps in format_input.py
* Update handle_load() to flow into embeddings-create instead of end
* Add run_id and embeddings fields to ResultPayload
* Add unit tests for new functionality

Side effects of this change:

* The pipeline will now continue into the embeddings steps after load completes

Relevant ticket(s):

* https://mitlibraries.atlassian.net/browse/USE-140
1 parent 737e8c9 commit 2c2e2a2
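To make the new handoff concrete, below is a rough sketch of the result payload a successful embeddings-create invocation now emits, assembled from the ResultPayload fields and command generator added in this commit. All values are illustrative, not from a real run:

# Sketch of ResultPayload contents after handle_embeddings_create; values hypothetical.
result = {
    "next_step": "embeddings-load",
    "run_date": "2025-01-01",
    "run_type": "daily",
    "run_id": "<run uuid>",
    "source": "testsource",
    "verbose": True,
    "embeddings": {
        "create": {
            "job_name": "create-embeddings-cpu-<uuid4>",
            "job_compute_env": "cpu",
            "command": [
                "--verbose",
                "create-embeddings",
                "--strategy=full_record",
                "--dataset-location=s3://<bucket>/dataset",
                "--run-id=<run uuid>",
            ],
        }
    },
}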

6 files changed, +280 −6 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ Takes input JSON (usually from EventBridge although it can be passed to a manual
 
 #### Required
 
-- `next-step`: The next step of the pipeline to be performed, must be one of `["extract", "transform", "load"]`. Determines which task run commands will be generated as output from the format lambda.
+- `next-step`: The next step of the pipeline to be performed. Determines which task run commands will be generated as output from the format lambda.
 - `run-date`: Must be in one of the formats ["yyyy-mm-dd", "yyyy-mm-ddThh:mm:ssZ"]. The provided date is used in the input/output file naming scheme for all steps of the pipeline.
 - `run-type`: Must be one of `["full", "daily"]`. The provided run type is used in the input/output file naming scheme for all steps of the pipeline. It also determines logic for both the OAI-PMH harvest and load commands as follows:
   - `full`: Perform a full harvest of all records from the provided `oai-pmh-host`. During load, create a new OpenSearch index, load all records into it, and then promote the new index.
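For illustration, a minimal input event using one of the newly allowed `next-step` values could look like the sketch below. Field values are hypothetical; the field set follows the unit tests added in this commit:

# Hypothetical format-lambda input event; not an example from the README itself.
event = {
    "next-step": "embeddings-create",
    "run-date": "2025-01-01",
    "run-type": "daily",
    "source": "testsource",
    "run-id": "<run uuid>",
}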

lambdas/commands.py

Lines changed: 45 additions & 0 deletions
@@ -1,4 +1,5 @@
 import logging
+import uuid
 from typing import TYPE_CHECKING
 
 from lambdas import helpers
@@ -11,6 +12,8 @@
 
 CONFIG = Config()
 
+GPU_RECORD_COUNT_THRESHOLD = 500
+
 
 def generate_extract_command(input_payload: "InputPayload") -> dict:
     step = "extract"
@@ -148,3 +151,45 @@ def generate_load_commands(input_payload: "InputPayload") -> dict:
         }
 
     return {"failure": f"Unexpected run-type: '{input_payload.run_type}'"}
+
+
+def generate_embeddings_create_command(
+    input_payload: "InputPayload",
+    record_count: int,
+) -> dict:
+    """Generate AWS Batch job parameters for creating embeddings.
+
+    Determines compute environment based on record count:
+        - cpu (ECS Fargate) for < 500 records
+        - gpu-spot (EC2 Spot) for >= 500 records
+    """
+    job_compute_env = "gpu-spot" if record_count >= GPU_RECORD_COUNT_THRESHOLD else "cpu"
+
+    return {
+        "create": {
+            "job_name": f"create-embeddings-{job_compute_env}-{uuid.uuid4()}",
+            "job_compute_env": job_compute_env,
+            "command": [
+                "--verbose",
+                "create-embeddings",
+                "--strategy=full_record",
+                f"--dataset-location={CONFIG.s3_timdex_dataset_location}",
+                f"--run-id={input_payload.run_id}",
+            ],
+        }
+    }
+
+
+def generate_embeddings_load_command(input_payload: "InputPayload") -> dict:
+    """Generate TIM command to update documents with embeddings."""
+    return {
+        "load": {
+            "bulk-update-embeddings-command": [
+                "--verbose",
+                "bulk-update-embeddings",
+                f"--source={input_payload.source}",
+                f"--run-id={input_payload.run_id}",
+                CONFIG.s3_timdex_dataset_location,
+            ],
+        }
+    }
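As a usage sketch of the new generators (mirroring the unit tests in tests/test_commands.py below; the event values and run id are illustrative):

# Sketch: compute env selection by record count; mirrors the tests further down.
from lambdas import commands
from lambdas.format_input import InputPayload

event = {
    "next-step": "embeddings-create",
    "run-date": "2022-01-02",
    "run-type": "daily",
    "source": "testsource",
    "run-id": "00000000-0000-0000-0000-000000000000",  # hypothetical run id
}
input_payload = InputPayload.from_event(event)

# Below GPU_RECORD_COUNT_THRESHOLD (500) the Fargate "cpu" env is chosen;
# at or above the threshold, "gpu-spot" is chosen.
small = commands.generate_embeddings_create_command(input_payload, record_count=100)
large = commands.generate_embeddings_create_command(input_payload, record_count=500)
assert small["create"]["job_compute_env"] == "cpu"
assert large["create"]["job_compute_env"] == "gpu-spot"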

lambdas/config.py

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ class Config:
     SOURCE_EXCLUSION_LISTS: ClassVar = {"libguides": "/config/libguides/exclusions.csv"}
     VALID_DATE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%SZ")
     VALID_RUN_TYPES = ("full", "daily")
-    VALID_STEPS = ("extract", "transform", "load")
+    VALID_STEPS = ("extract", "transform", "load", "embeddings-create", "embeddings-load")
+    SKIP_EMBEDDINGS_SOURCES = ("alma", "gisogm")
 
     def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Provide dot notation access to configurations and env vars on this class."""

lambdas/format_input.py

Lines changed: 121 additions & 3 deletions
@@ -1,24 +1,37 @@
+# ruff: noqa: S608
+
 import json
 import logging
 import uuid
 from dataclasses import asdict, dataclass
 from datetime import UTC, datetime
 from typing import Literal
 
+from timdex_dataset_api.dataset import TIMDEXDataset  # type: ignore[import-untyped]
+
 from lambdas import alma_prep, commands, errors, helpers
 from lambdas.config import Config, configure_logger
 
 logger = logging.getLogger(__name__)
 
 CONFIG = Config()
 
-type NextStep = Literal["extract", "transform", "load", "exit-ok", "exit-error", "end"]
+type NextStep = Literal[
+    "extract",
+    "transform",
+    "load",
+    "embeddings-create",
+    "embeddings-load",
+    "exit-ok",
+    "exit-error",
+    "end",
+]
 
 
 @dataclass
 class InputPayload:
     run_date: str
-    run_type: str
+    run_type: Literal["daily", "full"]
     source: str
     next_step: NextStep
     run_id: str
@@ -118,12 +131,14 @@ class ResultPayload:
     next_step: NextStep
     run_date: str
     run_type: str
+    run_id: str
     source: str
     verbose: bool = True
     harvester_type: str | None = None
     extract: dict | None = None
     transform: dict | None = None
     load: dict | None = None
+    embeddings: dict | None = None
     message: str | None = None
 
     @classmethod
@@ -132,6 +147,7 @@ def from_input_payload(cls, input_payload: "InputPayload") -> "ResultPayload":
             next_step=input_payload.next_step,
             run_date=input_payload.run_date,
             run_type=input_payload.run_type,
+            run_id=input_payload.run_id,
             source=input_payload.source,
             verbose=input_payload.verbose,
         )
@@ -154,6 +170,10 @@ def lambda_handler(event: dict, _context: dict) -> dict:
         result = handle_transform(input_payload, result)
     elif input_payload.next_step == "load":
         result = handle_load(input_payload, result)
+    elif input_payload.next_step == "embeddings-create":
+        result = handle_embeddings_create(input_payload, result)
+    elif input_payload.next_step == "embeddings-load":
+        result = handle_embeddings_load(input_payload, result)
     else:
         raise ValueError(f"'next-step' not supported: '{input_payload.next_step}'")
 
@@ -213,7 +233,7 @@ def handle_transform(input_payload: InputPayload, result: ResultPayload) -> ResultPayload:
 
 
 def handle_load(input_payload: InputPayload, result: ResultPayload) -> ResultPayload:
-    result.next_step = "end"
+    result.next_step = "embeddings-create"
     if not helpers.dataset_records_exist_for_run(input_payload.run_id):
         result.next_step = "exit-ok"
         message = (
@@ -225,3 +245,101 @@ def handle_load(input_payload: InputPayload, result: ResultPayload) -> ResultPayload:
         return result
     result.load = commands.generate_load_commands(input_payload)
     return result
+
+
+def handle_embeddings_create(
+    input_payload: InputPayload, result: ResultPayload
+) -> ResultPayload:
+    """Analyze ETL run and prepare parameters for AWS Batch job to create embeddings.
+
+    There are currently three compute environments we can create embeddings in:
+        - ECS Fargate - "cpu"
+        - EC2 - "gpu"
+        - EC2 Spot Instances - "gpu-spot"
+
+    This lambda handler is responsible for analyzing the size and shape of the ETL run,
+    and determining which AWS Batch compute environment is most appropriate.
+
+    We do not create embeddings for all sources. Those we skip are configured in
+    CONFIG.SKIP_EMBEDDINGS_SOURCES.
+
+    Additionally, at this time, we do not have a scenario or code path that would
+    utilize the "gpu" compute environment, only "gpu-spot". This is mostly because we
+    don't require an immediate turnaround for embeddings creation; when the job size
+    calls for a GPU, we have the luxury of waiting for a spot instance.
+    """
+    result.next_step = "embeddings-load"
+
+    if input_payload.source in CONFIG.SKIP_EMBEDDINGS_SOURCES:
+        result.next_step = "exit-ok"
+        result.message = (
+            f"Not currently creating embeddings for source '{input_payload.source}'"
+        )
+        return result
+
+    # retrieve records count for run
+    td = TIMDEXDataset(location=CONFIG.s3_timdex_dataset_location)
+    record_count = td.metadata.conn.query(f"""
+        select count(*)
+        from metadata.records
+        where run_id = '{input_payload.run_id}'
+        and action in ('index')
+    """).fetchone()[0]
+
+    # exit early if no records to create embeddings for
+    if record_count == 0:
+        result.next_step = "exit-ok"
+        result.message = f"No embeddable records found for run '{input_payload.run_id}'."
+        return result
+
+    job_compute_env = (
+        "gpu-spot" if record_count >= commands.GPU_RECORD_COUNT_THRESHOLD else "cpu"
+    )
+    logger.info(
+        f"ETL run '{input_payload.run_id}' had {record_count} records indexed, "
+        f"recommending '{job_compute_env}' compute env."
+    )
+
+    result.embeddings = commands.generate_embeddings_create_command(
+        input_payload, record_count
+    )
+    return result
+
+
+def handle_embeddings_load(
+    input_payload: InputPayload, result: ResultPayload
+) -> ResultPayload:
+    """Prepare TIM command to update documents in Opensearch with embeddings.
+
+    We do not create embeddings for all sources. Those we skip are configured in
+    CONFIG.SKIP_EMBEDDINGS_SOURCES.
+    """
+    result.next_step = "end"
+
+    if input_payload.source in CONFIG.SKIP_EMBEDDINGS_SOURCES:
+        result.next_step = "exit-ok"
+        result.message = (
+            f"Not currently indexing embeddings for source '{input_payload.source}'"
+        )
+        return result
+
+    # retrieve embeddings count for run
+    td = TIMDEXDataset(location=CONFIG.s3_timdex_dataset_location)
+    embeddings_count = td.metadata.conn.query(f"""
+        select count(*)
+        from data.current_run_embeddings
+        where run_id = '{input_payload.run_id}'
+    """).fetchone()[0]
+
+    # exit early if no embeddings to load
+    if embeddings_count == 0:
+        result.next_step = "exit-ok"
+        result.message = f"No embeddings found for run '{input_payload.run_id}'."
+        return result
+
+    logger.info(
+        f"Preparing TIM command to update {embeddings_count} documents with embeddings."
+    )
+
+    result.embeddings = commands.generate_embeddings_load_command(input_payload)
+    return result
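Note that the run-scoped counts above are issued as f-string SQL against the dataset's DuckDB metadata connection, which is why the file-level `# ruff: noqa: S608` suppression was added. Since `run_id` is pipeline-generated the injection risk is low, but for reference, here is a sketch of the same count as a parameterized query, assuming `td.metadata.conn` exposes the standard DuckDB `execute()` API:

# Sketch only; assumes td.metadata.conn is a DuckDB connection supporting execute()
# with "?" placeholders. Not part of this commit.
record_count = td.metadata.conn.execute(
    """
    select count(*)
    from metadata.records
    where run_id = ?
    and action in ('index')
    """,
    [input_payload.run_id],
).fetchone()[0]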

tests/test_commands.py

Lines changed: 57 additions & 0 deletions
@@ -338,3 +338,60 @@ def test_generate_load_commands_unhandled_run_type(run_id):
     }
     with pytest.raises(ValueError, match=r"Input 'run-type' value must be one of:"):
         InputPayload.from_event(event)
+
+
+def test_generate_embeddings_create_command_cpu(run_id):
+    """Record count below threshold uses cpu compute env."""
+    event = {
+        "next-step": "embeddings-create",
+        "run-date": "2022-01-02",
+        "run-type": "daily",
+        "source": "testsource",
+        "run-id": run_id,
+    }
+    input_payload = InputPayload.from_event(event)
+    result = commands.generate_embeddings_create_command(input_payload, record_count=100)
+
+    assert result["create"]["job_compute_env"] == "cpu"
+    assert "create-embeddings-cpu-" in result["create"]["job_name"]
+    assert f"--run-id={run_id}" in result["create"]["command"]
+
+
+def test_generate_embeddings_create_command_gpu_spot(run_id):
+    """Record count at/above threshold uses gpu-spot compute env."""
+    event = {
+        "next-step": "embeddings-create",
+        "run-date": "2022-01-02",
+        "run-type": "daily",
+        "source": "testsource",
+        "run-id": run_id,
+    }
+    input_payload = InputPayload.from_event(event)
+    result = commands.generate_embeddings_create_command(input_payload, record_count=500)
+
+    assert result["create"]["job_compute_env"] == "gpu-spot"
+    assert "create-embeddings-gpu-spot-" in result["create"]["job_name"]
+
+
+def test_generate_embeddings_load_command(run_id):
+    event = {
+        "next-step": "embeddings-load",
+        "run-date": "2022-01-02",
+        "run-type": "daily",
+        "source": "testsource",
+        "run-id": run_id,
+    }
+    input_payload = InputPayload.from_event(event)
+    result = commands.generate_embeddings_load_command(input_payload)
+
+    assert result == {
+        "load": {
+            "bulk-update-embeddings-command": [
+                "--verbose",
+                "bulk-update-embeddings",
+                "--source=testsource",
+                f"--run-id={run_id}",
+                "s3://test-timdex-bucket/dataset",
+            ],
+        }
+    }
