Skip to content

Commit 3e1083c

Browse files
author
Namrata Madan
committed
feat: support pipeline versioning
1 parent 23c3840 commit 3e1083c

File tree

4 files changed

+147
-16
lines changed

4 files changed

+147
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ classifiers = [
3232
]
3333
dependencies = [
3434
"attrs>=24,<26",
35-
"boto3>=1.35.36,<2.0",
35+
"boto3>=1.39.5,<2.0",
3636
"cloudpickle>=2.2.1",
3737
"docker",
3838
"fastapi",

src/sagemaker/workflow/pipeline.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ def __init__(
125125
self.sagemaker_session.boto_session.client("scheduler"),
126126
)
127127

128+
@property
129+
def latest_pipeline_version_id(self):
130+
"""Retrieves the latest version id of this pipeline"""
131+
summaries = self.list_pipeline_versions(max_results=1)["PipelineVersionSummaries"]
132+
if not summaries:
133+
return None
134+
else:
135+
return summaries[0].get("PipelineVersionId")
136+
128137
def create(
129138
self,
130139
role_arn: str = None,
@@ -166,7 +175,8 @@ def create(
166175
kwargs,
167176
Tags=tags,
168177
)
169-
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
178+
response = self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
179+
return response
170180

171181
def _create_args(
172182
self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration
@@ -214,15 +224,21 @@ def _create_args(
214224
)
215225
return kwargs
216226

217-
def describe(self) -> Dict[str, Any]:
227+
def describe(self, pipeline_version_id: int = None) -> Dict[str, Any]:
218228
"""Describes a Pipeline in the Workflow service.
219229
230+
Args:
231+
pipeline_version_id (Optional[int]): version ID of the pipeline to describe.
232+
220233
Returns:
221234
Response dict from the service. See `boto3 client documentation
222235
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/\
223236
sagemaker.html#SageMaker.Client.describe_pipeline>`_
224237
"""
225-
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
238+
kwargs = dict(PipelineName=self.name)
239+
if pipeline_version_id:
240+
kwargs["PipelineVersionId"] = pipeline_version_id
241+
return self.sagemaker_session.sagemaker_client.describe_pipeline(**kwargs)
226242

227243
def update(
228244
self,
@@ -257,7 +273,8 @@ def update(
257273
return self.sagemaker_session.sagemaker_client.update_pipeline(self, description)
258274

259275
kwargs = self._create_args(role_arn, description, parallelism_config)
260-
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
276+
response = self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
277+
return response
261278

262279
def upsert(
263280
self,
@@ -332,6 +349,7 @@ def start(
332349
execution_description: str = None,
333350
parallelism_config: ParallelismConfiguration = None,
334351
selective_execution_config: SelectiveExecutionConfig = None,
352+
pipeline_version_id: int = None,
335353
):
336354
"""Starts a Pipeline execution in the Workflow service.
337355
@@ -345,6 +363,8 @@ def start(
345363
over the parallelism configuration of the parent pipeline.
346364
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
347365
selective step execution.
366+
pipeline_version_id (Optional[int]): version ID of the pipeline to start the execution from. If not
367+
specified, uses the latest version ID.
348368
349369
Returns:
350370
A `_PipelineExecution` instance, if successful.
@@ -366,6 +386,7 @@ def start(
366386
PipelineExecutionDisplayName=execution_display_name,
367387
ParallelismConfiguration=parallelism_config,
368388
SelectiveExecutionConfig=selective_execution_config,
389+
PipelineVersionId=pipeline_version_id,
369390
)
370391
if self.sagemaker_session.local_mode:
371392
update_args(kwargs, PipelineParameters=parameters)
@@ -461,6 +482,32 @@ def list_executions(
461482
if key in response
462483
}
463484

485+
def list_pipeline_versions(
486+
self, sort_order: str = None, max_results: int = None, next_token: str = None
487+
) -> str:
488+
"""Lists a pipeline's versions.
489+
490+
Args:
491+
sort_order (str): The sort order for results (Ascending/Descending).
492+
max_results (int): The maximum number of pipeline versions to return in the response.
493+
next_token (str): If the result of the previous `ListPipelineVersions` request was
494+
truncated, the response includes a `NextToken`. To retrieve the next set of pipeline
495+
versions, use the token in the next request.
496+
497+
Returns:
498+
Response dict containing the list of Pipeline Version Summaries. See
499+
boto3 client list_pipeline_versions
500+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/list_pipeline_versions.html#
501+
"""
502+
kwargs = dict(PipelineName=self.name)
503+
update_args(
504+
kwargs,
505+
SortOrder=sort_order,
506+
NextToken=next_token,
507+
MaxResults=max_results,
508+
)
509+
return self.sagemaker_session.sagemaker_client.list_pipeline_versions(**kwargs)
510+
464511
def _get_latest_execution_arn(self):
465512
"""Retrieves the latest execution of this pipeline"""
466513
response = self.list_executions(
@@ -855,7 +902,7 @@ def describe(self):
855902
sagemaker.html#SageMaker.Client.describe_pipeline_execution>`_.
856903
"""
857904
return self.sagemaker_session.sagemaker_client.describe_pipeline_execution(
858-
PipelineExecutionArn=self.arn,
905+
PipelineExecutionArn=self.arn
859906
)
860907

861908
def list_steps(self):

tests/integ/sagemaker/workflow/test_workflow.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def test_three_step_definition(
312312
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
313313
create_arn,
314314
)
315+
assert pipeline.latest_pipeline_version_id == 1
315316
finally:
316317
try:
317318
pipeline.delete()
@@ -937,7 +938,6 @@ def test_large_pipeline(sagemaker_session_for_pipeline, role, pipeline_name, reg
937938
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
938939
create_arn,
939940
)
940-
response = pipeline.describe()
941941
assert len(json.loads(pipeline.describe()["PipelineDefinition"])["Steps"]) == 2000
942942

943943
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
@@ -1387,3 +1387,56 @@ def test_caching_behavior(
13871387
except Exception:
13881388
os.remove(script_dir + "/dummy_script.py")
13891389
pass
1390+
1391+
1392+
def test_pipeline_versioning(pipeline_session, role, pipeline_name, script_dir):
1393+
sklearn_train = SKLearn(
1394+
framework_version="0.20.0",
1395+
entry_point=os.path.join(script_dir, "train.py"),
1396+
instance_type="ml.m5.xlarge",
1397+
sagemaker_session=pipeline_session,
1398+
role=role,
1399+
)
1400+
1401+
step1 = TrainingStep(
1402+
name="my-train-1",
1403+
display_name="TrainingStep",
1404+
description="description for Training step",
1405+
step_args=sklearn_train.fit(),
1406+
)
1407+
1408+
step2 = TrainingStep(
1409+
name="my-train-2",
1410+
display_name="TrainingStep",
1411+
description="description for Training step",
1412+
step_args=sklearn_train.fit(),
1413+
)
1414+
pipeline = Pipeline(
1415+
name=pipeline_name,
1416+
steps=[step1],
1417+
sagemaker_session=pipeline_session,
1418+
)
1419+
1420+
try:
1421+
pipeline.create(role)
1422+
1423+
assert pipeline.latest_pipeline_version_id == 1
1424+
1425+
describe_response = pipeline.describe(pipeline_version_id=1)
1426+
assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 1
1427+
1428+
pipeline.steps.append(step2)
1429+
pipeline.upsert(role)
1430+
1431+
assert pipeline.latest_pipeline_version_id == 2
1432+
1433+
describe_response = pipeline.describe(pipeline_version_id=2)
1434+
assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 2
1435+
1436+
assert len(pipeline.list_pipeline_versions()["PipelineVersionSummaries"]) == 2
1437+
1438+
finally:
1439+
try:
1440+
pipeline.delete()
1441+
except Exception:
1442+
pass

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def _raise_does_already_exists_client_error(**kwargs):
391391
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(
392392
name="create_pipeline", side_effect=_raise_does_already_exists_client_error
393393
)
394-
395394
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
396395
"PipelineArn": "pipeline-arn"
397396
}
@@ -429,6 +428,12 @@ def _raise_does_already_exists_client_error(**kwargs):
429428
ResourceArn="pipeline-arn", Tags=tags
430429
)
431430

431+
sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
432+
"PipelineVersionSummaries": [{"PipelineVersionId": 2}]
433+
}
434+
435+
assert pipeline.latest_pipeline_version_id == 2
436+
432437

433438
def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn):
434439

@@ -476,18 +481,11 @@ def _raise_unexpected_client_error(**kwargs):
476481
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
477482

478483

479-
def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn):
484+
def test_pipeline_upsert_resource_doesnt_exist(sagemaker_session_mock, role_arn):
480485

481486
# case 3: resource does not exist
482487
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline")
483488

484-
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
485-
"PipelineArn": "pipeline-arn"
486-
}
487-
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
488-
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
489-
}
490-
491489
tags = [
492490
{"Key": "foo", "Value": "abc"},
493491
{"Key": "bar", "Value": "xyz"},
@@ -542,6 +540,11 @@ def test_pipeline_describe(sagemaker_session_mock):
542540
PipelineName="MyPipeline",
543541
)
544542

543+
pipeline.describe(pipeline_version_id=5)
544+
sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with(
545+
PipelineName="MyPipeline", PipelineVersionId=5
546+
)
547+
545548

546549
def test_pipeline_start(sagemaker_session_mock):
547550
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
@@ -568,6 +571,11 @@ def test_pipeline_start(sagemaker_session_mock):
568571
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
569572
)
570573

574+
pipeline.start(pipeline_version_id=5)
575+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
576+
PipelineName="MyPipeline", PipelineVersionId=5
577+
)
578+
571579

572580
def test_pipeline_start_selective_execution(sagemaker_session_mock):
573581
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
@@ -809,6 +817,29 @@ def test_pipeline_list_executions(sagemaker_session_mock):
809817
assert executions["NextToken"] == "token"
810818

811819

820+
def test_pipeline_list_versions(sagemaker_session_mock):
821+
sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
822+
"PipelineVersionSummaries": [Mock()],
823+
"NextToken": "token",
824+
}
825+
pipeline = Pipeline(
826+
name="MyPipeline",
827+
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
828+
steps=[],
829+
sagemaker_session=sagemaker_session_mock,
830+
)
831+
versions = pipeline.list_pipeline_versions()
832+
assert len(versions["PipelineVersionSummaries"]) == 1
833+
assert versions["NextToken"] == "token"
834+
835+
sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
836+
"PipelineVersionSummaries": [Mock(), Mock()],
837+
}
838+
versions = pipeline.list_pipeline_versions(next_token=versions["NextToken"])
839+
assert len(versions["PipelineVersionSummaries"]) == 2
840+
assert "NextToken" not in versions
841+
842+
812843
def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
813844
pipeline = Pipeline(
814845
name="MyPipeline",

0 commit comments

Comments
 (0)