forked from qodo-benchmark/prefect
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfetcher.py
More file actions
476 lines (383 loc) · 14.6 KB
/
fetcher.py
File metadata and controls
476 lines (383 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
"""
SDK data fetcher module.
This module is responsible for fetching deployment and work pool data from the
Prefect API and converting it to the internal data models used by the SDK generator.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import UUID
import prefect
# Logger for SDK fetcher operations
logger = logging.getLogger(__name__)
from prefect._sdk.models import (
DeploymentInfo,
FlowInfo,
SDKData,
SDKGenerationMetadata,
WorkPoolInfo,
)
from prefect.client.schemas.filters import (
DeploymentFilter,
DeploymentFilterName,
FlowFilter,
FlowFilterId,
FlowFilterName,
)
from prefect.exceptions import ObjectNotFound
from prefect.settings.context import get_current_settings
if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
from prefect.client.schemas.responses import DeploymentResponse
@dataclass
class FetchResult:
    """Outcome of an SDK data fetch against the Prefect API.

    Attributes:
        data: The fetched SDK data.
        warnings: Non-fatal warning messages produced while fetching.
        errors: Non-fatal error messages produced while fetching.
    """

    # The assembled SDK data; fatal problems raise instead of returning.
    data: SDKData
    # Human-readable warnings (e.g. work pools that could not be fetched).
    warnings: list[str] = field(default_factory=list)
    # Human-readable errors for items that were skipped during processing.
    errors: list[str] = field(default_factory=list)
class SDKFetcherError(Exception):
    """Base exception for all SDK fetcher errors."""
class AuthenticationError(SDKFetcherError):
    """Raised when authentication with the Prefect API fails."""
class APIConnectionError(SDKFetcherError):
    """Raised when the Prefect API cannot be reached."""
class NoDeploymentsError(SDKFetcherError):
    """Raised when no deployments are found."""
async def _check_authentication(client: "PrefectClient") -> None:
"""Check if the client is authenticated.
Args:
client: The Prefect client to check.
Raises:
AuthenticationError: If not authenticated.
APIConnectionError: If the API cannot be reached.
"""
try:
exc = await client.api_healthcheck()
if exc is not None:
# Check if it's an authentication error
exc_str = str(exc).lower()
if (
"unauthorized" in exc_str
or "forbidden" in exc_str
or "401" in exc_str
or "403" in exc_str
):
raise AuthenticationError(
"Not authenticated. Run `prefect cloud login` or configure "
"PREFECT_API_URL."
)
raise APIConnectionError(
f"Could not connect to Prefect API at {client.api_url}. "
f"Check your configuration. Error: {exc}"
)
except Exception as e:
if isinstance(e, (AuthenticationError, APIConnectionError)):
raise
raise APIConnectionError(
f"Could not connect to Prefect API at {client.api_url}. "
f"Check your configuration. Error: {e}"
) from e
async def _fetch_deployments(
client: "PrefectClient",
flow_filter: FlowFilter | None = None,
deployment_filter: DeploymentFilter | None = None,
) -> list["DeploymentResponse"]:
"""Fetch all deployments with pagination.
Args:
client: The Prefect client to use.
flow_filter: Optional filter for flows.
deployment_filter: Optional filter for deployments.
Returns:
List of deployment responses.
"""
page_size = 200
offset = 0
all_deployments: list[DeploymentResponse] = []
while True:
deployments = await client.read_deployments(
flow_filter=flow_filter,
deployment_filter=deployment_filter,
limit=page_size,
offset=offset,
)
if not deployments:
break
all_deployments.extend(deployments)
if len(deployments) < page_size:
break
offset += page_size
return all_deployments
async def _fetch_work_pool(
    client: "PrefectClient",
    work_pool_name: str,
) -> WorkPoolInfo | None:
    """Fetch a single work pool by name.

    Args:
        client: The Prefect client to use.
        work_pool_name: The name of the work pool to fetch.

    Returns:
        WorkPoolInfo if found, None if not found.

    Raises:
        Exception: For non-ObjectNotFound errors (to be caught by gather).
    """
    try:
        pool = await client.read_work_pool(work_pool_name)
        template = pool.base_job_template
        # The "variables" section of the base job template describes the
        # pool's job variables; fall back to an empty schema when absent.
        variables_schema: dict[str, Any] = (
            template["variables"] if template and "variables" in template else {}
        )
        return WorkPoolInfo(
            name=pool.name,
            pool_type=pool.type,
            job_variables_schema=variables_schema,
        )
    except ObjectNotFound:
        return None
    # Any other exception propagates so gather(return_exceptions=True)
    # can record it for the caller.
async def _fetch_work_pools_parallel(
    client: "PrefectClient",
    work_pool_names: set[str],
) -> tuple[dict[str, WorkPoolInfo], list[str]]:
    """Fetch multiple work pools concurrently.

    Args:
        client: The Prefect client to use.
        work_pool_names: Set of work pool names to fetch.

    Returns:
        Tuple of (work_pools dict, warnings list).
    """
    if not work_pool_names:
        return {}, []
    # Sorting gives a deterministic order for both fetches and warnings.
    ordered_names = sorted(work_pool_names)
    outcomes = await asyncio.gather(
        *(_fetch_work_pool(client, name) for name in ordered_names),
        return_exceptions=True,
    )
    pools: dict[str, WorkPoolInfo] = {}
    collected_warnings: list[str] = []
    for pool_name, outcome in zip(ordered_names, outcomes):
        if isinstance(outcome, BaseException):
            # The individual fetch failed for a reason other than "not found".
            collected_warnings.append(
                f"Could not fetch work pool '{pool_name}' - `with_infra()` will not be "
                f"generated for affected deployments: {outcome}"
            )
        elif outcome is None:
            # _fetch_work_pool signals ObjectNotFound by returning None.
            collected_warnings.append(
                f"Work pool '{pool_name}' not found - `with_infra()` will not be "
                f"generated for affected deployments"
            )
        else:
            pools[pool_name] = outcome
    return pools, collected_warnings
async def _fetch_flows_for_deployments(
    client: "PrefectClient",
    deployment_flow_ids: set[str],
) -> tuple[dict[str, str], list[str]]:
    """Resolve flow names for the given flow IDs.

    Args:
        client: The Prefect client to use.
        deployment_flow_ids: Set of flow IDs (as strings) to look up.

    Returns:
        Tuple of (dict mapping flow_id to flow_name, list of warnings).
    """
    if not deployment_flow_ids:
        return {}, []
    collected_warnings: list[str] = []
    valid_uuids: list[UUID] = []
    # Defensively parse the string IDs; skip (with a warning) any that are
    # not valid UUIDs rather than failing the whole fetch.
    for fid in deployment_flow_ids:
        try:
            valid_uuids.append(UUID(fid))
        except (ValueError, TypeError) as e:
            collected_warnings.append(f"Invalid flow ID '{fid}' - skipping: {e}")
    if not valid_uuids:
        return {}, collected_warnings
    name_by_id: dict[str, str] = {}
    id_filter = FlowFilter(id=FlowFilterId(any_=valid_uuids))
    batch_size = 200
    position = 0
    while True:
        page = await client.read_flows(
            flow_filter=id_filter,
            limit=batch_size,
            offset=position,
        )
        name_by_id.update({str(flow.id): flow.name for flow in page})
        # A short (or empty) page means we have exhausted the results.
        if len(page) < batch_size:
            break
        position += batch_size
    return name_by_id, collected_warnings
def _match_deployment_name(
    dep_name: str,
    full_name: str,
    deployment_names: list[str],
) -> tuple[bool, str | None]:
    """Check one deployment against the user-supplied name filters.

    Filter entries containing "/" are exact "flow-name/deployment-name"
    matches; entries without "/" match any deployment with that short name
    across all flows.

    Args:
        dep_name: The deployment's bare name.
        full_name: The deployment's "flow-name/deployment-name" form.
        deployment_names: The user-supplied filter entries.

    Returns:
        Tuple of (matched, short_name) where short_name is the filter entry
        that matched by short name, or None for a full-name (or no) match.
    """
    for dn in deployment_names:
        if "/" in dn:
            if full_name == dn:
                return True, None
        elif dep_name == dn:
            return True, dn
    return False, None


async def fetch_sdk_data(
    client: "PrefectClient",
    flow_names: list[str] | None = None,
    deployment_names: list[str] | None = None,
) -> FetchResult:
    """Fetch all data needed for SDK generation.

    Args:
        client: An active Prefect client.
        flow_names: Optional list of flow names to filter to.
        deployment_names: Optional list of deployment names to filter to.
            These should be in "flow-name/deployment-name" format for exact
            matching. Short names (without "/") will match any deployment
            with that name across all flows.

    Returns:
        FetchResult containing SDK data and any warnings/errors.

    Raises:
        AuthenticationError: If not authenticated.
        APIConnectionError: If the API cannot be reached.
        NoDeploymentsError: If no deployments match the filters.
    """
    warnings: list[str] = []
    errors: list[str] = []
    # Check authentication first so connectivity/auth problems fail fast.
    logger.debug("Checking authentication with Prefect API")
    await _check_authentication(client)
    # Build server-side filters. The API can only filter deployments by bare
    # name, so extract the part after "/" here; exact full-name matching is
    # done client-side in the loop below.
    flow_filter: FlowFilter | None = None
    deployment_filter: DeploymentFilter | None = None
    if flow_names:
        flow_filter = FlowFilter(name=FlowFilterName(any_=flow_names))
    if deployment_names:
        deploy_name_parts = [
            name.split("/", 1)[1] if "/" in name else name
            for name in deployment_names
        ]
        deployment_filter = DeploymentFilter(
            name=DeploymentFilterName(any_=deploy_name_parts)
        )
    # Fetch deployments
    deployment_responses = await _fetch_deployments(
        client,
        flow_filter=flow_filter,
        deployment_filter=deployment_filter,
    )
    if not deployment_responses:
        if flow_names or deployment_names:
            raise NoDeploymentsError(
                f"No deployments matched filters. "
                f"Filters: flow_names={flow_names}, deployment_names={deployment_names}"
            )
        raise NoDeploymentsError("No deployments found in workspace.")
    # Collect the unique flow IDs and work pool names referenced by the
    # deployments so the lookups below can be batched.
    flow_ids: set[str] = set()
    work_pool_names: set[str] = set()
    for dep in deployment_responses:
        flow_ids.add(str(dep.flow_id))
        if dep.work_pool_name:
            work_pool_names.add(dep.work_pool_name)
    # Fetch flow names and work pools in parallel
    flow_names_task = _fetch_flows_for_deployments(client, flow_ids)
    work_pools_task = _fetch_work_pools_parallel(client, work_pool_names)
    (flow_id_to_name, flow_warnings), (work_pools, wp_warnings) = await asyncio.gather(
        flow_names_task, work_pools_task
    )
    warnings.extend(flow_warnings)
    warnings.extend(wp_warnings)
    # Group deployments by flow
    flows: dict[str, FlowInfo] = {}
    # Track short name -> full names to detect ambiguous short-name filters.
    short_name_matches: dict[str, list[str]] = {}
    for dep in deployment_responses:
        flow_id = str(dep.flow_id)
        flow_name = flow_id_to_name.get(flow_id)
        if not flow_name:
            errors.append(
                f"Could not find flow name for deployment '{dep.name}' "
                f"(flow_id={flow_id}) - skipping"
            )
            continue
        full_name = f"{flow_name}/{dep.name}"
        if deployment_names:
            # BUG FIX: the matching previously only ran when dep.name was not
            # itself in deployment_names. Since a short-name filter entry is
            # always equal to dep.name for the deployments it matches, short
            # matches were never recorded and the ambiguity warning below was
            # dead code. Always matching here restores it, and also makes
            # full-name entries match strictly by full name.
            matched, matched_short_name = _match_deployment_name(
                dep.name, full_name, deployment_names
            )
            if not matched:
                continue
            if matched_short_name is not None:
                short_name_matches.setdefault(matched_short_name, []).append(full_name)
        # Create DeploymentInfo
        deployment_info = DeploymentInfo(
            name=dep.name,
            flow_name=flow_name,
            full_name=full_name,
            parameter_schema=dep.parameter_openapi_schema,
            work_pool_name=dep.work_pool_name,
            description=dep.description,
        )
        # Add to flow
        if flow_name not in flows:
            flows[flow_name] = FlowInfo(name=flow_name, deployments=[])
        flows[flow_name].deployments.append(deployment_info)
    if not flows:
        if flow_names or deployment_names:
            raise NoDeploymentsError(
                f"No deployments matched filters after processing. "
                f"Filters: flow_names={flow_names}, deployment_names={deployment_names}"
            )
        raise NoDeploymentsError("No deployments could be processed.")
    # Warn about ambiguous short names that matched multiple flows
    for short_name, full_names in short_name_matches.items():
        if len(full_names) > 1:
            warnings.append(
                f"Short deployment name '{short_name}' matched {len(full_names)} "
                f"deployments across different flows: {', '.join(sorted(full_names))}. "
                f"Consider using full names (flow/deployment) for precise filtering."
            )
    # Build metadata - prefer client.api_url, fall back to setting for provenance
    client_api_url = getattr(client, "api_url", None)
    api_url = str(client_api_url) if client_api_url else get_current_settings().api.url
    metadata = SDKGenerationMetadata(
        generation_time=datetime.now(timezone.utc).isoformat(),
        prefect_version=prefect.__version__,
        workspace_name=None,  # Could be extracted from Cloud API URL if needed
        api_url=api_url,
    )
    # Build final SDK data
    sdk_data = SDKData(
        metadata=metadata,
        flows=flows,
        work_pools=work_pools,
    )
    return FetchResult(
        data=sdk_data,
        warnings=warnings,
        errors=errors,
    )