From 14e5e7897ab908ab851223dda1a8d12994f858d3 Mon Sep 17 00:00:00 2001 From: Iain-S <25081046+Iain-S@users.noreply.github.com> Date: Mon, 9 Feb 2026 10:00:03 +0000 Subject: [PATCH 1/3] Add get_subscription_id function --- rctab/routers/accounting/routes.py | 24 +++++++++++++++++ tests/test_routes/test_routes.py | 42 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/rctab/routers/accounting/routes.py b/rctab/routers/accounting/routes.py index 8ad7145..b604b9c 100644 --- a/rctab/routers/accounting/routes.py +++ b/rctab/routers/accounting/routes.py @@ -5,6 +5,7 @@ from typing import List, Optional, Union from uuid import UUID +from databases import Database from fastapi import APIRouter from pydantic import BaseModel from sqlalchemy import desc, func, select, true @@ -512,3 +513,26 @@ def get_subscription_name(sub_id: Optional[UUID] = None) -> Select: return select(subscription_details.c.display_name.label("name")).where( subscription_details.c.subscription_id == sub_id ) + + +async def get_subscription_id(database: Database, display_name: str) -> list[UUID]: + """Get the subscription ID(s) from a display name.""" + details = subscription_details # alias for brevity + + subq = select( + [ + details.c.subscription_id, + details.c.display_name, + func.row_number() + .over(partition_by=details.c.subscription_id, order_by=details.c.id.desc()) + .label("rank"), + ] + ).alias("subq") + + results = await database.fetch_all( + select(subq.c.subscription_id).where( + (subq.c.rank == 1) & (subq.c.display_name == display_name) + ) + ) + + return [result[0] for result in results] diff --git a/tests/test_routes/test_routes.py b/tests/test_routes/test_routes.py index dd10a72..952d78b 100644 --- a/tests/test_routes/test_routes.py +++ b/tests/test_routes/test_routes.py @@ -32,6 +32,7 @@ ) from rctab.crud.models import database from rctab.routers.accounting.desired_states import refresh_desired_states +from rctab.routers.accounting.routes import get_subscription_id from tests.test_routes import constants @@ -379,3 +380,44 @@ async def test_refresh_desired_states_doesnt_duplicate( assert disabled_subscriptions == [ over_budget_sub_id, ] + + +@pytest.mark.asyncio +async def test_get_subscription_id(test_db: Database) -> None: + """The returned statement should select nothing, a single ID or raise.""" + # Check that we get None if there isn't a match. + result = await get_subscription_id(test_db, "My-Subscription-Name") + assert result == [] + + # Check things work for a single subscription_details entry. + sub1 = await create_subscription( + test_db, + # Note create_subscription has a hard-coded display name. + current_state=SubscriptionState.ENABLED, + ) + result = await get_subscription_id(test_db, "My-Subscription-Name") + assert result == [sub1] + + # Check a single subscription with > 1 name entry. + await test_db.execute( + subscription_details.insert().values(), + SubscriptionStatus( + subscription_id=sub1, + state=SubscriptionState.ENABLED, + display_name="New-Subscription-Name", + role_assignments=(), + ).model_dump(), + ) + result = await get_subscription_id(test_db, "My-Subscription-Name") + assert result == [] + + result = await get_subscription_id(test_db, "New-Subscription-Name") + assert result == [sub1] + + # Check two subscriptions with the same name + # (note I don't know whether Azure allows this). + sub1 = await create_subscription(test_db, current_state=SubscriptionState.ENABLED) + sub2 = await create_subscription(test_db, current_state=SubscriptionState.ENABLED) + + result = await get_subscription_id(test_db, "My-Subscription-Name") + assert result in ([sub1, sub2], [sub2, sub1]) From 695742bca85c004217246610cd75b202f94af3d8 Mon Sep 17 00:00:00 2001 From: Iain-S <25081046+Iain-S@users.noreply.github.com> Date: Mon, 9 Feb 2026 10:43:59 +0000 Subject: [PATCH 2/3] Add /subscription-id route --- rctab/routers/accounting/subscription.py | 10 ++++ tests/test_routes/test_subscription.py | 61 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/rctab/routers/accounting/subscription.py b/rctab/routers/accounting/subscription.py index 349fd9b..f4ed194 100644 --- a/rctab/routers/accounting/subscription.py +++ b/rctab/routers/accounting/subscription.py @@ -12,6 +12,7 @@ from rctab.crud.models import database from rctab.routers.accounting.routes import ( SubscriptionItem, + get_subscription_id, get_subscriptions_with_disable, router, ) @@ -53,3 +54,12 @@ async def post_subscription( "status": "success", "detail": f"Added subscription {subscription.sub_id} to RCTab", } + + +@router.get("/subscription-id") +async def get_sub_id( + display_name: str, _: UserRBAC = Depends(token_admin_verified) +) -> list[UUID]: + """Get the subscription ID given a display name.""" + results = await get_subscription_id(database, display_name) + return results diff --git a/tests/test_routes/test_subscription.py b/tests/test_routes/test_subscription.py index 609af3f..546972b 100644 --- a/tests/test_routes/test_subscription.py +++ b/tests/test_routes/test_subscription.py @@ -1,4 +1,5 @@ from typing import Tuple +from uuid import UUID from fastapi import FastAPI from fastapi.testclient import TestClient @@ -80,3 +81,63 @@ def test_get_subscription_summary( assert result.status_code == 200 api_calls.assert_subscription_status(client, expected_details=expected_details) + + +# here +def test_get_subscription_id( + app_with_signed_status_and_controller_tokens: Tuple[FastAPI, str, str], +) -> None: + """Returns a subscription id, given a subscription name.""" + (auth_app, status_token, _) = app_with_signed_status_and_controller_tokens + + with TestClient(auth_app) as client: + # Check the scenario with no matches. + result = client.get( + PREFIX + "/subscription-id", + params={"display_name": "-"}, + ) + result.raise_for_status() + assert result.json() == [] + + # Add a subscription and make sure we can get its ID. + api_calls.create_subscription( + client, constants.TEST_SUB_UUID + ).raise_for_status() + + api_calls.create_subscription_detail( + client=client, + token=status_token, + subscription_id=constants.TEST_SUB_UUID, + state=SubscriptionState.ENABLED, + display_name="MyDisplayName", + ).raise_for_status() + + result = client.get( + PREFIX + "/subscription-id", + params={"display_name": "MyDisplayName"}, + ) + result.raise_for_status() + assert [UUID(x) for x in result.json()] == [constants.TEST_SUB_UUID] + + # Check multiple matches. + api_calls.create_subscription( + client, constants.TEST_SUB_2_UUID + ).raise_for_status() + + api_calls.create_subscription_detail( + client=client, + token=status_token, + subscription_id=constants.TEST_SUB_2_UUID, + state=SubscriptionState.ENABLED, + display_name="MyDisplayName", + ).raise_for_status() + + result = client.get( + PREFIX + "/subscription-id", + params={"display_name": "MyDisplayName"}, + ) + result.raise_for_status() + assert [UUID(x) for x in result.json()] in [ + [constants.TEST_SUB_2_UUID, constants.TEST_SUB_UUID], + [constants.TEST_SUB_UUID, constants.TEST_SUB_2_UUID], + ] From ef19fbff55efed4b94cea34778fd0e65f8779ec0 Mon Sep 17 00:00:00 2001 From: Iain-S <25081046+Iain-S@users.noreply.github.com> Date: Mon, 9 Feb 2026 12:01:05 +0000 Subject: [PATCH 3/3] Fix unit test failures --- tests/test_routes/test_routes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_routes/test_routes.py b/tests/test_routes/test_routes.py index 952d78b..4fa26d6 100644 --- a/tests/test_routes/test_routes.py +++ b/tests/test_routes/test_routes.py @@ -386,7 +386,7 @@ async def test_refresh_desired_states_doesnt_duplicate( async def test_get_subscription_id(test_db: Database) -> None: """The returned statement should select nothing, a single ID or raise.""" # Check that we get None if there isn't a match. - result = await get_subscription_id(test_db, "My-Subscription-Name") + result = await get_subscription_id(test_db, "some display name") assert result == [] # Check things work for a single subscription_details entry. @@ -395,7 +395,7 @@ async def test_get_subscription_id(test_db: Database) -> None: # Note create_subscription has a hard-coded display name. current_state=SubscriptionState.ENABLED, ) - result = await get_subscription_id(test_db, "My-Subscription-Name") + result = await get_subscription_id(test_db, "a subscription") assert result == [sub1] # Check a single subscription with > 1 name entry. @@ -419,5 +419,5 @@ async def test_get_subscription_id(test_db: Database) -> None: sub1 = await create_subscription(test_db, current_state=SubscriptionState.ENABLED) sub2 = await create_subscription(test_db, current_state=SubscriptionState.ENABLED) - result = await get_subscription_id(test_db, "My-Subscription-Name") + result = await get_subscription_id(test_db, "a subscription") assert result in ([sub1, sub2], [sub2, sub1])