Commit 9bd7142

mihow, carlos-irreverentlabs, carlosgjs, and claude authored
AMI: Pipeline Registration (#106)
* Pipeline registration

* Convert worker tests to integration tests with real ML inference

Replaces fully mocked unit tests with integration tests that validate the Antenna API contract and run actual ML models. Tests now exercise the worker's unique code path (RESTDataset → rest_collate_fn) with real image loading and inference.

Changes:
- Add trapdata/api/tests/utils.py with shared test utilities
- Add trapdata/api/tests/antenna_api_server.py to mock Antenna API
- Rewrite test_worker.py as integration tests (17 tests, all passing)
- Update test_api.py to use shared utilities

Tests validate: real detector/classifier inference, HTTP image loading, schema compliance, batch processing, and end-to-end workflow.

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>

* Add AsyncPipelineRegistrationResponse schema

Add Pydantic model to validate responses from pipeline registration API. Fields: pipelines_created, pipelines_updated, processing_service_id.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

* Refactor registration functions to use get_http_session

Update get_user_projects() and register_pipelines_for_project() to use the session-based HTTP pattern established in PR #104:
- Use get_http_session() context manager for connection pooling
- Add retry_max and retry_backoff parameters with defaults
- Remove manual header management (session handles auth)
- Standardize URL paths (base_url now includes /api/v2)
- Use Pydantic model validation for API responses
- Fix error handling with hasattr() check

Co-Authored-By: Claude Opus 4.5 <[email protected]>

* Add integration tests for pipeline registration

Add mock Antenna API endpoints:
- GET /api/v2/projects/ - list user's projects
- POST /api/v2/projects/{id}/pipelines/ - register pipelines

Add TestRegistrationIntegration with 2 client tests:
- test_get_user_projects
- test_register_pipelines_for_project

Update TestWorkerEndToEnd.test_full_workflow_with_real_inference to include registration step: register → get jobs → process → post results.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

* Add git add -p to recommended development practices

Co-Authored-By: Claude Opus 4.5 <[email protected]>

* Read retry settings from Settings in get_http_session()

When max_retries or backoff_factor are not explicitly provided, get_http_session() now reads defaults from Settings (antenna_api_retry_max and antenna_api_retry_backoff). This centralizes retry configuration and allows callers to omit these low-level parameters.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

* Use Settings pattern in register_pipelines()

- Accept Settings object instead of base_url/auth_token params
- Remove direct os.environ.get() calls for ANTENNA_API_* vars
- Fix error message to reference correct env var (AMI_ANTENNA_API_AUTH_TOKEN)
- Remove retry params from get_user_projects() and register_pipelines_for_project() since get_http_session() now reads settings internally
- Remove unused os import

Co-Authored-By: Claude Opus 4.5 <[email protected]>

---------

Co-authored-by: Carlos Garcia Jurado Suarez <[email protected]>
Co-authored-by: Carlos Garcia Jurado Suarez <[email protected]>
Co-authored-by: Claude Sonnet 4.5 <[email protected]>
1 parent 1a523b2 commit 9bd7142

7 files changed, +429 -24 lines changed


CLAUDE.md

Lines changed: 3 additions & 2 deletions
@@ -13,8 +13,9 @@ This file helps AI agents (like Claude) work efficiently with the AMI Data Compa
 3. **Always prefer command line tools** to avoid expensive API requests (e.g., use git and jq instead of reading whole files)
 4. **Use bulk operations and prefetch patterns** to minimize database queries
 5. **Commit often** - Small, focused commits make debugging easier
-6. **Use TDD whenever possible** - Tests prevent regressions and document expected behavior
-7. **Keep it simple** - Always think hard and evaluate more complex approaches and alternative approaches before moving forward
+6. **Use `git add -p` for staging** - Interactive staging to add only relevant changes, creating logical commits
+7. **Use TDD whenever possible** - Tests prevent regressions and document expected behavior
+8. **Keep it simple** - Always think hard and evaluate more complex approaches and alternative approaches before moving forward
 
 ### Think Holistically
 

trapdata/api/schemas.py

Lines changed: 28 additions & 0 deletions
@@ -382,3 +382,31 @@ class ProcessingServiceInfoResponse(pydantic.BaseModel):
             ]
         ],
     )
+
+
+class AsyncPipelineRegistrationRequest(pydantic.BaseModel):
+    """
+    Request to register pipelines from an async processing service
+    """
+
+    processing_service_name: str
+    pipelines: list[PipelineConfigResponse] = []
+
+
+class AsyncPipelineRegistrationResponse(pydantic.BaseModel):
+    """
+    Response from registering pipelines with a project.
+    """
+
+    pipelines_created: list[str] = pydantic.Field(
+        default_factory=list,
+        description="List of pipeline slugs that were created",
+    )
+    pipelines_updated: list[str] = pydantic.Field(
+        default_factory=list,
+        description="List of pipeline slugs that were updated",
+    )
+    processing_service_id: int | None = pydantic.Field(
+        default=None,
+        description="ID of the processing service that was created or updated",
+    )
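
The response schema above gives every field a default, so a sparse server reply still parses into a complete object. The following is a minimal sketch of that behavior using a standard-library dataclass as a stand-in for the Pydantic model, so that it runs without dependencies. The names `RegistrationResponse` and `parse_response` are hypothetical, and dropping unknown keys is a choice made for this sketch; only the three field names come from the diff.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class RegistrationResponse:
    """Stand-in for AsyncPipelineRegistrationResponse (not the real Pydantic model)."""

    pipelines_created: list[str] = field(default_factory=list)
    pipelines_updated: list[str] = field(default_factory=list)
    processing_service_id: Optional[int] = None


def parse_response(payload: dict) -> RegistrationResponse:
    """Build a response from a JSON-like dict, dropping keys this sketch does not know."""
    known = {"pipelines_created", "pipelines_updated", "processing_service_id"}
    return RegistrationResponse(**{k: v for k, v in payload.items() if k in known})


# A payload that only lists created pipelines still yields all three fields.
resp = parse_response({"pipelines_created": ["quebec_vermont_moths_2023"]})
print(resp.pipelines_created, resp.pipelines_updated, resp.processing_service_id)
```

Using `default_factory=list` rather than a shared `[]` literal gives each instance its own list, which is the same reason the diff uses `pydantic.Field(default_factory=list)`.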

trapdata/api/tests/antenna_api_server.py

Lines changed: 74 additions & 0 deletions
@@ -14,13 +14,17 @@
     AntennaTaskResult,
     AntennaTaskResults,
     AntennaTasksListResponse,
+    AsyncPipelineRegistrationRequest,
+    AsyncPipelineRegistrationResponse,
 )
 
 app = FastAPI()
 
 # State management for tests
 _jobs_queue: dict[int, list[AntennaPipelineProcessingTask]] = {}
 _posted_results: dict[int, list[AntennaTaskResult]] = {}
+_projects: list[dict] = []
+_registered_pipelines: dict[int, list[str]] = {}  # project_id -> pipeline slugs
 
 
 @app.get("/api/v2/jobs")
@@ -84,6 +88,52 @@ def post_results(job_id: int, payload: list[dict]):
     return {"status": "ok"}
 
 
+@app.get("/api/v2/projects/")
+def get_projects():
+    """Return list of projects the user has access to.
+
+    Returns:
+        Paginated response with list of projects
+    """
+    return {"results": _projects}
+
+
+@app.post("/api/v2/projects/{project_id}/pipelines/")
+def register_pipelines(project_id: int, payload: dict):
+    """Register pipelines for a project.
+
+    Args:
+        project_id: Project ID to register pipelines for
+        payload: AsyncPipelineRegistrationRequest as dict
+
+    Returns:
+        AsyncPipelineRegistrationResponse
+    """
+    # Validate request
+    request = AsyncPipelineRegistrationRequest(**payload)
+
+    # Check if project exists
+    project_ids = [p["id"] for p in _projects]
+    if project_id not in project_ids:
+        raise HTTPException(status_code=404, detail="Project not found")
+
+    # Track registered pipelines
+    if project_id not in _registered_pipelines:
+        _registered_pipelines[project_id] = []
+
+    created = []
+    for pipeline in request.pipelines:
+        if pipeline.slug not in _registered_pipelines[project_id]:
+            _registered_pipelines[project_id].append(pipeline.slug)
+            created.append(pipeline.slug)
+
+    return AsyncPipelineRegistrationResponse(
+        pipelines_created=created,
+        pipelines_updated=[],
+        processing_service_id=1,
+    )
+
+
 # Test helper methods
 
 
@@ -109,7 +159,31 @@ def get_posted_results(job_id: int) -> list[AntennaTaskResult]:
     return _posted_results.get(job_id, [])
 
 
+def setup_projects(projects: list[dict]):
+    """Setup projects for testing.
+
+    Args:
+        projects: List of project dicts with 'id' and 'name' fields
+    """
+    _projects.clear()
+    _projects.extend(projects)
+
+
+def get_registered_pipelines(project_id: int) -> list[str]:
+    """Get list of pipeline slugs registered for a project.
+
+    Args:
+        project_id: Project ID to get pipelines for
+
+    Returns:
+        List of pipeline slugs
+    """
+    return _registered_pipelines.get(project_id, [])
+
+
 def reset():
     """Clear all state between tests."""
     _jobs_queue.clear()
     _posted_results.clear()
+    _projects.clear()
+    _registered_pipelines.clear()
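
The mock's registration endpoint only reports slugs it has not seen before for a given project. A minimal sketch of that bookkeeping, pulled out of FastAPI into a plain function so it can be run directly, is shown here. The function name `register` and the explicit `registry` argument are hypothetical; the loop mirrors the one in the diff.

```python
def register(registry: dict[int, list[str]], project_id: int, slugs: list[str]) -> list[str]:
    """Record slugs for a project and return only the newly added ones."""
    known = registry.setdefault(project_id, [])
    created = []
    for slug in slugs:
        if slug not in known:
            known.append(slug)
            created.append(slug)
    return created


registry: dict[int, list[str]] = {}
first = register(registry, 1, ["quebec_vermont_moths_2023"])
second = register(registry, 1, ["quebec_vermont_moths_2023"])
print(first)   # ['quebec_vermont_moths_2023']
print(second)  # []
```

Because the mock in the diff always returns `pipelines_updated=[]`, a second registration of the same slug reports it as neither created nor updated. A test that depends on update reporting would need the mock extended.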

trapdata/api/tests/test_worker.py

Lines changed: 81 additions & 15 deletions
@@ -17,13 +17,19 @@
     AntennaPipelineProcessingTask,
     AntennaTaskResult,
     AntennaTaskResultError,
+    PipelineConfigResponse,
     PipelineResultsResponse,
 )
 from trapdata.api.tests import antenna_api_server
 from trapdata.api.tests.antenna_api_server import app as antenna_app
 from trapdata.api.tests.image_server import StaticFileTestServer
 from trapdata.api.tests.utils import get_test_image_urls, patch_antenna_api_requests
-from trapdata.cli.worker import _get_jobs, _process_job
+from trapdata.cli.worker import (
+    _get_jobs,
+    _process_job,
+    get_user_projects,
+    register_pipelines_for_project,
+)
 from trapdata.tests import TEST_IMAGES_BASE_PATH
 
 # ---------------------------------------------------------------------------
@@ -473,10 +479,13 @@ def _make_settings(self):
 
     def test_full_workflow_with_real_inference(self):
         """
-        Complete workflow: fetch jobs → fetch tasks → load images →
+        Complete workflow: register → fetch jobs → fetch tasks → load images →
         run detection → run classification → post results.
         """
-        # Setup job with 2 test images
+        pipeline_slug = "quebec_vermont_moths_2023"
+
+        # Setup project and job with 2 test images
+        antenna_api_server.setup_projects([{"id": 1, "name": "Test Project"}])
         image_urls = get_test_image_urls(
             self.file_server, self.test_images_dir, subdir="vermont", num=2
         )
@@ -491,25 +500,33 @@ def test_full_workflow_with_real_inference(self):
         ]
         antenna_api_server.setup_job(job_id=200, tasks=tasks)
 
-        # Step 1: Get jobs
         with patch_antenna_api_requests(self.antenna_client):
+            # Step 1: Register pipeline
+            pipeline_configs = [
+                PipelineConfigResponse(name="Vermont Moths", slug=pipeline_slug, version=1)
+            ]
+            success, _ = register_pipelines_for_project(
+                base_url="http://testserver/api/v2",
+                auth_token="test-token",
+                project_id=1,
+                service_name="Test Worker",
+                pipeline_configs=pipeline_configs,
+            )
+            assert success is True
+
+            # Step 2: Get jobs
             job_ids = _get_jobs(
                 "http://testserver/api/v2",
                 "test-token",
-                "quebec_vermont_moths_2023",
-            )
-
-        assert 200 in job_ids
-
-        # Step 2: Process job
-        with patch_antenna_api_requests(self.antenna_client):
-            result = _process_job(
-                "quebec_vermont_moths_2023", 200, self._make_settings()
+                pipeline_slug,
             )
+            assert 200 in job_ids
 
-        assert result is True
+            # Step 3: Process job
+            result = _process_job(pipeline_slug, 200, self._make_settings())
+            assert result is True
 
-        # Step 3: Validate results posted
+        # Step 4: Validate results posted
         posted_results = antenna_api_server.get_posted_results(200)
         assert len(posted_results) == 2
 
@@ -566,3 +583,52 @@ def test_multiple_batches_processed(self):
         assert all(
             isinstance(r.result, PipelineResultsResponse) for r in posted_results
         )
+
+
+# ---------------------------------------------------------------------------
+# TestRegistrationIntegration - Basic tests for registration client functions
+# ---------------------------------------------------------------------------
+
+
+class TestRegistrationIntegration(TestCase):
+    """Integration tests for registration client functions."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.antenna_client = TestClient(antenna_app)
+
+    def setUp(self):
+        antenna_api_server.reset()
+
+    def test_get_user_projects(self):
+        """Client can fetch list of projects."""
+        antenna_api_server.setup_projects([
+            {"id": 1, "name": "Project A"},
+            {"id": 2, "name": "Project B"},
+        ])
+
+        with patch_antenna_api_requests(self.antenna_client):
+            result = get_user_projects("http://testserver/api/v2", "test-token")
+
+        assert len(result) == 2
+        assert result[0]["id"] == 1
+
+    def test_register_pipelines_for_project(self):
+        """Client can register pipelines for a project."""
+        antenna_api_server.setup_projects([{"id": 10, "name": "Test Project"}])
+
+        pipeline_configs = [
+            PipelineConfigResponse(name="Test Pipeline", slug="test_pipeline", version=1)
+        ]
+
+        with patch_antenna_api_requests(self.antenna_client):
+            success, message = register_pipelines_for_project(
+                base_url="http://testserver/api/v2",
+                auth_token="test-token",
+                project_id=10,
+                service_name="Test Service",
+                pipeline_configs=pipeline_configs,
+            )
+
+        assert success is True
+        assert "Created" in message

trapdata/api/utils.py

Lines changed: 14 additions & 4 deletions
@@ -40,8 +40,8 @@ def get_crop_fname(source_image: SourceImage, bbox: BoundingBox) -> str:
 
 def get_http_session(
     auth_token: str | None = None,
-    max_retries: int = 3,
-    backoff_factor: float = 0.5,
+    max_retries: int | None = None,
+    backoff_factor: float | None = None,
     status_forcelist: tuple[int, ...] = (500, 502, 503, 504),
 ) -> requests.Session:
     """
@@ -53,8 +53,8 @@ def get_http_session(
 
     Args:
         auth_token: Optional authentication token (adds "Token {token}" to Authorization header)
-        max_retries: Maximum number of retry attempts (default: 3)
-        backoff_factor: Exponential backoff multiplier in seconds (default: 0.5)
+        max_retries: Maximum number of retry attempts (default: from settings.antenna_api_retry_max)
+        backoff_factor: Exponential backoff multiplier in seconds (default: from settings.antenna_api_retry_backoff)
             Delays will be: backoff_factor * (2 ** retry_number)
             e.g., 0.5s, 1s, 2s for default settings
         status_forcelist: HTTP status codes that trigger a retry (default: 500, 502, 503, 504)
@@ -69,6 +69,16 @@ def get_http_session(
         >>> session = get_http_session(auth_token="abc123")
         >>> response = session.get("https://api.example.com/data")
     """
+    # Read defaults from settings if not explicitly provided
+    if max_retries is None or backoff_factor is None:
+        from trapdata.settings import read_settings
+
+        settings = read_settings()
+        if max_retries is None:
+            max_retries = settings.antenna_api_retry_max
+        if backoff_factor is None:
+            backoff_factor = settings.antenna_api_retry_backoff
+
     session = requests.Session()
 
     retry_strategy = Retry(
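
The change to `get_http_session()` follows an "explicit argument wins, otherwise fall back to settings" pattern. A minimal, dependency-free sketch of that pattern is shown here. The `Settings` dataclass and its default values (3 and 0.5, borrowed from the removed parameter defaults) are assumptions for illustration, as is the helper name `resolve_retry_config`; the real values come from `trapdata.settings`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Settings:
    """Hypothetical stand-in; the default values here are illustrative only."""

    antenna_api_retry_max: int = 3
    antenna_api_retry_backoff: float = 0.5


def resolve_retry_config(
    max_retries: Optional[int] = None,
    backoff_factor: Optional[float] = None,
    settings: Optional[Settings] = None,
) -> tuple[int, float]:
    """Explicit arguments win; anything left as None falls back to settings."""
    if max_retries is None or backoff_factor is None:
        # Settings are only loaded when at least one value is missing.
        settings = settings or Settings()
        if max_retries is None:
            max_retries = settings.antenna_api_retry_max
        if backoff_factor is None:
            backoff_factor = settings.antenna_api_retry_backoff
    return max_retries, backoff_factor


print(resolve_retry_config())               # (3, 0.5)
print(resolve_retry_config(max_retries=0))  # (0, 0.5)
```

Comparing against `None` rather than testing truthiness matters here: a caller who passes `max_retries=0` to disable retries keeps that value instead of silently getting the settings default.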

trapdata/cli/base.py

Lines changed: 33 additions & 0 deletions
@@ -128,5 +128,38 @@ def worker(
     run_worker(pipelines=pipelines)
 
 
+@cli.command("register")
+def register(
+    name: Annotated[
+        str,
+        typer.Argument(
+            help="Name for the processing service registration (e.g., 'AMI Data Companion on DRAC gpu-03'). "
+            "Hostname will be added automatically.",
+        ),
+    ],
+    project: Annotated[
+        list[int] | None,
+        typer.Option(
+            help="Specific project IDs to register pipelines for. "
+            "If not specified, registers for all accessible projects.",
+        ),
+    ] = None,
+):
+    """
+    Register available pipelines with the Antenna platform for specified projects.
+
+    This command registers all available pipeline configurations with the Antenna platform
+    for the specified projects (or all accessible projects if none specified).
+
+    Examples:
+        ami register --name "AMI Data Companion on DRAC gpu-03" --project 1 --project 2
+        ami register --name "My Processing Service"  # registers for all accessible projects
+    """
+    from trapdata.cli.worker import register_pipelines
+
+    project_ids = project if project else []
+    register_pipelines(project_ids=project_ids, service_name=name)
+
+
 if __name__ == "__main__":
     cli()
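
The `register` command combines one required name with a repeatable `--project` option that collapses to an empty list when omitted. The sketch below reproduces that argument shape with `argparse` from the standard library instead of Typer, purely as a runnable analogue; `build_parser` and `parse` are hypothetical names. Since the diff declares `name` with `typer.Argument`, it would normally be passed positionally, which is what this analogue does, whereas the docstring examples in the diff write it as `--name`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="ami register")
    # Positional value, analogous to a typer.Argument.
    parser.add_argument("name", help="Name for the processing service registration")
    # Repeatable option: each --project adds one ID; omitted entirely means None.
    parser.add_argument("--project", type=int, action="append", default=None)
    return parser


def parse(argv: list[str]) -> tuple[str, list[int]]:
    args = build_parser().parse_args(argv)
    # Same normalization as the diff: None (or empty) becomes [].
    project_ids = args.project if args.project else []
    return args.name, project_ids


print(parse(["My Processing Service", "--project", "1", "--project", "2"]))
print(parse(["My Processing Service"]))
```

Normalizing `None` to `[]` at the CLI boundary lets the downstream function treat "no projects given" as a single case, which the diff's help text describes as registering for all accessible projects.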
