diff --git a/src/prefect/infrastructure/provisioners/__init__.py b/src/prefect/infrastructure/provisioners/__init__.py index 545a8576a102..bb8270360841 100644 --- a/src/prefect/infrastructure/provisioners/__init__.py +++ b/src/prefect/infrastructure/provisioners/__init__.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Type +from prefect.infrastructure.provisioners.coiled import CoiledPushProvisioner from prefect.infrastructure.provisioners.modal import ModalPushProvisioner from .cloud_run import CloudRunPushProvisioner from .container_instance import ContainerInstancePushProvisioner @@ -15,6 +16,7 @@ "azure-container-instance:push": ContainerInstancePushProvisioner, "ecs:push": ElasticContainerServicePushProvisioner, "modal:push": ModalPushProvisioner, + "coiled:push": CoiledPushProvisioner, } diff --git a/src/prefect/infrastructure/provisioners/coiled.py b/src/prefect/infrastructure/provisioners/coiled.py new file mode 100644 index 000000000000..cf63b91866a0 --- /dev/null +++ b/src/prefect/infrastructure/provisioners/coiled.py @@ -0,0 +1,248 @@ +import importlib +import shlex +import sys +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, Optional + +from anyio import run_process +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.prompt import Confirm + +from prefect.client.schemas.actions import BlockDocumentCreate +from prefect.client.schemas.objects import BlockDocument +from prefect.client.utilities import inject_client +from prefect.exceptions import ObjectNotFound +from prefect.utilities.importtools import lazy_import + +if TYPE_CHECKING: + from prefect.client.orchestration import PrefectClient + + +coiled = lazy_import("coiled") + + +class CoiledPushProvisioner: + """ + A infrastructure provisioner for Coiled push work pools. + """ + + def __init__(self, client: Optional["PrefectClient"] = None): + self._console = Console() + + @property + def console(self): + return self._console + + @console.setter + def console(self, value): + self._console = value + + @staticmethod + def _is_coiled_installed() -> bool: + """ + Checks if the coiled package is installed. + + Returns: + True if the coiled package is installed, False otherwise + """ + try: + importlib.import_module("coiled") + return True + except ModuleNotFoundError: + return False + + async def _install_coiled(self): + """ + Installs the coiled package. + """ + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]Installing coiled..."), + transient=True, + console=self.console, + ) as progress: + task = progress.add_task("coiled install") + progress.start() + global coiled + await run_process( + [shlex.quote(sys.executable), "-m", "pip", "install", "coiled"] + ) + coiled = importlib.import_module("coiled") + progress.advance(task) + + async def _get_coiled_token(self) -> str: + """ + Gets a Coiled API token from the current Coiled configuration. + """ + import dask.config + + return dask.config.get("coiled.token", "") + + async def _create_new_coiled_token(self): + """ + Triggers a Coiled login via the browser if no current token. Will create a new token. + """ + await run_process(["coiled", "login"]) + + async def _create_coiled_credentials_block( + self, + block_document_name: str, + coiled_token: str, + client: "PrefectClient", + ) -> BlockDocument: + """ + Creates a CoiledCredentials block containing the provided token. + + Args: + block_document_name: The name of the block document to create + coiled_token: The Coiled API token + + Returns: + The ID of the created block + """ + assert client is not None, "client injection failed" + try: + credentials_block_type = await client.read_block_type_by_slug( + "coiled-credentials" + ) + except ObjectNotFound: + # Shouldn't happen, but just in case + raise RuntimeError( + "Unable to find CoiledCredentials block type. Please ensure you are" + " using Prefect Cloud." + ) + credentials_block_schema = ( + await client.get_most_recent_block_schema_for_block_type( + block_type_id=credentials_block_type.id + ) + ) + assert ( + credentials_block_schema is not None + ), f"Unable to find schema for block type {credentials_block_type.slug}" + + block_doc = await client.create_block_document( + block_document=BlockDocumentCreate( + name=block_document_name, + data={ + "api_token": coiled_token, + }, + block_type_id=credentials_block_type.id, + block_schema_id=credentials_block_schema.id, + ) + ) + return block_doc + + @inject_client + async def provision( + self, + work_pool_name: str, + base_job_template: Dict[str, Any], + client: Optional["PrefectClient"] = None, + ) -> Dict[str, Any]: + """ + Provisions resources necessary for a Coiled push work pool. + + Provisioned resources: + - A CoiledCredentials block containing a Coiled API token + + Args: + work_pool_name: The name of the work pool to provision resources for + base_job_template: The base job template to update + + Returns: + A copy of the provided base job template with the provisioned resources + """ + credentials_block_name = f"{work_pool_name}-coiled-credentials" + base_job_template_copy = deepcopy(base_job_template) + assert client is not None, "client injection failed" + try: + block_doc = await client.read_block_document_by_name( + credentials_block_name, "coiled-credentials" + ) + self.console.print( + f"Work pool [blue]{work_pool_name!r}[/] will reuse the existing Coiled" + f" credentials block [blue]{credentials_block_name!r}[/blue]" + ) + except ObjectNotFound: + if self._console.is_interactive and not Confirm.ask( + ( + "\n" + "To configure your Coiled push work pool we'll need to store a Coiled" + " API token with Prefect Cloud as a block. We'll pull the token from" + " your local Coiled configuration or create a new token if we" + " can't find one.\n" + "\n" + "Would you like to continue?" + ), + console=self.console, + default=True, + ): + self.console.print( + "No problem! You can always configure your Coiled push work pool" + " later via the Prefect UI." + ) + return base_job_template + + if not self._is_coiled_installed(): + if self.console.is_interactive and Confirm.ask( + ( + "The [blue]coiled[/] package is required to configure" + " authentication for your work pool.\n" + "\n" + "Would you like to install it now?" + ), + console=self.console, + default=True, + ): + await self._install_coiled() + + if not self._is_coiled_installed(): + raise RuntimeError( + "The coiled package is not installed.\n\nPlease try installing coiled," + " or you can use the Prefect UI to create your Coiled push work pool." + ) + + # Get the current Coiled API token + coiled_api_token = await self._get_coiled_token() + if not coiled_api_token: + # Create a new token one wasn't found + if self.console.is_interactive and Confirm.ask( + "Coiled credentials not found. Would you like to create a new token?", + console=self.console, + default=True, + ): + await self._create_new_coiled_token() + coiled_api_token = await self._get_coiled_token() + else: + raise RuntimeError( + "Coiled credentials not found. Please create a new token by" + " running [blue]coiled login[/] and try again." + ) + + # Create the credentials block + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]Saving Coiled credentials..."), + transient=True, + console=self.console, + ) as progress: + task = progress.add_task("create coiled credentials block") + progress.start() + block_doc = await self._create_coiled_credentials_block( + credentials_block_name, + coiled_api_token, + client=client, + ) + progress.advance(task) + + base_job_template_copy["variables"]["properties"]["credentials"]["default"] = { + "$ref": {"block_document_id": str(block_doc.id)} + } + if "image" in base_job_template_copy["variables"]["properties"]: + base_job_template_copy["variables"]["properties"]["image"]["default"] = "" + self.console.print( + f"Successfully configured Coiled push work pool {work_pool_name!r}!", + style="green", + ) + return base_job_template_copy diff --git a/tests/infrastructure/provisioners/test_coiled.py b/tests/infrastructure/provisioners/test_coiled.py new file mode 100644 index 000000000000..eca5cdceb11f --- /dev/null +++ b/tests/infrastructure/provisioners/test_coiled.py @@ -0,0 +1,186 @@ +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import UUID + +import pytest + +from prefect.blocks.core import Block +from prefect.client.orchestration import PrefectClient +from prefect.infrastructure.provisioners.coiled import CoiledPushProvisioner + + +@pytest.fixture(autouse=True) +async def coiled_credentials_block_cls(): + class MockCoiledCredentials(Block): + _block_type_name = "Coiled Credentials" + api_token: str + + await MockCoiledCredentials.register_type_and_schema() + + return MockCoiledCredentials + + +@pytest.fixture +async def coiled_credentials_block_id(coiled_credentials_block_cls: Block): + block_doc_id = await coiled_credentials_block_cls(api_token="existing_token").save( + "work-pool-name-coiled-credentials", overwrite=True + ) + + return block_doc_id + + +@pytest.fixture +def mock_run_process(): + with patch("prefect.infrastructure.provisioners.coiled.run_process") as mock: + yield mock + + +@pytest.fixture +def mock_coiled(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("prefect.infrastructure.provisioners.coiled.coiled", mock) + yield mock + + +@pytest.fixture +def mock_importlib(): + with patch("prefect.infrastructure.provisioners.coiled.importlib") as mock: + yield mock + + +@pytest.fixture +def mock_confirm(): + with patch("prefect.infrastructure.provisioners.coiled.Confirm") as mock: + yield mock + + +@pytest.fixture +def mock_dask_config(): + with patch( + "prefect.infrastructure.provisioners.coiled.CoiledPushProvisioner._get_coiled_token" + ) as mock: + mock.return_value = "local-api-token-from-dask-config" + yield mock + + +async def test_provision( + prefect_client: PrefectClient, + mock_run_process: AsyncMock, + mock_coiled: MagicMock, + mock_dask_config: MagicMock, + mock_confirm: MagicMock, + mock_importlib: MagicMock, +): + """ + Test provision from a clean slate: + - Coiled is not installed + - Coiled token does not exist + - CoiledCredentials block does not exist + """ + provisioner = CoiledPushProvisioner() + provisioner.console.is_interactive = True + + mock_confirm.ask.side_effect = [ + True, + True, + True, + ] # confirm provision, install coiled, create new token + mock_importlib.import_module.side_effect = [ + ModuleNotFoundError, + mock_coiled, + mock_coiled, + ] + # simulate coiled token creation + mock_coiled.config.Config.return_value.get.side_effect = [ + None, + None, + "mock_token", + ] + + work_pool_name = "work-pool-name" + base_job_template = {"variables": {"properties": {"credentials": {}}}} + + result = await provisioner.provision( + work_pool_name, base_job_template, client=prefect_client + ) + + # Check if the block document exists and has expected values + block_document = await prefect_client.read_block_document_by_name( + "work-pool-name-coiled-credentials", "coiled-credentials" + ) + + assert block_document.data["api_token"], str == "mock_token" + + # Check if the base job template was updated + assert result["variables"]["properties"]["credentials"] == { + "default": {"$ref": {"block_document_id": str(block_document.id)}}, + } + + +async def test_provision_existing_coiled_credentials_block( + prefect_client: PrefectClient, + coiled_credentials_block_id: UUID, + mock_run_process: AsyncMock, +): + """ + Test provision with an existing CoiledCredentials block. + """ + provisioner = CoiledPushProvisioner() + + work_pool_name = "work-pool-name" + base_job_template = {"variables": {"properties": {"credentials": {}}}} + + result = await provisioner.provision( + work_pool_name, base_job_template, client=prefect_client + ) + + # Check if the base job template was updated + assert result["variables"]["properties"]["credentials"] == { + "default": {"$ref": {"block_document_id": str(coiled_credentials_block_id)}}, + } + + mock_run_process.assert_not_called() + + +async def test_provision_existing_coiled_credentials( + prefect_client: PrefectClient, + mock_run_process: AsyncMock, + mock_coiled: MagicMock, + mock_dask_config: MagicMock, + mock_confirm: MagicMock, + mock_importlib: MagicMock, +): + """ + Test provision where the user has coiled installed and an existing Coiled configuration. + """ + provisioner = CoiledPushProvisioner() + mock_confirm.ask.side_effect = [ + True, + ] # confirm provision + mock_importlib.import_module.side_effect = [ + mock_coiled, + mock_coiled, + ] # coiled is already installed + mock_coiled.config.Config.return_value.get.side_effect = [ + "mock_token", + ] # coiled config exists + + work_pool_name = "work-pool-name" + base_job_template = {"variables": {"properties": {"credentials": {}}}} + + result = await provisioner.provision( + work_pool_name, base_job_template, client=prefect_client + ) + + # Check if the block document exists and has expected values + block_document = await prefect_client.read_block_document_by_name( + "work-pool-name-coiled-credentials", "coiled-credentials" + ) + + assert block_document.data["api_token"], str == "mock_token" + + # Check if the base job template was updated + assert result["variables"]["properties"]["credentials"] == { + "default": {"$ref": {"block_document_id": str(block_document.id)}}, + } + + mock_run_process.assert_not_called()