@@ -1040,7 +1040,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
 
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
-                preferred_element_type: DTypeLike | None = None) -> Array:
+                preferred_element_type: DTypeLike | None = None,
+                out_type=None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1086,6 +1087,13 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
   non-contracting/non-batch dimensions.
   """
+  if out_type is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_type only works when sharding_in_types "
+                              "config is True.")
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError(
+        '`out_type` argument of `dot_general` only supports NamedSharding '
+        'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
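For orientation, a minimal usage sketch of the new argument (not part of this change): given the checks above, passing `out_type` while the `sharding_in_types` config is at its default (off) raises the NotImplementedError added here, and once the config is enabled the value must be a `NamedSharding`. The mesh construction below is illustrative only.

# Illustrative sketch, assuming this version of jax.lax.dot_general and the
# sharding_in_types config at its default (off); not part of the diff.
import jax
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
out_sharding = NamedSharding(mesh, P('x', 'y'))
dnums = (((1,), (0,)), ((), ()))  # plain matmul: contract lhs dim 1 with rhs dim 0

try:
  lax.dot_general(np.ones((8, 4), np.float32), np.ones((4, 8), np.float32),
                  dnums, out_type=out_sharding)
except NotImplementedError as err:
  print(err)  # "out_type only works when sharding_in_types config is True."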
@@ -1097,7 +1105,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type)
+                            preferred_element_type=preferred_element_type,
+                            out_type=out_type)
 
 
 def ragged_dot(
@@ -1123,7 +1132,8 @@ def ragged_dot(
   """
   return ragged_dot_p.bind(lhs, rhs, group_sizes,
                            precision=canonicalize_precision(precision),
-                           preferred_element_type=preferred_element_type, group_offset=group_offset)
+                           preferred_element_type=preferred_element_type,
+                           group_offset=group_offset)
 
 
 def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
@@ -3002,7 +3012,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = hlo.real(operand)
     aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
-  return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
+  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]
 
 mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
 
@@ -3164,7 +3178,10 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
@@ -3241,24 +3258,29 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
     raise TypeError(msg)
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
        f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
+  if out_type is not None:
+    assert isinstance(out_type, NamedSharding)
+    return out_type
+
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
   rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
   msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
-         f"to have the consistent sharding, got {lhs_batch_spec} and "
-         f"{rhs_batch_spec}.")
+         f"to have the consistent sharding, got {lhs_batch_spec} and "
+         f"{rhs_batch_spec}.")
   _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)
 
   lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
   rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
   msg = ("dot_general requires contracting dimensions to have consistent "
-         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
+         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
   _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)
 
   return _dot_general_sharding_computation(
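Since `_dot_general_sharding_computation` is not shown in this diff, here is a hedged, comment-only reading of the rule above: an explicit `out_type` is returned as-is, while the inferred path presumably builds the output spec from the batch specs followed by each operand's non-contracting specs.

# Hypothetical illustration only; the inference helper is not part of this diff.
#
#   lhs spec: P('x', None), contracting dim 1  -> contracting spec: None
#   rhs spec: P(None, 'y'), contracting dim 0  -> contracting spec: None
#   batch dims: none
#
#   The contracting specs match, so the inferred output spec would be
#   P('x', 'y') (lhs non-contracting dim, then rhs non-contracting dim).
#
# With out_type=NamedSharding(mesh, some_spec), the rule above returns it
# directly and skips this inference.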
@@ -3280,7 +3302,10 @@ def tuple_delete(tup, idx):
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
   # primitive_util.h's HigherPrecisionType, e.g.
@@ -3327,7 +3352,9 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               swap_ans=False):
+                               out_type, swap_ans=False):
+  if out_type is not None:
+    raise NotImplementedError
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3347,12 +3374,16 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
   return x_bar
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
+  if out_type is not None:
+    raise NotImplementedError
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
       g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-      preferred_element_type=preferred_element_type, swap_ans=True)
+      preferred_element_type=preferred_element_type, out_type=out_type,
+      swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
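One practical consequence of the transpose rules above, stated as an assumption rather than a documented guarantee: reverse-mode differentiation of a `dot_general` call that passes `out_type` should currently fail, since both rules raise NotImplementedError for a non-None `out_type`.

# Hedged sketch, not part of the diff.
import jax
from jax import lax

def loss(x, y, out_sharding):
  dnums = (((1,), (0,)), ((), ()))  # plain matmul
  return lax.dot_general(x, y, dnums, out_type=out_sharding).sum()

# jax.grad(loss)(x, y, out_sharding)  # expected: NotImplementedError from
#                                     # _dot_general_transpose_lhs/_rhs above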
@@ -3366,6 +3397,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
+    out_type,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3395,12 +3427,16 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
+  if out_type is not None:
+    raise NotImplementedError("vmap with out_type is not supported. "
+                              "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
       rhs,
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
+      out_type=out_type,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
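Similarly, the batching rule above rejects a non-None `out_type` before invoking the primitive, so vmapping such a call is expected to fail with the message added here; a hedged sketch:

# Hedged sketch, not part of the diff.
import jax
from jax import lax

def matmul(x, y, out_sharding):
  dnums = (((1,), (0,)), ((), ()))
  return lax.dot_general(x, y, dnums, out_type=out_sharding)

# jax.vmap(matmul, in_axes=(0, 0, None))(xs, ys, out_sharding)
#   -> expected: NotImplementedError("vmap with out_type is not supported. ...")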
@@ -3570,7 +3606,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       platform: str = "default"):
+                       out_type, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3658,6 +3694,8 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
+    if out_type is not None:
+      assert aval_out.sharding == out_type
     out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
   if accumulation_aval.dtype != aval_out.dtype:
@@ -3711,12 +3749,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
   return (m, n)
 
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
+                           precision, preferred_element_type: DTypeLike | None,
+                           **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
     raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
   # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
-  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-                                 precision=precision, preferred_element_type=preferred_element_type)
+  return _dot_general_dtype_rule(
+      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      precision=precision, preferred_element_type=preferred_element_type,
+      out_type=None)
 
 
 def _ragged_dot_jvp_rule(
@@ -3839,7 +3880,9 @@ def _ragged_dot_invoke_prim(
     new_dimension_numbers,
     precision,
     preferred_element_type,
+    out_type,
 ):
+  del out_type
   return ragged_dot(
       lhs,
       rhs,
@@ -3868,6 +3911,7 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
+      out_type=None,
   )
 
 