diff --git a/lib/iris/tests/test_budget.py b/lib/iris/tests/test_budget.py index 8b474da7be..6032e51fcc 100644 --- a/lib/iris/tests/test_budget.py +++ b/lib/iris/tests/test_budget.py @@ -1,6 +1,15 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 +"""Tests for budget tracking (resource_value / compute_user_spend / interleave_by_user) +and the admin API RPCs that expose them.""" + +import pytest +from connectrpc.code import Code +from connectrpc.errors import ConnectError + +from iris.cluster.bundle import BundleStore +from iris.cluster.controller.auth import ControllerAuth from iris.cluster.controller.budget import ( UserTask, compute_effective_band, @@ -8,14 +17,21 @@ interleave_by_user, resource_value, ) -from iris.cluster.controller.db import ControllerDB -from iris.cluster.controller.transitions import ControllerTransitions +from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.transitions import Assignment, HeartbeatApplyRequest, TaskUpdate from iris.cluster.types import JobName, WorkerId -from iris.rpc import job_pb2 -from iris.rpc import controller_pb2 +from iris.log_server.server import LogServiceImpl +from iris.rpc import controller_pb2, job_pb2 +from iris.rpc.auth import VerifiedIdentity, _verified_identity from iris.rpc.proto_utils import PRIORITY_BAND_VALUES, priority_band_name, priority_band_value from rigging.timing import Timestamp +from tests.cluster.controller.conftest import ( + MockController, + make_controller_state, + make_test_entrypoint, +) + PRODUCTION = job_pb2.PRIORITY_BAND_PRODUCTION INTERACTIVE = job_pb2.PRIORITY_BAND_INTERACTIVE BATCH = job_pb2.PRIORITY_BAND_BATCH @@ -23,305 +39,359 @@ GiB = 1024**3 -def test_resource_value_cpu_only(): - # 4 cores, 16 GiB RAM, no accelerators - assert resource_value(cpu_millicores=4000, memory_bytes=16 * GiB, accelerator_count=0) == 5 * 4 + 16 +@pytest.fixture +def state(): + """Fresh ControllerTransitions with a temp DB, cleaned up on 
exit.""" + with make_controller_state() as s: + yield s -def test_resource_value_gpu(): - # 8 cores, 64 GiB, 4 GPUs - assert resource_value(cpu_millicores=8000, memory_bytes=64 * GiB, accelerator_count=4) == 1000 * 4 + 64 + 5 * 8 +# --------------------------------------------------------------------------- +# resource_value +# --------------------------------------------------------------------------- -def test_resource_value_tpu(): - # 96 cores, 320 GiB, 8 TPU chips - expected = 1000 * 8 + 320 + 5 * 96 - assert resource_value(cpu_millicores=96000, memory_bytes=320 * GiB, accelerator_count=8) == expected +@pytest.mark.parametrize( + "cpu_millicores,memory_bytes,accelerator_count,expected", + [ + (4000, 16 * GiB, 0, 5 * 4 + 16), # CPU only + (8000, 64 * GiB, 4, 1000 * 4 + 64 + 5 * 8), # 4 GPUs + (96000, 320 * GiB, 8, 1000 * 8 + 320 + 5 * 96), # 8 TPU chips + (1500, int(1.5 * GiB), 0, 5 * 1 + 1), # fractional → truncated + ], +) +def test_resource_value(cpu_millicores, memory_bytes, accelerator_count, expected): + assert resource_value(cpu_millicores, memory_bytes, accelerator_count) == expected -def test_resource_value_zero_resources(): - assert resource_value(cpu_millicores=0, memory_bytes=0, accelerator_count=0) == 0 +# --------------------------------------------------------------------------- +# interleave_by_user +# --------------------------------------------------------------------------- -def test_resource_value_truncates_fractional(): - # 1500 millicores = 1 core (truncated), 1.5 GiB = 1 GiB (truncated) - assert resource_value(cpu_millicores=1500, memory_bytes=int(1.5 * GiB), accelerator_count=0) == 5 * 1 + 1 +def test_interleave_by_user_empty(): + assert interleave_by_user([], user_spend={}) == [] -def test_interleave_by_user_single_user(): +def test_interleave_by_user_single_user_preserves_order(): tasks = [UserTask("alice", "t1"), UserTask("alice", "t2"), UserTask("alice", "t3")] - result = interleave_by_user(tasks, user_spend={}) - assert result == 
["t1", "t2", "t3"] + assert interleave_by_user(tasks, user_spend={}) == ["t1", "t2", "t3"] -def test_interleave_by_user_two_users_equal_spend(): - tasks = [UserTask("alice", "a1"), UserTask("alice", "a2"), UserTask("bob", "b1"), UserTask("bob", "b2")] - result = interleave_by_user(tasks, user_spend={"alice": 100, "bob": 100}) - # Equal spend: stable sort by user name, then round-robin - assert result == ["a1", "b1", "a2", "b2"] or result == ["b1", "a1", "b2", "a2"] - assert len(result) == 4 +def test_interleave_by_user_lower_spend_goes_first(): + tasks = [ + UserTask("alice", "a1"), + UserTask("alice", "a2"), + UserTask("bob", "b1"), + UserTask("bob", "b2"), + ] + # Bob has spent less, so his task goes first in each round. + assert interleave_by_user(tasks, user_spend={"alice": 8000, "bob": 1000}) == [ + "b1", + "a1", + "b2", + "a2", + ] -def test_interleave_by_user_spend_ordering(): +def test_interleave_by_user_missing_spend_defaults_to_zero(): + tasks = [UserTask("alice", "a1"), UserTask("bob", "b1")] + # Alice has no spend row → defaults to 0 (< bob's 5000), so alice goes first. + assert interleave_by_user(tasks, user_spend={"bob": 5000}) == ["a1", "b1"] + + +def test_interleave_by_user_three_users_unequal_counts(): tasks = [ UserTask("alice", "a1"), UserTask("alice", "a2"), UserTask("bob", "b1"), - UserTask("bob", "b2"), + UserTask("charlie", "c1"), + UserTask("charlie", "c2"), + UserTask("charlie", "c3"), ] - # Bob has spent less, so his tasks should come first in each round - result = interleave_by_user(tasks, user_spend={"alice": 8000, "bob": 1000}) - assert result == ["b1", "a1", "b2", "a2"] + # Spend order: bob (100) < charlie (3000) < alice (5000). 
+ result = interleave_by_user(tasks, user_spend={"alice": 5000, "bob": 100, "charlie": 3000}) + assert result == ["b1", "c1", "a1", "c2", "a2", "c3"] -def test_interleave_by_user_unequal_task_counts(): - tasks = [UserTask("alice", "a1"), UserTask("alice", "a2"), UserTask("alice", "a3"), UserTask("bob", "b1")] - result = interleave_by_user(tasks, user_spend={"alice": 0, "bob": 0}) - # Round 0: a1, b1; Round 1: a2; Round 2: a3 - assert result[0] in ("a1", "b1") - assert result[1] in ("a1", "b1") - assert "a2" in result - assert "a3" in result - assert len(result) == 4 +# --------------------------------------------------------------------------- +# compute_effective_band +# --------------------------------------------------------------------------- -def test_interleave_by_user_empty(): - assert interleave_by_user([], user_spend={}) == [] +@pytest.mark.parametrize( + "task_band,spend,limit,expected", + [ + (INTERACTIVE, 10000, 5000, BATCH), # over budget → demoted + (INTERACTIVE, 3000, 5000, INTERACTIVE), # within budget → kept + (PRODUCTION, 10000, 5000, PRODUCTION), # production never demoted + (INTERACTIVE, 999999, 0, INTERACTIVE), # limit=0 means unlimited + (BATCH, 10000, 5000, BATCH), # batch stays batch + ], +) +def test_effective_band(task_band, spend, limit, expected): + assert compute_effective_band(task_band, "alice", {"alice": spend}, {"alice": limit}) == expected -def test_interleave_by_user_missing_spend_defaults_to_zero(): - tasks = [UserTask("alice", "a1"), UserTask("bob", "b1")] - # Alice has no spend entry → defaults to 0, Bob has 5000 - result = interleave_by_user(tasks, user_spend={"bob": 5000}) - assert result == ["a1", "b1"] +def test_effective_band_no_limit_row_is_unlimited(): + assert compute_effective_band(INTERACTIVE, "alice", {"alice": 999999}, {}) == INTERACTIVE -def test_effective_band_over_budget_becomes_batch(): - """INTERACTIVE task becomes BATCH when user exceeds budget.""" - spend = {"alice": 10000} - limits = {"alice": 5000} - assert 
compute_effective_band(INTERACTIVE, "alice", spend, limits) == BATCH +# --------------------------------------------------------------------------- +# priority_band helpers +# --------------------------------------------------------------------------- -def test_effective_band_within_budget_keeps_band(): - """INTERACTIVE task stays INTERACTIVE when user is within budget.""" - spend = {"alice": 3000} - limits = {"alice": 5000} - assert compute_effective_band(INTERACTIVE, "alice", spend, limits) == INTERACTIVE +def test_priority_band_name_roundtrip(): + for band in PRIORITY_BAND_VALUES: + assert priority_band_value(priority_band_name(band)) == band -def test_effective_band_production_never_downgraded(): - """PRODUCTION tasks are never downgraded, even when over budget.""" - spend = {"alice": 10000} - limits = {"alice": 5000} - assert compute_effective_band(PRODUCTION, "alice", spend, limits) == PRODUCTION +# --------------------------------------------------------------------------- +# compute_user_spend +# --------------------------------------------------------------------------- -def test_effective_band_zero_budget_means_unlimited(): - """budget_limit=0 means no down-weighting regardless of spend.""" - spend = {"alice": 999999} - limits = {"alice": 0} - assert compute_effective_band(INTERACTIVE, "alice", spend, limits) == INTERACTIVE +def _launch_request( + name: str, + cpu_millicores: int = 4000, + memory_bytes: int = 16 * GiB, + include_resources: bool = True, + replicas: int = 1, + band: int = 0, +) -> controller_pb2.Controller.LaunchJobRequest: + req = controller_pb2.Controller.LaunchJobRequest( + name=name, + entrypoint=make_test_entrypoint(), + environment=job_pb2.EnvironmentConfig(), + replicas=replicas, + priority_band=band, + ) + if include_resources: + req.resources.CopyFrom(job_pb2.ResourceSpecProto(cpu_millicores=cpu_millicores, memory_bytes=memory_bytes)) + return req + + +def _start_running_job( + state, + user: str, + job_name: str, + *, + 
cpu_millicores: int = 4000, + memory_bytes: int = 16 * GiB, + replicas: int = 1, + include_resources: bool = True, +) -> None: + """Submit a job, register a worker, and transition each task to RUNNING.""" + job_id = JobName.root(user, job_name) + request = _launch_request( + job_id.to_wire(), + cpu_millicores=cpu_millicores, + memory_bytes=memory_bytes, + include_resources=include_resources, + replicas=replicas, + ) + state.submit_job(job_id, request, Timestamp.now()) + + worker_id = WorkerId(f"w-{user}") + state.register_or_refresh_worker( + worker_id=worker_id, + address=f"{worker_id}:8080", + metadata=job_pb2.WorkerMetadata( + hostname=str(worker_id), + ip_address="127.0.0.1", + cpu_count=16, + memory_bytes=64 * GiB, + disk_bytes=100 * GiB, + ), + ts=Timestamp.now(), + ) + for idx in range(replicas): + task_id = job_id.task(idx) + state.queue_assignments([Assignment(task_id=task_id, worker_id=worker_id)]) + state.apply_task_updates( + HeartbeatApplyRequest( + worker_id=worker_id, + worker_resource_snapshot=None, + updates=[TaskUpdate(task_id=task_id, attempt_id=0, new_state=job_pb2.TASK_STATE_RUNNING)], + ) + ) -def test_effective_band_no_budget_row_means_unlimited(): - """User with no budget row is treated as unlimited.""" - spend = {"alice": 999999} - limits = {} # no row for alice - assert compute_effective_band(INTERACTIVE, "alice", spend, limits) == INTERACTIVE +def test_compute_user_spend_empty(state): + with state._db.snapshot() as snap: + assert compute_user_spend(snap) == {} -def test_effective_band_batch_stays_batch_when_over_budget(): - """Already-BATCH task stays BATCH (max of requested and BATCH).""" - spend = {"alice": 10000} - limits = {"alice": 5000} - assert compute_effective_band(BATCH, "alice", spend, limits) == BATCH +def test_compute_user_spend_sums_running_tasks(state): + _start_running_job(state, "alice", "job", cpu_millicores=4000, memory_bytes=16 * GiB, replicas=2) + with state._db.snapshot() as snap: + spend = compute_user_spend(snap) + 
assert spend["alice"] == resource_value(4000, 16 * GiB, 0) * 2 -def test_priority_band_name_roundtrip(): - for band in PRIORITY_BAND_VALUES: - name = priority_band_name(band) - assert priority_band_value(name) == band +def test_compute_user_spend_excludes_pending(state): + """Tasks that never reach RUNNING/ASSIGNED/BUILDING do not contribute.""" + job_id = JobName.root("bob", "pending") + request = _launch_request(job_id.to_wire(), cpu_millicores=2000, memory_bytes=8 * GiB) + state.submit_job(job_id, request, Timestamp.now()) + with state._db.snapshot() as snap: + assert compute_user_spend(snap).get("bob", 0) == 0 -def test_priority_band_values_are_ordered(): - """Proto enum values are ordered: PRODUCTION < INTERACTIVE < BATCH.""" - assert job_pb2.PRIORITY_BAND_PRODUCTION < job_pb2.PRIORITY_BAND_INTERACTIVE - assert job_pb2.PRIORITY_BAND_INTERACTIVE < job_pb2.PRIORITY_BAND_BATCH +def test_compute_user_spend_null_resources_proto(state): + """Regression: res_device_json is NULL when LaunchJobRequest omits resources.""" + _start_running_job(state, "carol", "no-resources", include_resources=False) + with state._db.snapshot() as snap: + assert compute_user_spend(snap).get("carol", 0) == 0 # --------------------------------------------------------------------------- -# compute_user_spend — direct unit tests +# Budget admin API (service layer) # --------------------------------------------------------------------------- -def test_compute_user_spend_empty(tmp_path): - """No active tasks → empty spend dict.""" - db = ControllerDB(db_dir=tmp_path) +@pytest.fixture +def service(state, tmp_path) -> ControllerServiceImpl: + """ControllerServiceImpl wired with static-provider auth so that + priority-band authorization triggers (see launch_job band check).""" + return ControllerServiceImpl( + state, + state._db, + controller=MockController(), + bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), + log_service=LogServiceImpl(), + 
auth=ControllerAuth(provider="static"), + ) + + +def _as_admin(fn, *args, **kwargs): + reset = _verified_identity.set(VerifiedIdentity(user_id="admin", role="admin")) try: - with db.snapshot() as snap: - spend = compute_user_spend(snap) - assert spend == {} + return fn(*args, **kwargs) finally: - db.close() + _verified_identity.reset(reset) -def test_compute_user_spend_counts_running_tasks(tmp_path): - """Active tasks contribute to user spend based on resource value.""" - db = ControllerDB(db_dir=tmp_path) - state = ControllerTransitions(db=db) +def _as_user(fn, user_id, *args, **kwargs): + reset = _verified_identity.set(VerifiedIdentity(user_id=user_id, role="user")) try: - # Submit a job with 2 tasks - job_id = JobName.root("alice", "test-job") - request = controller_pb2.Controller.LaunchJobRequest( - name=job_id.to_wire(), - entrypoint=job_pb2.RuntimeEntrypoint(), - resources=job_pb2.ResourceSpecProto(cpu_millicores=4000, memory_bytes=16 * GiB), - environment=job_pb2.EnvironmentConfig(), - replicas=2, - ) - request.entrypoint.run_command.argv[:] = ["echo", "hi"] - state.submit_job(job_id, request, Timestamp.now()) - - # Register worker and assign both tasks - w1 = WorkerId("w1") - state.register_or_refresh_worker( - worker_id=w1, - address="w1:8080", - metadata=job_pb2.WorkerMetadata( - hostname="w1", - ip_address="127.0.0.1", - cpu_count=16, - memory_bytes=64 * GiB, - disk_bytes=100 * GiB, - ), - ts=Timestamp.now(), - ) + return fn(*args, **kwargs) + finally: + _verified_identity.reset(reset) - from iris.cluster.controller.transitions import Assignment, HeartbeatApplyRequest, TaskUpdate - - for idx in range(2): - task_id = job_id.task(idx) - state.queue_assignments([Assignment(task_id=task_id, worker_id=w1)]) - state.apply_task_updates( - HeartbeatApplyRequest( - worker_id=w1, - worker_resource_snapshot=None, - updates=[TaskUpdate(task_id=task_id, attempt_id=0, new_state=job_pb2.TASK_STATE_RUNNING)], - ) - ) - with db.snapshot() as snap: - spend = 
compute_user_spend(snap) +def _set_budget(user_id: str, limit: int = 5000, max_band: int = INTERACTIVE): + return controller_pb2.Controller.SetUserBudgetRequest(user_id=user_id, budget_limit=limit, max_band=max_band) - # Each task: 4 cores (5*4=20) + 16 GiB (16) + 0 accelerators = 36 - expected_per_task = resource_value(4000, 16 * GiB, 0) - assert spend["alice"] == expected_per_task * 2 - finally: - db.close() +def _get_budget(user_id: str): + return controller_pb2.Controller.GetUserBudgetRequest(user_id=user_id) -def test_compute_user_spend_only_counts_active_states(tmp_path): - """Completed tasks do not contribute to spend.""" - db = ControllerDB(db_dir=tmp_path) - state = ControllerTransitions(db=db) - try: - job_id = JobName.root("bob", "done-job") - request = controller_pb2.Controller.LaunchJobRequest( - name=job_id.to_wire(), - entrypoint=job_pb2.RuntimeEntrypoint(), - resources=job_pb2.ResourceSpecProto(cpu_millicores=2000, memory_bytes=8 * GiB), - environment=job_pb2.EnvironmentConfig(), - replicas=1, - ) - request.entrypoint.run_command.argv[:] = ["echo", "hi"] - state.submit_job(job_id, request, Timestamp.now()) - # Task is PENDING (not active for budget purposes since ACTIVE = ASSIGNED/BUILDING/RUNNING) - with db.snapshot() as snap: - spend = compute_user_spend(snap) - assert spend.get("bob", 0) == 0 - finally: - db.close() +def _launch(name: str, band: int = 0): + return _launch_request( + name, + cpu_millicores=1000, + memory_bytes=GiB, + band=band, + ) -# --------------------------------------------------------------------------- -# interleave_by_user — three-user round-robin -# --------------------------------------------------------------------------- +def test_admin_sets_and_reads_budget(service): + _as_admin(service.set_user_budget, _set_budget("alice", 5000, INTERACTIVE), None) + resp = _as_admin(service.get_user_budget, _get_budget("alice"), None) + assert resp.user_id == "alice" + assert resp.budget_limit == 5000 + assert resp.max_band == 
INTERACTIVE + assert resp.budget_spent == 0 -def test_compute_user_spend_null_resources_proto(tmp_path): - """Jobs submitted without a resources field must not crash compute_user_spend. +def test_non_admin_cannot_set_budget(service): + with pytest.raises(ConnectError) as exc: + _as_user(service.set_user_budget, "alice", _set_budget("alice"), None) + assert exc.value.code == Code.PERMISSION_DENIED - Regression: resources_proto is NULL in the DB when the LaunchJobRequest omits - the resources field. compute_user_spend must treat that as zero spend. - """ - db = ControllerDB(db_dir=tmp_path) - state = ControllerTransitions(db=db) - try: - job_id = JobName.root("carol", "no-resources") - # Intentionally omit the `resources` field - request = controller_pb2.Controller.LaunchJobRequest( - name=job_id.to_wire(), - entrypoint=job_pb2.RuntimeEntrypoint(), - environment=job_pb2.EnvironmentConfig(), - replicas=1, - ) - request.entrypoint.run_command.argv[:] = ["echo", "hi"] - state.submit_job(job_id, request, Timestamp.now()) - - # Register worker and move the task to RUNNING so it counts as active - w1 = WorkerId("w1") - state.register_or_refresh_worker( - worker_id=w1, - address="w1:8080", - metadata=job_pb2.WorkerMetadata( - hostname="w1", - ip_address="127.0.0.1", - cpu_count=8, - memory_bytes=16 * GiB, - disk_bytes=100 * GiB, - ), - ts=Timestamp.now(), - ) - from iris.cluster.controller.transitions import Assignment, HeartbeatApplyRequest, TaskUpdate - task_id = job_id.task(0) - state.queue_assignments([Assignment(task_id=task_id, worker_id=w1)]) - state.apply_task_updates( - HeartbeatApplyRequest( - worker_id=w1, - worker_resource_snapshot=None, - updates=[TaskUpdate(task_id=task_id, attempt_id=0, new_state=job_pb2.TASK_STATE_RUNNING)], - ) +def test_user_can_read_own_budget(service): + """get_user_budget requires identity, not admin.""" + _as_admin(service.set_user_budget, _set_budget("alice", 5000), None) + resp = _as_user(service.get_user_budget, "alice", 
_get_budget("alice"), None) + assert resp.user_id == "alice" + assert resp.budget_limit == 5000 + + +def test_get_budget_not_found(service): + with pytest.raises(ConnectError) as exc: + _as_admin(service.get_user_budget, _get_budget("nonexistent"), None) + assert exc.value.code == Code.NOT_FOUND + + +def test_set_budget_rejects_invalid_max_band(service): + with pytest.raises(ConnectError) as exc: + _as_admin(service.set_user_budget, _set_budget("alice", 5000, max_band=99), None) + assert exc.value.code == Code.INVALID_ARGUMENT + + +def test_set_budget_rejects_empty_user_id(service): + with pytest.raises(ConnectError) as exc: + _as_admin(service.set_user_budget, _set_budget(""), None) + assert exc.value.code == Code.INVALID_ARGUMENT + + +def test_list_user_budgets(service): + for user_id, limit, band in [ + ("alice", 5000, INTERACTIVE), + ("bob", 3000, BATCH), + ("charlie", 0, PRODUCTION), + ]: + _as_admin(service.set_user_budget, _set_budget(user_id, limit, band), None) + + resp = _as_admin( + service.list_user_budgets, + controller_pb2.Controller.ListUserBudgetsRequest(), + None, + ) + by_user = {u.user_id: u for u in resp.users} + assert set(by_user) == {"alice", "bob", "charlie"} + assert by_user["alice"].budget_limit == 5000 + assert by_user["bob"].max_band == BATCH + assert by_user["charlie"].max_band == PRODUCTION + + +def test_non_admin_cannot_submit_production(service): + with pytest.raises(ConnectError) as exc: + _as_user(service.launch_job, "alice", _launch("/alice/prod-job", band=PRODUCTION), None) + assert exc.value.code == Code.PERMISSION_DENIED + + +def test_admin_can_submit_production(service): + resp = _as_admin(service.launch_job, _launch("/admin/prod-job", band=PRODUCTION), None) + assert resp.job_id == "/admin/prod-job" + + +def test_launch_job_rejects_band_above_user_max(service): + """User with max_band=BATCH cannot submit INTERACTIVE (numerically lower) jobs.""" + _as_admin(service.set_user_budget, _set_budget("alice", 0, BATCH), None) + with 
pytest.raises(ConnectError) as exc: + _as_user( + service.launch_job, + "alice", + _launch("/alice/interactive-job", band=INTERACTIVE), + None, ) + assert exc.value.code == Code.PERMISSION_DENIED + assert "cannot submit" in str(exc.value.message).lower() - with db.snapshot() as snap: - spend = compute_user_spend(snap) - # Zero resources → zero spend, but must not crash - assert spend.get("carol", 0) == 0 - finally: - db.close() +def test_launch_job_unspecified_band_accepted(service): + """Submitting with band=0 (UNSPECIFIED) is accepted and defaults to INTERACTIVE.""" + resp = _as_user(service.launch_job, "alice", _launch("/alice/default-band-job", band=0), None) + assert resp.job_id == "/alice/default-band-job" -def test_interleave_by_user_three_users(): - """Three users with different task counts and spend levels.""" - tasks = [ - UserTask("alice", "a1"), - UserTask("alice", "a2"), - UserTask("bob", "b1"), - UserTask("charlie", "c1"), - UserTask("charlie", "c2"), - UserTask("charlie", "c3"), - ] - spend = {"alice": 5000, "bob": 100, "charlie": 3000} - result = interleave_by_user(tasks, spend) - # Bob (100) → charlie (3000) → alice (5000) ordering - assert result[0] == "b1" - assert result[1] == "c1" - assert result[2] == "a1" - # Round 1: only charlie and alice have tasks left - assert result[3] == "c2" - assert result[4] == "a2" - # Round 2: only charlie left - assert result[5] == "c3" + +def test_get_budget_spend_reflects_running_task(service, state): + _as_admin(service.set_user_budget, _set_budget("alice", 10000, INTERACTIVE), None) + _start_running_job(state, "alice", "running-job", cpu_millicores=1000, memory_bytes=GiB) + resp = _as_admin(service.get_user_budget, _get_budget("alice"), None) + assert resp.budget_spent > 0 diff --git a/lib/iris/tests/test_budget_api.py b/lib/iris/tests/test_budget_api.py deleted file mode 100644 index 0dcdb3d7cc..0000000000 --- a/lib/iris/tests/test_budget_api.py +++ /dev/null @@ -1,407 +0,0 @@ -# Copyright The Marin Authors 
-# SPDX-License-Identifier: Apache-2.0 - -"""Tests for budget admin API: set/get/list RPCs, auth enforcement, and band validation.""" - -import pytest -from connectrpc.code import Code -from connectrpc.errors import ConnectError - -from iris.cluster.controller.service import ControllerServiceImpl -from iris.cluster.controller.transitions import Assignment, HeartbeatApplyRequest, TaskUpdate -from iris.cluster.types import JobName, WorkerId -from iris.log_server.server import LogServiceImpl -from iris.rpc import job_pb2 -from iris.rpc import controller_pb2 -from iris.rpc.auth import VerifiedIdentity, _verified_identity -from rigging.timing import Timestamp - -from tests.cluster.controller.conftest import ( - MockController, - make_controller_state, - make_test_entrypoint, -) -from iris.cluster.bundle import BundleStore -from iris.cluster.controller.auth import ControllerAuth - - -@pytest.fixture -def state(): - with make_controller_state() as s: - yield s - - -@pytest.fixture -def mock_controller(): - return MockController() - - -def _make_service(state, mock_controller, tmp_path, auth=None): - return ControllerServiceImpl( - state, - state._db, - controller=mock_controller, - bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), - log_service=LogServiceImpl(), - auth=auth or ControllerAuth(), - ) - - -def _as_admin(fn, *args, **kwargs): - """Run fn with admin identity.""" - reset = _verified_identity.set(VerifiedIdentity(user_id="admin", role="admin")) - try: - return fn(*args, **kwargs) - finally: - _verified_identity.reset(reset) - - -def _as_user(fn, user_id="alice", *args, **kwargs): - """Run fn with user identity.""" - reset = _verified_identity.set(VerifiedIdentity(user_id=user_id, role="user")) - try: - return fn(*args, **kwargs) - finally: - _verified_identity.reset(reset) - - -def _make_job_request(name: str, band: int = 0) -> controller_pb2.Controller.LaunchJobRequest: - return controller_pb2.Controller.LaunchJobRequest( - name=name, - 
entrypoint=make_test_entrypoint(), - resources=job_pb2.ResourceSpecProto(cpu_millicores=1000, memory_bytes=1024**3), - environment=job_pb2.EnvironmentConfig(), - replicas=1, - priority_band=band, - ) - - -# --------------------------------------------------------------------------- -# test_admin_can_set_budget -# --------------------------------------------------------------------------- - - -def test_admin_can_set_budget(state, mock_controller, tmp_path): - service = _make_service(state, mock_controller, tmp_path) - - # Set budget as admin - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - user_id="alice", - budget_limit=5000, - max_band=job_pb2.PRIORITY_BAND_INTERACTIVE, - ), - None, - ) - - # Get budget back - resp = _as_admin( - service.get_user_budget, - controller_pb2.Controller.GetUserBudgetRequest(user_id="alice"), - None, - ) - assert resp.user_id == "alice" - assert resp.budget_limit == 5000 - assert resp.max_band == job_pb2.PRIORITY_BAND_INTERACTIVE - assert resp.budget_spent == 0 - - -# --------------------------------------------------------------------------- -# test_non_admin_cannot_set_budget -# --------------------------------------------------------------------------- - - -def test_non_admin_cannot_set_budget(state, mock_controller, tmp_path): - service = _make_service(state, mock_controller, tmp_path) - - with pytest.raises(ConnectError) as exc_info: - _as_user( - service.set_user_budget, - "alice", - controller_pb2.Controller.SetUserBudgetRequest( - user_id="alice", - budget_limit=5000, - max_band=job_pb2.PRIORITY_BAND_INTERACTIVE, - ), - None, - ) - assert exc_info.value.code == Code.PERMISSION_DENIED - - -# --------------------------------------------------------------------------- -# test_non_admin_cannot_submit_production -# --------------------------------------------------------------------------- - - -def test_non_admin_cannot_submit_production(state, mock_controller, tmp_path): - auth = 
ControllerAuth(provider="static") - service = _make_service(state, mock_controller, tmp_path, auth=auth) - - # Submit PRODUCTION band as non-admin user -> should fail - request = _make_job_request("/alice/prod-job", band=job_pb2.PRIORITY_BAND_PRODUCTION) - with pytest.raises(ConnectError) as exc_info: - _as_user( - service.launch_job, - "alice", - request, - None, - ) - assert exc_info.value.code == Code.PERMISSION_DENIED - - -# --------------------------------------------------------------------------- -# test_band_validation_max_band -# --------------------------------------------------------------------------- - - -def test_band_validation_rejects_above_max_band(state, mock_controller, tmp_path): - """User with max_band=BATCH cannot submit INTERACTIVE jobs.""" - auth = ControllerAuth(provider="static") - service = _make_service(state, mock_controller, tmp_path, auth=auth) - - # Set alice's max_band to BATCH (sort key 3) - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - user_id="alice", - budget_limit=0, - max_band=job_pb2.PRIORITY_BAND_BATCH, - ), - None, - ) - - # Try to submit INTERACTIVE (sort key 2 < 3) -> should fail - request = _make_job_request("/alice/interactive-job", band=job_pb2.PRIORITY_BAND_INTERACTIVE) - with pytest.raises(ConnectError) as exc_info: - _as_user( - service.launch_job, - "alice", - request, - None, - ) - assert exc_info.value.code == Code.PERMISSION_DENIED - assert "cannot submit" in str(exc_info.value.message).lower() - - -# --------------------------------------------------------------------------- -# test_budget_spend_reflects_running_tasks -# --------------------------------------------------------------------------- - - -def test_budget_spend_reflects_running_tasks(state, mock_controller, tmp_path): - service = _make_service(state, mock_controller, tmp_path) - - # Set budget for alice - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - 
user_id="alice", - budget_limit=10000, - max_band=job_pb2.PRIORITY_BAND_INTERACTIVE, - ), - None, - ) - - # Submit a job as alice (directly via transitions) - job_id = JobName.root("alice", "test-job") - request = _make_job_request(job_id.to_wire()) - state.submit_job(job_id, request, Timestamp.now()) - - # Register a worker and assign the task - worker_id = WorkerId("worker-1") - state.register_or_refresh_worker( - worker_id=worker_id, - address="worker-1:8080", - metadata=job_pb2.WorkerMetadata( - hostname="worker-1", - ip_address="127.0.0.1", - cpu_count=8, - memory_bytes=16 * 1024**3, - disk_bytes=100 * 1024**3, - ), - ts=Timestamp.now(), - ) - - task_id = JobName.from_wire(f"{job_id.to_wire()}/0") - state.queue_assignments([Assignment(task_id=task_id, worker_id=worker_id)]) - state.apply_task_updates( - HeartbeatApplyRequest( - worker_id=worker_id, - worker_resource_snapshot=None, - updates=[TaskUpdate(task_id=task_id, attempt_id=0, new_state=job_pb2.TASK_STATE_RUNNING)], - ) - ) - - # Get budget — spend should be non-zero - resp = _as_admin( - service.get_user_budget, - controller_pb2.Controller.GetUserBudgetRequest(user_id="alice"), - None, - ) - assert resp.budget_spent > 0 - - -# --------------------------------------------------------------------------- -# test_list_user_budgets -# --------------------------------------------------------------------------- - - -def test_list_user_budgets(state, mock_controller, tmp_path): - service = _make_service(state, mock_controller, tmp_path) - - # Set budgets for multiple users - for user_id, limit, band in [ - ("alice", 5000, job_pb2.PRIORITY_BAND_INTERACTIVE), - ("bob", 3000, job_pb2.PRIORITY_BAND_BATCH), - ("charlie", 0, job_pb2.PRIORITY_BAND_PRODUCTION), - ]: - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - user_id=user_id, - budget_limit=limit, - max_band=band, - ), - None, - ) - - # List all budgets - resp = _as_admin( - service.list_user_budgets, - 
controller_pb2.Controller.ListUserBudgetsRequest(), - None, - ) - assert len(resp.users) == 3 - user_ids = {u.user_id for u in resp.users} - assert user_ids == {"alice", "bob", "charlie"} - - # Verify specific budget values - by_user = {u.user_id: u for u in resp.users} - assert by_user["alice"].budget_limit == 5000 - assert by_user["bob"].budget_limit == 3000 - assert by_user["charlie"].budget_limit == 0 - assert by_user["charlie"].max_band == job_pb2.PRIORITY_BAND_PRODUCTION - - -# --------------------------------------------------------------------------- -# test_admin_can_submit_production -# --------------------------------------------------------------------------- - - -def test_admin_can_submit_production(state, mock_controller, tmp_path): - """Admin should be able to submit production jobs.""" - auth = ControllerAuth(provider="static") - service = _make_service(state, mock_controller, tmp_path, auth=auth) - - request = _make_job_request("/admin/prod-job", band=job_pb2.PRIORITY_BAND_PRODUCTION) - resp = _as_admin(service.launch_job, request, None) - assert resp.job_id == "/admin/prod-job" - - -# --------------------------------------------------------------------------- -# test_get_budget_not_found -# --------------------------------------------------------------------------- - - -def test_get_budget_not_found(state, mock_controller, tmp_path): - service = _make_service(state, mock_controller, tmp_path) - - with pytest.raises(ConnectError) as exc_info: - _as_admin( - service.get_user_budget, - controller_pb2.Controller.GetUserBudgetRequest(user_id="nonexistent"), - None, - ) - assert exc_info.value.code == Code.NOT_FOUND - - -# --------------------------------------------------------------------------- -# test_set_budget_invalid_max_band -# --------------------------------------------------------------------------- - - -def test_set_budget_invalid_max_band(state, mock_controller, tmp_path): - """Setting an invalid max_band value is rejected.""" - service = 
_make_service(state, mock_controller, tmp_path) - - with pytest.raises(ConnectError) as exc_info: - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - user_id="alice", - budget_limit=5000, - max_band=99, # invalid - ), - None, - ) - assert exc_info.value.code == Code.INVALID_ARGUMENT - - -# --------------------------------------------------------------------------- -# test_set_budget_empty_user_id -# --------------------------------------------------------------------------- - - -def test_set_budget_empty_user_id(state, mock_controller, tmp_path): - """Setting budget with empty user_id is rejected.""" - service = _make_service(state, mock_controller, tmp_path) - - with pytest.raises(ConnectError) as exc_info: - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - user_id="", - budget_limit=5000, - max_band=job_pb2.PRIORITY_BAND_INTERACTIVE, - ), - None, - ) - assert exc_info.value.code == Code.INVALID_ARGUMENT - - -# --------------------------------------------------------------------------- -# test_default_band_submission -# --------------------------------------------------------------------------- - - -def test_unspecified_band_defaults_to_interactive(state, mock_controller, tmp_path): - """Submitting with UNSPECIFIED (0) band defaults to INTERACTIVE.""" - auth = ControllerAuth(provider="static") - service = _make_service(state, mock_controller, tmp_path, auth=auth) - - # Submit with band=0 (UNSPECIFIED) - request = _make_job_request("/alice/default-band-job", band=0) - resp = _as_user(service.launch_job, "alice", request, None) - assert resp.job_id == "/alice/default-band-job" - - -# --------------------------------------------------------------------------- -# test_user_can_read_own_budget -# --------------------------------------------------------------------------- - - -def test_user_can_read_own_budget(state, mock_controller, tmp_path): - """Non-admin users can read budget info 
(require_identity, not require admin).""" - service = _make_service(state, mock_controller, tmp_path) - - # Set budget as admin first - _as_admin( - service.set_user_budget, - controller_pb2.Controller.SetUserBudgetRequest( - user_id="alice", - budget_limit=5000, - max_band=job_pb2.PRIORITY_BAND_INTERACTIVE, - ), - None, - ) - - # Read as non-admin user - resp = _as_user( - service.get_user_budget, - "alice", - controller_pb2.Controller.GetUserBudgetRequest(user_id="alice"), - None, - ) - assert resp.user_id == "alice" - assert resp.budget_limit == 5000