|
18 | 18 | import jax.numpy as jnp |
19 | 19 | import numpy as np |
20 | 20 | from absl.testing import absltest |
| 21 | +from jax._src import config |
21 | 22 | from jax._src import test_util as jtu |
22 | 23 | from jax.sharding import NamedSharding, PartitionSpec as P |
23 | 24 | from jax.experimental.shard_alike import shard_alike |
@@ -221,18 +222,16 @@ def test_shard_alike_inputs(self): |
221 | 222 | mesh = jtu.create_mesh((2,), ('x',)) |
222 | 223 | np_inp = np.arange(8.) |
223 | 224 | s = NamedSharding(mesh, P('x')) |
224 | | - rep_s = NamedSharding(mesh, P()) |
225 | 225 | arr = jax.device_put(np_inp, s) |
226 | | - arr2 = jax.device_put(np_inp, rep_s) |
227 | 226 |
|
228 | 227 | def f(x, y): |
229 | 228 | return shard_alike(x, y) |
230 | 229 |
|
231 | | - eager_out1, eager_out2 = f(arr, arr2) |
| 230 | + eager_out1, eager_out2 = f(arr, np_inp) |
232 | 231 | self.assertEqual(eager_out1.sharding, s) |
233 | 232 | self.assertEqual(eager_out2.sharding, s) |
234 | 233 |
|
235 | | - out1, out2 = jax.jit(f)(arr, arr2) |
| 234 | + out1, out2 = jax.jit(f)(arr, np_inp) |
236 | 235 | self.assertEqual(out1.sharding, s) |
237 | 236 | self.assertEqual(out2.sharding, s) |
238 | 237 |
|
@@ -282,6 +281,5 @@ def test_sharding_preserverd_single_device(self): |
282 | 281 | _, y = shard_alike(x, jnp.arange(8)) |
283 | 282 | self.assertEqual(y.sharding, s) |
284 | 283 |
|
285 | | - |
286 | 284 | if __name__ == '__main__': |
287 | 285 | absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments