
Commit e226af7

Propagate quantization mode in quantized layers (#3133)

1 parent 43f4a74 commit e226af7

File tree

2 files changed (+46, -7 lines)
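The user-visible effect of the change: a layer quantized with a non-default mode keeps that mode when it is converted to its sharded counterparts. A minimal sketch of that behavior, under the assumptions that shard_linear is importable from mlx.nn.layers.distributed and that the mxfp8 mode is available in the installed MLX version:

# Hedged usage sketch; the import path and parameter values are assumptions
# mirroring the test added in this commit.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear  # assumed import path

lin = nn.Linear(512, 512)
qlin = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")

slin = shard_linear(qlin, "all-to-sharded")
print(slin.mode)           # "mxfp8": the quantization mode travels with the layer
print(slin.get("biases"))  # None: mxfp8 produces no per-group biases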

python/mlx/nn/layers/distributed.py

Lines changed: 24 additions & 6 deletions
@@ -371,6 +371,8 @@ class QuantizedAllToShardedLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str, optional): The quantization method to use (see
+            :func:`~mlx.core.quantize`). Default: ``"affine"``.
         group (mx.distributed.Group, optional): The sharding will happen across
             this group. If not set then the global group is used. Default is
             ``None``.
@@ -383,13 +385,15 @@ def __init__(
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
         group: Optional[mx.distributed.Group] = None,
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode
 
         # Initialize the quantized weight
         scale = math.sqrt(1.0 / input_dims)
@@ -406,7 +410,10 @@ def __init__(
             high=scale,
             shape=(output_dims // N, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
 
         # And bias if needed
         if bias:
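The star-unpacking above exists because mx.quantize no longer always returns three arrays: affine quantization yields a quantized weight, scales, and biases, while modes such as mxfp8 yield only a weight and scales. A minimal sketch of the difference; the parameter values mirror the test added in this commit, and the exact accepted combinations are version-dependent:

# Sketch of the return-arity difference that the *biases capture handles.
import mlx.core as mx

w = mx.random.normal((64, 64))

affine_out = mx.quantize(w, group_size=32, bits=8, mode="affine")
mxfp8_out = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")

print(len(affine_out))  # 3: quantized weight, scales, biases
print(len(mxfp8_out))   # 2: quantized weight, scales only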
@@ -427,7 +434,7 @@ def _extra_repr(self) -> str:
         out_dims *= self.group.size()
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )
 
     def __call__(self, x: mx.array) -> mx.array:
@@ -438,10 +445,11 @@ def __call__(self, x: mx.array) -> mx.array:
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
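In the forward pass, the dictionary-style self["biases"] lookup would fail for modes that produce no biases, so self.get("biases") passes None instead and the mode is forwarded to the kernel. A sketch of the corresponding mx.quantized_matmul call, assuming (as this __call__ does) that biases may be None for such modes:

# Sketch: quantized matmul with a mode that has no per-group biases.
import mlx.core as mx

w = mx.random.normal((64, 32))  # (output_dims, input_dims)
w_q, scales, *rest = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")

x = mx.random.normal((4, 32))
y = mx.quantized_matmul(
    x,
    w_q,
    scales=scales,
    biases=rest[0] if rest else None,  # None here: mxfp8 has no biases
    transpose=True,
    group_size=32,
    bits=8,
    mode="mxfp8",
)
print(y.shape)  # (4, 64)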
@@ -465,6 +473,7 @@ def from_quantized_linear(
             hasattr(quantized_linear_layer, "bias"),
             group_size=quantized_linear_layer.group_size,
             bits=quantized_linear_layer.bits,
+            mode=getattr(quantized_linear_layer, "mode", "affine"),
             group=group,
         )
         sl.update(
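The getattr fallback keeps from_quantized_linear working with quantized layers created before a mode attribute existed; anything without the attribute is treated as affine, the only mode such layers could have used. A trivial sketch with a hypothetical stand-in object:

# LegacyQuantizedLinear is a hypothetical stand-in, not an MLX class.
class LegacyQuantizedLinear:
    group_size = 64
    bits = 4

legacy = LegacyQuantizedLinear()
print(getattr(legacy, "mode", "affine"))  # "affine"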
@@ -497,6 +506,8 @@ class QuantizedShardedToAllLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str, optional): The quantization method to use (see
+            :func:`~mlx.core.quantize`). Default: ``"affine"``.
         group (mx.distributed.Group, optional): The sharding will happen across
             this group. If not set then the global group is used. Default is
             ``None``.
@@ -509,13 +520,15 @@ def __init__(
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
         group: Optional[mx.distributed.Group] = None,
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode
 
         # Initialize the quantized weight
         scale = math.sqrt(1.0 / input_dims)
@@ -532,7 +545,10 @@ def __init__(
             high=scale,
             shape=(output_dims, input_dims // N),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
 
         # And bias if needed
         if bias:
@@ -552,18 +568,19 @@ def _extra_repr(self) -> str:
         in_dims = (in_dims * 32) // self.bits * self.group.size()
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )
 
     def __call__(self, x: mx.array) -> mx.array:
         x = mx.quantized_matmul(
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         x = mx.distributed.all_sum(x, group=self.group)
         if "bias" in self:
@@ -588,6 +605,7 @@ def from_quantized_linear(
             hasattr(quantized_linear_layer, "bias"),
             group_size=quantized_linear_layer.group_size,
             bits=quantized_linear_layer.bits,
+            mode=getattr(quantized_linear_layer, "mode", "affine"),
             group=group,
         )
         sl.update(

python/tests/mlx_distributed_tests.py

Lines changed: 22 additions & 1 deletion
@@ -146,7 +146,7 @@ def test_shard_linear(self):
         self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
         self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
 
-        # And their quant versions (QuintizedMatmul is not supported on CUDA)
+        # And their quant versions (QuantizedMatmul is not supported on CUDA)
         if not mx.cuda.is_available():
             qlin = lin.to_quantized()
             slin1 = shard_linear(qlin, "all-to-sharded")
@@ -157,6 +157,27 @@ def test_shard_linear(self):
             self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
             self.assertTrue(mx.allclose(y[part], y1))
 
+            # Test non-affine quantization modes (mxfp8)
+            qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")
+            self.assertEqual(qlin_mxfp8.mode, "mxfp8")
+
+            slin1_mxfp8 = shard_linear(qlin_mxfp8, "all-to-sharded")
+            slin2_mxfp8 = shard_linear(qlin_mxfp8, "sharded-to-all")
+
+            # Verify mode is propagated
+            self.assertEqual(slin1_mxfp8.mode, "mxfp8")
+            self.assertEqual(slin2_mxfp8.mode, "mxfp8")
+
+            # Verify biases parameter is not set for mxfp8
+            self.assertIsNone(slin1_mxfp8.get("biases"))
+            self.assertIsNone(slin2_mxfp8.get("biases"))
+
+            y = qlin_mxfp8(x)
+            y1 = slin1_mxfp8(x)
+            y2 = slin2_mxfp8(x[part])
+            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+            self.assertTrue(mx.allclose(y[part], y1))
+
         # Check the backward works as expected
         def dummy_loss(model, x, y):
             return (model(x) * y).sum()
