Skip to content

Commit ce56c53

Browse files
committed
feature: add emr-serverless step for SageMaker Pipelines
1 parent 5d3f175 commit ce56c53

File tree

4 files changed

+552
-0
lines changed

4 files changed

+552
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The step definitions for EMR Serverless workflow."""
14+
from __future__ import absolute_import
15+
16+
from typing import Any, Dict, List, Union, Optional
17+
18+
from sagemaker.workflow.entities import (
19+
RequestType,
20+
)
21+
from sagemaker.workflow.properties import (
22+
Properties,
23+
)
24+
from sagemaker.workflow.retry import StepRetryPolicy
25+
from sagemaker.workflow.step_collections import StepCollection
26+
from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig
27+
28+
29+
class EMRServerlessJobConfig:
30+
"""Config for EMR Serverless job."""
31+
32+
def __init__(
33+
self,
34+
job_driver: Dict,
35+
execution_role_arn: str,
36+
configuration_overrides: Optional[Dict] = None,
37+
execution_timeout_minutes: Optional[int] = None,
38+
name: Optional[str] = None,
39+
tags: Optional[Dict[str, str]] = None,
40+
): # pylint: disable=too-many-positional-arguments
41+
"""Create a definition for EMR Serverless job configuration.
42+
43+
Args:
44+
job_driver (Dict): The job driver for the job run.
45+
execution_role_arn (str): The execution role ARN for the job run.
46+
configuration_overrides (Dict, optional): Configuration overrides for the job run.
47+
execution_timeout_minutes (int, optional): The maximum duration for the job run.
48+
name (str, optional): The optional job run name.
49+
tags (Dict[str, str], optional): The tags assigned to the job run.
50+
"""
51+
self.job_driver = job_driver
52+
self.execution_role_arn = execution_role_arn
53+
self.configuration_overrides = configuration_overrides
54+
self.execution_timeout_minutes = execution_timeout_minutes
55+
self.name = name
56+
self.tags = tags
57+
58+
def to_request(self, application_id: Optional[str] = None) -> RequestType:
59+
"""Convert EMRServerlessJobConfig object to request dict."""
60+
config = {"executionRoleArn": self.execution_role_arn, "jobDriver": self.job_driver}
61+
if application_id is not None:
62+
config["applicationId"] = application_id
63+
if self.configuration_overrides is not None:
64+
config["configurationOverrides"] = self.configuration_overrides
65+
if self.execution_timeout_minutes is not None:
66+
config["executionTimeoutMinutes"] = self.execution_timeout_minutes
67+
if self.name is not None:
68+
config["name"] = self.name
69+
if self.tags is not None:
70+
config["tags"] = self.tags
71+
return config
72+
73+
74+
ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG = (
75+
"EMRServerlessStep {step_name} cannot have both application_id and application_config. "
76+
"To use EMRServerlessStep with application_config, "
77+
"application_id must be explicitly set to None."
78+
)
79+
80+
ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG = (
81+
"EMRServerlessStep {step_name} must have either application_id or application_config"
82+
)
83+
84+
85+
class EMRServerlessStep(ConfigurableRetryStep):
86+
"""EMR Serverless step for workflow with configurable retry policies."""
87+
88+
def __init__(
89+
self,
90+
name: str,
91+
display_name: str,
92+
description: str,
93+
job_config: EMRServerlessJobConfig,
94+
application_id: Optional[str] = None,
95+
application_config: Optional[Dict[str, Any]] = None,
96+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
97+
cache_config: Optional[CacheConfig] = None,
98+
retry_policies: Optional[List[StepRetryPolicy]] = None,
99+
): # pylint: disable=too-many-positional-arguments
100+
"""Constructs an `EMRServerlessStep`.
101+
102+
Args:
103+
name (str): The name of the EMR Serverless step.
104+
display_name (str): The display name of the EMR Serverless step.
105+
description (str): The description of the EMR Serverless step.
106+
job_config (EMRServerlessJobConfig): Job configuration for the EMR Serverless job.
107+
application_id (str, optional): The ID of the existing EMR Serverless application.
108+
application_config (Dict[str, Any], optional): Configuration for creating a new
109+
EMR Serverless application.
110+
depends_on (List[Union[str, Step, StepCollection]], optional): A list of
111+
`Step`/`StepCollection` names or `Step` instances or `StepCollection` instances
112+
that this `EMRServerlessStep` depends on.
113+
cache_config (CacheConfig, optional): A `sagemaker.workflow.steps.CacheConfig` instance.
114+
retry_policies (List[StepRetryPolicy], optional): A list of retry policies.
115+
"""
116+
super().__init__(
117+
name=name,
118+
step_type=StepTypeEnum.EMR_SERVERLESS,
119+
display_name=display_name,
120+
description=description,
121+
depends_on=depends_on,
122+
retry_policies=retry_policies,
123+
)
124+
125+
if application_id is None and application_config is None:
126+
raise ValueError(ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG.format(step_name=name))
127+
128+
if application_id is not None and application_config is not None:
129+
raise ValueError(ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG.format(step_name=name))
130+
131+
emr_serverless_args = {
132+
"ExecutionRoleArn": job_config.execution_role_arn, # Top-level role (used by backend)
133+
"JobConfig": job_config.to_request(
134+
application_id
135+
), # Role also in JobConfig (structure requirement)
136+
}
137+
138+
if application_id is not None:
139+
emr_serverless_args["ApplicationId"] = application_id
140+
elif application_config is not None:
141+
emr_serverless_args["ApplicationConfig"] = application_config
142+
143+
self.args = emr_serverless_args
144+
self.cache_config = cache_config
145+
146+
root_property = Properties(
147+
step_name=name, step=self, shape_name="GetJobRunResponse", service_name="emr-serverless"
148+
)
149+
self._properties = root_property
150+
151+
@property
152+
def arguments(self) -> RequestType:
153+
"""The arguments dict that is used to call EMR Serverless APIs."""
154+
return self.args
155+
156+
@property
157+
def properties(self) -> RequestType:
158+
"""A Properties object representing the EMR Serverless GetJobRunResponse model."""
159+
return self._properties
160+
161+
def to_request(self) -> RequestType:
162+
"""Updates the dictionary with cache configuration and retry policies."""
163+
request_dict = super().to_request()
164+
if self.cache_config:
165+
request_dict.update(self.cache_config.config)
166+
return request_dict

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class StepTypeEnum(Enum):
6969
QUALITY_CHECK = "QualityCheck"
7070
CLARIFY_CHECK = "ClarifyCheck"
7171
EMR = "EMR"
72+
EMR_SERVERLESS = "EMRServerless"
7273
FAIL = "Fail"
7374
AUTOML = "AutoML"
7475

0 commit comments

Comments
 (0)