Skip to content

Commit 47c708a

Browse files
committed
jobs: add type hints
Signed-off-by: Gaëtan Lehmann <gaetan.lehmann@vates.tech>
1 parent 07389bf commit 47c708a

File tree

1 file changed

+47
-30
lines changed

1 file changed

+47
-30
lines changed

jobs.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77

88
from lib.commands import ssh
99

10-
JOBS = {
10+
from typing import NotRequired, TypedDict, cast
11+
12+
class JobData(TypedDict):
13+
description: str
14+
requirements: list[str]
15+
nb_pools: int
16+
params: dict[str, str]
17+
paths: list[str]
18+
markers: NotRequired[str]
19+
name_filter: NotRequired[str]
20+
21+
JOBS: dict[str, JobData] = {
1122
"main": {
1223
"description": "a group of not-too-long tests that run either without a VM, or with a single small one",
1324
"requirements": [
@@ -473,12 +484,15 @@
473484
"tests/storage/zfsvol/test_zfsvol_sr.py::TestZfsvolVm::test_quicktest",
474485
]
475486

487+
VmDef = str | tuple[str, str]
488+
VMSDef = dict[str, dict[str, VmDef | list[VmDef]]]
489+
476490
# Returns the vm filename or None if a host_version is passed and matches the one specified
477491
# with the vm filename in vm_data.py. ex: ("centos6-32-hvm-created_8.2-zstd.xva", "8\.2\..*")
478-
def filter_vm(vm, host_version):
492+
def filter_vm(vm: VmDef, host_version: str | None) -> str | None:
479493
import re
480494

481-
if type(vm) is tuple:
495+
if isinstance(vm, tuple):
482496
if len(vm) != 2:
483497
print(f"ERROR: VM definition from vm_data.py is a tuple so it should contain exactly two items:\n{vm}")
484498
sys.exit(1)
@@ -497,34 +511,34 @@ def filter_vm(vm, host_version):
497511

498512
return vm
499513

500-
def get_vm_or_vms_refs(handle, host_version=None):
514+
def get_vm_or_vms_refs(handle: str, host_version: str | None = None) -> str | list[str]:
501515
try:
502-
from vm_data import VMS
516+
from vm_data import VMS as VMS_untyped
503517
except ImportError:
504518
print("ERROR: Could not import VMS from vm_data.py.")
505-
print("Get the latest vm_data.py from XCP-ng's internal lab or copy data.py-dist and fill with your VM refs.")
519+
print("Get the latest vm_data.py from XCP-ng's internal lab or copy vm_data.py-dist and fill"
520+
" with your VM refs.")
506521
print("You may also bypass this error by providing your own --vm parameter(s).")
507522
sys.exit(1)
508523

524+
VMS = cast(VMSDef, VMS_untyped)
509525
category, key = handle.split("/")
510-
if category not in VMS or not VMS[category].get(key):
526+
if category not in VMS or key not in VMS[category]:
511527
print(f"ERROR: Could not find VMS['{category}']['{key}'] in vm_data.py, or it's empty.")
512528
print("You need to update your local vm_data.py.")
513529
print("You may also bypass this error by providing your own --vm parameter(s).")
514530
sys.exit(1)
515531

516-
if type(VMS[category][key]) is list:
532+
vms: str | list[str] | None = []
533+
vms_unfiltered = VMS[category][key]
534+
if isinstance(vms_unfiltered, list):
517535
# Multi VMs
518-
vms = list()
519-
for vm in VMS[category][key]:
520-
xva = filter_vm(vm, host_version)
521-
if xva is not None:
522-
vms.append(xva)
523-
if len(vms) == 0:
536+
vms = [xva for vm in vms_unfiltered if (xva := filter_vm(vm, host_version)) is not None]
537+
if vms == []:
524538
vms = None
525-
else:
539+
elif isinstance(vms_unfiltered, str):
526540
# Single VMs
527-
vms = filter_vm(VMS[category][key], host_version)
541+
vms = filter_vm(vms_unfiltered, host_version)
528542

529543
if vms is None:
530544
print(f"ERROR: Could not find VMS['{category}']['{key}'] for host version {host_version}.")
@@ -534,7 +548,8 @@ def get_vm_or_vms_refs(handle, host_version=None):
534548

535549
return vms
536550

537-
def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
551+
def build_pytest_cmd(job_data: JobData, hosts: str | None = None, host_version: str | None = None,
552+
pytest_args: list[str] = []) -> list[str]:
538553
markers = job_data.get("markers", None)
539554
name_filter = job_data.get("name_filter", None)
540555

@@ -544,13 +559,12 @@ def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
544559
if hosts is not None:
545560
try:
546561
host = hosts.split(',')[0]
547-
cmd = "lsb_release -sr"
548-
host_version = ssh(host, cmd)
562+
host_version = ssh(host, "lsb_release -sr")
549563
except Exception as e:
550564
print(e, file=sys.stderr)
551565

552-
def _join_pytest_args(arg, option):
553-
cli_args = []
566+
def _join_pytest_args(arg: str | None, option: str) -> str | None:
567+
cli_args: list[str] = []
554568
try:
555569
while True:
556570
i = pytest_args.index(option)
@@ -601,21 +615,21 @@ def _join_pytest_args(arg, option):
601615
cmd += pytest_args
602616
return cmd
603617

604-
def action_list(args):
618+
def action_list(args: argparse.Namespace) -> None:
605619
for job, data in JOBS.items():
606620
print(f"{job}: {data['description']}")
607621

608-
def action_show(args):
622+
def action_show(args: argparse.Namespace) -> None:
609623
print(json.dumps(JOBS[args.job], indent=4))
610624

611-
def action_collect(args):
625+
def action_collect(args: argparse.Namespace) -> None:
612626
cmd = build_pytest_cmd(JOBS[args.job], None, args.host_version, ["--collect-only"] + args.pytest_args)
613627
subprocess.run(cmd)
614628

615-
def action_check(args):
629+
def action_check(args: argparse.Namespace) -> None:
616630
error = False
617631

618-
def extract_tests(cmd):
632+
def extract_tests(cmd: list[str]) -> set[str]:
619633
tests = set()
620634
res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
621635
if res.returncode != 0 and res.returncode != 5: # 5 means no test found
@@ -664,6 +678,7 @@ def extract_tests(cmd):
664678
multi_vm_tests = extract_tests(["pytest", "--collect-only", "-q", "-m", "multi_vms"]) - broken_tests
665679
job_tests = set()
666680
for job_data in JOBS.values():
681+
assert isinstance(job_data["params"], dict)
667682
if "--vm[]" in job_data["params"]:
668683
job_tests |= extract_tests(build_pytest_cmd(job_data, None, None, ["--collect-only", "-q", "--vm=a_vm"]))
669684
tests_missing = sorted(list(multi_vm_tests - job_tests))
@@ -677,23 +692,25 @@ def extract_tests(cmd):
677692
if error:
678693
sys.exit(1)
679694

680-
def action_run(args):
695+
def action_run(args: argparse.Namespace) -> None:
681696
cmd = build_pytest_cmd(JOBS[args.job], args.hosts, None, args.pytest_args)
682697
print(subprocess.list2cmdline(cmd))
683698
if args.print_only:
684699
return
685700

686701
# check that enough pool masters have been provided
687702
nb_pools = len(args.hosts.split(","))
688-
if nb_pools < JOBS[args.job]["nb_pools"]:
689-
print(f"Error: only {nb_pools} master host(s) provided, {JOBS[args.job]['nb_pools']} required.")
703+
job_nb_pools = JOBS[args.job]["nb_pools"]
704+
assert isinstance(job_nb_pools, int)
705+
if nb_pools < job_nb_pools:
706+
print(f"Error: only {nb_pools} master host(s) provided, {job_nb_pools} required.")
690707
sys.exit(1)
691708

692709
res = subprocess.run(cmd)
693710
if res.returncode:
694711
sys.exit(1)
695712

696-
def main():
713+
def main() -> None:
697714
parser = argparse.ArgumentParser(description="Manage test jobs")
698715
subparsers = parser.add_subparsers(dest="action", metavar="action")
699716
subparsers.required = True

0 commit comments

Comments
 (0)