Skip to content

Commit 4650a54

Browse files
1 parent 9fcc4e7 commit 4650a54

File tree

8 files changed

+672
-485
lines changed

8 files changed

+672
-485
lines changed

keras/src/backend/jax/core.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import jax
22
import jax.experimental.sparse as jax_sparse
3+
import jax.lax as lax
34
import jax.numpy as jnp
45
import ml_dtypes
56
import numpy as np
@@ -529,6 +530,61 @@ def remat(f):
529530
return jax.checkpoint(f)
530531

531532

533+
def all_reduce(x, op="sum", axis_name="model"):
    """Performs an all-reduce operation across a named distribution axis.

    The input tensor `x` is reduced across all devices/replicas that
    participate in the `axis_name` group, and the reduced result is made
    available on every participating device.

    This function wraps the `jax.lax` parallel collectives, so it is expected
    to be called from code in which `axis_name` is bound (for example a
    function mapped with `jax.pmap(..., axis_name=...)`).

    Args:
        x: The tensor to reduce.
        op: The reduction operation to perform. One of `"sum"`, `"mean"`,
            `"max"` or `"min"`. Defaults to `"sum"`.
        axis_name: The name of the distribution axis (e.g., "model",
            "data") over which to perform the reduction. Defaults to
            "model".

    Returns:
        The result of the all-reduce operation, with the same shape as the
        input `x`.

    Raises:
        ValueError: If `op` is not one of the supported reductions.
    """
    # Dispatch table from the public `op` name to the matching JAX collective.
    collectives = {
        "sum": lax.psum,
        "mean": lax.pmean,
        "max": lax.pmax,
        "min": lax.pmin,
    }
    # The `isinstance` guard keeps unhashable `op` values (e.g. a list) on the
    # `ValueError` path instead of failing the dict lookup with a `TypeError`.
    collective = collectives.get(op) if isinstance(op, str) else None
    if collective is None:
        raise ValueError(
            f"Unsupported reduction operation: {op}. "
            "Supported options are 'sum', 'mean', 'max' and 'min'."
        )
    return collective(x, axis_name=axis_name)
562+
563+
564+
def all_gather(x, axis, axis_name="model"):
    """Gathers `x` from every replica of a named distribution axis.

    Each device in the `axis_name` group contributes its local `x`; the
    pieces are concatenated along `axis` (tiled gather) rather than stacked
    along a new dimension. This is typically used in tensor parallelism to
    reassemble a tensor whose parts live on different devices.

    Args:
        x: The tensor to gather.
        axis: The dimension along which the gathered tensors are
            concatenated.
        axis_name: The name of the distribution axis (e.g., "model",
            "data") over which to perform the gather. Defaults to "model".

    Returns:
        The gathered tensor, which is larger than `x` along the `axis`
        dimension.
    """
    # `tiled=True` concatenates along `axis` instead of adding a new leading
    # per-device dimension.
    gathered = lax.all_gather(x, axis_name, axis=axis, tiled=True)
    return gathered
586+
587+
532588
class name_scope(base_name_scope):
533589
def __init__(self, name, **kwargs):
534590
super().__init__(name, **kwargs)

keras/src/backend/jax/core_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import os
23

34
import jax
@@ -9,6 +10,8 @@
910
from keras.src import backend
1011
from keras.src import testing
1112
from keras.src.backend.config import is_nnx_enabled
13+
from keras.src.backend.jax.core import all_gather
14+
from keras.src.backend.jax.core import all_reduce
1215

1316
if is_nnx_enabled():
1417
from flax import nnx
@@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self):
6669
state = jax.tree.map(lambda x: x + 1, state)
6770
variable2 = nnx.merge(graphdef, state)
6871
self.assertEqual(variable2._value, variable2.value)
72+
73+
74+
@pytest.mark.skipif(
    backend.backend() != "jax",
    reason="JAX backend specific test for collective operations.",
)
class JaxCollectiveOpsTest(testing.TestCase):
    """Tests for the JAX `all_reduce` and `all_gather` collective helpers.

    Every test maps a small function over the CPU devices with `jax.pmap`.
    The number of replicas is taken from that same CPU device list, so the
    mapped input always matches the devices the computation runs on, even
    when the default JAX backend is not the CPU.
    """

    def setUp(self):
        super().setUp()
        # Query the CPU backend explicitly: the pmapped functions below are
        # pinned to these devices, so the default-backend device count is
        # not the relevant number.
        self.cpu_devices = jax.devices("cpu")
        self.num_devices = len(self.cpu_devices)
        if self.num_devices < 2:
            self.skipTest("Requires multiple local devices for testing.")

    def test_all_reduce_sum(self):
        """Tests the all_reduce operation with the 'sum' reduction."""
        num_devices = self.num_devices
        local_value = 10.0

        # One scalar per device.
        local_inputs = jax.numpy.array([local_value] * num_devices)

        @functools.partial(
            jax.pmap, axis_name="all", devices=self.cpu_devices
        )
        def reduce_sum_fn(x):
            return all_reduce(x, op="sum", axis_name="all")

        result = reduce_sum_fn(local_inputs)
        expected_sum = local_value * num_devices

        # Every device must hold the full sum.
        np.testing.assert_allclose(
            result, np.full((num_devices,), expected_sum)
        )
        self.assertEqual(result.shape, (num_devices,))

    def test_all_reduce_mean(self):
        """Tests the all_reduce operation with the 'mean' reduction."""
        num_devices = self.num_devices
        local_value = 10.0

        # One scalar per device.
        local_inputs = jax.numpy.array([local_value] * num_devices)

        @functools.partial(
            jax.pmap, axis_name="all", devices=self.cpu_devices
        )
        def reduce_mean_fn(x):
            return all_reduce(x, op="mean", axis_name="all")

        result = reduce_mean_fn(local_inputs)
        expected_mean = local_value

        # The mean of identical values is that value, on every device.
        np.testing.assert_allclose(
            result, np.full((num_devices,), expected_mean)
        )
        self.assertEqual(result.shape, (num_devices,))

    def test_all_gather(self):
        """Tests the all_gather operation."""
        num_devices = self.num_devices
        local_data = np.arange(5)

        # Device `i` holds the slice [5 * i, 5 * i + 5) of a global range.
        local_inputs = jax.numpy.stack(
            [local_data + (i * 5) for i in range(num_devices)]
        )

        @functools.partial(
            jax.pmap, axis_name="all", devices=self.cpu_devices
        )
        def gather_fn(x):
            return all_gather(x, axis=0, axis_name="all")

        result_array_on_devices = gather_fn(local_inputs)

        expected_shape = (num_devices, num_devices * local_data.shape[0])
        self.assertEqual(result_array_on_devices.shape, expected_shape)

        expected_gathered_data = np.arange(num_devices * local_data.shape[0])

        # Each device must see the complete, in-order concatenation.
        for i in range(num_devices):
            np.testing.assert_array_equal(
                result_array_on_devices[i], expected_gathered_data
            )

0 commit comments

Comments
 (0)