Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 20 additions & 27 deletions src/shared/management/commands/fetch_all_channels.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
import sys
from collections.abc import Coroutine
from dataclasses import dataclass
from pprint import pprint
from pprint import pformat
from typing import Any

import requests
Expand All @@ -13,11 +11,11 @@
from shared.models.nix_evaluation import NixChannel


@dataclass
class MonitoredChannel:
name: str
revision: str
status: str
def __init__(self, name: str, revision: str, status: str) -> None:
self.name = name
self.revision = revision
self.status = status


def release_from_branch(branch: str) -> str | None:
Expand Down Expand Up @@ -92,39 +90,34 @@ def fetch_from_monitoring() -> dict[str, MonitoredChannel]:
return aggregate_by_channels(resp.json()["data"]["result"])


async def wait_for_parallel_fetches(
parallel_fetches: list[Coroutine[Any, Any, bool]],
) -> list[Any]:
return await asyncio.gather(*parallel_fetches, return_exceptions=True)


class Command(BaseCommand):
help = "Register Nix channels"

def handle(self, *args: Any, **kwargs: Any) -> str | None:
fresh_channels = fetch_from_monitoring()

registered: list[dict[str, Any]] = []
for channel in fresh_channels.values():
channel_branch = channel.name
staging_branch = staging_from_branch(channel.name)
branch_info = {
"staging_branch": staging_branch,
branch_info: dict[str, Any] = {
"channel_branch": channel.name,
"staging_branch": staging_from_branch(channel.name),
"state": state_from_status(channel.status),
"head_sha1_commit": channel.revision,
"release_version": release_from_branch(channel.name),
}
pprint(branch_info | {"channel_branch": channel.name})
NixChannel.objects.update_or_create(
branch_info, channel_branch=channel_branch
)
NixChannel.objects.update_or_create(branch_info, channel_branch=channel.name)
registered.append(branch_info)

repo = GitRepo(
settings.LOCAL_NIXPKGS_CHECKOUT,
stderr=sys.stderr.fileno(),
)
parallel_fetches = []
for channel in NixChannel.objects.iterator():
parallel_fetches.append(repo.update_from_ref(channel.head_sha1_commit))
results = asyncio.run(
asyncio.gather(
*[repo.update_from_ref(info["head_sha1_commit"]) for info in registered],
return_exceptions=True,
)
)

results = asyncio.run(wait_for_parallel_fetches(parallel_fetches))
# FIXME(@fricklerhandwerk): Fold that into `branch_info`, so there's only one output.
print("Parallel fetches results", results)
for branch_info, result in zip(registered, results):
self.stdout.write(pformat(branch_info | {"fetched": result}))
48 changes: 48 additions & 0 deletions src/shared/tests/test_fetch_all_channels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import io
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from django.core.management import call_command

from shared.management.commands.fetch_all_channels import MonitoredChannel
from shared.models.nix_evaluation import NixChannel


@pytest.mark.django_db
@patch("shared.management.commands.fetch_all_channels.fetch_from_monitoring")
@patch("shared.management.commands.fetch_all_channels.GitRepo")
@patch(
"shared.management.commands.fetch_all_channels.asyncio.gather",
new_callable=AsyncMock,
)
def test_command_upserts_channels_and_reports_fetch_results(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very good test to have!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

mock_gather: AsyncMock,
mock_git_repo_class: MagicMock,
mock_fetch_monitoring: MagicMock,
) -> None:
mock_fetch_monitoring.return_value = {
"nixos-24.11": MonitoredChannel(
name="nixos-24.11", revision="1234567890abcdef", status="stable"
),
"nixos-unstable": MonitoredChannel(
name="nixos-unstable", revision="aabbcc0011223344", status="rolling"
),
}

mock_gather.return_value = [True, False]

out = io.StringIO()
call_command("fetch_all_channels", stdout=out)

assert (
NixChannel.objects.get(channel_branch="nixos-24.11").state
== NixChannel.ChannelState.STABLE
)
assert (
NixChannel.objects.get(channel_branch="nixos-unstable").state
== NixChannel.ChannelState.UNSTABLE
)

output = out.getvalue()
assert "nixos-24.11" in output and "True" in output
assert "nixos-unstable" in output and "False" in output