Skip to content

Commit 2b28f60

Browse files
committed
chore: fix more Device spec
Signed-off-by: vividf <yihsiang.fang@tier4.jp>
1 parent dd46052 commit 2b28f60

5 files changed

Lines changed: 40 additions & 73 deletions

File tree

deployment/configs/schema.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
PrecisionPolicy,
1818
)
1919
from deployment.core.backend import Backend
20+
from deployment.core.device import DeviceSpec
2021
from deployment.exporters.common.configs import TensorRTProfileConfig
2122

2223

@@ -72,8 +73,8 @@ def should_export_tensorrt(self) -> bool:
7273
class DeviceConfig:
7374
"""Normalized device settings shared across deployment stages."""
7475

75-
cpu: str = "cpu"
76-
cuda: Optional[str] = "cuda:0"
76+
cpu: DeviceSpec = field(default_factory=lambda: DeviceSpec.from_value("cpu"))
77+
cuda: Optional[DeviceSpec] = field(default_factory=lambda: DeviceSpec.from_value("cuda:0"))
7778

7879
def __post_init__(self) -> None:
7980
object.__setattr__(self, "cpu", self._normalize_cpu(self.cpu))
@@ -82,47 +83,32 @@ def __post_init__(self) -> None:
8283
@classmethod
8384
def from_dict(cls, config_dict: Mapping[str, Any]) -> DeviceConfig:
8485
"""Create DeviceConfig from dict."""
85-
return cls(cpu=config_dict.get("cpu", cls.cpu), cuda=config_dict.get("cuda", cls.cuda))
86+
return cls(cpu=config_dict.get("cpu", "cpu"), cuda=config_dict.get("cuda", "cuda:0"))
8687

8788
@staticmethod
88-
def _normalize_cpu(device: Optional[str]) -> str:
89-
"""Normalize CPU device string."""
90-
if not device:
91-
return "cpu"
92-
normalized = str(device).strip().lower()
93-
if normalized.startswith("cuda"):
89+
def _normalize_cpu(device: Any) -> DeviceSpec:
90+
"""Normalize CPU device."""
91+
normalized = DeviceSpec.from_value(device if device is not None else "cpu")
92+
if normalized.is_cuda:
9493
raise ValueError("CPU device cannot be a CUDA device")
9594
return normalized
9695

9796
@staticmethod
98-
def _normalize_cuda(device: Optional[str]) -> Optional[str]:
99-
"""Normalize CUDA device string to 'cuda:N' format."""
97+
def _normalize_cuda(device: Any) -> Optional[DeviceSpec]:
98+
"""Normalize CUDA device to DeviceSpec."""
10099
if device is None:
101100
return None
102-
if not isinstance(device, str):
103-
raise ValueError("cuda device must be a string (e.g., 'cuda:0')")
104-
normalized = device.strip().lower()
105-
if normalized == "":
106-
return None
107-
if normalized == "cuda":
108-
normalized = "cuda:0"
109-
if not normalized.startswith("cuda"):
110-
raise ValueError(f"Invalid CUDA device '{device}'. Must start with 'cuda'")
111-
suffix = normalized.split(":", 1)[1] if ":" in normalized else "0"
112-
suffix = suffix.strip() or "0"
113-
if not suffix.isdigit():
114-
raise ValueError(f"Invalid CUDA device index in '{device}'")
115-
device_id = int(suffix)
116-
if device_id < 0:
117-
raise ValueError("CUDA device index must be non-negative")
118-
return f"cuda:{device_id}"
101+
normalized = DeviceSpec.from_value(device)
102+
if not normalized.is_cuda:
103+
raise ValueError(f"Invalid CUDA device '{device}'.")
104+
return normalized
119105

120106
@property
121107
def cuda_device_index(self) -> Optional[int]:
122108
"""Return CUDA device index as integer (if configured)."""
123109
if self.cuda is None:
124110
return None
125-
return int(self.cuda.split(":", 1)[1])
111+
return self.cuda.index
126112

127113

128114
@dataclass(frozen=True)
@@ -360,7 +346,7 @@ class EvaluationConfig:
360346
verbose: bool = False
361347
backends: Mapping[Any, Mapping[str, Any]] = field(default_factory=_empty_mapping)
362348
models: Mapping[Any, Any] = field(default_factory=_empty_mapping)
363-
devices: Mapping[str, str] = field(default_factory=_empty_mapping)
349+
devices: Mapping[str, DeviceSpec] = field(default_factory=_empty_mapping)
364350

365351
@classmethod
366352
def from_dict(cls, config_dict: Mapping[str, Any]) -> EvaluationConfig:
@@ -383,13 +369,15 @@ def from_dict(cls, config_dict: Mapping[str, Any]) -> EvaluationConfig:
383369
if not isinstance(devices_raw, Mapping):
384370
raise TypeError(f"evaluation.devices must be a mapping, got {type(devices_raw).__name__}")
385371

372+
normalized_devices = {str(key): DeviceSpec.from_value(value) for key, value in devices_raw.items()}
373+
386374
return cls(
387375
enabled=config_dict.get("enabled", False),
388376
num_samples=config_dict.get("num_samples", 10),
389377
verbose=config_dict.get("verbose", False),
390378
backends=MappingProxyType(backends_frozen),
391379
models=MappingProxyType(dict(models_raw)),
392-
devices=MappingProxyType(dict(devices_raw)),
380+
devices=MappingProxyType(normalized_devices),
393381
)
394382

395383

@@ -398,9 +386,9 @@ class VerificationScenario:
398386
"""Immutable verification scenario specification."""
399387

400388
ref_backend: Backend
401-
ref_device: str
389+
ref_device: DeviceSpec
402390
test_backend: Backend
403-
test_device: str
391+
test_device: DeviceSpec
404392

405393
@classmethod
406394
def from_dict(cls, data: Mapping[str, Any]) -> VerificationScenario:
@@ -410,9 +398,9 @@ def from_dict(cls, data: Mapping[str, Any]) -> VerificationScenario:
410398

411399
return cls(
412400
ref_backend=Backend.from_value(data["ref_backend"]),
413-
ref_device=str(data["ref_device"]),
401+
ref_device=DeviceSpec.from_value(data["ref_device"]),
414402
test_backend=Backend.from_value(data["test_backend"]),
415-
test_device=str(data["test_device"]),
403+
test_device=DeviceSpec.from_value(data["test_device"]),
416404
)
417405

418406

@@ -423,7 +411,7 @@ class VerificationConfig:
423411
enabled: bool = True
424412
num_verify_samples: int = 3
425413
tolerance: float = 0.1
426-
devices: Mapping[str, str] = field(default_factory=_empty_mapping)
414+
devices: Mapping[str, DeviceSpec] = field(default_factory=_empty_mapping)
427415
scenarios: Mapping[ExportMode, Tuple[VerificationScenario, ...]] = field(default_factory=_empty_mapping)
428416

429417
@classmethod
@@ -452,11 +440,13 @@ def from_dict(cls, config_dict: Mapping[str, Any]) -> VerificationConfig:
452440
if not isinstance(devices_raw, Mapping):
453441
raise TypeError(f"verification.devices must be a mapping, got {type(devices_raw).__name__}")
454442

443+
normalized_devices = {str(key): DeviceSpec.from_value(value) for key, value in devices_raw.items()}
444+
455445
return cls(
456446
enabled=config_dict.get("enabled", True),
457447
num_verify_samples=config_dict.get("num_verify_samples", 3),
458448
tolerance=config_dict.get("tolerance", 0.1),
459-
devices=MappingProxyType(dict(devices_raw)),
449+
devices=MappingProxyType(normalized_devices),
460450
scenarios=MappingProxyType(scenario_map),
461451
)
462452

deployment/exporters/export_pipelines/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from deployment.configs import BaseDeploymentConfig
1111
from deployment.core.artifacts import Artifact
12+
from deployment.core.device import DeviceSpec
1213
from deployment.core.io.base_data_loader import BaseDataLoader
1314

1415

@@ -54,7 +55,7 @@ def export(
5455
onnx_path: str,
5556
output_dir: str,
5657
config: BaseDeploymentConfig,
57-
device: str,
58+
device: DeviceSpec,
5859
) -> Artifact:
5960
"""
6061
Execute the TensorRT export pipeline and return the produced artifact.
@@ -63,7 +64,7 @@ def export(
6364
onnx_path: Path to ONNX model file/directory
6465
output_dir: Directory for output files
6566
config: Deployment configuration
66-
device: CUDA device string
67+
device: CUDA device specification
6768
6869
Returns:
6970
Artifact describing the exported TensorRT output

deployment/runtime/evaluation_orchestrator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _get_models_to_evaluate(self, artifact_manager: ArtifactManager) -> List[Mod
128128
if not backend_cfg.get("enabled", False):
129129
continue
130130

131-
raw_device = backend_cfg.get("device") or str(self._get_default_device(backend_enum))
131+
raw_device = backend_cfg.get("device") or self._get_default_device(backend_enum)
132132
device = DeviceSpec.from_value(raw_device)
133133
artifact, is_valid = artifact_manager.resolve_artifact(backend_enum)
134134

@@ -175,8 +175,10 @@ def _get_default_device(self, backend: Backend) -> DeviceSpec:
175175
Default device string
176176
"""
177177
if backend is Backend.TENSORRT:
178-
return DeviceSpec.from_value(self.config.devices.cuda or "cuda:0")
179-
return DeviceSpec.from_value(self.config.devices.cpu or "cpu")
178+
if self.config.devices.cuda is None:
179+
raise RuntimeError("TensorRT backend requires a configured CUDA device.")
180+
return self.config.devices.cuda
181+
return self.config.devices.cpu
180182

181183
def _print_cross_backend_comparison(self, all_results: Mapping[str, Any]) -> None:
182184
"""

deployment/runtime/export_orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,9 @@ def _export_tensorrt(self, onnx_path: str, context: ExportContext) -> Optional[A
400400
os.makedirs(tensorrt_dir, exist_ok=True)
401401

402402
cuda_device = self.config.devices.cuda
403-
device_id = self.config.devices.cuda_device_index
404-
if cuda_device is None or device_id is None:
403+
if cuda_device is None:
405404
raise RuntimeError("TensorRT export requires a CUDA device. Set deploy_cfg.devices['cuda'].")
405+
device_id = cuda_device.index
406406
torch.cuda.set_device(device_id)
407407
self.logger.info(f"Using CUDA device for TensorRT export: {cuda_device}")
408408

deployment/runtime/verification_orchestrator.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
from __future__ import annotations
88

99
import logging
10-
from typing import Any, Dict, Mapping
10+
from typing import Any, Dict
1111

1212
from deployment.configs import BaseDeploymentConfig
1313
from deployment.core.backend import Backend
14-
from deployment.core.device import DeviceSpec
1514
from deployment.core.evaluation.base_evaluator import BaseEvaluator
1615
from deployment.core.evaluation.evaluator_types import ModelSpec
1716
from deployment.core.io.base_data_loader import BaseDataLoader
@@ -85,16 +84,6 @@ def run(self, artifact_manager: ArtifactManager) -> Dict[str, Any]:
8584

8685
num_verify_samples = verification_cfg.num_verify_samples
8786
tolerance = verification_cfg.tolerance
88-
devices_raw = verification_cfg.devices
89-
if devices_raw is None:
90-
devices_raw = {}
91-
if not isinstance(devices_raw, Mapping):
92-
raise TypeError(f"verification.devices must be a mapping, got {type(devices_raw).__name__}")
93-
devices_map = dict(devices_raw)
94-
devices_map.setdefault("cpu", self.config.devices.cpu or "cpu")
95-
if self.config.devices.cuda:
96-
devices_map.setdefault("cuda", self.config.devices.cuda)
97-
9887
self.logger.info("=" * 80)
9988
self.logger.info(f"Running Verification (mode: {export_mode.value})")
10089
self.logger.info("=" * 80)
@@ -104,8 +93,8 @@ def run(self, artifact_manager: ArtifactManager) -> Dict[str, Any]:
10493
total_failed = 0
10594

10695
for i, policy in enumerate(scenarios):
107-
ref_device = self._resolve_device(policy.ref_device, devices_map)
108-
test_device = self._resolve_device(policy.test_device, devices_map)
96+
ref_device = policy.ref_device
97+
test_device = policy.test_device
10998

11099
self.logger.info(
111100
f"\nScenario {i+1}/{len(scenarios)}: "
@@ -164,18 +153,3 @@ def run(self, artifact_manager: ArtifactManager) -> Dict[str, Any]:
164153
}
165154

166155
return all_results
167-
168-
def _resolve_device(self, device_key: str, devices_map: Mapping[str, str]) -> DeviceSpec:
169-
"""
170-
Resolve a device key to a full device string.
171-
172-
Args:
173-
device_key: Device key to resolve
174-
devices_map: Mapping of device keys to full device strings
175-
Returns:
176-
Resolved device
177-
"""
178-
if device_key in devices_map:
179-
return DeviceSpec.from_value(devices_map[device_key])
180-
self.logger.warning(f"Device alias '{device_key}' not found in devices map, using as-is")
181-
return DeviceSpec.from_value(device_key)

0 commit comments

Comments (0)