11from dataclasses import dataclass , field
2- from typing import TYPE_CHECKING , Any , Dict , List , Mapping , Optional , Union
2+ from typing import TYPE_CHECKING , Any , Dict , List , Literal , Mapping , Optional , Union
33
44from aibs_informatics_core .env import ENV_BASE_KEY_ALIAS , EnvBase , get_env_base
55from aibs_informatics_core .models .aws .batch import JobName , ResourceRequirements
66from aibs_informatics_core .utils .decorators import retry
77from aibs_informatics_core .utils .hashing import sha256_hexdigest
88from aibs_informatics_core .utils .logging import get_logger
9- from aibs_informatics_core .utils .tools .dicttools import convert_key_case , remove_null_values
9+ from aibs_informatics_core .utils .tools .dicttools import convert_key_case
1010from aibs_informatics_core .utils .tools .strtools import pascalcase
1111from botocore .exceptions import ClientError
1212
2424 HostTypeDef ,
2525 JobDefinitionTypeDef ,
2626 KeyValuePairTypeDef ,
27+ LinuxParametersTypeDef ,
2728 MountPointTypeDef ,
28- RegisterJobDefinitionRequestRequestTypeDef ,
2929 RegisterJobDefinitionResponseTypeDef ,
3030 ResourceRequirementTypeDef ,
3131 RetryStrategyTypeDef ,
4242 JobDefinitionTypeDef = dict
4343 DescribeJobsResponseTypeDef = dict
4444 KeyValuePairTypeDef = dict
45+ LinuxParametersTypeDef = dict
4546 MountPointTypeDef = dict
46- RegisterJobDefinitionRequestRequestTypeDef = dict
4747 RegisterJobDefinitionResponseTypeDef = dict
4848 ResourceRequirementTypeDef = dict
4949 RetryStrategyTypeDef = dict
@@ -61,27 +61,27 @@ def to_volume(
6161 name : Optional [str ],
6262 efs_volume_configuration : Optional [EFSVolumeConfigurationTypeDef ],
6363) -> VolumeTypeDef :
64- return remove_null_values (
65- VolumeTypeDef (
66- host = HostTypeDef (sourcePath = source_path ) if source_path else None ,
67- name = name ,
68- efsVolumeConfiguration = efs_volume_configuration ,
69- )
70- )
64+ volume_dict = VolumeTypeDef ()
65+ if source_path :
66+ volume_dict ["host" ] = HostTypeDef (sourcePath = source_path )
67+ if name :
68+ volume_dict ["name" ] = name
69+ if efs_volume_configuration :
70+ volume_dict ["efsVolumeConfiguration" ] = efs_volume_configuration
71+ return volume_dict
7172
7273
7374def to_mount_point (
7475 container_path : Optional [str ],
7576 read_only : bool ,
7677 source_volume : Optional [str ],
7778) -> MountPointTypeDef :
78- return remove_null_values (
79- MountPointTypeDef (
80- containerPath = container_path ,
81- readOnly = read_only ,
82- sourceVolume = source_volume ,
83- )
84- )
79+ mount_point_dict = MountPointTypeDef (readOnly = read_only )
80+ if container_path :
81+ mount_point_dict ["containerPath" ] = container_path
82+ if source_volume :
83+ mount_point_dict ["sourceVolume" ] = source_volume
84+ return mount_point_dict
8585
8686
8787def to_key_value_pairs (
@@ -104,7 +104,7 @@ def to_key_value_pairs(
104104 for k , v in environment .items ()
105105 if not remove_null_values or v is not None
106106 ],
107- key = lambda _ : _ .get ("name" ),
107+ key = lambda _ : _ .get ("name" , "" ),
108108 )
109109
110110
@@ -115,16 +115,23 @@ def to_resource_requirements(
115115) -> List [ResourceRequirementTypeDef ]:
116116 """Converts Batch resource requirement parameters into a list of ResourceRequirement objects
117117
118+ The returned list only includes dictionary entries for resources that specify
119+ an explicit value. Anything unset will be dropped.
120+
118121 Args:
119- gpu (Optional[int], optional): number of . Defaults to None.
120- memory (Optional[int], optional): _description_ . Defaults to None.
121- vcpus (Optional[int], optional): _description_ . Defaults to None.
122+ gpu (Optional[int], optional): number of GPUs to use . Defaults to None.
123+ memory (Optional[int], optional): amount of memory in MiB . Defaults to None.
124+ vcpus (Optional[int], optional): Number of VCPUs to use . Defaults to None.
122125
123126 Returns:
124127 List[ResourceRequirementTypeDef]: list of resource requirements
125128 """
126129
127- pairs = [("GPU" , gpu ), ("MEMORY" , memory ), ("VCPU" , vcpus )]
130+ pairs : list [tuple [Literal ["GPU" , "MEMORY" , "VCPU" ], Optional [int ]]] = [
131+ ("GPU" , gpu ),
132+ ("MEMORY" , memory ),
133+ ("VCPU" , vcpus ),
134+ ]
128135 return [ResourceRequirementTypeDef (type = t , value = str (v )) for t , v in pairs if v is not None ]
129136
130137
@@ -176,7 +183,7 @@ def register_job_definition(
176183 tags : Optional [Mapping [str , str ]] = None ,
177184 propagate_tags : bool = False ,
178185 region : Optional [str ] = None ,
179- ) -> JobDefinitionTypeDef :
186+ ) -> JobDefinitionTypeDef | RegisterJobDefinitionResponseTypeDef :
180187 batch = get_batch_client (region = region )
181188
182189 # First we check to make sure that we aren't crearting unnecessary revisions
@@ -210,7 +217,7 @@ def register_job_definition(
210217 logger .info (
211218 f"Registering job definition with following properties: { register_job_definition_kwargs } "
212219 )
213- response = batch .register_job_definition (** register_job_definition_kwargs )
220+ response = batch .register_job_definition (** register_job_definition_kwargs ) # type: ignore[arg-type]
214221 return response
215222
216223
@@ -263,7 +270,7 @@ class BatchJobBuilder:
263270 mount_points : List [MountPointTypeDef ] = field (default_factory = list )
264271 volumes : List [VolumeTypeDef ] = field (default_factory = list )
265272 privileged : bool = field (default = False )
266- linux_parameters : Optional [Dict [ str , Any ] ] = field (default = None )
273+ linux_parameters : Optional [LinuxParametersTypeDef ] = field (default = None )
267274 env_base : EnvBase = field (default_factory = EnvBase .from_env )
268275
269276 def __post_init__ (self ):
@@ -295,7 +302,7 @@ def container_overrides(self) -> ContainerOverridesTypeDef:
295302
296303 @property
297304 def container_overrides__sfn (self ) -> Dict [str , Any ]:
298- return convert_key_case (self .container_overrides , pascalcase )
305+ return convert_key_case (self .container_overrides , pascalcase ) # type: ignore[arg-type]
299306
300307 def _normalized_resource_requirements (self ) -> List [ResourceRequirementTypeDef ]:
301308 if isinstance (self .resource_requirements , list ):
0 commit comments