Commit f8a1f02
yashk2810 authored and Google-ML-Automation committed
[sharding_in_types][Take 2] Add an out_type argument to einsum and dot_general to allow specifying the output type. Right now it only accepts a NamedSharding, but in the future we can allow a polymorphic type: jax.ShapeDtypeStruct | Sharding | Layout.

Reverts 0b3f0e1

PiperOrigin-RevId: 688663504
1 parent 32be199 commit f8a1f02

8 files changed: 163 additions, 35 deletions

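A minimal sketch of the new surface area (illustrative shapes and single-device mesh only; not part of the commit). At this revision out_type is gated on the experimental sharding_in_types config: with the config off, the guard added below rejects it, and with the config on, the lowering constrains the result to the requested NamedSharding.

import numpy as np
import jax
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical single-device mesh, used only to build a NamedSharding.
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
requested = NamedSharding(mesh, P(None, None))

x = np.ones((4, 3), np.float32)
y = np.ones((3, 2), np.float32)
dnums = (((1,), (0,)), ((), ()))  # contract x's last dim with y's first dim

# With the sharding_in_types config disabled (the default), the new check fires;
# with it enabled, the output sharding would instead be constrained to `requested`.
try:
  lax.dot_general(x, y, dnums, out_type=requested)
except NotImplementedError as e:
  print(e)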
jax/_src/lax/lax.py

Lines changed: 61 additions & 17 deletions
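As a reading aid for the hunks below (standard lax.dot_general behavior, not something this commit changes): dimension_numbers are ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), and the output is laid out as batch dims, then the kept lhs dims, then the kept rhs dims, which is the ordering the shape and sharding rules below iterate over. A small runnable refresher with made-up shapes:

import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 4, 3))  # (batch, m, k)
b = jnp.ones((2, 3, 5))  # (batch, k, n)

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dnums = (((2,), (1,)), ((0,), (0,)))
print(lax.dot_general(a, b, dimension_numbers=dnums).shape)  # (2, 4, 5)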
@@ -1040,7 +1040,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
 
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
-                preferred_element_type: DTypeLike | None = None) -> Array:
+                preferred_element_type: DTypeLike | None = None,
+                out_type=None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1086,6 +1087,13 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
   non-contracting/non-batch dimensions.
   """
+  if out_type is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_type only works when sharding_in_types "
+                              "config is True.")
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError(
+        '`out_type` argument of `dot_general` only supports NamedSharding '
+        'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
@@ -1097,7 +1105,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type)
+                            preferred_element_type=preferred_element_type,
+                            out_type=out_type)
 
 
 def ragged_dot(
@@ -1123,7 +1132,8 @@ def ragged_dot(
   """
   return ragged_dot_p.bind(lhs, rhs, group_sizes,
                            precision=canonicalize_precision(precision),
-                           preferred_element_type=preferred_element_type, group_offset=group_offset)
+                           preferred_element_type=preferred_element_type,
+                           group_offset=group_offset)
 
 
 def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
@@ -3002,7 +3012,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = hlo.real(operand)
     aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
-  return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
+  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]
 
 mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
 
@@ -3164,7 +3178,10 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
@@ -3241,24 +3258,29 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
     raise TypeError(msg)
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
        f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
+  if out_type is not None:
+    assert isinstance(out_type, NamedSharding)
+    return out_type
+
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
   rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
   msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
-        f"to have the consistent sharding, got {lhs_batch_spec} and "
-        f"{rhs_batch_spec}.")
+         f"to have the consistent sharding, got {lhs_batch_spec} and "
+         f"{rhs_batch_spec}.")
   _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)
 
   lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
   rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
   msg = ("dot_general requires contracting dimensions to have consistent "
-        f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
+         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
   _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)
 
   return _dot_general_sharding_computation(
@@ -3280,7 +3302,10 @@ def tuple_delete(tup, idx):
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
   # primitive_util.h's HigherPrecisionType, e.g.
@@ -3327,7 +3352,9 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               swap_ans=False):
+                               out_type, swap_ans=False):
+  if out_type is not None:
+    raise NotImplementedError
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3347,12 +3374,16 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
   return x_bar
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
+  if out_type is not None:
+    raise NotImplementedError
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
     g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-    preferred_element_type=preferred_element_type, swap_ans=True)
+    preferred_element_type=preferred_element_type, out_type=out_type,
+    swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
@@ -3366,6 +3397,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
+    out_type,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3395,12 +3427,16 @@
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
+  if out_type is not None:
+    raise NotImplementedError("vmap with out_type is not supported. "
+                              "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
       rhs,
      new_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
+      out_type=out_type,
   )
   result_batch_dim = batching.shape_as_bdim(
      result_stack_dim,
@@ -3570,7 +3606,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       platform: str = "default"):
+                       out_type, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3658,6 +3694,8 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
+    if out_type is not None:
+      assert aval_out.sharding == out_type
     out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
   if accumulation_aval.dtype != aval_out.dtype:
@@ -3711,12 +3749,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
   return (m, n)
 
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
+                           precision, preferred_element_type: DTypeLike | None,
+                           **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
     raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
   # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
-  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-                                 precision=precision, preferred_element_type=preferred_element_type)
+  return _dot_general_dtype_rule(
+      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      precision=precision, preferred_element_type=preferred_element_type,
+      out_type=None)
 
 
 def _ragged_dot_jvp_rule(
@@ -3839,7 +3880,9 @@ def _ragged_dot_invoke_prim(
     new_dimension_numbers,
     precision,
    preferred_element_type,
+    out_type,
 ):
+  del out_type
   return ragged_dot(
      lhs,
      rhs,
@@ -3868,6 +3911,7 @@
      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
      precision=precision,
      preferred_element_type=preferred_element_type,
+      out_type=None,
   )
 
 
jax/_src/numpy/lax_numpy.py

Lines changed: 33 additions & 9 deletions
@@ -67,10 +67,10 @@
     DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
 )
 from jax._src.util import (
-    NumpyComplexWarning,
-    canonicalize_axis as _canonicalize_axis,
-    ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
-from jax.sharding import Sharding, SingleDeviceSharding
+    NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
+    ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
+from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
+                          PartitionSpec as P)
 from jax.tree_util import tree_flatten, tree_leaves, tree_map
 import numpy as np
 import opt_einsum
@@ -9081,6 +9081,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
+    out_type=None,
 ) -> Array: ...
 
 @overload
@@ -9093,6 +9094,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
+    out_type=None,
 ) -> Array: ...
 
 def einsum(
@@ -9103,6 +9105,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
+    out_type=None,
 ) -> Array:
   """Einstein summation
 
@@ -9334,11 +9337,11 @@ def einsum(
 
   contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
 
-  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
+  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
   if spec is not None:
     einsum = jax.named_call(einsum, name=spec)
   return einsum(operands, contractions, precision,
-                preferred_element_type, _dot_general)
+                preferred_element_type, _dot_general, out_type)
 
 
 # Enable other modules to override einsum_contact_path.
@@ -9437,7 +9440,15 @@ def _einsum(
     precision,
     preferred_element_type,
     _dot_general=lax.dot_general,
+    out_type=None,
 ):
+  if out_type is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_type only works when sharding_in_types "
+                              "config is True.")
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError(
+        "`out_type` argument of `einsum` only supports NamedSharding instances."
+        " Please file a bug if this is not enough for your use case.")
   dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
   operands = list(map(asarray, operands))
   if preferred_element_type is None:
@@ -9559,13 +9570,25 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
         names = batch_names_str + remaining_rhs_names + remaining_lhs_names
         if names == result_names:
           dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
+          k_out_type = {} if out_type is None else {'out_type': out_type}
          operand = _dot_general(rhs, lhs, dimension_numbers, precision,
-                                 preferred_element_type=preferred_element_type)
+                                 preferred_element_type=preferred_element_type,
+                                 **k_out_type)
        else:
          names = batch_names_str + remaining_lhs_names + remaining_rhs_names
+          if (config.sharding_in_types.value and out_type is not None and
+              names != result_names):
+            spec = out_type.spec
+            inverse_spec = tuple(spec[result_names.index(name)] for name in names)
+            dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
+          else:
+            dot_general_out_type = out_type  # type: ignore
          dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
+          dot_general_out_type = ({} if dot_general_out_type is None else  # type: ignore
+                                  {'out_type': dot_general_out_type})
          operand = _dot_general(lhs, rhs, dimension_numbers, precision,
-                                 preferred_element_type=preferred_element_type)
+                                 preferred_element_type=preferred_element_type,
+                                 **dot_general_out_type)
      else:
        raise NotImplementedError  # if this is actually reachable, open an issue!
 
@@ -9578,7 +9601,8 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
      operand = lax.transpose(operand, perm)
    operands.append(operand)  # used in next iteration
 
-  return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)
+  return lax_internal._convert_element_type(operands[0], preferred_element_type,
+                                            output_weak_type)
 
 
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)

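The inverse_spec computation added to _einsum above handles the case where the contraction is emitted with dimension names `names` in a different order than the requested `result_names`: the user-supplied output spec is permuted into the intermediate order before being passed down as dot_general's out_type. A standalone illustration of that indexing (the helper name is hypothetical, not part of the commit):

from jax.sharding import PartitionSpec as P

def permute_spec(spec, result_names, names):
  # spec[i] shards the dimension called result_names[i]; reorder the entries so
  # that entry j shards the dimension called names[j] (mirrors inverse_spec).
  return P(*(spec[result_names.index(name)] for name in names))

# Output 'bij' requested as ('x', None, 'y'); the intermediate comes out as 'bji'.
print(permute_spec(P('x', None, 'y'), 'bij', 'bji'))  # PartitionSpec('x', 'y', None)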
jax/_src/pallas/triton/lowering.py

Lines changed: 2 additions & 1 deletion
@@ -2089,10 +2089,11 @@ def _dot_general_lowering(
     b,
     *,
     dimension_numbers,
+    out_type,
     precision,
     preferred_element_type,
 ):
-  del preferred_element_type  # Unused.
+  del preferred_element_type, out_type  # Unused.
   ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())
 
jax/experimental/jax2tf/jax2tf.py

Lines changed: 1 addition & 1 deletion
@@ -2180,7 +2180,7 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None):
 tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
 
 
-def _dot_general(lhs, rhs, *, dimension_numbers,
+def _dot_general(lhs, rhs, *, dimension_numbers, out_type,
                  precision: lax_internal.CanonicalPrecision,
                  preferred_element_type: DType | None,
                  _in_avals: Sequence[core.ShapedArray],

jax/experimental/sparse/bcoo.py

Lines changed: 8 additions & 4 deletions
@@ -606,8 +606,11 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc
 
 bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
 
-def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
-                     precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
+def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
+                     dimension_numbers: DotDimensionNumbers,
+                     precision: None = None,
+                     preferred_element_type: None = None,
+                     out_type=None) -> BCOO | Array:
   """A general contraction operation.
 
   Args:
@@ -625,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision  # unused
+  del precision, out_type  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
@@ -1051,7 +1054,8 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
-          'preferred_element_type': None}
+          'preferred_element_type': None,
+          'out_type': None}
   A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
   return A, B, indices
 