|
16 | 16 |
|
17 | 17 | import primus.backends.transformer_engine.transformer_engine_torch as ptex |
18 | 18 |
|
19 | | -if is_te_min_version("2.0"): |
20 | | - from transformer_engine.pytorch.cpp_extensions.gemm import ( |
21 | | - reset_swizzled_inputs, |
22 | | - swizzle_inputs, |
| 19 | +if is_te_min_version("2.3", check_equality=False): |
| 20 | + from transformer_engine.debug.pytorch.debug_quantization import DebugQuantizer |
| 21 | + from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import ( |
| 22 | + Float8BlockwiseQTensorBase, |
23 | 23 | ) |
| 24 | + |
| 25 | +if is_te_min_version("2.0"): |
| 26 | + |
| 27 | + # TE version >= 2.0 and <= 2.3 |
| 28 | + if not is_te_min_version("2.3", check_equality=False): |
| 29 | + from transformer_engine.pytorch.cpp_extensions.gemm import ( |
| 30 | + reset_swizzled_inputs, |
| 31 | + swizzle_inputs, |
| 32 | + ) |
24 | 33 | from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer |
25 | 34 |
|
26 | 35 | def general_gemm( |
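
Note on the version gates above: with Megatron-style semantics, is_te_min_version(v) checks installed-TE >= v, and check_equality=False makes the comparison strict (> v). The imports therefore split into a swizzle path for TE >= 2.0 and <= 2.3 and a DebugQuantizer path for TE > 2.3. A minimal sketch of that partition, using a hypothetical standalone helper (the real one reads the installed TE version instead of taking it as an argument):

    from packaging.version import Version

    def is_te_min_version(version, te_version, check_equality=True):
        # >= when check_equality is True, strictly > when False.
        target = Version(version)
        return te_version >= target if check_equality else te_version > target

    for v in ("1.14", "2.0", "2.3", "2.3.1", "2.4"):
        te = Version(v)
        swizzle_path = is_te_min_version("2.0", te) and not is_te_min_version("2.3", te, check_equality=False)
        debug_path = is_te_min_version("2.3", te, check_equality=False)
        print(v, "swizzle:", swizzle_path, "debug-quantizer:", debug_path)
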
@@ -86,6 +95,19 @@ def general_gemm( |
86 | 95 | if not out.is_contiguous(): |
87 | 96 | raise ValueError("Output tensor is not contiguous.") |
88 | 97 |
|
| 98 | + # TE version > 2.3 |
| 99 | + if is_te_min_version("2.3", check_equality=False): |
| 100 | + debug_quantizer = None |
| 101 | + if isinstance(quantization_params, DebugQuantizer): |
| 102 | + debug_quantizer = quantization_params |
| 103 | + quantization_params = quantization_params.parent_quantizer |
| 104 | + A = A.get_tensor(not transa) |
| 105 | + B = B.get_tensor(transb) |
| 106 | + if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): |
|  | 107 | + # There is no use_split_accumulator == False
|  | 108 | + # implementation for Float8BlockwiseQTensorBase GEMM
| 109 | + use_split_accumulator = True |
| 110 | + |
89 | 111 | # Use bfloat16 as default bias_dtype |
90 | 112 | bias_dtype = torch.bfloat16 if bias is None else bias.dtype |
91 | 113 |
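
The added block unwraps a TE debug quantizer before dispatch and forces split accumulation for blockwise FP8 operands. Pulled out as a self-contained helper for readability (a sketch assuming A and B are debug-wrapped tensors exposing get_tensor(rowwise), as in TE's debug tensors; not the patch's literal structure):

    def resolve_gemm_inputs(A, B, quantization_params, transa, transb, use_split_accumulator):
        debug_quantizer = None
        if isinstance(quantization_params, DebugQuantizer):
            # Run the GEMM with the real (parent) quantizer and keep the
            # debug wrapper around to post-process the GEMM output.
            debug_quantizer = quantization_params
            quantization_params = quantization_params.parent_quantizer
            A = A.get_tensor(not transa)
            B = B.get_tensor(transb)
        if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
            # Blockwise FP8 GEMM only ships a split-accumulator kernel,
            # so the flag is overridden regardless of what the caller passed.
            use_split_accumulator = True
        return A, B, quantization_params, debug_quantizer, use_split_accumulator
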
|
@@ -114,9 +136,18 @@ def general_gemm( |
114 | 136 | "bulk_overlap": bulk_overlap, |
115 | 137 | } |
116 | 138 |
|
117 | | - original_scale_inverses = swizzle_inputs(A, B, layout) |
| 139 | + # TE version >= 2.0 and <= 2.3 |
| 140 | + if not is_te_min_version("2.3", check_equality=False): |
| 141 | + original_scale_inverses = swizzle_inputs(A, B, layout) |
| 142 | + |
118 | 143 | out, bias_grad, gelu_input, extra_output = ptex.generic_gemm(*args, **kwargs) |
119 | | - reset_swizzled_inputs(A, B, original_scale_inverses) |
| 144 | + |
| 145 | + # TE version >= 2.0 and <= 2.3 |
| 146 | + if not is_te_min_version("2.3", check_equality=False): |
| 147 | + reset_swizzled_inputs(A, B, original_scale_inverses) |
| 148 | + elif debug_quantizer is not None: |
|  | 149 | +        # TE version > 2.3
| 150 | + out = debug_quantizer.process_gemm_output(out) |
120 | 151 |
|
121 | 152 | return out, bias_grad, gelu_input, extra_output |
122 | 153 |
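
On the TE >= 2.0 and <= 2.3 path, swizzle_inputs and reset_swizzled_inputs must bracket the GEMM: the first permutes the operands' scale-inverse tensors into the layout the kernel expects and returns the originals, the second restores them so A and B stay usable downstream. A sketch of that pairing, with a try/finally added here (the patch itself calls them back to back) so the restore also runs if generic_gemm raises:

    original_scale_inverses = swizzle_inputs(A, B, layout)
    try:
        out, bias_grad, gelu_input, extra_output = ptex.generic_gemm(*args, **kwargs)
    finally:
        # Restore the pre-swizzle scale inverses on A and B.
        reset_swizzled_inputs(A, B, original_scale_inverses)
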
|
|