-
Notifications
You must be signed in to change notification settings - Fork 15.3k
Add AWS SageMaker Unified Studio Workflow Operator #45726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 17 commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
1118295
Add sagemaker_unified_studio notebook operator, sensor, triggers, and…
agupta01 44e45c8
Fix sagemaker_unified_studio unit tests
agupta01 6c627ef
Add basic system test for sagemaker_unified_studio
agupta01 33c4a3d
Add setup/teardown stubs for sagemaker_unified_studio system test
agupta01 a1e0766
Add more specifics to SUS system test
agupta01 5ee6aad
Update name of SUS helper
agupta01 e378358
Cleanup and format SUS system test
agupta01 b369756
Merge branch 'apache:main' into main
agupta01 209cb3c
Merge branch 'apache:main' into main
agupta01 092c786
Merge branch 'apache:main' into main
agupta01 55f5486
Merge branch 'apache:main' into main
agupta01 c640efa
Fix notebook path in SUS system test
agupta01 41bd997
Merge branch 'apache:main' into main
agupta01 7acc649
Merge branch 'apache:main' into main
agupta01 b7ad9db
Update SUS docs
agupta01 7ca5022
Update SUS docs to include vETL
agupta01 13a10ce
Clarity updates on SUS docs
agupta01 756fcad
Merge branch 'apache:main' into main
agupta01 ffd3363
Merge branch 'apache:main' into main
agupta01 beaadda
Update SUS operator files after providers refactor
agupta01 654f4e1
Add failure on timeout
agupta01 1450dc6
Merge branch 'apache:main' into main
agupta01 ad56227
Set public sagemaker studio SDK as dependency for SUS operator
agupta01 1faaf92
Remove private _openapi usage from SageMaker Studio SDK
agupta01 ae9e2a0
Add extra link for SUS operator
agupta01 f1721c9
Merge branch 'apache:main' into main
agupta01 40a2fb5
Move SUS operator unit tests to new location
agupta01 237f78d
Update SUS system test to use executor_config for environment variables
agupta01 d60d2e9
Fix linting and formatting
agupta01 b23c3d3
Fix system test for localexecutor
agupta01 8e02fbd
Fix formatting
agupta01 af2dc1a
Merge branch 'apache:main' into main
agupta01 f001638
Merge branch 'apache:main' into main
agupta01 99e820e
Update sagemaker-studio lower bound dependency
agupta01 7373926
Fix SMUS system test
agupta01 b581b2e
Fix broken link in SMUS documentation
agupta01 7001393
Merge branch 'apache:main' into main
agupta01 f52762c
Merge branch 'apache:main' into main
agupta01 1307ed3
Fix pre-commit violations
agupta01 de85e5c
Convert tests to pytest
agupta01 692f3f8
Register hook + add license file
agupta01 96468ee
Merge branch 'apache:main' into main
agupta01 3bd17ca
Merge branch 'apache:main' into main
agupta01 ffcd453
Merge branch 'main' into main
o-nikolas 2e05bcf
Merge branch 'apache:main' into main
agupta01 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
61 changes: 61 additions & 0 deletions
61
docs/apache-airflow-providers-amazon/operators/sagemakerunifiedstudio.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
.. Licensed to the Apache Software Foundation (ASF) under one | ||
or more contributor license agreements. See the NOTICE file | ||
distributed with this work for additional information | ||
regarding copyright ownership. The ASF licenses this file | ||
to you under the Apache License, Version 2.0 (the | ||
"License"); you may not use this file except in compliance | ||
with the License. You may obtain a copy of the License at | ||
|
||
.. http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
.. Unless required by applicable law or agreed to in writing, | ||
software distributed under the License is distributed on an | ||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations | ||
under the License. | ||
|
||
=============================== | ||
Amazon SageMaker Unified Studio | ||
=============================== | ||
|
||
`Amazon SageMaker Unified Studio <https://docs.aws.amazon.com/sagemaker-unified-studio>`__ is a unified development experience that | ||
brings together AWS data, analytics, artificial intelligence (AI), and machine learning (ML) services. | ||
It provides a place to build, deploy, execute, and monitor end-to-end workflows from a single interface. | ||
This helps drive collaboration across teams and facilitate agile development. | ||
|
||
Airflow provides operators to orchestrate Notebooks, Querybooks, and Visual ETL jobs within SageMaker Unified Studio Workflows. | ||
|
||
Prerequisite Tasks | ||
------------------ | ||
|
||
To use these operators, you must do a few things: | ||
|
||
* Create a SageMaker Unified Studio domain. | ||
* Within your domain, create a project with the "Data analytics and AI-ML model development" project profile. | ||
* Within your project: | ||
* Navigate to the "Compute > Workflow environments" tab, and click "Create" to create a new MWAA environment. | ||
* Create a Notebook, Querybook, or Visual ETL job and save it to your project. | ||
|
||
Operators | ||
--------- | ||
|
||
.. _howto/operator:SageMakerNotebookOperator: | ||
|
||
Create an Amazon SageMaker Unified Studio Workflow | ||
================================================== | ||
|
||
To create an Amazon SageMaker Unified Studio workflow to orchestrate your notebook, querybook, and visual ETL runs you can use | ||
:class:`~airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookOperator`. | ||
|
||
.. exampleinclude:: /../../providers/tests/system/amazon/aws/example_sagemaker_unified_studio.py | ||
:language: python | ||
:dedent: 4 | ||
:start-after: [START howto_operator_sagemaker_unified_studio_notebook] | ||
:end-before: [END howto_operator_sagemaker_unified_studio_notebook] | ||
|
||
|
||
Reference | ||
--------- | ||
|
||
* `What is Amazon SageMaker Unified Studio <https://docs.aws.amazon.com/sagemaker-unified-studio/latest/userguide/what-is-sagemaker-unified-studio.html>`__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
186 changes: 186 additions & 0 deletions
186
providers/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" | ||
|
||
import time | ||
|
||
from sagemaker_studio import ClientConfig | ||
from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest | ||
from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI | ||
|
||
from airflow import AirflowException | ||
from airflow.hooks.base import BaseHook | ||
from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner | ||
|
||
|
||
class SageMakerNotebookHook(BaseHook): | ||
""" | ||
Interact with the Sagemaker Workflows API. | ||
|
||
This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook | ||
|
||
notebook_hook = NotebookHook( | ||
input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}}, | ||
output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'}, | ||
execution_name='notebook_execution', | ||
poll_interval=10, | ||
) | ||
|
||
:param execution_name: The name of the notebook job to be executed, this is same as task_id. | ||
:param input_config: Configuration for the input file. | ||
Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} | ||
:param output_config: Configuration for the output format. It should include an output_formats parameter to control | ||
Example: {'output_formats': ['NOTEBOOK']} | ||
:param compute: compute configuration to use for the notebook execution. This is an required attribute | ||
agupta01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if the execution is on a remote compute. | ||
Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]} | ||
:param termination_condition: conditions to match to terminate the remote execution. | ||
Example: { "MaxRuntimeInSeconds": 3600 } | ||
:param tags: tags to be associated with the remote execution runs. | ||
Example: { "md_analytics": "logs" } | ||
:param poll_interval: Interval in seconds to check the notebook execution status. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
execution_name: str, | ||
input_config: dict = {}, | ||
output_config: dict = {"output_formats": ["NOTEBOOK"]}, | ||
compute: dict = {}, | ||
termination_condition: dict = {}, | ||
tags: dict = {}, | ||
poll_interval: int = 10, | ||
*args, | ||
**kwargs, | ||
): | ||
super().__init__(*args, **kwargs) | ||
self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config()) | ||
self.execution_name = execution_name | ||
self.input_config = input_config | ||
self.output_config = output_config | ||
self.compute = compute | ||
self.termination_condition = termination_condition | ||
self.tags = tags | ||
self.poll_interval = poll_interval | ||
|
||
def _get_sagemaker_studio_config(self): | ||
config = ClientConfig() | ||
config.overrides["execution"] = {"local": is_local_runner()} | ||
return config | ||
|
||
def _format_start_execution_input_config(self): | ||
config = { | ||
"notebook_config": { | ||
"input_path": self.input_config.get("input_path"), | ||
"input_parameters": self.input_config.get("input_params"), | ||
}, | ||
} | ||
|
||
return config | ||
|
||
def _format_start_execution_output_config(self): | ||
output_formats = ( | ||
self.output_config.get("output_formats") if self.output_config else ["NOTEBOOK"] | ||
agupta01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
config = { | ||
"notebook_config": { | ||
"output_formats": output_formats, | ||
} | ||
} | ||
return config | ||
|
||
def start_notebook_execution(self): | ||
start_execution_params = { | ||
"execution_name": self.execution_name, | ||
"execution_type": "NOTEBOOK", | ||
"input_config": self._format_start_execution_input_config(), | ||
"output_config": self._format_start_execution_output_config(), | ||
"termination_condition": self.termination_condition, | ||
"tags": self.tags, | ||
} | ||
if self.compute: | ||
start_execution_params["compute"] = self.compute | ||
|
||
request = StartExecutionRequest(**start_execution_params) | ||
|
||
return self._sagemaker_studio.execution_client.start_execution(request) | ||
|
||
def wait_for_execution_completion(self, execution_id, context): | ||
|
||
agupta01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
while True: | ||
time.sleep(self.poll_interval) | ||
response = self.get_execution_response(execution_id) | ||
error_message = response.get("error_details", {}).get("error_message") | ||
status = response["status"] | ||
if "files" in response: | ||
self._set_xcom_files(response["files"], context) | ||
if "s3_path" in response: | ||
self._set_xcom_s3_path(response["s3_path"], context) | ||
|
||
ret = self._handle_state(execution_id, status, error_message) | ||
if ret: | ||
return ret | ||
|
||
def _set_xcom_files(self, files, context): | ||
if not context: | ||
error_message = "context is required" | ||
raise AirflowException(error_message) | ||
for file in files: | ||
context["ti"].xcom_push( | ||
key=f"{file['display_name']}.{file['file_format']}", | ||
value=file["file_path"], | ||
) | ||
|
||
def _set_xcom_s3_path(self, s3_path, context): | ||
if not context: | ||
error_message = "context is required" | ||
raise AirflowException(error_message) | ||
context["ti"].xcom_push( | ||
key="s3_path", | ||
value=s3_path, | ||
) | ||
|
||
def get_execution_response(self, execution_id): | ||
response = self._sagemaker_studio.execution_client.get_execution( | ||
GetExecutionRequest(execution_id=execution_id) | ||
) | ||
return response | ||
agupta01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _handle_state(self, execution_id, status, error_message): | ||
finished_states = ["COMPLETED"] | ||
in_progress_states = ["IN_PROGRESS", "STOPPING"] | ||
|
||
if status in in_progress_states: | ||
self.log.info( | ||
f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.poll_interval}" | ||
) | ||
return None | ||
execution_message = f"Exiting Execution {execution_id} State: {status}" | ||
if status in finished_states: | ||
self.log.info(execution_message) | ||
return {"Status": status, "ExecutionId": execution_id} | ||
else: | ||
self.log.error(f"{execution_message} Message: {error_message}") | ||
if error_message == "": | ||
error_message = execution_message | ||
raise AirflowException(error_message) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.