|
| 1 | +import functools |
1 | 2 | import os |
2 | 3 |
|
3 | 4 | import jax |
|
9 | 10 | from keras.src import backend |
10 | 11 | from keras.src import testing |
11 | 12 | from keras.src.backend.config import is_nnx_enabled |
| 13 | +from keras.src.backend.jax.core import all_gather |
| 14 | +from keras.src.backend.jax.core import all_reduce |
12 | 15 |
|
13 | 16 | if is_nnx_enabled(): |
14 | 17 | from flax import nnx |
@@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self): |
66 | 69 | state = jax.tree.map(lambda x: x + 1, state) |
67 | 70 | variable2 = nnx.merge(graphdef, state) |
68 | 71 | self.assertEqual(variable2._value, variable2.value) |
| 72 | + |
| 73 | + |
| 74 | +@pytest.mark.skipif( |
| 75 | + backend.backend() != "jax", |
| 76 | + reason="JAX backend specific test for collective operations.", |
| 77 | +) |
| 78 | +@pytest.mark.skipif( |
| 79 | + jax.local_device_count() < 2, |
| 80 | + reason="Requires multiple local devices for testing.", |
| 81 | +) |
| 82 | +class JaxCollectiveOpsTest(testing.TestCase): |
| 83 | + def test_all_reduce_sum(self): |
| 84 | + """Tests the all_reduce operation with the 'sum' reduction.""" |
| 85 | + num_devices = jax.local_device_count() |
| 86 | + local_value = 10.0 |
| 87 | + |
| 88 | + local_inputs = jax.numpy.array([local_value] * num_devices) |
| 89 | + |
| 90 | + @functools.partial( |
| 91 | + jax.pmap, axis_name="all", devices=jax.devices("cpu") |
| 92 | + ) |
| 93 | + def reduce_sum_fn(x): |
| 94 | + return all_reduce(x, op="sum", axis_name="all") |
| 95 | + |
| 96 | + result = reduce_sum_fn(local_inputs) |
| 97 | + expected_sum = local_value * num_devices |
| 98 | + |
| 99 | + self.assertTrue(np.allclose(result, expected_sum)) |
| 100 | + self.assertEqual(result.shape, (num_devices,)) |
| 101 | + |
| 102 | + def test_all_reduce_mean(self): |
| 103 | + """Tests the all_reduce operation with the 'mean' reduction.""" |
| 104 | + num_devices = jax.local_device_count() |
| 105 | + local_value = 10.0 |
| 106 | + |
| 107 | + local_inputs = jax.numpy.array([local_value] * num_devices) |
| 108 | + |
| 109 | + @functools.partial( |
| 110 | + jax.pmap, axis_name="all", devices=jax.devices("cpu") |
| 111 | + ) |
| 112 | + def reduce_mean_fn(x): |
| 113 | + return all_reduce(x, op="mean", axis_name="all") |
| 114 | + |
| 115 | + result = reduce_mean_fn(local_inputs) |
| 116 | + expected_mean = local_value |
| 117 | + |
| 118 | + self.assertTrue(np.allclose(result, expected_mean)) |
| 119 | + self.assertEqual(result.shape, (num_devices,)) |
| 120 | + |
| 121 | + def test_all_gather(self): |
| 122 | + """Tests the all_gather operation.""" |
| 123 | + num_devices = jax.local_device_count() |
| 124 | + local_data = np.arange(5) |
| 125 | + |
| 126 | + local_inputs = jax.numpy.stack( |
| 127 | + [local_data + (i * 5) for i in range(num_devices)] |
| 128 | + ) |
| 129 | + |
| 130 | + @functools.partial( |
| 131 | + jax.pmap, axis_name="all", devices=jax.devices("cpu") |
| 132 | + ) |
| 133 | + def gather_fn(x): |
| 134 | + return all_gather(x, axis=0, axis_name="all") |
| 135 | + |
| 136 | + result_array_on_devices = gather_fn(local_inputs) |
| 137 | + |
| 138 | + expected_shape = (num_devices, num_devices * local_data.shape[0]) |
| 139 | + self.assertEqual(result_array_on_devices.shape, expected_shape) |
| 140 | + |
| 141 | + expected_gathered_data = np.arange(num_devices * local_data.shape[0]) |
| 142 | + |
| 143 | + for i in range(num_devices): |
| 144 | + self.assertTrue( |
| 145 | + np.allclose(result_array_on_devices[i], expected_gathered_data) |
| 146 | + ) |
0 commit comments