Skip to content

feat: support pipeline versioning #5248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
]
dependencies = [
"attrs>=24,<26",
"boto3>=1.35.36,<2.0",
"boto3>=1.39.5,<2.0",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had previously relaxed the requirement to enable airflow and resolve some dependency conflicts. Any reason why this is needed ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we launched new apis UpdatePipelineVersion and ListPipelineVersions that were added to boto in version 1.39.5- boto/boto3@a1a8371#diff-2c623f3c6a917be56c59d43279244996836262cb1e12d9d0786c9c49eef6b43c

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So currently we needed to downgrade to a lower version because airflow, which is another open source dependency depends on a lower boto version : #5245

If we do make this update , it would break customer experiences .

"cloudpickle>=2.2.1",
"docker",
"fastapi",
Expand Down
57 changes: 52 additions & 5 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def __init__(
self.sagemaker_session.boto_session.client("scheduler"),
)

@property
def latest_pipeline_version_id(self):
"""Retrieves the latest version id of this pipeline"""
summaries = self.list_pipeline_versions(max_results=1)["PipelineVersionSummaries"]
if not summaries:
return None
else:
return summaries[0].get("PipelineVersionId")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this always guaranteed to be the latest version when sort_order is None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the default sort order is descending


def create(
self,
role_arn: str = None,
Expand Down Expand Up @@ -166,7 +175,8 @@ def create(
kwargs,
Tags=tags,
)
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
response = self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
return response
Comment on lines -169 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit : Is this change needed ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no. Would you like me to revert this?


def _create_args(
self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration
Expand Down Expand Up @@ -214,15 +224,21 @@ def _create_args(
)
return kwargs

def describe(self) -> Dict[str, Any]:
def describe(self, pipeline_version_id: int = None) -> Dict[str, Any]:
"""Describes a Pipeline in the Workflow service.

Args:
pipeline_version_id (Optional[str]): version ID of the pipeline to describe.

Returns:
Response dict from the service. See `boto3 client documentation
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/\
sagemaker.html#SageMaker.Client.describe_pipeline>`_
"""
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
kwargs = dict(PipelineName=self.name)
if pipeline_version_id:
kwargs["PipelineVersionId"] = pipeline_version_id
return self.sagemaker_session.sagemaker_client.describe_pipeline(**kwargs)

def update(
self,
Expand Down Expand Up @@ -257,7 +273,8 @@ def update(
return self.sagemaker_session.sagemaker_client.update_pipeline(self, description)

kwargs = self._create_args(role_arn, description, parallelism_config)
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
response = self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
return response

def upsert(
self,
Expand Down Expand Up @@ -332,6 +349,7 @@ def start(
execution_description: str = None,
parallelism_config: ParallelismConfiguration = None,
selective_execution_config: SelectiveExecutionConfig = None,
pipeline_version_id: int = None,
):
"""Starts a Pipeline execution in the Workflow service.

Expand All @@ -345,6 +363,8 @@ def start(
over the parallelism configuration of the parent pipeline.
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
selective step execution.
pipeline_version_id (Optional[str]): version ID of the pipeline to start the execution from. If not
specified, uses the latest version ID.

Returns:
A `_PipelineExecution` instance, if successful.
Expand All @@ -366,6 +386,7 @@ def start(
PipelineExecutionDisplayName=execution_display_name,
ParallelismConfiguration=parallelism_config,
SelectiveExecutionConfig=selective_execution_config,
PipelineVersionId=pipeline_version_id,
)
if self.sagemaker_session.local_mode:
update_args(kwargs, PipelineParameters=parameters)
Expand Down Expand Up @@ -461,6 +482,32 @@ def list_executions(
if key in response
}

def list_pipeline_versions(
self, sort_order: str = None, max_results: int = None, next_token: str = None
) -> str:
"""Lists a pipeline's versions.

Args:
sort_order (str): The sort order for results (Ascending/Descending).
max_results (int): The maximum number of pipeline executions to return in the response.
next_token (str): If the result of the previous `ListPipelineExecutions` request was
truncated, the response includes a `NextToken`. To retrieve the next set of pipeline
executions, use the token in the next request.

Returns:
List of Pipeline Version Summaries. See
boto3 client list_pipeline_versions
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/list_pipeline_versions.html#
"""
kwargs = dict(PipelineName=self.name)
update_args(
kwargs,
SortOrder=sort_order,
NextToken=next_token,
MaxResults=max_results,
)
return self.sagemaker_session.sagemaker_client.list_pipeline_versions(**kwargs)

def _get_latest_execution_arn(self):
"""Retrieves the latest execution of this pipeline"""
response = self.list_executions(
Expand Down Expand Up @@ -855,7 +902,7 @@ def describe(self):
sagemaker.html#SageMaker.Client.describe_pipeline_execution>`_.
"""
return self.sagemaker_session.sagemaker_client.describe_pipeline_execution(
PipelineExecutionArn=self.arn,
PipelineExecutionArn=self.arn
)

def list_steps(self):
Expand Down
55 changes: 54 additions & 1 deletion tests/integ/sagemaker/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def test_three_step_definition(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
assert pipeline.latest_pipeline_version_id == 1
finally:
try:
pipeline.delete()
Expand Down Expand Up @@ -937,7 +938,6 @@ def test_large_pipeline(sagemaker_session_for_pipeline, role, pipeline_name, reg
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
response = pipeline.describe()
assert len(json.loads(pipeline.describe()["PipelineDefinition"])["Steps"]) == 2000

pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
Expand Down Expand Up @@ -1387,3 +1387,56 @@ def test_caching_behavior(
except Exception:
os.remove(script_dir + "/dummy_script.py")
pass


def test_pipeline_versioning(pipeline_session, role, pipeline_name, script_dir):
sklearn_train = SKLearn(
framework_version="0.20.0",
entry_point=os.path.join(script_dir, "train.py"),
instance_type="ml.m5.xlarge",
sagemaker_session=pipeline_session,
role=role,
)

step1 = TrainingStep(
name="my-train-1",
display_name="TrainingStep",
description="description for Training step",
step_args=sklearn_train.fit(),
)

step2 = TrainingStep(
name="my-train-2",
display_name="TrainingStep",
description="description for Training step",
step_args=sklearn_train.fit(),
)
pipeline = Pipeline(
name=pipeline_name,
steps=[step1],
sagemaker_session=pipeline_session,
)

try:
pipeline.create(role)

assert pipeline.latest_pipeline_version_id == 1

describe_response = pipeline.describe(pipeline_version_id=1)
assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 1

pipeline.steps.append(step2)
pipeline.upsert(role)

assert pipeline.latest_pipeline_version_id == 2

describe_response = pipeline.describe(pipeline_version_id=2)
assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 2

assert len(pipeline.list_pipeline_versions()["PipelineVersionSummaries"]) == 2

finally:
try:
pipeline.delete()
except Exception:
pass
49 changes: 40 additions & 9 deletions tests/unit/sagemaker/workflow/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ def _raise_does_already_exists_client_error(**kwargs):
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(
name="create_pipeline", side_effect=_raise_does_already_exists_client_error
)

sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
"PipelineArn": "pipeline-arn"
}
Expand Down Expand Up @@ -429,6 +428,12 @@ def _raise_does_already_exists_client_error(**kwargs):
ResourceArn="pipeline-arn", Tags=tags
)

sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
"PipelineVersionSummaries": [{"PipelineVersionId": 2}]
}

assert pipeline.latest_pipeline_version_id == 2


def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn):

Expand Down Expand Up @@ -476,18 +481,11 @@ def _raise_unexpected_client_error(**kwargs):
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()


def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn):
def test_pipeline_upsert_resource_doesnt_exist(sagemaker_session_mock, role_arn):

# case 3: resource does not exist
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline")

sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
"PipelineArn": "pipeline-arn"
}
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
}

tags = [
{"Key": "foo", "Value": "abc"},
{"Key": "bar", "Value": "xyz"},
Expand Down Expand Up @@ -542,6 +540,11 @@ def test_pipeline_describe(sagemaker_session_mock):
PipelineName="MyPipeline",
)

pipeline.describe(pipeline_version_id=5)
sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with(
PipelineName="MyPipeline", PipelineVersionId=5
)


def test_pipeline_start(sagemaker_session_mock):
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
Expand All @@ -568,6 +571,11 @@ def test_pipeline_start(sagemaker_session_mock):
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
)

pipeline.start(pipeline_version_id=5)
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
PipelineName="MyPipeline", PipelineVersionId=5
)


def test_pipeline_start_selective_execution(sagemaker_session_mock):
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
Expand Down Expand Up @@ -809,6 +817,29 @@ def test_pipeline_list_executions(sagemaker_session_mock):
assert executions["NextToken"] == "token"


def test_pipeline_list_versions(sagemaker_session_mock):
sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
"PipelineVersionSummaries": [Mock()],
"NextToken": "token",
}
pipeline = Pipeline(
name="MyPipeline",
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
steps=[],
sagemaker_session=sagemaker_session_mock,
)
versions = pipeline.list_pipeline_versions()
assert len(versions["PipelineVersionSummaries"]) == 1
assert versions["NextToken"] == "token"

sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
"PipelineVersionSummaries": [Mock(), Mock()],
}
versions = pipeline.list_pipeline_versions(next_token=versions["NextToken"])
assert len(versions["PipelineVersionSummaries"]) == 2
assert "NextToken" not in versions


def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
pipeline = Pipeline(
name="MyPipeline",
Expand Down
Loading