forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 79
Expand file tree
/
Copy pathtriton_heuristics.py
More file actions
3447 lines (3019 loc) · 125 KB
/
triton_heuristics.py
File metadata and controls
3447 lines (3019 loc) · 125 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-defs
from __future__ import annotations
import builtins
import copy
import dataclasses
import functools
import hashlib
import inspect
import itertools
import logging
import math
import operator
import os
import os.path
import re
import sys
import threading
import time
from collections import namedtuple
from typing import (
Any,
Callable,
Generic,
Literal,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
import torch
from torch._dynamo.utils import set_feature_use
from torch._environment import is_fbcode
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet
from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
from . import triton_helpers
from .autotune_cache import AutotuneCache
from .benchmarking import benchmarker
from .coordinate_descent_tuner import CoordescTuner
from .hints import (
_NUM_THREADS_PER_WARP,
AutotuneHint,
DeviceProperties,
HeuristicType,
ReductionHint,
TileHint,
TRITON_MAX_BLOCK,
TRITON_MAX_RSPLIT,
)
from .runtime_utils import (
ceildiv,
conditional_product,
create_bandwidth_info_str,
dynamo_timed,
get_first_attr,
get_max_y_grid,
get_num_bytes,
next_power_of_2,
triton_cache_dir,
triton_config_to_hashable,
triton_hash_to_path_key,
validate_triton_config,
)
from .static_cuda_launcher import StaticallyLaunchedCudaKernel
from .triton_compat import (
ASTSource,
autograd_profiler,
cc_warp_size,
CompiledKernel,
Config,
GPUTarget,
HAS_WARP_SPEC,
KernelInterface,
knobs,
OutOfResources,
PTXASError,
triton,
)
class InductorConfig(Config):
    """Triton ``Config`` subclass that carries Inductor-only tuning switches.

    ``dynamic_scale_rblock`` (keyword-only, default ``True``) is stored on the
    instance; all remaining positional and keyword arguments are forwarded
    unchanged to ``Config.__init__``.
    """

    def __init__(self, *args, dynamic_scale_rblock=True, **kwargs):
        # Let Triton's Config consume everything it understands first.
        super().__init__(*args, **kwargs)
        # Inductor-specific flag; Triton's Config has no notion of it.
        self.dynamic_scale_rblock = dynamic_scale_rblock
class NoTritonConfigsError(RuntimeError):
    """Raised when there are no Triton configs to compile, or none compiled successfully."""
if TYPE_CHECKING:
from collections.abc import Container, Hashable
from torch._guards import CompileId
# Launchers come from ``make_launcher()`` on compile results; they are left
# untyped here.
LauncherType = Any
# A kernel is either a regular Triton CompiledKernel or one driven by the
# static CUDA launcher.
_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel]
# Type variable restricted to the kernel types above.
_T = TypeVar("_T", bound=_KernelType)
log = logging.getLogger(__name__)
# Matches the leading "def <name>(" of a kernel's source; it is replaced by "("
# so that lookup hashes do not depend on the generated kernel name.
triton_name_sub = re.compile(r"^def [^(]+\(")
def generate_lookup_hash_from_source_code(size_hints_str: str, source_code: str) -> str:
    """Return the SHA-256 hex digest used as the autotune lookup-table key.

    Leading/trailing whitespace is stripped from ``source_code`` and its first
    ``def <name>(`` is replaced with ``(`` so the key ignores the kernel name.
    ``size_hints_str`` is prepended so identical source with different size
    hints yields a different key.
    """
    anonymous_src = triton_name_sub.sub("(", source_code.strip(), count=1)
    payload = size_hints_str + anonymous_src
    digest = hashlib.sha256(payload.encode("utf-8"))
    return digest.hexdigest()
def lookup_autotune_config(size_hints, fn) -> Optional[Config]:
    """Return a pre-tuned ``Config`` for ``fn`` from the user lookup table, or None.

    The table (``torch._inductor.config.autotune_lookup_table``) is only
    consulted when it is non-empty and the kernel source contains
    ``"_fused_"``. Entries are keyed by
    ``generate_lookup_hash_from_source_code(str(size_hints), fn.src)``; a
    matching entry supplies every key containing ``"BLOCK"`` as a kernel
    kwarg, plus ``num_warps`` and ``num_stages``.
    """
    table = torch._inductor.config.autotune_lookup_table
    if not (len(table) > 0 and "_fused_" in fn.src):
        return None

    key = generate_lookup_hash_from_source_code(str(size_hints), fn.src)
    if key not in table:
        return None

    entry = table[key]
    block_sizes = {name: value for name, value in entry.items() if "BLOCK" in name}
    return Config(
        block_sizes,
        num_warps=entry["num_warps"],
        num_stages=entry["num_stages"],
    )
def get_total_reduction_numel(numels: dict[str, int]) -> int:
    """Combine, via ``conditional_product``, the numels of all reduction prefixes."""
    reduction_sizes = [
        size for prefix_name, size in numels.items() if prefix_is_reduction(prefix_name)
    ]
    return conditional_product(*reduction_sizes)
def autotune_hints_to_configs(
    hints: OrderedSet[AutotuneHint],
    size_hints,
    block_size: int,
    device_props: DeviceProperties,
) -> list[Config]:
    """
    AutotuneHints can be attached to the metadata of triton kernels for providing
    suggestions about what to try for autotuning. One reason to do this is if there are
    some configs that are only useful in specific scenarios, in which case we can avoid
    wasting compile time on autotuning unless we know we are in one of those scenarios.

    Based on those hints, this function will generate a list of additional autotuning
    configs to try.

    Only ``AutotuneHint.ONE_ELEMENT_PER_THREAD`` is acted upon: for kernels with
    1, 2 or 3 size hints it yields one config per dimension with a
    ``block_size // 4`` block along that dimension and 1 elsewhere. Kernels of
    any other rank get no extra configs for this hint (previously this path
    raised ``UnboundLocalError``).

    Args:
        hints: hints attached to the kernel metadata.
        size_hints: per-dimension size hints; only its length and the value
            forwarded to ``triton_config`` are used here.
        block_size: base block size; a quarter of it is used per config.
        device_props: device properties; ``warp_size`` (falling back to 32)
            becomes ``num_elements_per_warp``.

    Returns:
        The list of extra configs (possibly empty).
    """
    xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
    configs: list[Config] = []
    for hint in hints:
        if hint != AutotuneHint.ONE_ELEMENT_PER_THREAD:
            continue

        quarter = block_size // 4
        ndim = len(size_hints)
        if ndim == 1:
            xyz_options = ((quarter, None, None),)
        elif ndim == 2:
            xyz_options = ((quarter, 1, None), (1, quarter, None))
        elif ndim == 3:
            xyz_options = (
                (quarter, 1, 1),
                (1, quarter, 1),
                (1, 1, quarter),
            )
        else:
            # No x/y/z mapping exists for this rank; skip the hint instead of
            # falling through with an unbound (or stale) ``xyz_options``.
            continue

        warp_size = device_props.warp_size if device_props.warp_size else 32
        configs.extend(
            triton_config(size_hints, *xyz, num_elements_per_warp=warp_size)
            for xyz in xyz_options
        )
    return configs
def disable_pointwise_autotuning(inductor_meta):
    """Return True when pointwise kernels should not be autotuned.

    Benchmarking can pick different configs from run to run, so autotuning is
    always disabled when deterministic algorithms are requested. Otherwise the
    ``autotune_pointwise`` flag decides (autotuning is on when it is absent).
    """
    wants_determinism = bool(inductor_meta.get("are_deterministic_algorithms_enabled"))
    return wants_determinism or not inductor_meta.get("autotune_pointwise", True)
def _format_launch_arg(value) -> str:
    """Render one launch argument: int/bool verbatim, anything else as ``"T"``."""
    return str(value) if isinstance(value, (int, bool)) else "T"


def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
    """Append one line describing a kernel launch to ``<abspath(argv[0])>.launch_params``.

    The line is ``<kernel_name> | <args> | <grid!r>``. ``<args>`` lists, in
    order: the positional args, the caller's kwargs, the launcher config's
    kwargs, ``num_warps`` and ``num_stages`` (plus ``num_consumer_groups`` and
    ``num_buffers_warp_spec`` when ``HAS_WARP_SPEC``). Positional args and
    caller kwargs that are int/bool are written verbatim; any other value is
    abbreviated to ``"T"``.
    """
    call_args = [_format_launch_arg(arg) for arg in args]
    # Previously this loop tested the stale positional loop variable ``arg``
    # (UnboundLocalError when ``args`` was empty) and stored the raw value in
    # both branches; kwargs now get the same treatment as positional args.
    call_kwargs = {k: _format_launch_arg(v) for k, v in kwargs.items()}
    call_kwargs.update(launcher.config.kwargs)
    call_kwargs["num_warps"] = launcher.config.num_warps
    call_kwargs["num_stages"] = launcher.config.num_stages
    if HAS_WARP_SPEC:
        call_kwargs["num_consumer_groups"] = getattr(
            launcher.config, "num_consumer_groups", 0
        )
        call_kwargs["num_buffers_warp_spec"] = getattr(
            launcher.config, "num_buffers_warp_spec", 0
        )
    args_str = ", ".join(
        [*call_args, *(f"{k}={v}" for k, v in call_kwargs.items())]
    )
    abs_path = os.path.abspath(sys.argv[0])
    with open(f"{abs_path}.launch_params", "a") as f:
        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
def check_autotune_cache(
    configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
    """
    Look ``configs`` up in the autotune cache.

    Returns ``(configs, autotune_cache, autotune_cache_info)``:
      * ``configs`` is narrowed to ``[best_config]`` on a cache hit, otherwise
        returned unchanged;
      * ``autotune_cache`` is the handle from ``AutotuneCache.create`` (``None``
        when the cache was not consulted);
      * ``autotune_cache_info`` records what happened; its
        ``autotune_cache_state`` (when set) is "hit", "miss", "only 1 config"
        or "force_disabled".

    The cache is consulted only when ``force_disable_caches`` is off, a filename
    is known, there is something to tune (more than one config, or coordinate
    descent tuning enabled), and ``TRITON_INTERPRET`` is not "1".
    """
    info: dict[str, Any] = {}
    force_disabled = inductor_meta.get("force_disable_caches", False)
    coordesc = inductor_meta.get("coordinate_descent_tuning")

    cache_applicable = (
        not force_disabled
        and filename is not None
        and (len(configs) > 1 or coordesc)
        and os.environ.get("TRITON_INTERPRET", "0") != "1"
    )
    if not cache_applicable:
        if len(configs) == 1:
            info["autotune_cache_state"] = "only 1 config"
            info["only_config"] = triton_config_to_hashable(configs[0])
        if force_disabled:
            info["autotune_cache_state"] = "force_disabled"
            log.debug("autotune caching is disabled by config.force_disable_caches")
        return configs, None, info

    cache = AutotuneCache.create(inductor_meta, filename, hash_configs(configs))
    if not cache:
        return configs, cache, info

    best_config = cache.read_best(inductor_meta, configs)
    if best_config:
        info["best_config"] = triton_config_to_hashable(best_config)
        info["autotune_cache_state"] = "hit"
        return [best_config], cache, info

    info["autotune_cache_state"] = "miss"
    info["num_configs"] = len(configs)
    if coordesc:
        info["coordesc_tuning"] = True
        if len(configs) == 1:
            # Starting point of coordinate descent tuning, which is not the
            # config it finally picks (i.e. not only_config / best_config).
            info["coordesc_tuning_start_config"] = triton_config_to_hashable(
                configs[0]
            )
    return configs, cache, info
class CachingAutotuner(KernelInterface):
"""
Simplified version of Triton autotuner that has no invalidation
key and caches the best config to disk to improve cold start times.
Unlike the main triton Autotuner, this version can precompile all
configs, and does not rely on the Triton JIT.
"""
def __init__(
    self,
    fn,
    triton_meta,  # passed directly to triton
    configs,
    save_cache_hook,
    mutated_arg_names: list[str],  # see [Note: clone mutated buffers]
    optimize_mem,
    heuristic_type,
    size_hints=None,
    inductor_meta=None,  # metadata not relevant to triton
    custom_kernel=False,  # whether the kernel is inductor-generated or custom
    filename: Optional[str] = None,
    reset_to_zero_arg_names: Optional[list[str]] = None,
    autotune_cache_info: Optional[dict[str, Any]] = None,
):
    """Set up the autotuner for one kernel; nothing is compiled here.

    Args:
        fn: the kernel function; ``fn.src`` (via the lookup table) and
            ``fn.__name__`` are read here.
        triton_meta: metadata forwarded to Triton. ``triton_meta["device"]``
            must be a ``DeviceProperties``; the stored copy replaces it with
            the device index and adds ``device_type``.
        configs: non-empty list of candidate configs; each is validated. It is
            replaced by a single config when the autotune lookup table matches.
        save_cache_hook: stored unchanged.
        mutated_arg_names: names of arguments the kernel mutates; used when
            protecting them during benchmarking.
        optimize_mem: when true, benchmarking may back mutated args up to CPU
            instead of cloning them.
        heuristic_type: kind of heuristic (e.g. ``HeuristicType.REDUCTION``).
        size_hints: per-dimension size hints, also given to the coordesc tuner.
        inductor_meta: Inductor-only metadata; defaults to ``{}``.
        custom_kernel: True for user-written Triton kernels.
        filename: path of the generated kernel module; its basename (without
            extension) becomes ``kernel_hash`` when it contains ".py".
        reset_to_zero_arg_names: args zeroed before each benchmark run.
        autotune_cache_info: metadata describing the autotune-cache lookup.

    Side effects: sets ``TRITON_CACHE_DIR`` in ``os.environ`` when it is unset.
    """
    super().__init__()
    assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
    # makes sure there are no pre-hooks on any of the triton configs
    for cfg in configs:
        validate_triton_config(cfg)

    self.fn = fn
    self.device_props: DeviceProperties = triton_meta["device"]
    # Triton receives the device index/type rather than the DeviceProperties object.
    self.triton_meta = {
        **triton_meta,
        "device": self.device_props.index,
        "device_type": self.device_props.type,
    }
    self.inductor_meta = {} if inductor_meta is None else inductor_meta
    self.save_cache_hook = save_cache_hook
    self.mutated_arg_names = mutated_arg_names
    self.reset_to_zero_arg_names = (
        [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names
    )
    self.optimize_mem = optimize_mem
    # A user-provided lookup-table hit short-circuits autotuning to one config.
    cached_config = lookup_autotune_config(size_hints, fn)
    self.configs = [cached_config] if cached_config else configs
    self.heuristic_type = heuristic_type
    self.custom_kernel = custom_kernel
    self.cuda_kernel_saved = False
    self.autotune_cache_info = autotune_cache_info
    if log.isEnabledFor(logging.DEBUG):
        log.debug(
            "CachingAutotuner gets %d configs for %s",
            len(self.configs),
            self.fn.__name__,
        )
        for c in self.configs:
            log.debug(c)

    # Populated by precompile() / _make_launchers().
    self.compile_results: list[CompileResult[_KernelType]] = []
    self.launchers: list[LauncherType] = []
    self.lock = threading.Lock()
    if os.getenv("TRITON_CACHE_DIR") is None:
        os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
            self.triton_meta.get("device", 0)
        )
    log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])

    self.size_hints = size_hints
    self.coordesc_tuner = CoordescTuner(
        is_mm=False,
        name=self.fn.__name__,
        size_hints=size_hints,
        inductor_meta=self.inductor_meta,
    )
    self.filename = filename

    # used for profiling
    self.kernel_hash: str = ""

    # Kernels are stored in the codecache with the filename as a hash of the code.
    # We rely on this to obtain the kernel hash
    if self.filename is not None:
        base_name = os.path.basename(self.filename)
        if ".py" in base_name:
            self.kernel_hash = os.path.splitext(base_name)[0]

    self.precompile_time_taken_ns = 0
    self.autotune_time_taken_ns = 0
    # Dumps the launch configs after autotuning.
    self.dump_launch_params = (
        os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
    )

    self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

    # Compile-time info included in runtime logging
    self.compile_id: Optional[CompileId] = None
    self.is_backward = False

    # Mode for launch grid calculation
    self.grid_mode: Literal["python", "python_slow", "cpp"] = "python"
def is_statically_launchable(self):
    """
    Report whether this autotuner can be cached efficiently in FXGraphCache:
    there must be at least one compile result, and every one of them must be
    a ``StaticTritonCompileResult``.
    """
    results = self.compile_results
    return bool(results) and all(
        isinstance(res, StaticTritonCompileResult) for res in results
    )
def recheck_autotune_cache(
    self, reload_kernel_from_src: Callable[[], CachingAutotuner]
) -> None:
    """
    Repeat the autotune-cache lookup for a statically launchable autotuner
    that was just loaded from cache, since a previous run may have recorded a
    best config in the meantime.

    On a hit, ``compile_results`` is narrowed to the compiled config matching
    the cached best. If no compiled config matches but the best config was
    found by coordinate descent, it is compiled now; ``reload_kernel_from_src``
    restores the JIT function first if it was dropped for pickling.
    """
    assert self.is_statically_launchable()
    compiled_configs = [res.config for res in self.compile_results]
    cached_configs, _, self.autotune_cache_info = check_autotune_cache(
        compiled_configs, self.filename, self.inductor_meta
    )

    # A single config coming back from several candidates means a cache hit.
    was_hit = len(cached_configs) == 1 and len(compiled_configs) > 1
    if not was_hit:
        return

    best_config = cached_configs[0]
    wanted_hash = triton_config_to_hashable(best_config)
    matching = next(
        (
            res
            for res in self.compile_results
            if triton_config_to_hashable(res.config) == wanted_hash
        ),
        None,
    )
    if matching is not None:
        self.compile_results = [matching]
        return

    # Not among our compile results: most likely coordesc found it after the
    # cache entry had already been saved.
    if not best_config.found_by_coordesc:
        return
    with dynamo_timed("CachingAutotuner.slow_precompile_config"):
        if self.fn.fn is None:
            self.fn = reload_kernel_from_src().fn
        self.compile_results = [self._precompile_config(best_config)]
def set_compile_info(
    self, compile_id: Optional[CompileId], is_backward: bool
) -> None:
    """Remember which compilation produced this kernel (used in runtime logging)."""
    self.compile_id, self.is_backward = compile_id, is_backward
def precompile(
    self,
    warm_cache_only=False,
    reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
    static_triton_bundle_key: Optional[str] = None,
):
    """
    Compile the kernel's configs ahead of time.

    With ``warm_cache_only`` the configs are only compiled; no lock is taken
    and no launchers are built. Otherwise, under ``self.lock``, the configs
    are compiled, the autotuner is registered with ``TritonBundler`` as a
    static autotuner when ``static_triton_bundle_key`` is given and every
    result is statically launchable, launchers are created, and
    ``_dynamic_scale_rblock`` may add extra configs.

    ``reload_kernel`` is remembered so the parent process can recompile when
    the kernel was originally built in a worker (needed for coordesc tuning
    and dynamic_scale_rblock).
    """
    if not warm_cache_only:
        with self.lock:
            if reload_kernel is not None:
                self._reload_kernel = reload_kernel
            self._precompile_worker()
            should_bundle = (
                static_triton_bundle_key is not None
                and self.is_statically_launchable()
            )
            if should_bundle:
                TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
            self._make_launchers()
            self._dynamic_scale_rblock()
        return

    self._precompile_worker()
def _precompile_worker(self):
    """Compile every config in ``self.configs`` into ``self.compile_results``.

    If compile results already exist (e.g. loaded from a cache), they are only
    re-registered with ``TritonBundler`` and nothing is recompiled. Configs
    failing with ``OutOfResources`` or ``PTXASError`` are skipped. On success
    ``self.configs`` is set to ``None``.

    Raises:
        NoTritonConfigsError: if there are no configs, or every config failed
            to compile (chained to the last compile error).
    """
    if self.compile_results:
        for result in self.compile_results:
            TritonBundler.put(
                triton_hash_to_path_key(result.kernel.hash),  # type: ignore[attr-defined]
                self.triton_meta.get("device", 0),
            )
        return

    assert not self.launchers
    if not self.configs:
        raise NoTritonConfigsError("No triton configs are available")

    compile_results = []
    exc = None
    for c in self.configs:
        try:
            compile_results.append(self._precompile_config(c))
        except (OutOfResources, PTXASError) as e:
            # Keep the last failure so it can be reported if nothing compiles.
            exc = e
    if len(compile_results) == 0:
        # Chain the real cause so its traceback is not lost.
        raise NoTritonConfigsError(
            f"No valid triton configs. {type(exc).__name__}: {exc}"
        ) from exc
    self.compile_results = compile_results
    self.configs = None
def _dynamic_scale_rblock(self):
    """Add reduction configs with a halved R*_BLOCK when register pressure limits occupancy.

    Only runs for non-persistent REDUCTION kernels on CUDA (SM >= 8) or HIP
    devices that report ``regs_per_multiprocessor``. For each compiled config
    whose per-thread register count (``n_regs``) exceeds what full occupancy
    allows, and whose grid would not fit in one wave of resident blocks, a
    copy with its largest reduction block halved is compiled and appended to
    ``self.compile_results``. Launchers are rebuilt at the end.
    """
    # TODO(jansel): we should find a way to move this extra compile into the worker process
    # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
    device_prop = self.device_props
    if (
        self.inductor_meta.get("dynamic_scale_rblock", True)
        and not self.inductor_meta.get("persistent_reduction")
        and self.heuristic_type == HeuristicType.REDUCTION
        and self.size_hints is not None
        # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
        and device_prop.type in ["cuda", "hip"]
        and device_prop.major
        and (device_prop.major >= 8 or torch.version.hip)
        and device_prop.regs_per_multiprocessor is not None
    ):
        assert device_prop.regs_per_multiprocessor
        assert device_prop.max_threads_per_multi_processor
        assert device_prop.multi_processor_count
        # Built lazily, only once a candidate config is actually produced.
        seen_config_hashes: Optional[OrderedSet[Hashable]] = None
        warp_size = device_prop.warp_size or 32
        # NOTE: configs appended to self.compile_results below are also visited
        # by this same loop (intentional; see the B909 suppression).
        for result in self.compile_results:
            triton_config = result.config
            compiled_binary = result.kernel
            assert len(self.size_hints) >= 2
            xblock = triton_config.kwargs.get("XBLOCK", 1)
            # Every kwarg starting with "R" is treated as a reduction block size
            # (R0_BLOCK, R1_BLOCK, ...). NOTE(review): any other "R"-prefixed
            # kwarg would be picked up as well — confirm none exist.
            reduction_kwargs = [
                kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
            ]
            rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs]
            # Number of blocks along x: ceil(size_hints["x"] / XBLOCK).
            total_block = (self.size_hints["x"] + xblock - 1) // xblock
            nreg = getattr(compiled_binary, "n_regs", None)
            if nreg is None:
                continue

            # make sure rblocks are not too small
            if conditional_product(*rblocks) <= 64:
                continue

            # each SM of A100 has 65536 32-bit registers. To maximize
            # the theoretical occupancy, we need run 2048 threads on each
            # SM. So each thread should use no more than 65536 / 2048
            # = 32 registers. In cases where occupancy matters, and each
            # thread uses too many registers, reduce R0_BLOCK to reduce
            # the register usage.
            # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
            # from PLBartForCausalLM, latency improve from
            # 7.795ms to 4.883ms.
            #
            if (
                nreg
                <= device_prop.regs_per_multiprocessor
                // device_prop.max_threads_per_multi_processor
            ):
                continue

            nreg_per_warp = nreg * warp_size
            nreg_per_block = nreg_per_warp * triton_config.num_warps

            # Previously we set max_blocks_per_sm to
            # 'max_threads_per_multi_processor / (warp_size * num_warps)'.
            # The formula below is a tighter upper bound since we have the assumption that
            # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
            # due to the if condition above and:
            # regs_per_multiprocessor / nreg_per_block
            # = regs_per_multiprocessor / (nreg * warp_size * num_warps)
            # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * warp_size * num_warps)
            # = max_threads_per_multi_processor / (warp_size * num_warps)
            # Using a tighter upper bound can reveal more optimization opportunities.
            max_blocks_per_sm = max(
                device_prop.regs_per_multiprocessor // nreg_per_block, 1
            )

            if total_block <= max_blocks_per_sm * device_prop.multi_processor_count:
                # no need to improve occupancy
                continue
            new_config = copy.deepcopy(triton_config)

            # Reduce the largest Rn_BLOCK by a factor of 2.
            largest_rkwarg: str = max(
                reduction_kwargs, key=triton_config.kwargs.__getitem__
            )
            new_config.kwargs[largest_rkwarg] //= 2

            if seen_config_hashes is None:
                seen_config_hashes = OrderedSet(
                    [
                        triton_config_to_hashable(x.config)
                        for x in self.compile_results
                    ]
                )
            new_config_hash = triton_config_to_hashable(new_config)
            # Skip configs that are already compiled (or were already added).
            if new_config_hash in seen_config_hashes:
                continue
            seen_config_hashes.add(new_config_hash)
            log.debug(
                "Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
                largest_rkwarg,
                triton_config,
                new_config,
            )
            if self.fn.fn is None:
                """
                We are in the parent process, while this program was compiled in a worker
                and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
                containing the real fn yet.
                """
                assert hasattr(self, "_reload_kernel")
                assert callable(self._reload_kernel)
                self.fn = self._reload_kernel().fn
            self.compile_results.append(self._precompile_config(new_config))  # noqa: B909

        self._make_launchers()
def _make_launchers(self):
    """(Re)build ``self.launchers`` from ``self.compile_results``.

    Does nothing when there are already as many launchers as compile results.
    Otherwise, under a ``DeviceGuard`` for the kernel's device, the device is
    synchronized (which also initializes its context) and a launcher is made
    for every compile result. Results whose launcher creation fails with
    ``OutOfResources``, ``PTXASError`` or ``torch.cuda.OutOfMemoryError`` are
    dropped.

    Raises:
        RuntimeError: if no launcher could be created (chained to the last
            failure).
    """
    if len(self.launchers) == len(self.compile_results):
        return

    from torch._dynamo.device_interface import DeviceGuard

    device_interface = self.get_device_interface()

    # load binary to the correct device
    with DeviceGuard(device_interface, self.triton_meta["device"]):
        # need to initialize context
        with dynamo_timed(
            "CachingAutotuner.synchronize",
            # Deliberately avoid overloading pt2_compile_events:
            log_pt2_compile_event=False,
        ):
            device_interface.synchronize(device_interface.current_device())
        launchers = []
        exc = None
        for result in self.compile_results:
            try:
                launchers.append(result.make_launcher())
            except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e:
                # Keep the last failure so it can be reported if nothing works.
                exc = e
    if len(launchers) == 0:
        # Chain the real cause so its traceback is not lost.
        raise RuntimeError(
            f"No valid triton configs. {type(exc).__name__}: {exc}"
        ) from exc
    self.launchers = launchers
def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]:
    """Strip the parts of ``self.fn`` (triton.JITFunction) that do not pickle.

    Must be called after precompile, since the stripped objects are no longer
    needed then. The launchers are dropped as well, and ``fn.repr`` is
    frozen into a ``_ConstRepr``. Returns the previous values, in the order
    ``restore_after_unpickle`` expects:
    ``(fn, __globals__, used_global_vals, repr, launchers, _hash_lock)``.
    """
    jit_fn = self.fn
    saved = (
        jit_fn.fn,
        jit_fn.__globals__,
        jit_fn.used_global_vals,
        jit_fn.repr,
        self.launchers,
        getattr(jit_fn, "_hash_lock", None),
    )
    # Cleared in the same order as before (fn, __globals__, used_global_vals).
    jit_fn.fn = jit_fn.__globals__ = jit_fn.used_global_vals = None
    jit_fn.repr = _ConstRepr(jit_fn.repr(jit_fn))
    self.launchers = []
    jit_fn._hash_lock = None
    return saved
def restore_after_unpickle(
    self, old_values: Optional[tuple[Any, Any, Any, Any, Any, Any]]
) -> None:
    """Undo ``prepare_for_pickle`` using the tuple it returned.

    When ``old_values`` is falsy, only a fresh ``threading.RLock`` is
    installed as ``fn._hash_lock``, which must always be a valid RLock.
    """
    jit_fn = self.fn
    if not old_values:
        jit_fn._hash_lock = threading.RLock()
        return
    fn_obj, fn_globals, used_globals, fn_repr, launchers, hash_lock = old_values
    jit_fn.fn = fn_obj
    jit_fn.__globals__ = fn_globals
    jit_fn.used_global_vals = used_globals
    jit_fn.repr = fn_repr
    self.launchers = launchers
    jit_fn._hash_lock = hash_lock
def prepare_for_caching(self) -> None:
    """
    Drop the raw cubin carried by statically launched CUDA kernels before
    caching: it is large, and TritonBundler already collects it separately,
    so it should not be stored in the inductor cache.
    """
    static_results = (
        res
        for res in self.compile_results
        if isinstance(res, StaticTritonCompileResult)
    )
    for res in static_results:
        res.kernel.cubin_raw = None
def __getstate__(self) -> dict[str, Any]:
    """Return the pickle state: every instance attribute, with ``lock`` set to None.

    Launchers must already have been dropped (``prepare_for_pickle`` clears them).
    """
    assert not self.launchers, (
        "pickle should not be called with after make_launchers()"
    )
    state = dict(self.__dict__)
    state["lock"] = None  # threading.Lock objects cannot be pickled
    return state
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore pickled attributes and recreate the lock that ``__getstate__`` dropped."""
    self.__dict__.update(state)
    fresh_lock = threading.Lock()
    self.lock = fresh_lock
def get_device_interface(self):
    """Return the torch device interface matching this kernel's device type.

    "hip" devices are served by the "cuda" interface.
    """
    # this code cannot run in compile workers, because it imports from torch
    from torch._dynamo.device_interface import get_interface_for_device

    interface_type = self.device_props.type.replace("hip", "cuda")
    return get_interface_for_device(interface_type)
def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
    """Ahead of time compile a given autotuner config.

    Builds Triton compile metadata from a deep copy of ``self.triton_meta``
    plus ``cfg``, compiles ``self.fn`` with ``triton.compile``, optionally
    invokes Triton's ``jit_post_compile_hook``, registers the binary with
    ``TritonBundler``, and wraps it in a ``StaticTritonCompileResult`` when
    the static launcher can handle it, otherwise in a ``TritonCompileResult``.

    Raises:
        RuntimeError: if the installed Triton has no ``ASTSource``.
        Any exception from ``triton.compile`` is logged and re-raised.
    """
    compile_meta = copy.deepcopy(self.triton_meta)
    cfg_kwargs = cfg.kwargs
    if self.device_props.type == "hip":
        # Copy first so popping the AMD-specific options does not mutate cfg.kwargs.
        cfg_kwargs = {**cfg_kwargs}
        for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
            if k in cfg_kwargs:
                compile_meta[k] = cfg_kwargs.pop(k)
    # Config kwargs (block sizes etc.) become compile-time constants.
    compile_meta["constants"].update(cfg_kwargs)
    # Kernels that declare num_warps/num_stages as constexpr params get them
    # from the config, unless already given as constants.
    for i in self.fn.constexprs:
        arg_name = self.fn.arg_names[i]
        if arg_name not in compile_meta["constants"] and (
            arg_name == "num_warps" or arg_name == "num_stages"
        ):
            compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
    compile_meta["num_warps"] = cfg.num_warps
    compile_meta["num_stages"] = cfg.num_stages
    if HAS_WARP_SPEC:
        compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
        compile_meta["num_buffers_warp_spec"] = getattr(
            cfg, "num_buffers_warp_spec", 0
        )
    compile_meta["debug"] = self.inductor_meta.get(
        "assert_indirect_indexing", True
    ) and not self.inductor_meta.get("is_hip", False)

    # device type will be "hip" rather than "cuda" here
    compile_meta["device_type"] = self.device_props.type
    compile_meta["cc"] = self.device_props.cc

    if self.device_props.type == "cpu":
        triton_helpers.set_driver_to_cpu()
    else:
        triton_helpers.set_driver_to_gpu()

    if not ASTSource:
        raise RuntimeError("Installed triton version too old, please upgrade")

    compile_args = (
        ASTSource(
            self.fn,
            compile_meta["signature"],
            compile_meta["constants"],
            compile_meta["configs"][0],
        ),
    )

    if self.device_props.type == "mtia":
        from mtia.host_runtime.torch_mtia.acc_flags import (  # type: ignore[import-not-found]
            build_codename,
        )

        arch = build_codename()
    else:
        arch = compile_meta["cc"]

    target = GPUTarget(
        compile_meta["device_type"],
        arch,
        cc_warp_size(compile_meta["cc"]),
    )

    options = {
        "num_warps": compile_meta["num_warps"],
        "num_stages": compile_meta["num_stages"],
        "debug": compile_meta["debug"],
        "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
    }
    if "enable_fp_fusion" in compile_meta:
        options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
    if HAS_WARP_SPEC:
        options.update(
            {
                "num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
                "num_buffers_warp_spec": compile_meta.get(
                    "num_buffers_warp_spec", 0
                ),
            }
        )
    if self.device_props.type == "cuda":
        options.update(
            {
                "launch_cooperative_grid": compile_meta.get(
                    "launch_cooperative_grid", False
                ),
                "launch_pdl": compile_meta.get("launch_pdl", False),  # True
            }
        )
    if self.device_props.type == "hip":
        # NOTE(review): "kpack" is moved into compile_meta above but is not
        # forwarded into options here — confirm this is intended.
        if "waves_per_eu" in compile_meta:
            options["waves_per_eu"] = compile_meta["waves_per_eu"]
        if "matrix_instr_nonkdim" in compile_meta:
            options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
    compile_kwargs = {
        "target": target,
        "options": options,
    }

    try:
        binary = triton.compile(*compile_args, **compile_kwargs)
    except Exception:
        log.exception(
            "Triton compilation failed: %s\n%s\nmetadata: %s",
            self.inductor_meta.get("kernel_name", "triton_"),
            self.fn.src,
            compile_meta,
        )
        raise

    # Simulate JIT Hook call
    if (
        torch._inductor.config.run_jit_post_compile_hook
        and knobs
        and getattr(knobs.runtime, "jit_post_compile_hook", None)
    ):
        # Best-effort: any failure inside the hook is logged, not raised.
        try:
            hook = knobs.runtime.jit_post_compile_hook

            # base args everyone should get
            call_kwargs = dict(
                key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
                repr=getattr(self.fn, "src", None),
                fn=self.fn,
                compile=binary,
                is_manual_warmup=False,
                already_compiled=True,
            )

            # only add inductor_args if the hook takes it
            sig = inspect.signature(hook)
            params = sig.parameters
            if "inductor_args" in params:
                call_kwargs["inductor_args"] = self.inductor_meta["config_args"]

            hook(**call_kwargs)
        except Exception:
            log.exception("jit_post_compile_hook failed")

    TritonBundler.put(
        triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
    )
    # If the binary has a cubin file to directly launch, save it on the binary
    static_launcher = StaticTritonCompileResult.can_statically_launch(
        binary, self.inductor_meta, self.triton_meta, self.heuristic_type
    )

    if static_launcher is not None:
        result = StaticTritonCompileResult(
            static_launcher, cfg, compile_meta, self.inductor_meta
        )
        return result

    return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
def bench(self, launcher, *args, with_profiler=False, **kwargs):
    """Measure the performance of a given launcher.

    Returns ``float("inf")`` for non-custom kernels whose launcher spills more
    registers than ``inductor_meta["spill_threshold"]`` (default 16).
    Otherwise returns the timing from ``do_bench_using_profiling`` (when
    ``with_profiler`` and no profiler is already active),
    ``benchmarker.benchmark_cpu`` (CPU devices) or
    ``benchmarker.benchmark_gpu``. Mutated args are protected either by
    cloning (``maybe_clone_args``) or by a CPU backup that is restored after
    every call.
    """
    # we don't skip configs with spilled registers when auto-tuning custom
    # (user-written) Triton kernels, as (i) we don't have any knowledge or
    # control over the kernel code; (ii) there is empirical evidence that
    # for some (complicated) custom Triton kernels, a register-spilling
    # config may yield the best latency.
    if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
        "spill_threshold", 16
    ):
        log.debug(
            "Skip config %s because of register spilling: %d",
            launcher.config,
            launcher.n_spills,
        )
        return float("inf")

    device_interface = self.get_device_interface()
    stream = device_interface.get_raw_stream(device_interface.current_device())

    # Mutated args too large to clone are backed up to CPU once, up front.
    cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)

    def kernel_call():
        # One benchmark iteration: clone, zero, launch, then restore CPU backups.
        cloned_args, cloned_kwargs = self.maybe_clone_args(
            cpu_copies, *args, **kwargs
        )
        # reset to zero before evaluating any config
        # NOTE(review): this zeroes the original args after cloning — confirm
        # clones of reset-to-zero args do not also need zeroing.
        self.reset_to_zero_args(*args, **kwargs)
        if autograd_profiler._is_profiler_enabled:
            profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
            with torch._C._profiler._RecordFunctionFast(
                self.inductor_meta.get("kernel_name", "triton kernel"),
                cloned_args,
                profiler_kwargs,
            ):
                launcher(
                    *cloned_args,
                    **cloned_kwargs,
                    stream=stream,
                )

        else:
            launcher(
                *cloned_args,
                **cloned_kwargs,
                stream=stream,
            )
        self.restore_args_from_cpu(cpu_copies)

    # only use profiler when not already in a profiler instance
    if with_profiler and not autograd_profiler._is_profiler_enabled:
        from torch._inductor.utils import do_bench_using_profiling

        return do_bench_using_profiling(kernel_call, warmup=10, rep=40)

    if self.device_props.type == "cpu":
        return benchmarker.benchmark_cpu(kernel_call)

    return benchmarker.benchmark_gpu(kernel_call, rep=40)
def copy_args_to_cpu_if_needed(self, *args, **kwargs):
    """
    To support benchmarking in the presence of mutated args, we need to avoid
    autotuning contaminating them. We try to pass cloned args to the kernel.
    If those clones would increase the peak memory usage, however, we instead
    copy to cpu and restore them after each iteration. Figure out the args
    to be copied and do the copying.

    The budget is the gap between peak and currently allocated CUDA memory.
    Each mutated CUDA tensor arg whose full storage span fits in the
    remaining budget is left to be cloned (and its size is deducted);
    otherwise it is copied (non-blocking) into a pinned CPU buffer.

    Returns:
        ``{arg_name: (original_tensor, pinned_cpu_copy)}`` for the args that
        were backed up. ``{}`` when ``optimize_mem`` is off or CUDA memory
        statistics raise ``RuntimeError``.
    """
    if not self.optimize_mem:
        return {}

    copies = {}
    try:
        budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
    except RuntimeError:
        # Possibly a custom CUDA allocator, see https://github.com/pytorch/pytorch/issues/163257
        return {}

    def maybe_copy(name, arg):
        nonlocal budget
        # Check the type before touching ``.is_cuda``: previously the
        # isinstance assert ran after ``arg.is_cuda``, so a non-tensor
        # mutated arg raised AttributeError instead of being ignored.
        if not (
            name in self.mutated_arg_names
            and isinstance(arg, torch.Tensor)
            and arg.is_cuda
        ):
            return
        required_storage_length = compute_required_storage_length(
            arg.size(),
            arg.stride(),
            0,
        )
        size = required_storage_length * arg.element_size()
        if size <= budget:
            # A clone fits within the existing peak; reserve room for it.
            budget -= size
            return
        cpu_arg = torch.empty_strided(
            (required_storage_length,),
            (1,),
            dtype=arg.dtype,
            device="cpu",
            pin_memory=True,
        )
        # Copy the whole underlying storage span, not just the visible view.
        cpu_arg.copy_(
            arg.as_strided((required_storage_length,), (1,)),
            non_blocking=True,
        )
        copies[name] = (arg, cpu_arg)

    for name, arg in zip(self.fn.arg_names, args):
        maybe_copy(name, arg)

    for name, arg in kwargs.items():
        maybe_copy(name, arg)

    return copies
def restore_args_from_cpu(self, cpu_copies):
    """Write each CPU backup from ``copy_args_to_cpu_if_needed`` back into its tensor.

    The backup covers the tensor's whole storage span, so it is copied into a
    flat 1-D view of that span (non-blocking).
    """
    for original, backup in cpu_copies.values():
        span = compute_required_storage_length(original.size(), original.stride(), 0)
        flat_view = original.as_strided((span,), (1,))
        flat_view.copy_(backup, non_blocking=True)
def reset_to_zero_args(self, *args, **kwargs):
    """Zero, in place, every argument named in ``reset_to_zero_arg_names``.

    Positional arguments are named through ``self.fn.arg_names`` (by index);
    keyword arguments by their keyword. Every selected argument must be a
    tensor.
    """
    targets = self.reset_to_zero_arg_names
    if not targets:
        return
    # Positional args first, then kwargs, resolved lazily in that order.
    named_args = itertools.chain(
        ((self.fn.arg_names[idx], value) for idx, value in enumerate(args)),
        kwargs.items(),
    )
    for arg_name, value in named_args:
        if arg_name not in targets:
            continue
        assert isinstance(
            value,
            torch.Tensor,
        ), (
            "self.reset_to_zero_arg_names should only contain valid argument names"
        )
        value.zero_()
def maybe_clone_args(
self, exclude: Container[str], *args, **kwargs
) -> tuple[list[Any], dict[str, Any]]:
"""
Prepare new args and kwargs by cloning any in-place buffers
(that are not in the provided exclusion list), to avoid autotune
contaminating them. Avoid cloning the other buffers because it
leads to increased memory usage.
"""
from ..compile_fx import clone_preserve_strides