|
| 1 | +"""GPU type normalization utility. |
| 2 | +
|
| 3 | +Normalizes user-specified GPU types to canonical names used in benchmark data. |
| 4 | +Uses ModelCatalog as the single source of truth for GPU aliases. |
| 5 | +""" |
| 6 | + |
| 7 | +import logging |
| 8 | +from typing import TYPE_CHECKING |
| 9 | + |
| 10 | +if TYPE_CHECKING: |
| 11 | + from ..knowledge_base.model_catalog import ModelCatalog |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | +# Canonical GPU names from benchmark data |
| 16 | +CANONICAL_GPUS = {"L4", "A100-40", "A100-80", "H100", "H200", "B200"} |
| 17 | + |
| 18 | +# Expansion map for shorthand/ambiguous names |
| 19 | +# When user says "A100" without specifying variant, include both |
| 20 | +GPU_EXPANSIONS = { |
| 21 | + "A100": ["A100-80", "A100-40"], |
| 22 | +} |
| 23 | + |
| 24 | +# Singleton catalog instance to avoid repeated loading |
| 25 | +_catalog_instance: "ModelCatalog | None" = None |
| 26 | + |
| 27 | + |
| 28 | +def _get_catalog() -> "ModelCatalog": |
| 29 | + """Get or create the ModelCatalog singleton.""" |
| 30 | + global _catalog_instance |
| 31 | + if _catalog_instance is None: |
| 32 | + from ..knowledge_base.model_catalog import ModelCatalog |
| 33 | + _catalog_instance = ModelCatalog() |
| 34 | + return _catalog_instance |
| 35 | + |
| 36 | + |
| 37 | +def normalize_gpu_types(gpu_types: list[str]) -> list[str]: |
| 38 | + """ |
| 39 | + Normalize GPU types to canonical names using ModelCatalog aliases. |
| 40 | +
|
| 41 | + - Case-insensitive matching |
| 42 | + - Uses ModelCatalog's alias lookup (from model_catalog.json) |
| 43 | + - Expands shorthand (A100 → [A100-80, A100-40]) |
| 44 | + - Returns empty list for empty input |
| 45 | +
|
| 46 | + Args: |
| 47 | + gpu_types: List of GPU type strings from user input or intent extraction |
| 48 | +
|
| 49 | + Returns: |
| 50 | + List of canonical GPU names (uppercase), deduplicated and sorted |
| 51 | + """ |
| 52 | + if not gpu_types: |
| 53 | + return [] |
| 54 | + |
| 55 | + catalog = _get_catalog() |
| 56 | + normalized = set() |
| 57 | + |
| 58 | + for gpu in gpu_types: |
| 59 | + if not gpu or not isinstance(gpu, str): |
| 60 | + continue |
| 61 | + |
| 62 | + gpu_stripped = gpu.strip() |
| 63 | + gpu_upper = gpu_stripped.upper() |
| 64 | + |
| 65 | + # Skip empty or "any gpu" values |
| 66 | + if not gpu_upper or gpu_upper == "ANY GPU": |
| 67 | + continue |
| 68 | + |
| 69 | + # Check if it's an expansion case (e.g., A100 → both variants) |
| 70 | + if gpu_upper in GPU_EXPANSIONS: |
| 71 | + normalized.update(GPU_EXPANSIONS[gpu_upper]) |
| 72 | + logger.debug(f"Expanded '{gpu}' to {GPU_EXPANSIONS[gpu_upper]}") |
| 73 | + continue |
| 74 | + |
| 75 | + # Use ModelCatalog's alias lookup (handles case-insensitivity) |
| 76 | + gpu_info = catalog.get_gpu_type(gpu_stripped) |
| 77 | + if gpu_info: |
| 78 | + normalized.add(gpu_info.gpu_type.upper()) |
| 79 | + logger.debug(f"Resolved '{gpu}' to '{gpu_info.gpu_type}' via ModelCatalog") |
| 80 | + continue |
| 81 | + |
| 82 | + # Check if it's already a canonical name (direct match) |
| 83 | + if gpu_upper in CANONICAL_GPUS: |
| 84 | + normalized.add(gpu_upper) |
| 85 | + continue |
| 86 | + |
| 87 | + # Unknown GPU type - log warning and skip |
| 88 | + logger.warning( |
| 89 | + f"Unknown GPU type '{gpu}' - not found in ModelCatalog or canonical list. " |
| 90 | + "Skipping this GPU filter." |
| 91 | + ) |
| 92 | + |
| 93 | + return sorted(normalized) # Sorted for consistent ordering |
0 commit comments