56 changes: 56 additions & 0 deletions keras/src/backend/jax/core.py
@@ -1,5 +1,6 @@
import jax
import jax.experimental.sparse as jax_sparse
import jax.lax as lax
import jax.numpy as jnp
import ml_dtypes
import numpy as np
@@ -529,6 +530,61 @@ def remat(f):
return jax.checkpoint(f)


def all_reduce(x, op="sum", axis_name="model"):
"""
    Performs an all-reduce operation across all replicas in the specified
    distribution axis.

    The all-reduce operation computes a reduction (such as a sum or mean)
    of the input tensor `x` across all devices/replicas in the `axis_name`
    group, and then broadcasts the result back to all participating devices.

Args:
x: The tensor to reduce.
        op: The reduction operation to perform. Supported options are
            "sum" and "mean". Defaults to "sum".
axis_name: The name of the distribution axis (e.g., "model",
"data") over which to perform the reduction. Defaults to "model".

Returns:
The result of the all-reduce operation, with the same shape as the
input `x`.
"""
if op == "sum":
return lax.psum(x, axis_name=axis_name)
elif op == "mean":
return lax.pmean(x, axis_name=axis_name)
else:
raise ValueError(
f"Unsupported reduction operation: {op}. "
"Supported options are 'sum' and 'mean'."
)


def all_gather(x, axis, axis_name="model"):
"""
Performs an all-gather operation across all replicas in the specified
distribution axis.

The all-gather operation collects the input tensor `x` from all devices
in the `axis_name` group and concatenates them along the specified `axis`.
This is often used in tensor parallelism to combine parts of a tensor
distributed across devices.

Args:
x: The tensor to gather.
axis: The dimension along which to concatenate the gathered tensors.
axis_name: The name of the distribution axis (e.g., "model",
"data") over which to perform the gather.
Defaults to "model".

Returns:
        The gathered tensor, whose size along the `axis` dimension is
        multiplied by the number of participating devices.
"""
return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
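
To make the intended use of these collectives concrete, here is a minimal usage sketch (not part of the diff) of a row-parallel matmul under `jax.pmap`: each device multiplies its local weight shard, the partial products are summed with `all_reduce`, and `all_gather` re-assembles a column-sharded activation. The shapes, the "model" axis name, and the assumption that the JAX backend is active are illustrative choices for this sketch, not part of the PR.

import functools

import jax
import jax.numpy as jnp
import numpy as np

# Assumes KERAS_BACKEND="jax"; these helpers are the ones added in this PR.
from keras.src.backend.jax.core import all_gather
from keras.src.backend.jax.core import all_reduce

num_devices = jax.local_device_count()

# Row-parallel matmul: x and w are both sharded along the contraction
# dimension, so each device computes a partial (4, 16) product.
x_shards = np.random.rand(num_devices, 4, 8).astype("float32")
w_shards = np.random.rand(num_devices, 8, 16).astype("float32")


@functools.partial(jax.pmap, axis_name="model")
def row_parallel_matmul(x, w):
    partial = jnp.matmul(x, w)
    # Sum the partial products across the "model" axis; every device
    # receives the full (4, 16) result.
    return all_reduce(partial, op="sum", axis_name="model")


y = row_parallel_matmul(x_shards, w_shards)  # y.shape == (num_devices, 4, 16)


@functools.partial(jax.pmap, axis_name="model")
def gather_columns(a):
    # Concatenate per-device (4, 2) slices along axis 1 into (4, 2 * n).
    return all_gather(a, axis=1, axis_name="model")


a_shards = np.random.rand(num_devices, 4, 2).astype("float32")
gathered = gather_columns(a_shards)  # (num_devices, 4, 2 * num_devices)
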
78 changes: 78 additions & 0 deletions keras/src/backend/jax/core_test.py
@@ -1,3 +1,4 @@
import functools
import os

import jax
@@ -9,6 +10,8 @@
from keras.src import backend
from keras.src import testing
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax.core import all_gather
from keras.src.backend.jax.core import all_reduce

if is_nnx_enabled():
from flax import nnx
@@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self):
state = jax.tree.map(lambda x: x + 1, state)
variable2 = nnx.merge(graphdef, state)
self.assertEqual(variable2._value, variable2.value)


@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for collective operations.",
)
@pytest.mark.skipif(
jax.local_device_count() < 2,
reason="Requires multiple local devices for testing.",
)
class JaxCollectiveOpsTest(testing.TestCase):
def test_all_reduce_sum(self):
"""Tests the all_reduce operation with the 'sum' reduction."""
num_devices = jax.local_device_count()
local_value = 10.0

local_inputs = jax.numpy.array([local_value] * num_devices)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def reduce_sum_fn(x):
return all_reduce(x, op="sum", axis_name="all")

result = reduce_sum_fn(local_inputs)
expected_sum = local_value * num_devices

self.assertTrue(np.allclose(result, expected_sum))
self.assertEqual(result.shape, (num_devices,))

def test_all_reduce_mean(self):
"""Tests the all_reduce operation with the 'mean' reduction."""
num_devices = jax.local_device_count()
local_value = 10.0

local_inputs = jax.numpy.array([local_value] * num_devices)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def reduce_mean_fn(x):
return all_reduce(x, op="mean", axis_name="all")

result = reduce_mean_fn(local_inputs)
expected_mean = local_value

self.assertTrue(np.allclose(result, expected_mean))
self.assertEqual(result.shape, (num_devices,))

def test_all_gather(self):
"""Tests the all_gather operation."""
num_devices = jax.local_device_count()
local_data = np.arange(5)

local_inputs = jax.numpy.stack(
[local_data + (i * 5) for i in range(num_devices)]
)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def gather_fn(x):
return all_gather(x, axis=0, axis_name="all")

result_array_on_devices = gather_fn(local_inputs)

expected_shape = (num_devices, num_devices * local_data.shape[0])
self.assertEqual(result_array_on_devices.shape, expected_shape)

expected_gathered_data = np.arange(num_devices * local_data.shape[0])

for i in range(num_devices):
self.assertTrue(
np.allclose(result_array_on_devices[i], expected_gathered_data)
)
43 changes: 43 additions & 0 deletions keras/src/distribution/tensor_parallel/tensor_layout.py
@@ -0,0 +1,43 @@
import collections

from keras.src import ops


def split_tensor_for_parallelism(tensor, index, device_count, dim):
"""Calculates a slice of a tensor along a specified dimension for a
given index.
This utility is used in tensor parallelism API to distribute a
tensor across multiple devices.
Args:
tensor: The full tensor to be sharded.
index: The index of the device/shard to return (e.g., 0, 1, 2...).
device_count: The total number of parallel devices or splits.
dim: The dimension along which to split the tensor. If -1, the
last dimension is used.
Returns:
A tensor slice corresponding to the given `index`.
"""
if dim == -1:
static_shape = getattr(tensor, "shape", None)
if static_shape is not None:
rank = len(static_shape)
else:
rank = None

if rank is not None:
split_dim = rank - 1
else:
split_dim = ops.ndim(tensor) - 1
else:
split_dim = dim

splits = ops.array_split(
tensor, indices_or_sections=device_count, axis=split_dim
)
return splits[index]


LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"])
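
As a quick illustration (again not part of the diff), the sketch below pre-shards a kernel on the host with `split_tensor_for_parallelism`, checks that concatenating the shards reconstructs it, and registers the sharding callables in a `LayoutMap`. The rule key "dense/kernel" is a hypothetical name chosen for the example.

import numpy as np

from keras.src import ops
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import (
    split_tensor_for_parallelism,
)

device_count = 2
kernel = ops.reshape(ops.arange(24, dtype="float32"), (6, 4))

# Column-parallel sharding: split along the last dimension so each device
# holds a (6, 2) slice of the kernel.
shards = [
    split_tensor_for_parallelism(
        kernel, index=i, device_count=device_count, dim=-1
    )
    for i in range(device_count)
]

# Concatenating the shards along the split dimension recovers the original.
np.testing.assert_allclose(ops.concatenate(shards, axis=-1), kernel)

# A LayoutMap pairs state (variable) rules with output rules; each rule maps
# (tensor, device index) to the shard that device should own.
layout_map = LayoutMap(
    state_rules={
        "dense/kernel": lambda t, i: split_tensor_for_parallelism(
            t, index=i, device_count=device_count, dim=-1
        )
    },
    output_rules={},
)
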
163 changes: 163 additions & 0 deletions keras/src/distribution/tensor_parallel/tensor_layout_test.py
@@ -0,0 +1,163 @@
from keras.src import ops
from keras.src import testing
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import (
split_tensor_for_parallelism,
)


class LayoutTest(testing.TestCase):
"""Test suite for tensor layout actions and mappings."""

def test_split_with_even_division(self):
"""Tests splitting a tensor that divides evenly among workers."""
device_count = 4
dim = 0
tensor = ops.reshape(ops.arange(16, dtype="float32"), (8, 2))

expected_shard_0 = ops.array([[0.0, 1.0], [2.0, 3.0]])
expected_shard_2 = ops.array([[8.0, 9.0], [10.0, 11.0]])

shard_0 = split_tensor_for_parallelism(
tensor, index=0, device_count=device_count, dim=dim
)
shard_2 = split_tensor_for_parallelism(
tensor, index=2, device_count=device_count, dim=dim
)

self.assertAllClose(shard_0, expected_shard_0)
self.assertAllClose(shard_2, expected_shard_2)
self.assertEqual(shard_0.shape, (2, 2))

def test_split_with_uneven_division(self):
"""Tests splitting tensor where remainder is distributed correctly."""
device_count = 3
dim = 0
tensor = ops.reshape(ops.arange(10, dtype="float32"), (10, 1))

shard_0 = split_tensor_for_parallelism(
tensor, index=0, device_count=device_count, dim=dim
)
self.assertEqual(shard_0.shape, (4, 1))
self.assertAllClose(shard_0, ops.array([[0.0], [1.0], [2.0], [3.0]]))

shard_1 = split_tensor_for_parallelism(
tensor, index=1, device_count=device_count, dim=dim
)
self.assertEqual(shard_1.shape, (3, 1))
self.assertAllClose(shard_1, ops.array([[4.0], [5.0], [6.0]]))

shard_2 = split_tensor_for_parallelism(
tensor, index=2, device_count=device_count, dim=dim
)
self.assertEqual(shard_2.shape, (3, 1))
self.assertAllClose(shard_2, ops.array([[7.0], [8.0], [9.0]]))

    def test_split_and_undo_cycle_even(self):
"""
Confirms that the original tensor can be reconstructed.
"""
device_count = 2
dim = 0
original_tensor = ops.reshape(ops.arange(12, dtype="float32"), (6, 2))

shards = [
split_tensor_for_parallelism(
original_tensor, index=i, device_count=device_count, dim=dim
)
for i in range(device_count)
]

reconstructed_tensor = ops.concatenate(shards, axis=dim)

self.assertAllClose(original_tensor, reconstructed_tensor)

    def test_split_and_undo_cycle_uneven(self):
"""
        Confirms that an unevenly split tensor can be reconstructed.
"""
device_count = 4
dim = 0
original_tensor = ops.reshape(ops.arange(22, dtype="float32"), (11, 2))

shards = [
split_tensor_for_parallelism(
original_tensor, index=i, device_count=device_count, dim=dim
)
for i in range(device_count)
]

self.assertEqual(shards[0].shape, (3, 2))
self.assertEqual(shards[1].shape, (3, 2))
self.assertEqual(shards[2].shape, (3, 2))
self.assertEqual(shards[3].shape, (2, 2))

reconstructed_tensor = ops.concatenate(shards, axis=dim)
self.assertAllClose(original_tensor, reconstructed_tensor)

def test_split_last_dimension(self):
"""Tests splitting on the last dimension using dim=-1."""
device_count = 3
dim = -1
original_tensor = ops.reshape(
ops.arange(30, dtype="float32"), (2, 5, 3)
)

shards = [
split_tensor_for_parallelism(
original_tensor, index=i, device_count=device_count, dim=dim
)
for i in range(device_count)
]

self.assertEqual(shards[0].shape, (2, 5, 1))
self.assertEqual(shards[1].shape, (2, 5, 1))
self.assertEqual(shards[2].shape, (2, 5, 1))

def test_split_with_sharding_type_hint(self):
"""Tests using 'row' and 'column' sharding hints for 2D tensors."""
device_count = 2
tensor = ops.reshape(ops.arange(16, dtype="float32"), (4, 4))

row_dim = 0
shard_row_0 = split_tensor_for_parallelism(
tensor, index=0, device_count=device_count, dim=row_dim
)
self.assertAllClose(shard_row_0, tensor[:2, :])

col_dim = 1
shard_col_0 = split_tensor_for_parallelism(
tensor, index=0, device_count=device_count, dim=col_dim
)
self.assertAllClose(shard_col_0, tensor[:, :2])

def test_layout_map_namedtuple_behavior(self):
"""Tests basic behavior of the LayoutMap namedtuple."""

def rule_kernel(tensor, index):
return split_tensor_for_parallelism(
tensor, index=index, device_count=2, dim=0
)

def rule_output(tensor, index):
return split_tensor_for_parallelism(
tensor, index=index, device_count=2, dim=-1
)

state_rules = {"kernel": rule_kernel}
output_rules = {"output": rule_output}

layout_map = LayoutMap(
state_rules=state_rules, output_rules=output_rules
)

self.assertIs(layout_map.state_rules, state_rules)
self.assertIs(layout_map.output_rules, output_rules)

self.assertIs(layout_map[0], state_rules)
self.assertIs(layout_map[1], output_rules)

with self.assertRaises(AttributeError):
layout_map.state_rules = {}

self.assertTrue(callable(layout_map.state_rules["kernel"]))