Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions rctab/routers/accounting/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List, Optional, Union
from uuid import UUID

from databases import Database
from fastapi import APIRouter
from pydantic import BaseModel
from sqlalchemy import desc, func, select, true
Expand Down Expand Up @@ -512,3 +513,26 @@ def get_subscription_name(sub_id: Optional[UUID] = None) -> Select:
return select(subscription_details.c.display_name.label("name")).where(
subscription_details.c.subscription_id == sub_id
)


async def get_subscription_id(database: Database, display_name: str) -> list[UUID]:
"""Get the subscription ID(s) from a display name."""
details = subscription_details # alias for brevity

subq = select(
[
details.c.subscription_id,
details.c.display_name,
func.row_number()
.over(partition_by=details.c.subscription_id, order_by=details.c.id.desc())
.label("rank"),
]
).alias("subq")

results = await database.fetch_all(
select(subq.c.subscription_id).where(
(subq.c.rank == 1) & (subq.c.display_name == display_name)
)
)

return [result[0] for result in results]
10 changes: 10 additions & 0 deletions rctab/routers/accounting/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rctab.crud.models import database
from rctab.routers.accounting.routes import (
SubscriptionItem,
get_subscription_id,
get_subscriptions_with_disable,
router,
)
Expand Down Expand Up @@ -53,3 +54,12 @@ async def post_subscription(
"status": "success",
"detail": f"Added subscription {subscription.sub_id} to RCTab",
}


@router.get("/subscription-id")
async def get_sub_id(
display_name: str, _: UserRBAC = Depends(token_admin_verified)
) -> list[UUID]:
"""Get the subscription ID given a display name."""
results = await get_subscription_id(database, display_name)
return results
42 changes: 42 additions & 0 deletions tests/test_routes/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from rctab.crud.models import database
from rctab.routers.accounting.desired_states import refresh_desired_states
from rctab.routers.accounting.routes import get_subscription_id
from tests.test_routes import constants


Expand Down Expand Up @@ -379,3 +380,44 @@ async def test_refresh_desired_states_doesnt_duplicate(
assert disabled_subscriptions == [
over_budget_sub_id,
]


@pytest.mark.asyncio
async def test_get_subscription_id(test_db: Database) -> None:
"""The returned statement should select nothing, a single ID or raise."""
# Check that we get None if there isn't a match.
result = await get_subscription_id(test_db, "some display name")
assert result == []

# Check things work for a single subscription_details entry.
sub1 = await create_subscription(
test_db,
# Note create_subscription has a hard-coded display name.
current_state=SubscriptionState.ENABLED,
)
result = await get_subscription_id(test_db, "a subscription")
assert result == [sub1]

# Check a single subscription with > 1 name entry.
await test_db.execute(
subscription_details.insert().values(),
SubscriptionStatus(
subscription_id=sub1,
state=SubscriptionState.ENABLED,
display_name="New-Subscription-Name",
role_assignments=(),
).model_dump(),
)
result = await get_subscription_id(test_db, "My-Subscription-Name")
assert result == []

result = await get_subscription_id(test_db, "New-Subscription-Name")
assert result == [sub1]

# Check two subscriptions with the same name
# (note I don't know whether Azure allows this).
sub1 = await create_subscription(test_db, current_state=SubscriptionState.ENABLED)
sub2 = await create_subscription(test_db, current_state=SubscriptionState.ENABLED)

result = await get_subscription_id(test_db, "a subscription")
assert result in ([sub1, sub2], [sub2, sub1])
61 changes: 61 additions & 0 deletions tests/test_routes/test_subscription.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Tuple
from uuid import UUID

from fastapi import FastAPI
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -80,3 +81,63 @@ def test_get_subscription_summary(
assert result.status_code == 200

api_calls.assert_subscription_status(client, expected_details=expected_details)


# here
def test_get_subscription_id(
app_with_signed_status_and_controller_tokens: Tuple[FastAPI, str, str],
) -> None:
"""Returns a subscription id, given a subscription name."""
(auth_app, status_token, _) = app_with_signed_status_and_controller_tokens

with TestClient(auth_app) as client:
# Check the scenario with no matches.
result = client.get(
PREFIX + "/subscription-id",
params={"display_name": "-"},
)
result.raise_for_status()
assert result.json() == []

# Add a subscription and make sure we can get its ID.
api_calls.create_subscription(
client, constants.TEST_SUB_UUID
).raise_for_status()

api_calls.create_subscription_detail(
client=client,
token=status_token,
subscription_id=constants.TEST_SUB_UUID,
state=SubscriptionState.ENABLED,
display_name="MyDisplayName",
).raise_for_status()

result = client.get(
PREFIX + "/subscription-id",
params={"display_name": "MyDisplayName"},
)
result.raise_for_status()
assert [UUID(x) for x in result.json()] == [constants.TEST_SUB_UUID]

# Check multiple matches.
api_calls.create_subscription(
client, constants.TEST_SUB_2_UUID
).raise_for_status()

api_calls.create_subscription_detail(
client=client,
token=status_token,
subscription_id=constants.TEST_SUB_2_UUID,
state=SubscriptionState.ENABLED,
display_name="MyDisplayName",
).raise_for_status()

result = client.get(
PREFIX + "/subscription-id",
params={"display_name": "MyDisplayName"},
)
result.raise_for_status()
assert [UUID(x) for x in result.json()] in [
[constants.TEST_SUB_2_UUID, constants.TEST_SUB_UUID],
[constants.TEST_SUB_UUID, constants.TEST_SUB_2_UUID],
]
Loading