Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AWS SageMaker Unified Studio Workflow Operator #45726

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.

===============================
Amazon SageMaker Unified Studio
===============================

`Amazon SageMaker Unified Studio <https://docs.aws.amazon.com/sagemaker-unified-studio>`__ is a unified development experience that
brings together AWS data, analytics, artificial intelligence (AI), and machine learning (ML) services.
It provides a place to build, deploy, execute, and monitor end-to-end workflows from a single interface.
This helps drive collaboration across teams and facilitate agile development.

Airflow provides operators to orchestrate Notebooks, Querybooks, and Visual ETL jobs within SageMaker Unified Studio Workflows.

Prerequisite Tasks
------------------

To use these operators, you must do a few things:

* Create a SageMaker Unified Studio domain.
* Within your domain, create a project with the "Data analytics and AI-ML model development" project profile.
* Within your project:
* Navigate to the "Compute > Workflow environments" tab, and click "Create" to create a new MWAA environment.
* Create a Notebook, Querybook, or Visual ETL job and save it to your project.

Operators
---------

.. _howto/operator:SageMakerNotebookOperator:

Create an Amazon SageMaker Unified Studio Workflow
==================================================

To create an Amazon SageMaker Unified Studio workflow to orchestrate your notebook, querybook, and visual ETL runs you can use
:class:`~airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookOperator`.

.. exampleinclude:: /../../providers/tests/system/amazon/aws/example_sagemaker_unified_studio.py
:language: python
:dedent: 4
:start-after: [START howto_operator_sagemaker_unified_studio_notebook]
:end-before: [END howto_operator_sagemaker_unified_studio_notebook]


Reference
---------

* `What is Amazon SageMaker Unified Studio <https://docs.aws.amazon.com/sagemaker-unified-studio/latest/userguide/what-is-sagemaker-unified-studio.html>`__
9 changes: 9 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ connectTimeoutMS
connexion
containerConfiguration
containerd
ContainerEntrypoint
ContainerGroup
containerinstance
ContainerPort
Expand Down Expand Up @@ -819,6 +820,7 @@ ImageAnnotatorClient
imageORfile
imagePullPolicy
imagePullSecrets
ImageUri
imageVersion
Imap
imap
Expand Down Expand Up @@ -852,6 +854,7 @@ InstanceFlexibilityPolicy
InstanceGroupConfig
InstanceSelection
instanceTemplates
InstanceType
instantiation
integrations
interdependencies
Expand All @@ -869,6 +872,7 @@ IPv4
ipv4
IPv6
ipv6
ipynb
iPython
irreproducible
IRSA
Expand Down Expand Up @@ -1043,6 +1047,7 @@ masterType
Matomo
matomo
Maxime
MaxRuntimeInSeconds
mb
md
mediawiki
Expand Down Expand Up @@ -1361,6 +1366,8 @@ quantile
Quantiles
quantiles
Quantopian
Querybooks
querybook
Qubole
qubole
QuboleCheckHook
Expand Down Expand Up @@ -1880,6 +1887,8 @@ virtualenvs
vm
VolumeMount
volumeMounts
VolumeKmsKeyId
VolumeSizeInGB
vpc
WaiterModel
wape
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module contains the Amazon SageMaker Unified Studio Notebook hook."""

import time

from sagemaker_studio import ClientConfig
from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest
from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI

from airflow import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner


class SageMakerNotebookHook(BaseHook):
"""
Interact with the Sagemaker Workflows API.

This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API.

Examples:
.. code-block:: python

from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook

notebook_hook = NotebookHook(
input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}},
output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'},
execution_name='notebook_execution',
poll_interval=10,
)

:param execution_name: The name of the notebook job to be executed, this is same as task_id.
:param input_config: Configuration for the input file.
Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
:param output_config: Configuration for the output format. It should include an output_formats parameter to control
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sentence seems to just tail off in the middle?

Example: {'output_formats': ['NOTEBOOK']}
:param compute: compute configuration to use for the notebook execution. This is an required attribute
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
:param compute: compute configuration to use for the notebook execution. This is an required attribute
:param compute: compute configuration to use for the notebook execution. This is a required attribute

if the execution is on a remote compute.
Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]}
:param termination_condition: conditions to match to terminate the remote execution.
Example: { "MaxRuntimeInSeconds": 3600 }
:param tags: tags to be associated with the remote execution runs.
Example: { "md_analytics": "logs" }
:param poll_interval: Interval in seconds to check the notebook execution status.
"""

def __init__(
self,
execution_name: str,
input_config: dict = {},
output_config: dict = {"output_formats": ["NOTEBOOK"]},
compute: dict = {},
termination_condition: dict = {},
tags: dict = {},
poll_interval: int = 10,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config())
self.execution_name = execution_name
self.input_config = input_config
self.output_config = output_config
self.compute = compute
self.termination_condition = termination_condition
self.tags = tags
self.poll_interval = poll_interval

def _get_sagemaker_studio_config(self):
config = ClientConfig()
config.overrides["execution"] = {"local": is_local_runner()}
return config

def _format_start_execution_input_config(self):
config = {
"notebook_config": {
"input_path": self.input_config.get("input_path"),
"input_parameters": self.input_config.get("input_params"),
},
}

return config

def _format_start_execution_output_config(self):
output_formats = (
self.output_config.get("output_formats") if self.output_config else ["NOTEBOOK"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ternary is unnecessary right? There is a default value provided in the constructor so the output_config can't be empty?

)
config = {
"notebook_config": {
"output_formats": output_formats,
}
}
return config

def start_notebook_execution(self):
start_execution_params = {
"execution_name": self.execution_name,
"execution_type": "NOTEBOOK",
"input_config": self._format_start_execution_input_config(),
"output_config": self._format_start_execution_output_config(),
"termination_condition": self.termination_condition,
"tags": self.tags,
}
if self.compute:
start_execution_params["compute"] = self.compute

request = StartExecutionRequest(**start_execution_params)

return self._sagemaker_studio.execution_client.start_execution(request)

def wait_for_execution_completion(self, execution_id, context):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

while True:
time.sleep(self.poll_interval)
response = self.get_execution_response(execution_id)
error_message = response.get("error_details", {}).get("error_message")
status = response["status"]
if "files" in response:
self._set_xcom_files(response["files"], context)
if "s3_path" in response:
self._set_xcom_s3_path(response["s3_path"], context)

ret = self._handle_state(execution_id, status, error_message)
if ret:
return ret

def _set_xcom_files(self, files, context):
if not context:
error_message = "context is required"
raise AirflowException(error_message)
for file in files:
context["ti"].xcom_push(
key=f"{file['display_name']}.{file['file_format']}",
value=file["file_path"],
)

def _set_xcom_s3_path(self, s3_path, context):
if not context:
error_message = "context is required"
raise AirflowException(error_message)
context["ti"].xcom_push(
key="s3_path",
value=s3_path,
)

def get_execution_response(self, execution_id):
response = self._sagemaker_studio.execution_client.get_execution(
GetExecutionRequest(execution_id=execution_id)
)
return response
Comment on lines +164 to +167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
response = self._sagemaker_studio.execution_client.get_execution(
GetExecutionRequest(execution_id=execution_id)
)
return response
return self._sagemaker_studio.execution_client.get_execution(
GetExecutionRequest(execution_id=execution_id)
)

Any specific reason you're wrapping this one-liner in a method? I thought maybe for mocking purposes but I don't see anything like that below.


def _handle_state(self, execution_id, status, error_message):
finished_states = ["COMPLETED"]
in_progress_states = ["IN_PROGRESS", "STOPPING"]

if status in in_progress_states:
self.log.info(
f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.poll_interval}"
)
return None
execution_message = f"Exiting Execution {execution_id} State: {status}"
if status in finished_states:
self.log.info(execution_message)
return {"Status": status, "ExecutionId": execution_id}
else:
self.log.error(f"{execution_message} Message: {error_message}")
if error_message == "":
error_message = execution_message
raise AirflowException(error_message)
Loading
Loading