
Commit f405af5

Tianyu Liang authored and facebook-github-bot committed
Fused SILU with quantization and RMS with quantization (#4204)
Summary:
Pull Request resolved: #4204

X-link: facebookresearch/FBGEMM#1280

FP4: fused silu + quantization and rms + quantization

Reviewed By: jiawenliu64

Differential Revision: D75549445
1 parent f52ab82 commit f405af5
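For orientation, a minimal usage sketch of the two fused ops exercised by the tests in this commit. The calls and keyword arguments mirror the test code below; the shapes and the interpretation of w (RMSNorm weight for the RMS variant, elementwise multiplier for the SiLU variant) follow the tests' reference computations and are illustrative only, not a definitive API description.

    # Illustrative only; mirrors the calls made in fp4_quantize_test.py.
    import torch
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
        triton_rms_quantize_mx4_unpack,
        triton_silu_quantize_mx4_unpack,
    )
    from fbgemm_gpu.quantize_utils import RoundingMode

    x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")

    # Fused RMSNorm (weighted by w) followed by MX4/FP4 quantization.
    xq_rms, scale_rms = triton_rms_quantize_mx4_unpack(
        x, w, EPS=1e-5, group_size=32, rounding_mode=RoundingMode.even
    )

    # Fused silu(x) * w followed by MX4/FP4 quantization.
    xq_silu, scale_silu = triton_silu_quantize_mx4_unpack(
        x, w, group_size=32, rounding_mode=RoundingMode.even
    )

    # Per the tests, each call returns packed uint8 FP4 data of shape [M, N // 2]
    # and per-group scales in a blocked layout (compared via _to_blocked).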

File tree

2 files changed: +1022 -12 lines changed


fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py

Lines changed: 129 additions & 1 deletion
@@ -11,7 +11,10 @@
 import torch

 from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+    _to_blocked,
     triton_quantize_mx4_unpack,
+    triton_rms_quantize_mx4_unpack,
+    triton_silu_quantize_mx4_unpack,
 )
 from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode
@@ -57,8 +60,133 @@ def _test_quantize_fp4(
                 x_scale[:, i] = xq_packed[:, end_idx]

             self.assertTrue(torch.equal(xq, xq_ref))
-            self.assertTrue(torch.equal(x_scale, x_scale_ref))
+            self.assertTrue(
+                torch.equal(_to_blocked(x_scale), x_scale_ref.view(torch.uint8))
+            )

         _test_quantize_fp4((1, 128))
         _test_quantize_fp4((3, 512))
         _test_quantize_fp4((128, 1024))
+        _test_quantize_fp4((4096, 10240))
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available()
+    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
+    "Skip when H100 is not available",
+)
+class TestFp4RmsQuantize(unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(0)
+
+    def test_rms_quantize_fp4(self) -> None:
+        def _test_rms_quantize_fp4(
+            shape: Tuple[int, int],
+            device: str = "cuda",
+        ) -> None:
+            M, N = shape
+            group_size = 32
+            rounding_mode = RoundingMode.even
+            packed_group_size = group_size // 2
+            groups_per_row = math.ceil(N / group_size)
+            x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+            w = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+            xq_ref, x_scale_ref = triton_rms_quantize_mx4_unpack(
+                x, w, EPS=1e-5, group_size=group_size, rounding_mode=rounding_mode
+            )
+
+            intermediate = (
+                x.to(torch.float32).reshape(-1, group_size)
+                * torch.rsqrt(
+                    torch.pow(x.to(torch.float32).reshape(-1, group_size), 2).mean(
+                        dim=1
+                    )
+                    + 1e-5
+                ).unsqueeze(1)
+            ) * w.reshape(-1, group_size).to(torch.float32)
+
+            intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
+            xq_packed = fp32_to_mx4(
+                intermediate, group_size=group_size, rounding_mode=rounding_mode
+            )
+
+            xq = torch.empty([M, N // 2], device=x.device, dtype=torch.uint8)
+            x_scale = torch.empty(
+                [M, groups_per_row], device=x.device, dtype=torch.uint8
+            )
+
+            for i in range(groups_per_row):
+                start_idx = i * (packed_group_size + 1)
+                end_idx = start_idx + packed_group_size
+                xq[:, i * packed_group_size : (i + 1) * packed_group_size] = xq_packed[
+                    :, start_idx:end_idx
+                ]
+                x_scale[:, i] = xq_packed[:, end_idx]
+
+            self.assertTrue(torch.equal(xq, xq_ref))
+            self.assertTrue(
+                torch.equal(_to_blocked(x_scale), x_scale_ref.view(torch.uint8))
+            )
+
+        _test_rms_quantize_fp4((1, 32))
+        _test_rms_quantize_fp4((1, 128))
+        _test_rms_quantize_fp4((3, 512))
+        _test_rms_quantize_fp4((128, 1024))
+        # TODO: fix potential bug with large tensors
+        # _test_rms_quantize_fp4((4096, 10240))
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available()
+    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
+    "Skip when H100 is not available",
+)
+class TestFp4SiluQuantize(unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(0)
+
+    def test_silu_quantize_fp4(self) -> None:
+        def _test_silu_quantize_fp4(
+            shape: Tuple[int, int],
+            device: str = "cuda",
+        ) -> None:
+            M, N = shape
+            group_size = 32
+            rounding_mode = RoundingMode.even
+            packed_group_size = group_size // 2
+            groups_per_row = math.ceil(N / group_size)
+            x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+            w = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+            xq_ref, x_scale_ref = triton_silu_quantize_mx4_unpack(
+                x, w, group_size=group_size, rounding_mode=rounding_mode
+            )
+            intermediate = torch.nn.functional.silu(x.to(torch.float32)) * w.to(
+                torch.float32
+            )
+            intermediate = intermediate.to(torch.bfloat16)
+            xq_packed = fp32_to_mx4(
+                intermediate, group_size=group_size, rounding_mode=rounding_mode
+            )
+
+            xq = torch.empty([M, N // 2], device=x.device, dtype=torch.uint8)
+            x_scale = torch.empty(
+                [M, groups_per_row], device=x.device, dtype=torch.uint8
+            )
+
+            for i in range(groups_per_row):
+                start_idx = i * (packed_group_size + 1)
+                end_idx = start_idx + packed_group_size
+                xq[:, i * packed_group_size : (i + 1) * packed_group_size] = xq_packed[
+                    :, start_idx:end_idx
+                ]
+                x_scale[:, i] = xq_packed[:, end_idx]
+
+            self.assertTrue(torch.equal(xq, xq_ref))
+            self.assertTrue(
+                torch.equal(_to_blocked(x_scale), x_scale_ref.view(torch.uint8))
+            )
+
+        _test_silu_quantize_fp4((1, 128))
+        _test_silu_quantize_fp4((3, 512))
+        _test_silu_quantize_fp4((128, 1024))
+        _test_silu_quantize_fp4((10240, 10240))
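For context on the expected-value construction above: the tests assume fp32_to_mx4 stores each group of group_size elements as group_size // 2 packed FP4 data bytes followed by one shared-exponent scale byte. A minimal sketch of that unpacking step, with descriptive comments; the helper name and shapes are illustrative and not part of this commit.

    # Illustrative helper; assumes the interleaved layout exercised by the tests.
    import math
    import torch

    def split_mx4_packed(xq_packed: torch.Tensor, M: int, N: int, group_size: int = 32):
        packed_group_size = group_size // 2
        groups_per_row = math.ceil(N / group_size)
        xq = torch.empty([M, N // 2], device=xq_packed.device, dtype=torch.uint8)
        x_scale = torch.empty([M, groups_per_row], device=xq_packed.device, dtype=torch.uint8)
        for i in range(groups_per_row):
            start = i * (packed_group_size + 1)   # group i starts after i earlier scale bytes
            end = start + packed_group_size       # packed FP4 data bytes for group i
            xq[:, i * packed_group_size : (i + 1) * packed_group_size] = xq_packed[:, start:end]
            x_scale[:, i] = xq_packed[:, end]     # trailing byte is the group's scale
        return xq, x_scale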
