|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""The step definitions for EMR Serverless workflow.""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +from typing import Any, Dict, List, Union, Optional |
| 17 | + |
| 18 | +from sagemaker.workflow.entities import ( |
| 19 | + RequestType, |
| 20 | +) |
| 21 | +from sagemaker.workflow.properties import ( |
| 22 | + Properties, |
| 23 | +) |
| 24 | +from sagemaker.workflow.retry import StepRetryPolicy |
| 25 | +from sagemaker.workflow.step_collections import StepCollection |
| 26 | +from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig |
| 27 | + |
| 28 | + |
class EMRServerlessJobConfig:
    """Config for EMR Serverless job."""

    def __init__(
        self,
        job_driver: Dict,
        execution_role_arn: str,
        configuration_overrides: Optional[Dict] = None,
        execution_timeout_minutes: Optional[int] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):  # pylint: disable=too-many-positional-arguments
        """Create a definition for EMR Serverless job configuration.

        Args:
            job_driver (Dict): The job driver for the job run.
            execution_role_arn (str): The execution role ARN for the job run.
            configuration_overrides (Dict, optional): Configuration overrides for the job run.
            execution_timeout_minutes (int, optional): The maximum duration for the job run.
            name (str, optional): The optional job run name.
            tags (Dict[str, str], optional): The tags assigned to the job run.
        """
        self.job_driver = job_driver
        self.execution_role_arn = execution_role_arn
        self.configuration_overrides = configuration_overrides
        self.execution_timeout_minutes = execution_timeout_minutes
        self.name = name
        self.tags = tags

    def to_request(self, application_id: Optional[str] = None) -> RequestType:
        """Convert EMRServerlessJobConfig object to request dict.

        Args:
            application_id (str, optional): When not None, it is included in the
                request under the "applicationId" key.

        Returns:
            RequestType: A dict that always contains "executionRoleArn" and
            "jobDriver". "applicationId", "configurationOverrides",
            "executionTimeoutMinutes", "name" and "tags" are included only when
            the corresponding value is not None.
        """
        config = {"executionRoleArn": self.execution_role_arn, "jobDriver": self.job_driver}

        # Optional fields, listed in the order they are emitted. The check is
        # `is not None` (not truthiness) so values such as {} or 0 are still sent.
        optional_fields = (
            ("applicationId", application_id),
            ("configurationOverrides", self.configuration_overrides),
            ("executionTimeoutMinutes", self.execution_timeout_minutes),
            ("name", self.name),
            ("tags", self.tags),
        )
        for key, value in optional_fields:
            if value is not None:
                config[key] = value
        return config
| 72 | + |
| 73 | + |
# Error message templates for EMRServerlessStep argument validation.
# Each takes a `step_name` field that is filled in with `str.format`.
ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG = (
    "EMRServerlessStep {step_name} cannot have both application_id and application_config. "
    "To use EMRServerlessStep with application_config, "
    "application_id must be explicitly set to None."
)

ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG = (
    "EMRServerlessStep {step_name} must have either application_id or application_config"
)
| 83 | + |
| 84 | + |
class EMRServerlessStep(ConfigurableRetryStep):
    """EMR Serverless step for workflow with configurable retry policies."""

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        job_config: EMRServerlessJobConfig,
        application_id: Optional[str] = None,
        application_config: Optional[Dict[str, Any]] = None,
        depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
        cache_config: Optional[CacheConfig] = None,
        retry_policies: Optional[List[StepRetryPolicy]] = None,
    ):  # pylint: disable=too-many-positional-arguments
        """Constructs an `EMRServerlessStep`.

        Exactly one of `application_id` and `application_config` must be provided.

        Args:
            name (str): The name of the EMR Serverless step.
            display_name (str): The display name of the EMR Serverless step.
            description (str): The description of the EMR Serverless step.
            job_config (EMRServerlessJobConfig): Job configuration for the EMR Serverless job.
            application_id (str, optional): The ID of the existing EMR Serverless application.
            application_config (Dict[str, Any], optional): Configuration for creating a new
                EMR Serverless application.
            depends_on (List[Union[str, Step, StepCollection]], optional): A list of
                `Step`/`StepCollection` names or `Step` instances or `StepCollection` instances
                that this `EMRServerlessStep` depends on.
            cache_config (CacheConfig, optional): A `sagemaker.workflow.steps.CacheConfig` instance.
            retry_policies (List[StepRetryPolicy], optional): A list of retry policies.

        Raises:
            ValueError: If both or neither of `application_id` and
                `application_config` are provided.
        """
        super().__init__(
            name=name,
            step_type=StepTypeEnum.EMR_SERVERLESS,
            display_name=display_name,
            description=description,
            depends_on=depends_on,
            retry_policies=retry_policies,
        )

        if application_id is None and application_config is None:
            raise ValueError(ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG.format(step_name=name))

        if application_id is not None and application_config is not None:
            raise ValueError(ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG.format(step_name=name))

        emr_serverless_args = {
            # Top-level role (used by backend)
            "ExecutionRoleArn": job_config.execution_role_arn,
            # Role also in JobConfig (structure requirement)
            "JobConfig": job_config.to_request(application_id),
        }

        # The checks above guarantee exactly one of the two is set here.
        if application_id is not None:
            emr_serverless_args["ApplicationId"] = application_id
        else:
            emr_serverless_args["ApplicationConfig"] = application_config

        self.args = emr_serverless_args
        self.cache_config = cache_config

        self._properties = Properties(
            step_name=name,
            step=self,
            shape_name="GetJobRunResponse",
            service_name="emr-serverless",
        )

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to call EMR Serverless APIs."""
        return self.args

    @property
    def properties(self) -> Properties:
        """A Properties object representing the EMR Serverless GetJobRunResponse model."""
        return self._properties

    def to_request(self) -> RequestType:
        """Builds the step request, adding the cache configuration when one is set.

        Returns:
            RequestType: The request dict produced by the parent class, updated
            with `cache_config.config` if `cache_config` is truthy.
        """
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)
        return request_dict
0 commit comments