@@ -216,207 +216,3 @@ def round_up(x: int, y: int) -> int:
     torch.ops.fbgemm.scaled_fp4_quant(output, input, output_scale, input_global_scale)
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale
-
-
-def _fp32_to_fp4_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
-    """Converts a float32 tensor to a unpacked float4 tensor.
-    Args:
-        x (torch.Tensor): The input float32 tensor.
-        ebits (int): The number of bits in the exponent.
-        mbits (int): The number of bits in the mantissa.
-    Returns:
-        torch.Tensor: The resulting unpacked float4 tensor.
-    """
-
-    def _n_ones(n: int) -> int:
-        return (1 << n) - 1
-
-    EBITS_F32, MBITS_F32 = 8, 23
-    F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
-
-    assert x.dtype == torch.float
-    assert 1 + ebits + mbits <= 8
-
-    # calculate constants
-    exp_bias = _n_ones(ebits - 1)
-    max_int = _n_ones(ebits + mbits)
-    sign_mask = 1 << (ebits + mbits)
-
-    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
-
-    # all E bits and M bits are 1s
-    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
-
-    # E bits = 1, M bits = 0
-    min_normal = 2 ** (1 - exp_bias)
-
-    denorm_exp = (
-        # exp bias conversion between formats
-        (F32_EXP_BIAS - exp_bias)
-        # mantissa length difference between formats
-        + (MBITS_F32 - mbits)
-        # add one to encoded exponent for denormalized numbers
-        + 1
-    )
-    denorm_mask_int = denorm_exp << MBITS_F32
-
-    # reinterpret int32 as float32
-    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
-        torch.float32
-    )
-
-    # save the sign
-    # Note that we have torch.uint32, but some ops like cpu bit shifts
-    # do not work on it. So, we stay in int32.
-    x = x.view(torch.int32)
-    sign = x & 0x80000000
-
-    # set everything to positive, will add sign back at the end
-    x = x ^ sign
-    x = x.view(torch.float)
-
-    # rewrite saturate/denorm/norm branches without explicit data dependent
-    # control flow, to be more compiler friendly
-    saturate_mask = x >= max_normal
-    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
-    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
-
-    denormal_x = x + denorm_mask_float
-    denormal_x = denormal_x.view(torch.int32)
-    denormal_x -= denorm_mask_int
-    denormal_x = denormal_x.to(torch.uint8)
-
-    normal_x = x.view(torch.int32)
-    # resulting mantissa is odd
-    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
-    # update exponent, rounding bias part 1
-    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
-    normal_x += val_to_add
-    # rounding bias part 2
-    normal_x += mant_odd
-    # take the bits!
-    normal_x = normal_x >> (MBITS_F32 - mbits)
-    normal_x = normal_x.to(torch.uint8)
-
-    x = torch.full_like(x, max_int, dtype=torch.uint8)
-    x = torch.where(denormal_mask, denormal_x, x)
-    x = torch.where(normal_mask, normal_x, x)
-
-    # add sign back
-    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
-    sign_lp = sign_lp.to(torch.uint8)
-    # Right shift of a negative signed integer can fill the least significant
-    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
-    # doesn't have an uint32 dtype, we mask out these bits to get just the
-    # f4 sign bit
-    sign_lp = sign_lp & sign_mask
-    x = x | sign_lp
-
-    return x.to(torch.uint8)
-
-
-def _to_blocked(x: torch.Tensor) -> torch.Tensor:
-    """Converts a tensor to the blocked layout.
-    Args:
-        x (torch.Tensor): The input tensor in non-blocked layout.
-    Returns:
-        torch.Tensor: The output tensor in the blocked layout.
-    """
-
-    def ceil_div(a: int, b: int) -> int:
-        return (a + b - 1) // b
-
-    rows, cols = x.shape
-    n_row_blocks = ceil_div(rows, 128)
-    n_col_blocks = ceil_div(cols, 4)
-
-    # Calculate the padded shape
-    padded_rows = n_row_blocks * 128
-    padded_cols = n_col_blocks * 4
-
-    padded = x
-    if (rows, cols) != (padded_rows, padded_cols):
-        padded = torch.zeros(
-            (padded_rows, padded_cols),
-            device=x.device,
-            dtype=x.dtype,
-        )
-        padded[:rows, :cols] = x
-
-    # Rearrange the blocks
-    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
-    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
-
-    return rearranged.flatten()
-
-
-# This PyTorch version refers to https://github.com/pytorch/ao/blob/v0.10.0/torchao/prototype/mx_formats/mx_tensor.py#L146
-def scale_mxfp4_quant(
-    x: torch.Tensor, block_size: int = 32
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP4 and return quantized tensor and scale.
-    Args:
-        x (torch.Tensor): The input tensor to be quantized to FP4
-        block_size (int): The block size to use for quantization. Default is 32.
-    Returns:
-        xq (torch.Tensor): Quantized FP4 output tensor
-        scale (torch.Tensor): Scale E8M0 tensor
-    """
-
-    F4_E2M1_MAX = 6.0
-    E8M0_EXPONENT_BIAS = 127
-    EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
-
-    # calculate the scale in e8m0 format
-    orig_shape = x.shape
-    x = x.reshape(-1, block_size)
-
-    # find max value of the data
-    # Note: this only implements the `minimally supported` version of
-    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-    # section 6.3.
-    max_abs = torch.amax(torch.abs(x), 1)
-    max_pos = F4_E2M1_MAX
-
-    descale = max_abs / max_pos
-    scale = torch.where(
-        torch.isnan(descale),
-        0xFF,  # Handle biased exponent for nan
-        # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
-        (
-            torch.clamp(
-                torch.ceil(torch.log2(descale)),
-                min=-E8M0_EXPONENT_BIAS,
-                max=E8M0_EXPONENT_BIAS,
-            )
-            + E8M0_EXPONENT_BIAS
-        ).to(torch.uint8),
-    )
-
-    descale_fp = torch.where(
-        scale == 0,
-        1.0,
-        torch.exp2(E8M0_EXPONENT_BIAS - scale.to(torch.float32)),
-    )
-
-    # scale and saturated cast the data elements to max of target dtype
-    xq = torch.clamp(x * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos)
-
-    xq = xq.reshape(orig_shape)
-    xq = _fp32_to_fp4_unpacked(xq, EBITS_F4_E2M1, MBITS_F4_E2M1)
-    orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
-
-    shape = xq.shape
-    assert shape[-1] % 2 == 0
-    xq = xq.contiguous().view(-1)
-    xq = (xq[::2] << 4 | xq[1::2]).view((*shape[:-1], shape[-1] // 2))
-
-    target_numel = scale.numel() * block_size / 2
-    assert target_numel == xq.numel(), f"{target_numel} != {xq.numel()}"
-
-    scale = scale.view(torch.float8_e8m0fnu)
-    scale = scale.view(orig_shape[0], -1)
-    scale = _to_blocked(scale)
-
-    return xq, scale
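For reference, a minimal usage sketch of the removed scale_mxfp4_quant helper (illustrative only, not part of this commit; assumes a 2-D float32 input whose last dimension is a multiple of the 32-element block size):

    import torch

    # hypothetical input: 128 rows of 256 float32 values
    x = torch.randn(128, 256, dtype=torch.float32)
    xq, scale = scale_mxfp4_quant(x, block_size=32)
    # xq packs two FP4 (E2M1) values per byte -> shape (128, 128), dtype uint8
    # scale holds one E8M0 exponent per 32-element block (1024 bytes here),
    # rearranged by _to_blocked into the 128x4 blocked layout and returned
    # as a flat float8_e8m0fnu tensor

As a worked example of the scale computation: for a block whose max |x| is 3.0, descale = 3.0 / 6.0 = 0.5 and ceil(log2(0.5)) = -1, so the stored E8M0 byte is -1 + 127 = 126 and descale_fp = 2^(127 - 126) = 2.0; the block is multiplied by 2.0 before the saturating cast to the E2M1 range [-6, 6].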