Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/prefect/infrastructure/provisioners/coiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from copy import deepcopy
from types import ModuleType
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

from anyio import run_process
from rich.console import Console
Expand Down Expand Up @@ -72,13 +72,24 @@ async def _install_coiled(self):
coiled = importlib.import_module("coiled")
progress.advance(task)

async def _get_coiled_token(self) -> str:
async def _get_coiled_creds(self) -> Tuple[str, Optional[str]]:
"""
Gets a Coiled API token from the current Coiled configuration.
"""
import dask.config

return dask.config.get("coiled.token", "")
token = dask.config.get("coiled.token", "")
workspace = None

if token:
# this will validate the token, and then determine what workspace to use based on both
# - locally configured `coiled.workspace` (in dask config file)
# - default workspace (in Coiled database)
# Local config takes precedence.
async with coiled.Cloud() as cloud:
workspace = cloud.default_workspace

return token, workspace

async def _create_new_coiled_token(self):
"""
Expand All @@ -90,6 +101,7 @@ async def _create_coiled_credentials_block(
self,
block_document_name: str,
coiled_token: str,
coiled_workspace: Optional[str],
client: "PrefectClient",
) -> BlockDocument:
"""
Expand Down Expand Up @@ -127,6 +139,7 @@ async def _create_coiled_credentials_block(
name=block_document_name,
data={
"api_token": coiled_token,
"workspace": coiled_workspace,
},
block_type_id=credentials_block_type.id,
block_schema_id=credentials_block_schema.id,
Expand Down Expand Up @@ -204,8 +217,8 @@ async def provision(
" or you can use the Prefect UI to create your Coiled push work pool."
)

# Get the current Coiled API token
coiled_api_token = await self._get_coiled_token()
# Get the current Coiled API token and workspace
coiled_api_token, workspace = await self._get_coiled_creds()
if not coiled_api_token:
# Create a new token one wasn't found
if self.console.is_interactive and Confirm.ask(
Expand All @@ -214,7 +227,7 @@ async def provision(
default=True,
):
await self._create_new_coiled_token()
coiled_api_token = await self._get_coiled_token()
coiled_api_token, workspace = await self._get_coiled_creds()
else:
raise RuntimeError(
"Coiled credentials not found. Please create a new token by"
Expand All @@ -232,7 +245,8 @@ async def provision(
progress.start()
block_doc = await self._create_coiled_credentials_block(
credentials_block_name,
coiled_api_token,
coiled_token=coiled_api_token,
coiled_workspace=workspace,
client=client,
)
progress.advance(task)
Expand Down
4 changes: 2 additions & 2 deletions tests/infrastructure/provisioners/test_coiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def mock_confirm():
@pytest.fixture
def mock_dask_config():
with patch(
"prefect.infrastructure.provisioners.coiled.CoiledPushProvisioner._get_coiled_token"
"prefect.infrastructure.provisioners.coiled.CoiledPushProvisioner._get_coiled_creds"
) as mock:
mock.return_value = "local-api-token-from-dask-config"
mock.return_value = "local-api-token-from-dask-config", "my-workspace"
yield mock


Expand Down
Loading