|
11 | 11 | import torch
|
12 | 12 |
|
13 | 13 | from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
|
| 14 | + _to_blocked, |
14 | 15 | triton_quantize_mx4_unpack,
|
| 16 | + triton_rms_quantize_mx4_unpack, |
| 17 | + triton_silu_quantize_mx4_unpack, |
15 | 18 | )
|
16 | 19 | from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode
|
17 | 20 |
|
@@ -57,8 +60,133 @@ def _test_quantize_fp4(
|
57 | 60 | x_scale[:, i] = xq_packed[:, end_idx]
|
58 | 61 |
|
59 | 62 | self.assertTrue(torch.equal(xq, xq_ref))
|
60 |
| - self.assertTrue(torch.equal(x_scale, x_scale_ref)) |
| 63 | + self.assertTrue( |
| 64 | + torch.equal(_to_blocked(x_scale), x_scale_ref.view(torch.uint8)) |
| 65 | + ) |
61 | 66 |
|
62 | 67 | _test_quantize_fp4((1, 128))
|
63 | 68 | _test_quantize_fp4((3, 512))
|
64 | 69 | _test_quantize_fp4((128, 1024))
|
| 70 | + _test_quantize_fp4((4096, 10240)) |
| 71 | + |
| 72 | + |
| 73 | +@unittest.skipIf( |
| 74 | + not torch.cuda.is_available() |
| 75 | + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9, |
| 76 | + "Skip when H100 is not available", |
| 77 | +) |
| 78 | +class TestFp4RmsQuantize(unittest.TestCase): |
| 79 | + def setUp(self) -> None: |
| 80 | + torch.manual_seed(0) |
| 81 | + |
| 82 | + def test_rms_quantize_fp4(self) -> None: |
| 83 | + def _test_rms_quantize_fp4( |
| 84 | + shape: Tuple[int, int], |
| 85 | + device: str = "cuda", |
| 86 | + ) -> None: |
| 87 | + M, N = shape |
| 88 | + group_size = 32 |
| 89 | + rounding_mode = RoundingMode.even |
| 90 | + packed_group_size = group_size // 2 |
| 91 | + groups_per_row = math.ceil(N / group_size) |
| 92 | + x = torch.randn(M, N, dtype=torch.bfloat16, device=device) |
| 93 | + w = torch.randn(M, N, dtype=torch.bfloat16, device=device) |
| 94 | + xq_ref, x_scale_ref = triton_rms_quantize_mx4_unpack( |
| 95 | + x, w, EPS=1e-5, group_size=group_size, rounding_mode=rounding_mode |
| 96 | + ) |
| 97 | + |
| 98 | + intermediate = ( |
| 99 | + x.to(torch.float32).reshape(-1, group_size) |
| 100 | + * torch.rsqrt( |
| 101 | + torch.pow(x.to(torch.float32).reshape(-1, group_size), 2).mean( |
| 102 | + dim=1 |
| 103 | + ) |
| 104 | + + 1e-5 |
| 105 | + ).unsqueeze(1) |
| 106 | + ) * w.reshape(-1, group_size).to(torch.float32) |
| 107 | + |
| 108 | + intermediate = intermediate.to(torch.bfloat16).reshape(M, N) |
| 109 | + xq_packed = fp32_to_mx4( |
| 110 | + intermediate, group_size=group_size, rounding_mode=rounding_mode |
| 111 | + ) |
| 112 | + |
| 113 | + xq = torch.empty([M, N // 2], device=x.device, dtype=torch.uint8) |
| 114 | + x_scale = torch.empty( |
| 115 | + [M, groups_per_row], device=x.device, dtype=torch.uint8 |
| 116 | + ) |
| 117 | + |
| 118 | + for i in range(groups_per_row): |
| 119 | + start_idx = i * (packed_group_size + 1) |
| 120 | + end_idx = start_idx + packed_group_size |
| 121 | + xq[:, i * packed_group_size : (i + 1) * packed_group_size] = xq_packed[ |
| 122 | + :, start_idx:end_idx |
| 123 | + ] |
| 124 | + x_scale[:, i] = xq_packed[:, end_idx] |
| 125 | + |
| 126 | + self.assertTrue(torch.equal(xq, xq_ref)) |
| 127 | + self.assertTrue( |
| 128 | + torch.equal(_to_blocked(x_scale), x_scale_ref.view(torch.uint8)) |
| 129 | + ) |
| 130 | + |
| 131 | + _test_rms_quantize_fp4((1, 32)) |
| 132 | + _test_rms_quantize_fp4((1, 128)) |
| 133 | + _test_rms_quantize_fp4((3, 512)) |
| 134 | + _test_rms_quantize_fp4((128, 1024)) |
| 135 | + # TODO: fix potential bug with large tensors |
| 136 | + # _test_rms_quantize_fp4((4096, 10240)) |
| 137 | + |
| 138 | + |
| 139 | +@unittest.skipIf( |
| 140 | + not torch.cuda.is_available() |
| 141 | + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9, |
| 142 | + "Skip when H100 is not available", |
| 143 | +) |
| 144 | +class TestFp4SiluQuantize(unittest.TestCase): |
| 145 | + def setUp(self) -> None: |
| 146 | + torch.manual_seed(0) |
| 147 | + |
| 148 | + def test_silu_quantize_fp4(self) -> None: |
| 149 | + def _test_silu_quantize_fp4( |
| 150 | + shape: Tuple[int, int], |
| 151 | + device: str = "cuda", |
| 152 | + ) -> None: |
| 153 | + M, N = shape |
| 154 | + group_size = 32 |
| 155 | + rounding_mode = RoundingMode.even |
| 156 | + packed_group_size = group_size // 2 |
| 157 | + groups_per_row = math.ceil(N / group_size) |
| 158 | + x = torch.randn(M, N, dtype=torch.bfloat16, device=device) |
| 159 | + w = torch.randn(M, N, dtype=torch.bfloat16, device=device) |
| 160 | + xq_ref, x_scale_ref = triton_silu_quantize_mx4_unpack( |
| 161 | + x, w, group_size=group_size, rounding_mode=rounding_mode |
| 162 | + ) |
| 163 | + intermediate = torch.nn.functional.silu(x.to(torch.float32)) * w.to( |
| 164 | + torch.float32 |
| 165 | + ) |
| 166 | + intermediate = intermediate.to(torch.bfloat16) |
| 167 | + xq_packed = fp32_to_mx4( |
| 168 | + intermediate, group_size=group_size, rounding_mode=rounding_mode |
| 169 | + ) |
| 170 | + |
| 171 | + xq = torch.empty([M, N // 2], device=x.device, dtype=torch.uint8) |
| 172 | + x_scale = torch.empty( |
| 173 | + [M, groups_per_row], device=x.device, dtype=torch.uint8 |
| 174 | + ) |
| 175 | + |
| 176 | + for i in range(groups_per_row): |
| 177 | + start_idx = i * (packed_group_size + 1) |
| 178 | + end_idx = start_idx + packed_group_size |
| 179 | + xq[:, i * packed_group_size : (i + 1) * packed_group_size] = xq_packed[ |
| 180 | + :, start_idx:end_idx |
| 181 | + ] |
| 182 | + x_scale[:, i] = xq_packed[:, end_idx] |
| 183 | + |
| 184 | + self.assertTrue(torch.equal(xq, xq_ref)) |
| 185 | + self.assertTrue( |
| 186 | + torch.equal(_to_blocked(x_scale), x_scale_ref.view(torch.uint8)) |
| 187 | + ) |
| 188 | + |
| 189 | + _test_silu_quantize_fp4((1, 128)) |
| 190 | + _test_silu_quantize_fp4((3, 512)) |
| 191 | + _test_silu_quantize_fp4((128, 1024)) |
| 192 | + _test_silu_quantize_fp4((10240, 10240)) |
0 commit comments