Skip to content

Commit 149e668

Browse files
fix(tp-overlap): adapt transformer_engine 2.4 for Megatron backend (#259)
Co-authored-by: Xiaoming-AMD <[email protected]>
1 parent 4e8d1fc commit 149e668

File tree

3 files changed

+43
-12
lines changed

3 files changed

+43
-12
lines changed

primus/backends/transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,20 @@
1616

1717
import primus.backends.transformer_engine.transformer_engine_torch as ptex
1818

19-
if is_te_min_version("2.0"):
20-
from transformer_engine.pytorch.cpp_extensions.gemm import (
21-
reset_swizzled_inputs,
22-
swizzle_inputs,
19+
if is_te_min_version("2.3", check_equality=False):
20+
from transformer_engine.debug.pytorch.debug_quantization import DebugQuantizer
21+
from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import (
22+
Float8BlockwiseQTensorBase,
2323
)
24+
25+
if is_te_min_version("2.0"):
26+
27+
# TE version >= 2.0 and <= 2.3
28+
if not is_te_min_version("2.3", check_equality=False):
29+
from transformer_engine.pytorch.cpp_extensions.gemm import (
30+
reset_swizzled_inputs,
31+
swizzle_inputs,
32+
)
2433
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer
2534

2635
def general_gemm(
@@ -86,6 +95,19 @@ def general_gemm(
8695
if not out.is_contiguous():
8796
raise ValueError("Output tensor is not contiguous.")
8897

98+
# TE version > 2.3
99+
if is_te_min_version("2.3", check_equality=False):
100+
debug_quantizer = None
101+
if isinstance(quantization_params, DebugQuantizer):
102+
debug_quantizer = quantization_params
103+
quantization_params = quantization_params.parent_quantizer
104+
A = A.get_tensor(not transa)
105+
B = B.get_tensor(transb)
106+
if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
107+
# There is not use_split_accumulator == False
108+
# implementation for Float8BlockwiseQTensorBase GEMM
109+
use_split_accumulator = True
110+
89111
# Use bfloat16 as default bias_dtype
90112
bias_dtype = torch.bfloat16 if bias is None else bias.dtype
91113

@@ -114,9 +136,18 @@ def general_gemm(
114136
"bulk_overlap": bulk_overlap,
115137
}
116138

117-
original_scale_inverses = swizzle_inputs(A, B, layout)
139+
# TE version >= 2.0 and <= 2.3
140+
if not is_te_min_version("2.3", check_equality=False):
141+
original_scale_inverses = swizzle_inputs(A, B, layout)
142+
118143
out, bias_grad, gelu_input, extra_output = ptex.generic_gemm(*args, **kwargs)
119-
reset_swizzled_inputs(A, B, original_scale_inverses)
144+
145+
# TE version >= 2.0 and <= 2.3
146+
if not is_te_min_version("2.3", check_equality=False):
147+
reset_swizzled_inputs(A, B, original_scale_inverses)
148+
elif debug_quantizer is not None:
149+
# TE version >= 2.4
150+
out = debug_quantizer.process_gemm_output(out)
120151

121152
return out, bias_grad, gelu_input, extra_output
122153

primus/backends/transformer_engine/transformer_engine_torch/comm_overlap.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def is_p2p_overlap(self) -> bool: ...
124124
def is_fp8_ubuf(self) -> bool:
125125
return self.buf_dtype.itemsize == 1
126126

127-
def copy_into_buffer(self, input: torch.Tensor, quantizer: Quantizer, local_chunk: bool = False):
127+
def copy_into_buffer(
128+
self, input: torch.Tensor, quantizer: Quantizer = None, local_chunk: bool = False
129+
):
128130
"""copy input to local buffer
129131
130132
Args:

tests/run_unit_tests.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,15 @@
1313

1414
UNIT_TEST_PASS = True
1515

16+
EXCLUDE_UNIT_TESTS = []
17+
1618

1719
def get_all_unit_tests():
18-
global DISTRIBUTED_UNIT_TESTS
20+
global DISTRIBUTED_UNIT_TESTS, EXCLUDE_UNIT_TESTS
1921

2022
cur_dir = "./tests"
2123
unit_tests = {}
2224

23-
EXCLUDE_UNIT_TESTS = [
24-
"unit_tests/megatron/cco/test_tp_overlap.py",
25-
]
26-
2725
for root, dirs, files in os.walk(cur_dir):
2826
for file_name in files:
2927
if not file_name.endswith(".py") or not file_name.startswith("test_"):

0 commit comments

Comments
 (0)