@@ -259,6 +259,25 @@ def volume_validator(cls, value):
259259 return value
260260
261261
262+ class AffinityConfig (defaultdict ):
263+ def __init__ (self , default_factory = None , * args , ** kwargs ):
264+ super ().__init__ (default_factory , * args , ** kwargs )
265+
266+ def is_set_for (self , node_name ):
267+ return bool (self .get_for (node_name ))
268+
269+ def get_for (self , node_name ):
270+ node_value = self ._get_or_default (node_name , None )
271+ if node_value is not None :
272+ return node_value
273+ return self ._get_or_default ("__default__" , None )
274+
275+ def _get_or_default (self , node_name , default ):
276+ if node_name in self :
277+ return self [node_name ]
278+ return self .get ("__default__" , default )
279+
280+
262281class RunConfig (BaseModel ):
263282 def __init__ (self , ** kwargs ):
264283 super ().__init__ (** kwargs )
@@ -293,6 +312,10 @@ def _validate_tolerations(cls, value):
293312 def _validate_extra_volumes (cls , value ):
294313 return RunConfig ._create_default_dict_with (value , [], defaultdict )
295314
315+ @validator ("affinity" , always = True )
316+ def _validate_affinity (cls , value ):
317+ return RunConfig ._create_default_dict_with (value , None , AffinityConfig )
318+
296319 image : str
297320 image_pull_policy : str = "IfNotPresent"
298321 root : Optional [str ]
@@ -305,6 +328,7 @@ def _validate_extra_volumes(cls, value):
305328 retry_policy : Optional [Dict [str , Optional [RetryPolicyConfig ]]]
306329 volume : Optional [VolumeConfig ] = None
307330 extra_volumes : Optional [Dict [str , List [ExtraVolumeConfig ]]] = None
331+ affinity : Optional [Dict [str , Any ]] = None
308332 wait_for_completion : bool = False
309333 store_kedro_outputs_as_kfp_artifacts : bool = True
310334 max_cache_staleness : Optional [str ] = None
0 commit comments