Skip to content

Commit afd8239

Browse files
VarchoGoogle-ML-Automation
authored andcommitted
[SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.
PiperOrigin-RevId: 694567084
1 parent 4d1a126 commit afd8239

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

jax/_src/shard_alike.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import partial
1616
import itertools
1717

18+
from jax._src import config
1819
from jax._src import core
1920
from jax._src.interpreters import ad
2021
from jax._src.interpreters import mlir
@@ -24,7 +25,7 @@
2425
from jax._src.util import safe_zip
2526
from jax._src.lib import xla_client as xc
2627
from jax._src.api_util import shaped_abstractify
27-
from jax._src.lib.mlir import ir
28+
from jax._src.lib.mlir import dialects, ir
2829

2930
_next_shard_group_id = itertools.count()
3031

@@ -91,6 +92,11 @@ def _group_shard(
9192
) -> tuple[ir.Value, ir.Value]:
9293
shard_group_id = next(_next_shard_group_id)
9394

95+
if config.use_shardy_partitioner.value:
96+
dialects.sdy.ShardingGroupOp(x, shard_group_id)
97+
dialects.sdy.ShardingGroupOp(y, shard_group_id)
98+
return x, y
99+
94100
unknown_op_sharding = xc.OpSharding()
95101
unknown_op_sharding.type = xc.OpSharding.Type.UNKNOWN
96102
unknown_op_sharding.is_shard_group = True

tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ jax_multiplatform_test(
283283
"tpu_v3_2x2",
284284
"tpu_v5e_4x2",
285285
"tpu_v4_2x2",
286+
"tpu_v3_2x2_shardy",
286287
],
287288
deps = [
288289
"//jax:experimental",

tests/shard_alike_test.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import jax.numpy as jnp
1919
import numpy as np
2020
from absl.testing import absltest
21+
from jax._src import config
2122
from jax._src import test_util as jtu
2223
from jax.sharding import NamedSharding, PartitionSpec as P
2324
from jax.experimental.shard_alike import shard_alike
@@ -221,18 +222,16 @@ def test_shard_alike_inputs(self):
221222
mesh = jtu.create_mesh((2,), ('x',))
222223
np_inp = np.arange(8.)
223224
s = NamedSharding(mesh, P('x'))
224-
rep_s = NamedSharding(mesh, P())
225225
arr = jax.device_put(np_inp, s)
226-
arr2 = jax.device_put(np_inp, rep_s)
227226

228227
def f(x, y):
229228
return shard_alike(x, y)
230229

231-
eager_out1, eager_out2 = f(arr, arr2)
230+
eager_out1, eager_out2 = f(arr, np_inp)
232231
self.assertEqual(eager_out1.sharding, s)
233232
self.assertEqual(eager_out2.sharding, s)
234233

235-
out1, out2 = jax.jit(f)(arr, arr2)
234+
out1, out2 = jax.jit(f)(arr, np_inp)
236235
self.assertEqual(out1.sharding, s)
237236
self.assertEqual(out2.sharding, s)
238237

@@ -282,6 +281,5 @@ def test_sharding_preserverd_single_device(self):
282281
_, y = shard_alike(x, jnp.arange(8))
283282
self.assertEqual(y.sharding, s)
284283

285-
286284
if __name__ == '__main__':
287285
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)