-
Notifications
You must be signed in to change notification settings - Fork 179
Expand file tree
/
Copy pathaggregator.py
More file actions
537 lines (432 loc) · 18.8 KB
/
Copy pathaggregator.py
File metadata and controls
537 lines (432 loc) · 18.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
"""Fair share aggregator for computing kernel resource usage.
This module provides the FairShareAggregator that performs pure computation:
1. Prepares kernel usage records (resource-seconds) in 5-minute slices
2. Aggregates usage deltas by user/project/domain for bucket updates
3. Calculates scheduling ranks for fair share sequencing (future)
The aggregator is stateless and does not interact with databases directly.
Repository operations are handled by FairShareObserver.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from uuid import UUID
from ai.backend.common.types import ResourceSlot
from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.repositories.resource_usage_history import (
KernelUsageRecordCreatorSpec,
)
log = BraceStyleAdapter(logging.getLogger(__spec__.name))

# Length of one observation slice in seconds (five minutes). Slice boundaries
# are whole multiples of this value counted from the Unix epoch.
SLICE_DURATION_SECONDS = 5 * 60
if TYPE_CHECKING:
from ai.backend.manager.data.kernel.types import KernelInfo
# =============================================================================
# Bucket Key Types
# =============================================================================
@dataclass(frozen=True)
class UserUsageBucketKey:
    """Key for user usage bucket aggregation.

    Uniquely identifies a user's usage bucket within a resource group and time period.
    The dataclass is frozen, so instances are hashable and can be used directly as
    dict keys when accumulating per-user deltas.
    """

    # Kernel owner the usage is attributed to.
    user_uuid: UUID
    # Project (group) the usage is attributed to.
    project_id: UUID
    # Domain name recorded on the usage spec.
    domain_name: str
    # Resource group (scaling group) the kernel was observed in.
    resource_group: str
    # Calendar day of the bucket (date in the usage timestamps' own timezone).
    period_date: date
@dataclass(frozen=True)
class ProjectUsageBucketKey:
    """Key for project usage bucket aggregation.

    Uniquely identifies a project's usage bucket within a resource group and time period.
    The dataclass is frozen, so instances are hashable and can be used directly as
    dict keys when accumulating per-project deltas.
    """

    # Project (group) the usage is attributed to.
    project_id: UUID
    # Domain name recorded on the usage spec.
    domain_name: str
    # Resource group (scaling group) the kernel was observed in.
    resource_group: str
    # Calendar day of the bucket (date in the usage timestamps' own timezone).
    period_date: date
@dataclass(frozen=True)
class DomainUsageBucketKey:
    """Key for domain usage bucket aggregation.

    Uniquely identifies a domain's usage bucket within a resource group and time period.
    The dataclass is frozen, so instances are hashable and can be used directly as
    dict keys when accumulating per-domain deltas.
    """

    # Domain name recorded on the usage spec.
    domain_name: str
    # Resource group (scaling group) the kernel was observed in.
    resource_group: str
    # Calendar day of the bucket (date in the usage timestamps' own timezone).
    period_date: date
# =============================================================================
# Result Types
# =============================================================================
@dataclass
class KernelUsagePreparationResult:
    """Result of preparing kernel usage records.

    Attributes:
        specs: List of kernel usage record specs for bulk creation, one per
            5-minute (or partial) slice.
        kernel_observation_times: Mapping of kernel_id to the end of the period
            just observed, i.e. the new last_observed_at value for that kernel.
            Only kernels that produced at least one spec appear here.
        observed_count: Number of kernels with generated specs
    """

    specs: list[KernelUsageRecordCreatorSpec] = field(default_factory=list)
    kernel_observation_times: dict[UUID, datetime] = field(default_factory=dict)
    observed_count: int = 0
@dataclass
class BucketDelta:
    """Separated resource amount and duration for a usage bucket.

    Stores raw resource amounts and duration separately instead of
    pre-multiplied resource-seconds. The product ``amount * duration_seconds``
    is computed at SQL query time where PostgreSQL auto-extends NUMERIC precision,
    eliminating overflow risk for large memory values.

    NOTE(review): the aggregator in this module accumulates ``slots`` and
    ``duration_seconds`` by summing them independently across slices. Once a
    bucket receives more than one slice, (sum of amounts) * (sum of durations)
    is not equal to the sum of per-slice amount * duration. Confirm how the
    repository consumes these two fields before relying on that product.

    Attributes:
        slots: Raw resource amounts (e.g., {"cpu": 2, "mem": 4096000000})
        duration_seconds: Total usage duration in seconds
    """

    slots: ResourceSlot = field(default_factory=ResourceSlot)
    duration_seconds: int = 0
@dataclass
class UsageBucketAggregationResult:
    """Result of aggregating kernel usage into daily buckets.

    Each mapping is keyed by a frozen bucket key (entity + resource group + day)
    and holds the delta to apply to that bucket.

    Attributes:
        user_usage_deltas: Aggregated usage deltas by user for bucket updates
        project_usage_deltas: Aggregated usage deltas by project for bucket updates
        domain_usage_deltas: Aggregated usage deltas by domain for bucket updates
    """

    user_usage_deltas: dict[UserUsageBucketKey, BucketDelta] = field(default_factory=dict)
    project_usage_deltas: dict[ProjectUsageBucketKey, BucketDelta] = field(default_factory=dict)
    domain_usage_deltas: dict[DomainUsageBucketKey, BucketDelta] = field(default_factory=dict)
class FairShareAggregator:
    """Computes kernel usage for fair share scheduling.

    This class performs pure computation without database interactions.
    It prepares kernel usage records aligned to 5-minute clock boundaries
    (whole multiples of ``SLICE_DURATION_SECONDS`` counted from the Unix epoch).

    Partial slices are only allowed at:
    - Start: When starts_at is not on a boundary (first observation only)
    - End: When the kernel terminates at a non-boundary time

    For RUNNING kernels, only complete slices up to the last boundary are generated.
    """

    def prepare_kernel_usage_records(
        self,
        kernels: Sequence[KernelInfo],
        scaling_group: str,
        now: datetime,
    ) -> KernelUsagePreparationResult:
        """Prepare kernel usage records for bulk creation.

        For each kernel:
        1. Determine period: last_observed_at (or starts_at) to end_time
           - RUNNING kernels: end_time = now floored to the last 5-minute boundary
           - TERMINATED kernels: end_time = terminated_at
        2. Split period into 5-minute slices aligned to clock boundaries
        3. Create usage record specs for each slice

        Args:
            kernels: Kernels to process
            scaling_group: The scaling group name (stored as each spec's resource_group)
            now: Current time from DB

        Returns:
            KernelUsagePreparationResult containing specs and metadata. Kernels
            that produce no specs get no entry in ``kernel_observation_times``
            and are not counted in ``observed_count``.

        Raises:
            ValueError: If a kernel observed for the first time has no starts_at.
        """
        result = KernelUsagePreparationResult()
        for kernel in kernels:
            kernel_specs, observation_end = self._prepare_kernel_usage_specs(
                kernel, scaling_group, now
            )
            if not kernel_specs:
                continue
            result.specs.extend(kernel_specs)
            result.kernel_observation_times[UUID(str(kernel.id))] = observation_end
            result.observed_count += 1
        return result

    def aggregate_kernel_usage_to_buckets(
        self,
        specs: Sequence[KernelUsageRecordCreatorSpec],
    ) -> UsageBucketAggregationResult:
        """Aggregate kernel usage specs into daily bucket deltas.

        Splits each spec's resource usage across day boundaries and aggregates
        by user/project/domain. Buckets are aligned to day boundaries (midnight
        in the timezone of the spec's datetimes).

        For example, a spec covering 23:57-00:03 will be split:
        - 23:57-00:00 (3 minutes) -> day 1 bucket
        - 00:00-00:03 (3 minutes) -> day 2 bucket

        Args:
            specs: Kernel usage record specs to aggregate

        Returns:
            UsageBucketAggregationResult with deltas for each bucket (plain dicts)
        """
        user_deltas: dict[UserUsageBucketKey, BucketDelta] = defaultdict(BucketDelta)
        project_deltas: dict[ProjectUsageBucketKey, BucketDelta] = defaultdict(BucketDelta)
        domain_deltas: dict[DomainUsageBucketKey, BucketDelta] = defaultdict(BucketDelta)

        for spec in specs:
            # A spec crossing midnight contributes to one bucket per calendar day.
            for period_date, raw_slots, segment_seconds in self._split_spec_by_day(spec):
                self._add_to_bucket_deltas(
                    spec=spec,
                    period_date=period_date,
                    raw_slots=raw_slots,
                    segment_seconds=segment_seconds,
                    user_deltas=user_deltas,
                    project_deltas=project_deltas,
                    domain_deltas=domain_deltas,
                )

        # Convert to plain dicts so callers don't get auto-vivifying lookups.
        return UsageBucketAggregationResult(
            user_usage_deltas=dict(user_deltas),
            project_usage_deltas=dict(project_deltas),
            domain_usage_deltas=dict(domain_deltas),
        )

    def _split_spec_by_day(
        self,
        spec: KernelUsageRecordCreatorSpec,
    ) -> list[tuple[date, ResourceSlot, int]]:
        """Split a spec's resource usage across day boundaries.

        Most 5-minute specs will fit within a single day, but specs crossing
        midnight (e.g., 23:57-00:02) need to be split.

        Returns the raw (un-multiplied) occupied_slots and the segment
        duration in seconds so that callers can store them separately.
        Durations are truncated to whole seconds. Segments that truncate to
        zero seconds are dropped on both the same-day and the cross-midnight
        path, so a sub-second spec never contributes slots without duration.

        Args:
            spec: Kernel usage record spec to split

        Returns:
            List of (period_date, raw_slots, segment_seconds) tuples.
            ``raw_slots`` is the kernel's occupied_slots (not resource-seconds).
            Empty when the spec covers no whole second.
        """
        raw_slots = spec.occupied_slots or ResourceSlot()
        total_seconds = (spec.period_end - spec.period_start).total_seconds()
        if total_seconds <= 0:
            return []

        # Fast path: most specs don't cross midnight
        if spec.period_start.date() == spec.period_end.date():
            whole_seconds = int(total_seconds)
            if whole_seconds <= 0:
                # Sub-second spec: mirror the slow path's `segment_seconds > 0` filter.
                return []
            return [(spec.period_start.date(), raw_slots, whole_seconds)]

        # Slow path: split across day boundaries
        result: list[tuple[date, ResourceSlot, int]] = []
        current_start = spec.period_start
        while current_start < spec.period_end:
            current_date = current_start.date()
            next_midnight = datetime.combine(
                current_date + timedelta(days=1),
                datetime.min.time(),
                tzinfo=current_start.tzinfo,
            )
            # Segment ends at next midnight or spec end, whichever is earlier
            segment_end = min(next_midnight, spec.period_end)
            segment_seconds = int((segment_end - current_start).total_seconds())
            if segment_seconds > 0:
                result.append((current_date, raw_slots, segment_seconds))
            current_start = segment_end
        return result

    def _add_to_bucket_deltas(
        self,
        spec: KernelUsageRecordCreatorSpec,
        period_date: date,
        raw_slots: ResourceSlot,
        segment_seconds: int,
        user_deltas: dict[UserUsageBucketKey, BucketDelta],
        project_deltas: dict[ProjectUsageBucketKey, BucketDelta],
        domain_deltas: dict[DomainUsageBucketKey, BucketDelta],
    ) -> None:
        """Add resource usage to the user, project and domain bucket deltas for a day.

        Accumulates raw resource amounts and duration separately.
        Slots are accumulated additively (sum of ``raw_slots`` across all
        slices within the same bucket key) while ``duration_seconds`` tracks
        total observation time.

        NOTE(review): the BucketDelta docstring states that
        ``amount * duration_seconds`` is computed in SQL. With both factors
        summed independently, that product over-counts as soon as a bucket
        receives more than one slice. Confirm the repository-side formula.

        Args:
            spec: Original spec (for entity identifiers)
            period_date: Date of the bucket
            raw_slots: Raw occupied_slots (not pre-multiplied)
            segment_seconds: Duration of this segment in seconds
            user_deltas: User deltas to update (mutated; missing keys must
                auto-create, e.g. a defaultdict(BucketDelta))
            project_deltas: Project deltas to update (mutated; same requirement)
            domain_deltas: Domain deltas to update (mutated; same requirement)
        """
        user_key = UserUsageBucketKey(
            user_uuid=spec.user_uuid,
            project_id=spec.project_id,
            domain_name=spec.domain_name,
            resource_group=spec.resource_group,
            period_date=period_date,
        )
        user_deltas[user_key] = self._merge_delta(
            user_deltas[user_key], raw_slots, segment_seconds
        )

        project_key = ProjectUsageBucketKey(
            project_id=spec.project_id,
            domain_name=spec.domain_name,
            resource_group=spec.resource_group,
            period_date=period_date,
        )
        project_deltas[project_key] = self._merge_delta(
            project_deltas[project_key], raw_slots, segment_seconds
        )

        domain_key = DomainUsageBucketKey(
            domain_name=spec.domain_name,
            resource_group=spec.resource_group,
            period_date=period_date,
        )
        domain_deltas[domain_key] = self._merge_delta(
            domain_deltas[domain_key], raw_slots, segment_seconds
        )

    @staticmethod
    def _merge_delta(
        current: BucketDelta,
        raw_slots: ResourceSlot,
        segment_seconds: int,
    ) -> BucketDelta:
        """Return a new delta with ``raw_slots`` and ``segment_seconds`` added to ``current``."""
        return BucketDelta(
            slots=current.slots + raw_slots,
            duration_seconds=current.duration_seconds + segment_seconds,
        )

    def _prepare_kernel_usage_specs(
        self,
        kernel: KernelInfo,
        scaling_group: str,
        now: datetime,
    ) -> tuple[list[KernelUsageRecordCreatorSpec], datetime]:
        """Prepare usage record specs for a single kernel.

        Generates 5-minute slices aligned to clock boundaries.
        Partial slices are only allowed at:
        - Start: When starts_at is not on a boundary (first observation only)
        - End: When the kernel terminates at a non-boundary time
        For RUNNING kernels, only complete slices up to the last boundary are generated.

        Args:
            kernel: Kernel to process
            scaling_group: The scaling group
            now: Current time from DB

        Returns:
            Tuple of (list of specs, last_observed_at to save). When the period
            is empty, the list is empty and the unchanged start time is returned.

        Raises:
            ValueError: If this is the kernel's first observation and it has no starts_at.
        """
        lifecycle = kernel.lifecycle
        # Bind once so the None checks below also narrow the types; this makes
        # the original "impossible" re-checks unnecessary.
        last_observed_at = lifecycle.last_observed_at
        terminated_at = lifecycle.terminated_at
        is_first_observation = last_observed_at is None
        is_terminated = terminated_at is not None

        # Determine start time: resume from the previous observation, or begin
        # at starts_at when the kernel has never been observed.
        start_time: datetime
        if last_observed_at is None:
            if lifecycle.starts_at is None:
                raise ValueError(f"Kernel {kernel.id} has no starts_at for first observation")
            start_time = lifecycle.starts_at
        else:
            start_time = last_observed_at

        # Determine end time based on kernel status
        end_time: datetime
        if terminated_at is not None:
            # TERMINATED: use terminated_at (allows partial end slice)
            end_time = terminated_at
        else:
            # RUNNING: floor to last complete boundary (no partial end slice)
            end_time = self._floor_to_boundary(now)

        log.debug(
            "[Aggregator] Kernel {}: is_first={}, is_terminated={}, "
            "start_time={}, end_time={}, now={}",
            kernel.id,
            is_first_observation,
            is_terminated,
            start_time,
            end_time,
            now,
        )

        if end_time <= start_time:
            # Not enough time has passed to generate any slices
            log.debug(
                "[Aggregator] Kernel {}: skipped (end_time <= start_time)",
                kernel.id,
            )
            return [], start_time

        specs = self._generate_slice_specs(
            kernel=kernel,
            scaling_group=scaling_group,
            start_time=start_time,
            end_time=end_time,
        )
        log.debug(
            "[Aggregator] Kernel {}: generated {} specs",
            kernel.id,
            len(specs),
        )
        # end_time becomes the new last_observed_at:
        # floored to a boundary for RUNNING kernels, terminated_at for TERMINATED ones.
        return specs, end_time

    def _generate_slice_specs(
        self,
        kernel: KernelInfo,
        scaling_group: str,
        start_time: datetime,
        end_time: datetime,
    ) -> list[KernelUsageRecordCreatorSpec]:
        """Generate 5-minute slice specs aligned to clock boundaries.

        Slices are aligned to 5-minute clock boundaries (00:00, 00:05, 00:10, etc.).
        The caller is responsible for adjusting start_time and end_time to control
        whether partial slices are allowed.

        Args:
            kernel: Kernel info
            scaling_group: The scaling group
            start_time: Start of observation period
            end_time: End of observation period

        Returns:
            List of usage record specs for each slice (empty if start_time >= end_time)
        """
        if start_time >= end_time:
            return []

        # Values identical for every slice of this kernel; compute them once.
        kernel_id = UUID(str(kernel.id))
        session_id = UUID(kernel.session.session_id)
        owner = kernel.user_permission
        occupied_slots = kernel.resource.occupied_slots

        specs: list[KernelUsageRecordCreatorSpec] = []
        current_start = start_time
        while current_start < end_time:
            # Next 5-minute epoch boundary strictly after current_start
            start_epoch = int(current_start.timestamp())
            next_boundary_epoch = (
                (start_epoch // SLICE_DURATION_SECONDS) + 1
            ) * SLICE_DURATION_SECONDS
            next_boundary = datetime.fromtimestamp(next_boundary_epoch, tz=current_start.tzinfo)

            # Slice ends at the next boundary or end_time, whichever is earlier
            current_end = min(next_boundary, end_time)
            slice_seconds = (current_end - current_start).total_seconds()
            if slice_seconds <= 0:
                break

            specs.append(
                KernelUsageRecordCreatorSpec(
                    kernel_id=kernel_id,
                    session_id=session_id,
                    user_uuid=owner.owner_id,
                    project_id=owner.group_id,
                    domain_name=owner.domain_name,
                    resource_group=scaling_group,
                    period_start=current_start,
                    period_end=current_end,
                    resource_usage=self._calculate_resource_seconds(
                        occupied_slots,
                        slice_seconds,
                    ),
                    occupied_slots=occupied_slots,
                )
            )
            current_start = current_end
        return specs

    def _floor_to_boundary(self, dt: datetime) -> datetime:
        """Floor datetime to the previous 5-minute boundary.

        Examples:
            07:47:30 -> 07:45:00
            07:50:00 -> 07:50:00 (already on boundary)
            07:52:15 -> 07:50:00

        Args:
            dt: Datetime to floor

        Returns:
            Datetime floored to 5-minute boundary, in the same tzinfo as ``dt``
        """
        epoch = int(dt.timestamp())
        floored_epoch = (epoch // SLICE_DURATION_SECONDS) * SLICE_DURATION_SECONDS
        return datetime.fromtimestamp(floored_epoch, tz=dt.tzinfo)

    def _calculate_resource_seconds(
        self,
        slots: ResourceSlot,
        seconds: float,
    ) -> ResourceSlot:
        """Convert resource slots to resource-seconds.

        Multiplies each resource value by the number of seconds to get
        the total resource-seconds consumed during the period.

        Args:
            slots: Resource slots (e.g., {"cpu": 2, "mem": 4096})
            seconds: Duration in seconds

        Returns:
            ResourceSlot with values in resource-seconds
        """
        # Go through str() so the float's shortest repr becomes the Decimal value,
        # not its full binary expansion.
        multiplier = Decimal(str(seconds))
        return ResourceSlot({key: value * multiplier for key, value in slots.items()})