1717 PrecisionPolicy ,
1818)
1919from deployment .core .backend import Backend
20+ from deployment .core .device import DeviceSpec
2021from deployment .exporters .common .configs import TensorRTProfileConfig
2122
2223
@@ -72,8 +73,8 @@ def should_export_tensorrt(self) -> bool:
7273class DeviceConfig :
7374 """Normalized device settings shared across deployment stages."""
7475
75- cpu : str = "cpu"
76- cuda : Optional [str ] = "cuda:0"
76+ cpu : DeviceSpec = field ( default_factory = lambda : DeviceSpec . from_value ( "cpu" ))
77+ cuda : Optional [DeviceSpec ] = field ( default_factory = lambda : DeviceSpec . from_value ( "cuda:0" ))
7778
7879 def __post_init__ (self ) -> None :
7980 object .__setattr__ (self , "cpu" , self ._normalize_cpu (self .cpu ))
@@ -82,47 +83,32 @@ def __post_init__(self) -> None:
8283 @classmethod
8384 def from_dict (cls , config_dict : Mapping [str , Any ]) -> DeviceConfig :
8485 """Create DeviceConfig from dict."""
85- return cls (cpu = config_dict .get ("cpu" , cls . cpu ), cuda = config_dict .get ("cuda" , cls . cuda ))
86+ return cls (cpu = config_dict .get ("cpu" , " cpu" ), cuda = config_dict .get ("cuda" , " cuda:0" ))
8687
8788 @staticmethod
88- def _normalize_cpu (device : Optional [str ]) -> str :
89- """Normalize CPU device string."""
90- if not device :
91- return "cpu"
92- normalized = str (device ).strip ().lower ()
93- if normalized .startswith ("cuda" ):
89+ def _normalize_cpu (device : Any ) -> DeviceSpec :
90+ """Normalize CPU device."""
91+ normalized = DeviceSpec .from_value (device if device is not None else "cpu" )
92+ if normalized .is_cuda :
9493 raise ValueError ("CPU device cannot be a CUDA device" )
9594 return normalized
9695
9796 @staticmethod
98- def _normalize_cuda (device : Optional [ str ] ) -> Optional [str ]:
99- """Normalize CUDA device string to 'cuda:N' format ."""
97+ def _normalize_cuda (device : Any ) -> Optional [DeviceSpec ]:
98+ """Normalize CUDA device to DeviceSpec ."""
10099 if device is None :
101100 return None
102- if not isinstance (device , str ):
103- raise ValueError ("cuda device must be a string (e.g., 'cuda:0')" )
104- normalized = device .strip ().lower ()
105- if normalized == "" :
106- return None
107- if normalized == "cuda" :
108- normalized = "cuda:0"
109- if not normalized .startswith ("cuda" ):
110- raise ValueError (f"Invalid CUDA device '{ device } '. Must start with 'cuda'" )
111- suffix = normalized .split (":" , 1 )[1 ] if ":" in normalized else "0"
112- suffix = suffix .strip () or "0"
113- if not suffix .isdigit ():
114- raise ValueError (f"Invalid CUDA device index in '{ device } '" )
115- device_id = int (suffix )
116- if device_id < 0 :
117- raise ValueError ("CUDA device index must be non-negative" )
118- return f"cuda:{ device_id } "
101+ normalized = DeviceSpec .from_value (device )
102+ if not normalized .is_cuda :
103+ raise ValueError (f"Invalid CUDA device '{ device } '." )
104+ return normalized
119105
120106 @property
121107 def cuda_device_index (self ) -> Optional [int ]:
122108 """Return CUDA device index as integer (if configured)."""
123109 if self .cuda is None :
124110 return None
125- return int ( self .cuda .split ( ":" , 1 )[ 1 ])
111+ return self .cuda .index
126112
127113
128114@dataclass (frozen = True )
@@ -360,7 +346,7 @@ class EvaluationConfig:
360346 verbose : bool = False
361347 backends : Mapping [Any , Mapping [str , Any ]] = field (default_factory = _empty_mapping )
362348 models : Mapping [Any , Any ] = field (default_factory = _empty_mapping )
363- devices : Mapping [str , str ] = field (default_factory = _empty_mapping )
349+ devices : Mapping [str , DeviceSpec ] = field (default_factory = _empty_mapping )
364350
365351 @classmethod
366352 def from_dict (cls , config_dict : Mapping [str , Any ]) -> EvaluationConfig :
@@ -383,13 +369,15 @@ def from_dict(cls, config_dict: Mapping[str, Any]) -> EvaluationConfig:
383369 if not isinstance (devices_raw , Mapping ):
384370 raise TypeError (f"evaluation.devices must be a mapping, got { type (devices_raw ).__name__ } " )
385371
372+ normalized_devices = {str (key ): DeviceSpec .from_value (value ) for key , value in devices_raw .items ()}
373+
386374 return cls (
387375 enabled = config_dict .get ("enabled" , False ),
388376 num_samples = config_dict .get ("num_samples" , 10 ),
389377 verbose = config_dict .get ("verbose" , False ),
390378 backends = MappingProxyType (backends_frozen ),
391379 models = MappingProxyType (dict (models_raw )),
392- devices = MappingProxyType (dict ( devices_raw ) ),
380+ devices = MappingProxyType (normalized_devices ),
393381 )
394382
395383
@@ -398,9 +386,9 @@ class VerificationScenario:
398386 """Immutable verification scenario specification."""
399387
400388 ref_backend : Backend
401- ref_device : str
389+ ref_device : DeviceSpec
402390 test_backend : Backend
403- test_device : str
391+ test_device : DeviceSpec
404392
405393 @classmethod
406394 def from_dict (cls , data : Mapping [str , Any ]) -> VerificationScenario :
@@ -410,9 +398,9 @@ def from_dict(cls, data: Mapping[str, Any]) -> VerificationScenario:
410398
411399 return cls (
412400 ref_backend = Backend .from_value (data ["ref_backend" ]),
413- ref_device = str (data ["ref_device" ]),
401+ ref_device = DeviceSpec . from_value (data ["ref_device" ]),
414402 test_backend = Backend .from_value (data ["test_backend" ]),
415- test_device = str (data ["test_device" ]),
403+ test_device = DeviceSpec . from_value (data ["test_device" ]),
416404 )
417405
418406
@@ -423,7 +411,7 @@ class VerificationConfig:
423411 enabled : bool = True
424412 num_verify_samples : int = 3
425413 tolerance : float = 0.1
426- devices : Mapping [str , str ] = field (default_factory = _empty_mapping )
414+ devices : Mapping [str , DeviceSpec ] = field (default_factory = _empty_mapping )
427415 scenarios : Mapping [ExportMode , Tuple [VerificationScenario , ...]] = field (default_factory = _empty_mapping )
428416
429417 @classmethod
@@ -452,11 +440,13 @@ def from_dict(cls, config_dict: Mapping[str, Any]) -> VerificationConfig:
452440 if not isinstance (devices_raw , Mapping ):
453441 raise TypeError (f"verification.devices must be a mapping, got { type (devices_raw ).__name__ } " )
454442
443+ normalized_devices = {str (key ): DeviceSpec .from_value (value ) for key , value in devices_raw .items ()}
444+
455445 return cls (
456446 enabled = config_dict .get ("enabled" , True ),
457447 num_verify_samples = config_dict .get ("num_verify_samples" , 3 ),
458448 tolerance = config_dict .get ("tolerance" , 0.1 ),
459- devices = MappingProxyType (dict ( devices_raw ) ),
449+ devices = MappingProxyType (normalized_devices ),
460450 scenarios = MappingProxyType (scenario_map ),
461451 )
462452
0 commit comments