-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathop_task.py
220 lines (181 loc) · 7.26 KB
/
op_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import dataclasses
import gc
import os
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from tritonbench.components.tasks import base as base_task
from tritonbench.components.workers import subprocess_worker
class Worker(subprocess_worker.SubprocessWorker):
    """Subprocess worker that honors CPU affinity via `taskset`.

    When GOMP_CPU_AFFINITY is set, importing `torch` in the main process has
    the very surprising effect of changing the threading behavior in the
    subprocess. (See https://github.com/pytorch/pytorch/issues/49971 for
    details.) This is a problem, because it means that the worker is not
    hermetic and also tends to force the subprocess torch to run in single
    threaded mode which drastically skews results.

    This can be ameliorated by calling the subprocess using `taskset`, which
    allows the subprocess PyTorch to properly bind threads.
    """

    @property
    def args(self) -> List[str]:
        cpu_list = os.environ.get("GOMP_CPU_AFFINITY", "")
        # Only prepend the taskset invocation when an affinity mask is set.
        prefix: List[str] = ["taskset", "--cpu-list", cpu_list] if cpu_list else []
        return prefix + super().args
@dataclasses.dataclass(frozen=True)
class OpDetails:
    """Static description of what a particular TritonBench operator supports.

    When parameterizing tests, we only want to generate sensible ones.
    (e.g. Those where an operator can be imported and supports the feature to be
    tested or benchmarked.) This requires us to import the operator; however many
    of the operators are EXTREMELY stateful, and even importing them consumes
    significant system resources. As a result, we only want one (or a few)
    alive at any given time.

    Note that affinity cannot be solved by simply calling `torch.set_num_threads`
    in the child process; this will cause PyTorch to use all of the cores but
    at a much lower efficiency.

    This class describes what a particular operator does and does not support, so
    that we can release the underlying subprocess but retain any pertinent
    metadata.
    """

    # Operator name as passed to `load_opbench_by_name` in the worker.
    name: str
    # True when the operator class was found (i.e. import returned non-None).
    exists: bool
class OpTask(base_task.TaskBase):
    """Handle to a single TritonBench operator hosted in a subprocess worker.

    The worker may (and often does) consume significant system resources.
    In order to ensure that runs do not interfere with each other, we only
    allow a single OpTask to exist at a time.
    """

    # Class-level lock enforcing the "one live OpTask" invariant.
    _lock = threading.Lock()

    def __init__(
        self,
        name: str,
        timeout: Optional[float] = None,
        extra_env: Optional[Dict[str, str]] = None,
        save_output_dir: Optional[Path] = None,
    ) -> None:
        gc.collect()  # Make sure previous task has a chance to release the lock

        # Record whether *this* instance acquired the class-level lock.
        # `__del__` must only release the lock when this instance holds it;
        # otherwise a failed acquisition below would leave a partially
        # constructed object whose finalizer releases a lock owned by a
        # different, still-live OpTask (or raises on an unheld lock).
        self._lock_acquired: bool = self._lock.acquire(blocking=False)
        assert self._lock_acquired, "Failed to acquire lock."

        self._op_name = name
        self._worker = Worker(
            timeout=timeout, extra_env=extra_env, save_output_dir=save_output_dir
        )
        # Import torch in the child up front; later worker snippets (e.g.
        # `maybe_sync`) rely on it being in the worker's global namespace.
        self.worker.run("import torch")

        self._details: OpDetails = OpDetails(
            **self._maybe_import_operator(
                package=__name__,
                op_name=name,
            )
        )

    # =========================================================================
    # == Import Operator in the child process ====================================
    # =========================================================================

    @property
    def worker(self) -> subprocess_worker.SubprocessWorker:
        """The subprocess worker hosting the operator."""
        return self._worker

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
        """Import the operator class in the child process.

        Runs inside the worker. Stashes the class in the worker's globals
        under `Operator` and returns a dict used to build `OpDetails` in the
        parent. (`package` is part of the established call signature but is
        not consulted here.)
        """
        from tritonbench.operators import load_opbench_by_name

        Operator = load_opbench_by_name(op_name)

        # Populate global namespace so subsequent calls to worker.run can access `Operator`
        globals()["Operator"] = Operator

        # This will be used to populate a `OpDetails` instance in the parent.
        return {
            "name": op_name,
            "exists": Operator is not None,
        }

    # =========================================================================
    # == Instantiate a concrete `op` instance ==============================
    # =========================================================================

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def make_operator_instance(
        args: List[str],
    ) -> None:
        """Instantiate the operator in the child from CLI-style `args`.

        Stores the instance as `op` and a device-appropriate synchronization
        callable as `maybe_sync` in the worker's global namespace.
        """
        from tritonbench.utils.parser import get_parser

        parser = get_parser()
        tb_args, extra_args = parser.parse_known_args(args)
        Operator = globals()["Operator"]
        op = Operator(
            tb_args=tb_args,
            extra_args=extra_args,
        )

        import gc

        gc.collect()

        # `torch` is available here because the parent ran "import torch" in
        # this worker during OpTask.__init__.
        if op.device == "cuda":
            maybe_sync = torch.cuda.synchronize
        else:
            maybe_sync = lambda: None

        globals().update(
            {
                "op": op,
                "maybe_sync": maybe_sync,
            }
        )

    # =========================================================================
    # == Forward calls to `op` from parent to worker =======================
    # =========================================================================

    def run(self) -> None:
        """Run the operator benchmark in the worker, then synchronize."""
        self.worker.run(
            """
            op.run()
            maybe_sync()
            """
        )

    # =========================================================================
    # == Get Operator attribute in the child process =============================
    # =========================================================================

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def get_attribute(
        attr: str, field: Optional[str] = None, classattr: bool = False
    ) -> Any:
        """Fetch an attribute from the operator living in the child.

        Args:
            attr: Attribute name to read.
            field: Optional sub-attribute to read off the value of `attr`.
            classattr: When True, read from the Operator class rather than
                the instantiated `op`.

        Returns:
            The requested value, or None when `attr` is absent.
        """
        if classattr:
            op = globals()["Operator"]
        else:
            op = globals()["op"]
        if hasattr(op, attr):
            if field:
                op_attr = getattr(op, attr)
                return getattr(op_attr, field)
            else:
                return getattr(op, attr)
        else:
            return None

    # =========================================================================
    # == Check output is expected in the child process ========================
    # =========================================================================

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def check_output() -> None:
        """Assert that the op's output covers every CI-enabled implementation.

        Compares the impl names present in `op.output` against the registered
        benchmarks for this op (minus any in `op._skip`); raises
        AssertionError on mismatch.
        """
        op = globals()["op"]
        from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS

        output = op.output
        output_impls = output.result[0][1].keys()
        ci_enabled_impls = [
            x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in op._skip
        ]

        # Make sure that all the ci_enabled impls are in the output
        assert set(output_impls) == set(
            ci_enabled_impls
        ), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"

    def del_op_instance(self):
        """Drop the operator instance in the worker and collect garbage."""
        self.worker.run(
            """
            del op
            del maybe_sync
            """
        )
        self.gc_collect()

    def gc_collect(self) -> None:
        """Force a garbage-collection pass in the worker process."""
        self.worker.run(
            """
            import gc
            gc.collect()
            """
        )

    def __del__(self) -> None:
        # Release the class-level lock only if this instance acquired it;
        # getattr guards against __init__ failing before the flag was set.
        if getattr(self, "_lock_acquired", False):
            self._lock.release()