@@ -371,6 +371,8 @@ class QuantizedAllToShardedLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str, optional): The quantization method to use (see
+            :func:`~mlx.core.quantize`). Default: ``"affine"``.
         group (mx.distributed.Group, optional): The sharding will happen across
             this group. If not set then the global group is used. Default is
             ``None``.
@@ -383,13 +385,15 @@ def __init__(
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
         group: Optional[mx.distributed.Group] = None,
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode
 
394398 # Initialize the quantized weight
395399 scale = math .sqrt (1.0 / input_dims )
@@ -406,7 +410,10 @@ def __init__(
             high=scale,
             shape=(output_dims // N, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
 
         # And bias if needed
         if bias:
@@ -427,7 +434,7 @@ def _extra_repr(self) -> str:
             out_dims *= self.group.size()
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )
 
     def __call__(self, x: mx.array) -> mx.array:
@@ -438,10 +445,11 @@ def __call__(self, x: mx.array) -> mx.array:
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
@@ -465,6 +473,7 @@ def from_quantized_linear(
             hasattr(quantized_linear_layer, "bias"),
             group_size=quantized_linear_layer.group_size,
             bits=quantized_linear_layer.bits,
+            mode=getattr(quantized_linear_layer, "mode", "affine"),
             group=group,
         )
         sl.update(
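
The star-unpacking in the constructor above is the compatibility hinge: with the default "affine" mode, mx.quantize returns (weight, scales, biases), while other quantization modes may return only (weight, scales), so self.biases falls back to None and self.get("biases") keeps the matmul call valid either way. Below is a minimal sketch of that same pattern outside the layer; only "affine" and the mode= keyword are taken from this diff, and the idea that some modes omit biases is inferred from the unpacking logic rather than stated here.

    import mlx.core as mx

    # Sketch of the quantize/matmul round trip the layer above performs.
    # A biasless mode would leave maybe_biases empty, which is exactly the
    # case the *-unpacking and the None fallback guard against.
    w = mx.random.normal(shape=(128, 256))
    wq, scales, *maybe_biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    biases = maybe_biases[0] if maybe_biases else None  # None for biasless modes

    x = mx.random.normal(shape=(8, 256))
    y = mx.quantized_matmul(
        x,
        wq,
        scales=scales,
        biases=biases,
        transpose=True,
        group_size=64,
        bits=4,
        mode="affine",
    )
    print(y.shape)  # (8, 128)
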
@@ -497,6 +506,8 @@ class QuantizedShardedToAllLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str, optional): The quantization method to use (see
+            :func:`~mlx.core.quantize`). Default: ``"affine"``.
         group (mx.distributed.Group, optional): The sharding will happen across
             this group. If not set then the global group is used. Default is
             ``None``.
@@ -509,13 +520,15 @@ def __init__(
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
         group: Optional[mx.distributed.Group] = None,
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode
 
520533 # Initialize the quantized weight
521534 scale = math .sqrt (1.0 / input_dims )
@@ -532,7 +545,10 @@ def __init__(
             high=scale,
             shape=(output_dims, input_dims // N),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
 
         # And bias if needed
         if bias:
@@ -552,18 +568,19 @@ def _extra_repr(self) -> str:
         in_dims = (in_dims * 32) // self.bits * self.group.size()
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )
 
     def __call__(self, x: mx.array) -> mx.array:
         x = mx.quantized_matmul(
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         x = mx.distributed.all_sum(x, group=self.group)
         if "bias" in self:
@@ -588,6 +605,7 @@ def from_quantized_linear(
             hasattr(quantized_linear_layer, "bias"),
             group_size=quantized_linear_layer.group_size,
             bits=quantized_linear_layer.bits,
+            mode=getattr(quantized_linear_layer, "mode", "affine"),
             group=group,
         )
         sl.update(
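
Taken together, both sharded variants now thread mode through quantization, the matmul, and their repr, and the getattr(..., "mode", "affine") fallback lets from_quantized_linear accept older QuantizedLinear layers that have no mode attribute. A rough single-process usage sketch follows; the import path for the sharded class and the assumption that mx.distributed.init() without a launcher yields a size-1 group are mine, not shown in this diff.

    import mlx.core as mx
    import mlx.nn as nn

    # Assumed import path for the sharded layer; adjust if your MLX version
    # re-exports it elsewhere.
    from mlx.nn.layers.distributed import QuantizedAllToShardedLinear

    # A plain quantized layer, quantized with the default affine mode.
    ql = nn.QuantizedLinear(256, 512, bias=True, group_size=64, bits=4)

    # Without a distributed launcher this is assumed to be a size-1 group,
    # so the local shard covers the full output dimension.
    group = mx.distributed.init()

    sharded = QuantizedAllToShardedLinear.from_quantized_linear(ql, group=group)
    print(sharded)  # repr now reports group_size, bits, and mode

    y = sharded(mx.random.normal(shape=(2, 256)))
    print(y.shape)  # (2, 512) on a single process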