77
88from lib .commands import ssh
99
10- JOBS = {
10+ from typing import NotRequired , TypedDict , cast
11+
12+ class JobData (TypedDict ):
13+ description : str
14+ requirements : list [str ]
15+ nb_pools : int
16+ params : dict [str , str ]
17+ paths : list [str ]
18+ markers : NotRequired [str ]
19+ name_filter : NotRequired [str ]
20+
21+ JOBS : dict [str , JobData ] = {
1122 "main" : {
1223 "description" : "a group of not-too-long tests that run either without a VM, or with a single small one" ,
1324 "requirements" : [
473484 "tests/storage/zfsvol/test_zfsvol_sr.py::TestZfsvolVm::test_quicktest" ,
474485]
475486
487+ VmDef = str | tuple [str , str ]
488+ VMSDef = dict [str , dict [str , VmDef | list [VmDef ]]]
489+
476490# Returns the vm filename or None if a host_version is passed and matches the one specified
477491# with the vm filename in vm_data.py. ex: ("centos6-32-hvm-created_8.2-zstd.xva", "8\.2\..*")
478- def filter_vm (vm , host_version ) :
492+ def filter_vm (vm : VmDef , host_version : str | None ) -> str | None :
479493 import re
480494
481- if type (vm ) is tuple :
495+ if isinstance (vm , tuple ) :
482496 if len (vm ) != 2 :
483497 print (f"ERROR: VM definition from vm_data.py is a tuple so it should contain exactly two items:\n { vm } " )
484498 sys .exit (1 )
@@ -497,34 +511,34 @@ def filter_vm(vm, host_version):
497511
498512 return vm
499513
500- def get_vm_or_vms_refs (handle , host_version = None ):
514+ def get_vm_or_vms_refs (handle : str , host_version : str | None = None ) -> str | list [ str ] :
501515 try :
502- from vm_data import VMS
516+ from vm_data import VMS as VMS_untyped
503517 except ImportError :
504518 print ("ERROR: Could not import VMS from vm_data.py." )
505- print ("Get the latest vm_data.py from XCP-ng's internal lab or copy data.py-dist and fill with your VM refs." )
519+ print ("Get the latest vm_data.py from XCP-ng's internal lab or copy vm_data.py-dist and fill"
520+ " with your VM refs." )
506521 print ("You may also bypass this error by providing your own --vm parameter(s)." )
507522 sys .exit (1 )
508523
524+ VMS = cast (VMSDef , VMS_untyped )
509525 category , key = handle .split ("/" )
510- if category not in VMS or not VMS [category ]. get ( key ) :
526+ if category not in VMS or key not in VMS [category ]:
511527 print (f"ERROR: Could not find VMS['{ category } ']['{ key } '] in vm_data.py, or it's empty." )
512528 print ("You need to update your local vm_data.py." )
513529 print ("You may also bypass this error by providing your own --vm parameter(s)." )
514530 sys .exit (1 )
515531
516- if type (VMS [category ][key ]) is list :
532+ vms : str | list [str ] | None = []
533+ vms_unfiltered = VMS [category ][key ]
534+ if isinstance (vms_unfiltered , list ):
517535 # Multi VMs
518- vms = list ()
519- for vm in VMS [category ][key ]:
520- xva = filter_vm (vm , host_version )
521- if xva is not None :
522- vms .append (xva )
523- if len (vms ) == 0 :
536+ vms = [xva for vm in vms_unfiltered if (xva := filter_vm (vm , host_version )) is not None ]
537+ if vms == []:
524538 vms = None
525- else :
539+ elif isinstance ( vms_unfiltered , str ) :
526540 # Single VMs
527- vms = filter_vm (VMS [ category ][ key ] , host_version )
541+ vms = filter_vm (vms_unfiltered , host_version )
528542
529543 if vms is None :
530544 print (f"ERROR: Could not find VMS['{ category } ']['{ key } '] for host version { host_version } ." )
@@ -534,7 +548,8 @@ def get_vm_or_vms_refs(handle, host_version=None):
534548
535549 return vms
536550
537- def build_pytest_cmd (job_data , hosts = None , host_version = None , pytest_args = []):
551+ def build_pytest_cmd (job_data : JobData , hosts : str | None = None , host_version : str | None = None ,
552+ pytest_args : list [str ] = []) -> list [str ]:
538553 markers = job_data .get ("markers" , None )
539554 name_filter = job_data .get ("name_filter" , None )
540555
@@ -544,13 +559,12 @@ def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
544559 if hosts is not None :
545560 try :
546561 host = hosts .split (',' )[0 ]
547- cmd = "lsb_release -sr"
548- host_version = ssh (host , cmd )
562+ host_version = ssh (host , "lsb_release -sr" )
549563 except Exception as e :
550564 print (e , file = sys .stderr )
551565
552- def _join_pytest_args (arg , option ) :
553- cli_args = []
566+ def _join_pytest_args (arg : str | None , option : str ) -> str | None :
567+ cli_args : list [ str ] = []
554568 try :
555569 while True :
556570 i = pytest_args .index (option )
@@ -601,21 +615,21 @@ def _join_pytest_args(arg, option):
601615 cmd += pytest_args
602616 return cmd
603617
604- def action_list (args ) :
618+ def action_list (args : argparse . Namespace ) -> None :
605619 for job , data in JOBS .items ():
606620 print (f"{ job } : { data ['description' ]} " )
607621
608- def action_show (args ) :
622+ def action_show (args : argparse . Namespace ) -> None :
609623 print (json .dumps (JOBS [args .job ], indent = 4 ))
610624
611- def action_collect (args ) :
625+ def action_collect (args : argparse . Namespace ) -> None :
612626 cmd = build_pytest_cmd (JOBS [args .job ], None , args .host_version , ["--collect-only" ] + args .pytest_args )
613627 subprocess .run (cmd )
614628
615- def action_check (args ) :
629+ def action_check (args : argparse . Namespace ) -> None :
616630 error = False
617631
618- def extract_tests (cmd ) :
632+ def extract_tests (cmd : list [ str ]) -> set [ str ] :
619633 tests = set ()
620634 res = subprocess .run (cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
621635 if res .returncode != 0 and res .returncode != 5 : # 5 means no test found
@@ -664,6 +678,7 @@ def extract_tests(cmd):
664678 multi_vm_tests = extract_tests (["pytest" , "--collect-only" , "-q" , "-m" , "multi_vms" ]) - broken_tests
665679 job_tests = set ()
666680 for job_data in JOBS .values ():
681+ assert isinstance (job_data ["params" ], dict )
667682 if "--vm[]" in job_data ["params" ]:
668683 job_tests |= extract_tests (build_pytest_cmd (job_data , None , None , ["--collect-only" , "-q" , "--vm=a_vm" ]))
669684 tests_missing = sorted (list (multi_vm_tests - job_tests ))
@@ -677,23 +692,25 @@ def extract_tests(cmd):
677692 if error :
678693 sys .exit (1 )
679694
680- def action_run (args ) :
695+ def action_run (args : argparse . Namespace ) -> None :
681696 cmd = build_pytest_cmd (JOBS [args .job ], args .hosts , None , args .pytest_args )
682697 print (subprocess .list2cmdline (cmd ))
683698 if args .print_only :
684699 return
685700
686701 # check that enough pool masters have been provided
687702 nb_pools = len (args .hosts .split ("," ))
688- if nb_pools < JOBS [args .job ]["nb_pools" ]:
689- print (f"Error: only { nb_pools } master host(s) provided, { JOBS [args .job ]['nb_pools' ]} required." )
703+ job_nb_pools = JOBS [args .job ]["nb_pools" ]
704+ assert isinstance (job_nb_pools , int )
705+ if nb_pools < job_nb_pools :
706+ print (f"Error: only { nb_pools } master host(s) provided, { job_nb_pools } required." )
690707 sys .exit (1 )
691708
692709 res = subprocess .run (cmd )
693710 if res .returncode :
694711 sys .exit (1 )
695712
696- def main ():
713+ def main () -> None :
697714 parser = argparse .ArgumentParser (description = "Manage test jobs" )
698715 subparsers = parser .add_subparsers (dest = "action" , metavar = "action" )
699716 subparsers .required = True
0 commit comments