-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathhint_manager.py
More file actions
135 lines (105 loc) · 4.41 KB
/
hint_manager.py
File metadata and controls
135 lines (105 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import sys
import importlib
class BaseHintHandler:
    """Default hint handler that dispatches hooks to same-named methods.

    ``trigger(hook_name, ...)`` looks up an attribute called ``hook_name`` on
    the handler and calls it. Backend-specific handlers can subclass this and
    define one method per hook; hooks with no matching callable attribute are
    silently ignored (``trigger`` returns ``None``).
    """

    def trigger(self, hook_name, *args, **kwargs):
        """Dynamically find the method ``hook_name`` and call it.

        Args:
            hook_name: Name of the method to look up on this handler.
            *args, **kwargs: Forwarded unchanged to the method.

        Returns:
            The method's return value, or ``None`` when this handler has no
            callable attribute named ``hook_name``.

        Raises:
            TypeError: Re-raised unchanged from the hook call, after a
                diagnostic comparing the method's signature with the supplied
                arguments has been written to stderr.
        """
        method = getattr(self, hook_name, None)
        if not callable(method):
            return None
        try:
            return method(*args, **kwargs)
        except TypeError as e:
            self._report_trigger_mismatch(hook_name, method, args, kwargs, e)
            # Bare raise keeps the original traceback intact.
            raise

    def _report_trigger_mismatch(self, hook_name, method, args, kwargs, error):
        """Write a diagnostic for a hook call that raised TypeError to stderr."""
        import inspect

        try:
            expected = str(inspect.signature(method))
        except (TypeError, ValueError):
            # Some callables (e.g. certain builtins) expose no signature.
            expected = "(unknown)"
        actual_args = f"{len(args)} positional"
        actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"
        # stderr, like the other [FlagTree] warnings, so stdout stays clean.
        print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}", file=sys.stderr)
        print(f" > Expect : {expected}", file=sys.stderr)
        print(f" > Actual : {actual_args}, {actual_kwargs}", file=sys.stderr)
        print(f" > Reason : {error}\n", file=sys.stderr)
class HintManager:
    """Hold the hint handler selected for a backend name.

    Attributes:
        backend_name: The backend name passed to the constructor.
        handler: The backend-specific handler instance, or a plain
            ``BaseHintHandler`` when the backend is unknown or its handler
            could not be loaded.
    """

    # backend name -> (handler module path, handler class name, label used in warnings)
    _HANDLER_SPECS = {
        'npu': ("triton.backends.ascend.ascend_hint_handler", "AscendHintHandler", "Ascend"),
        'aipu': ("triton.backends.aipu.aipu_hint_handler", "AipuHintHandler", "aipu"),
        'cuda': ("triton.backends.nvidia.nvidia_hint_handler", "NvidiaHintHandler", "Nvidia"),
    }

    def __init__(self, backend_name):
        self.backend_name = backend_name
        # load Handler with backend name
        self.handler = self._load_handler(backend_name)

    def _load_handler(self, backend):
        """Instantiate the hint handler registered for ``backend``.

        Unknown backends get a ``BaseHintHandler``. For a known backend, a
        failed import, or a module without the expected handler class, prints
        a warning to stderr and also falls back to ``BaseHintHandler``.
        """
        spec = self._HANDLER_SPECS.get(backend)
        if spec is None:
            return BaseHintHandler()
        module_path, class_name, label = spec
        try:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)()
        except (ImportError, AttributeError) as e:
            # AttributeError: the module imported but lacks the handler class;
            # degrade like a failed import instead of crashing the first hint.
            print(f"[FlagTree] Warning: Failed to load {label} Hint Handler: {e}", file=sys.stderr)
            return BaseHintHandler()
# Canonical backend names that have a matching hint handler.
SUPPORTED_BACKENDS = "aipu npu cuda".split()

# Alternate spellings mapped onto canonical backend names.
# TODO: "npu" may become ambiguous if more backends are involved.
BACKEND_ALIASES = dict(
    ascend="npu",
    huawei="npu",
    nvidia="cuda",
)
def normalize_backend_name(name: str) -> str:
    """Return the canonical backend name for ``name``.

    The name is lower-cased and then looked up in ``BACKEND_ALIASES``; names
    without an alias are returned lower-cased. Falsy input yields ``""``.
    """
    if name:
        lowered = name.lower()
        return BACKEND_ALIASES.get(lowered, lowered)
    return ""
def hint_get_flagtree_backend() -> str:
    """Detect which FlagTree backend hint hooks should target.

    Detection order:
      1. Triton's active driver, via ``driver.active.get_active_torch_device()``.
      2. Torch global state: the first of ``torch.aipu``, ``torch.npu``,
         ``torch.cuda`` whose ``is_available()`` returns true.

    The detected name is passed through ``normalize_backend_name``.

    Returns:
        A name from ``SUPPORTED_BACKENDS``, or ``""`` when no supported backend
        is detected (including when torch cannot be imported).
    """
    try:
        import torch
    except ImportError:
        # Both probes below need torch. "" makes HintManager fall back to the
        # no-op BaseHintHandler instead of crashing the first hint_trigger().
        return ""

    detected_backend = ""

    # Priority 1: ask the active Triton driver which torch device it targets.
    try:
        from triton.runtime import driver
        active = getattr(driver, "active", None)
        get_device = getattr(active, "get_active_torch_device", None)
        if get_device is not None:
            device = get_device()
            if isinstance(device, torch.device):
                detected_backend = device.type
            elif isinstance(device, str):
                # A plain device-name string is used as-is.
                detected_backend = device
    except (ImportError, AttributeError, RuntimeError):
        # Best-effort probe: any failure here falls through to Priority 2.
        # NOTE(review): accessing driver.active may raise RuntimeError when
        # Triton cannot select a driver - confirm for the Triton version in use.
        pass

    # Priority 2: torch global state. Kept because some backends may not
    # support Priority 1. Candidates are probed in this fixed order.
    if not detected_backend:
        for candidate in ("aipu", "npu", "cuda"):
            module = getattr(torch, candidate, None)
            if module and hasattr(module, "is_available") and module.is_available():
                detected_backend = candidate
                break

    # Normalization and validation ("" is never in SUPPORTED_BACKENDS).
    canonical_backend = normalize_backend_name(detected_backend)
    return canonical_backend if canonical_backend in SUPPORTED_BACKENDS else ""
# Module-wide HintManager; created lazily by the first hint_trigger() call.
_global_hint_manager = None


def hint_trigger(hook_name, *args, **kwargs):
    """Dispatch ``hook_name`` to the handler for the detected backend.

    On first use this builds the shared ``HintManager`` from
    ``hint_get_flagtree_backend()``; later calls reuse it. Returns whatever
    the handler's ``trigger`` returns.
    """
    global _global_hint_manager
    manager = _global_hint_manager
    if manager is None:
        manager = HintManager(hint_get_flagtree_backend())
        _global_hint_manager = manager
    return manager.handler.trigger(hook_name, *args, **kwargs)