diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..c7a2d8789612 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -43,8 +43,6 @@ elif backend() == "torch": from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable - - distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py index 371a62cd0f52..ba12e71cfb0e 100644 --- a/keras/src/backend/torch/__init__.py +++ b/keras/src/backend/torch/__init__.py @@ -16,6 +16,7 @@ from keras.src.backend.common.name_scope import name_scope from keras.src.backend.torch import core +from keras.src.backend.torch import distribution_lib from keras.src.backend.torch import image from keras.src.backend.torch import linalg from keras.src.backend.torch import math diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index f76d2f70935e..6037aa3c702e 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -2,10 +2,12 @@ import contextlib import functools import os +import threading import ml_dtypes import numpy as np import torch +import torch.func from keras.src import tree from keras.src.backend.common import KerasVariable @@ -19,6 +21,7 @@ from keras.src.backend.common.stateless_scope import in_stateless_scope from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.config import floatx +from keras.src.utils.module_utils import torch_xla SUPPORTS_SPARSE_TENSORS = False SUPPORTS_RAGGED_TENSORS = False @@ -35,6 +38,8 @@ DEFAULT_DEVICE = "cuda" elif hasattr(torch, "xpu") and torch.xpu.is_available(): DEFAULT_DEVICE = "xpu" +elif torch_xla.available: + DEFAULT_DEVICE = "xla" else: DEFAULT_DEVICE = "cpu" @@ -102,17 +107,48 @@ def to_torch_dtype(dtype): class Variable(KerasVariable): + def __init__(self, *args, layout=None, **kwargs): + self._layout = layout + super().__init__(*args, **kwargs) + + def _initialize_layout(self): + """Initialize the variable layout based on the current distribution.""" + distribution = global_state.get_global_attribute("distribution") + if self._layout is None and distribution is not None: + self._layout = distribution.get_variable_layout(self) + if self._layout is None: + from keras.src.distribution import TensorLayout + + self._layout = TensorLayout( + [None] * len(self._shape), distribution.device_mesh + ) + def _initialize(self, value): + self._shape = self._validate_shape(value.shape) + self._initialize_layout() if isinstance(value, torch.nn.Parameter): # Reuse same parameter self._value = value else: + requires_grad = self.trainable and torch.is_floating_point( + convert_to_tensor(value, dtype=self._dtype) + ) self._value = torch.nn.Parameter( convert_to_tensor(value, dtype=self._dtype), - requires_grad=self.trainable, + requires_grad=requires_grad, ).to(get_device()) + if self._layout is not None: + from keras.src.backend.torch import distribution_lib + + self._value = distribution_lib.distribute_variable( + self._value, self._layout + ) def _direct_assign(self, value): + if self._layout is not None: + from keras.src.backend.torch import distribution_lib + + value = distribution_lib.distribute_variable(value, self._layout) with torch.no_grad(): self.value.copy_(value) @@ -122,13 +158,15 @@ def _convert_to_tensor(self, value, dtype=None): # Overload native accessor. @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - args = [arg.value if isinstance(arg, Variable) else arg for arg in args] + def unwrap(x): + if isinstance(x, Variable): + return x.value + return x + if kwargs is None: kwargs = {} - kwargs = { - key: value.value if isinstance(value, Variable) else value - for key, value in kwargs.items() - } + args = tree.map_structure(unwrap, args) + kwargs = tree.map_structure(unwrap, kwargs) return func(*args, **kwargs) def __array__(self, dtype=None): @@ -208,17 +246,19 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): return x if dtype is None: if isinstance(x, bool): - return torch.as_tensor(x, dtype=torch.bool, device=get_device()) + res = torch.as_tensor(x, dtype=torch.bool, device=get_device()) + return maybe_distribute_tensor(res) elif isinstance(x, int): if x < -(2**31) or x >= 2**31: - return torch.as_tensor( - x, dtype=torch.int64, device=get_device() - ) - return torch.as_tensor(x, dtype=torch.int32, device=get_device()) + res = torch.as_tensor(x, dtype=torch.int64, device=get_device()) + else: + res = torch.as_tensor(x, dtype=torch.int32, device=get_device()) + return maybe_distribute_tensor(res) elif isinstance(x, float): - return torch.as_tensor( + res = torch.as_tensor( x, dtype=to_torch_dtype(floatx()), device=get_device() ) + return maybe_distribute_tensor(res) # Convert to np in case of any array-like that is not list or tuple. if not isinstance(x, (list, tuple)): @@ -240,7 +280,8 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)] ) dtype = to_torch_dtype(dtype) - return torch.as_tensor(x, dtype=dtype, device=get_device()) + res = torch.as_tensor(x, dtype=dtype, device=get_device()) + return maybe_distribute_tensor(res) def convert_to_numpy(x): @@ -248,6 +289,7 @@ def transform(x): if is_tensor(x): if x.requires_grad: x = x.detach() + x = _ensure_replicated_local(x) # Tensor has to be moved to CPU before converting to numpy. if x.device != torch.device("cpu"): x = x.cpu() @@ -287,9 +329,10 @@ def cast(x, dtype): x = x.value if is_tensor(x): if x.dtype == dtype: - return x + res = x else: - return x.to(dtype) + res = x.to(dtype) + return maybe_distribute_tensor(res) return convert_to_tensor(x, dtype) @@ -614,6 +657,8 @@ def scatter_update(inputs, indices, updates, reduction=None): def slice(inputs, start_indices, shape): shape_dtype = to_torch_dtype("int64") inputs = convert_to_tensor(inputs) + if hasattr(inputs, "device_mesh"): + inputs = _ensure_replicated_local(inputs) start_indices = convert_to_tensor(start_indices).to(shape_dtype) shape = convert_to_tensor(shape).to(shape_dtype) @@ -622,7 +667,8 @@ def slice(inputs, start_indices, shape): python_slice(start_index, start_index + length) for start_index, length in zip(start_indices, shape) ] - return inputs[slices] + res = inputs[slices] + return maybe_distribute_tensor(res) def slice_update(inputs, start_indices, updates): @@ -631,6 +677,12 @@ def slice_update(inputs, start_indices, updates): start_indices = convert_to_tensor(start_indices).to(shape_dtype) updates = convert_to_tensor(updates) + if hasattr(inputs, "device_mesh"): + inputs = _ensure_replicated_local(inputs) + + if hasattr(updates, "device_mesh"): + updates = _ensure_replicated_local(updates) + python_slice = __builtins__["slice"] slices = [ python_slice(start_index, start_index + update_length) @@ -638,7 +690,7 @@ def slice_update(inputs, start_indices, updates): ] outputs = torch.clone(inputs) outputs[slices] = updates - return outputs + return maybe_distribute_tensor(outputs) def switch(index, branches, *operands): @@ -765,3 +817,182 @@ def backward(ctx, grad_output): if not isinstance(grads, tuple): grads = (grads,) return (None,) + grads + + +def _is_sharded(tensor): + """Check if a tensor is sharded across a device mesh.""" + if hasattr(tensor, "placements"): + from torch.distributed.tensor import Shard + + return any(isinstance(p, Shard) for p in tensor.placements) + return False + + +def _ensure_replicated_local(x): + """Ensure a sharded tensor is replicated and converted to a local tensor.""" + if is_tensor(x) and hasattr(x, "device_mesh"): + from torch.distributed.tensor import Replicate + + if _is_sharded(x): + x = x.redistribute( + x.device_mesh, [Replicate()] * x.device_mesh.ndim + ) + return x.to_local() + return x + + +def _sharding_aware_op(fn): + """Decorator to make a Torch operation aware of sharded tensors.""" + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + if hasattr(self, "device_mesh"): + res = fn(_ensure_replicated_local(self), *args, **kwargs) + return tree.map_structure(maybe_distribute_tensor, res) + return fn(self, *args, **kwargs) + + return wrapper + + +_original_getitem = torch.Tensor.__getitem__ + + +def _sharding_aware_getitem(self, *args, **kwargs): + """Sharding-aware implementation of __getitem__.""" + if hasattr(self, "device_mesh"): + index = args[0] + if isinstance(index, (list, np.ndarray)): + index = tuple(index) + res = _original_getitem(_ensure_replicated_local(self), index, **kwargs) + return maybe_distribute_tensor(res) + return _original_getitem(self, *args, **kwargs) + + +_original_unbind_fn = torch.unbind + + +def _sharding_aware_unbind_fn(input, *args, **kwargs): + """Sharding-aware implementation of torch.unbind.""" + if hasattr(input, "device_mesh"): + res = _original_unbind_fn( + _ensure_replicated_local(input), *args, **kwargs + ) + return tree.map_structure(maybe_distribute_tensor, res) + return _original_unbind_fn(input, *args, **kwargs) + + +_original_broadcast_to = torch.broadcast_to + + +def _sharding_aware_broadcast_to(input, *args, **kwargs): + """Sharding-aware implementation of torch.broadcast_to.""" + if hasattr(input, "device_mesh"): + res = _original_broadcast_to( + _ensure_replicated_local(input), *args, **kwargs + ) + return maybe_distribute_tensor(res) + return _original_broadcast_to(input, *args, **kwargs) + + +_original_einsum = torch.einsum + + +def _sharding_aware_einsum(subscripts, *operands, **kwargs): + """Sharding-aware implementation of torch.einsum.""" + new_operands = [] + any_dtensor = False + for x in operands: + if is_tensor(x) and hasattr(x, "device_mesh"): + new_operands.append(_ensure_replicated_local(x)) + any_dtensor = True + else: + new_operands.append(x) + if any_dtensor: + res = _original_einsum(subscripts, *new_operands, **kwargs) + return maybe_distribute_tensor(res) + return _original_einsum(subscripts, *operands, **kwargs) + + +torch.Tensor.reshape = _sharding_aware_op(torch.Tensor.reshape) +torch.Tensor.view = _sharding_aware_op(torch.Tensor.view) +torch.Tensor.expand = _sharding_aware_op(torch.Tensor.expand) +torch.Tensor.unbind = _sharding_aware_op(torch.Tensor.unbind) +torch.Tensor.squeeze = _sharding_aware_op(torch.Tensor.squeeze) +torch.Tensor.unsqueeze = _sharding_aware_op(torch.Tensor.unsqueeze) +torch.Tensor.__getitem__ = _sharding_aware_getitem +torch.unbind = _sharding_aware_unbind_fn +torch.broadcast_to = _sharding_aware_broadcast_to +torch.einsum = _sharding_aware_einsum + +_original_detach = torch.Tensor.detach + + +def _sharding_aware_detach(self, *args, **kwargs): + """Sharding-aware implementation of torch.Tensor.detach.""" + if hasattr(self, "device_mesh"): + res = _original_detach(_ensure_replicated_local(self), *args, **kwargs) + return maybe_distribute_tensor(res) + return _original_detach(self, *args, **kwargs) + + +torch.Tensor.detach = _sharding_aware_detach + + +def _distribution_aware_creation_op(fn): + """Decorator to make a tensor creation operation distribution-aware.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + res = fn(*args, **kwargs) + return maybe_distribute_tensor(res) + + return wrapper + + +torch.func.jvp(lambda x: x, (torch.tensor(1.0),), (torch.tensor(1.0),)) + + +for name in [ + "arange", + "ones", + "zeros", + "eye", + "full", + "linspace", + "logspace", + "ones_like", + "zeros_like", + "rand", + "randn", + "randint", +]: + if hasattr(torch, name): + setattr( + torch, name, _distribution_aware_creation_op(getattr(torch, name)) + ) + + +_DISTRIBUTION_AWARE_ACTIVE = threading.local() + + +def maybe_distribute_tensor(res): + """Distribute a tensor if a distribution is currently active.""" + if not is_tensor(res) or hasattr(res, "device_mesh"): + return res + if getattr(_DISTRIBUTION_AWARE_ACTIVE, "active", False): + return res + distribution = global_state.get_global_attribute("distribution") + if distribution is not None: + _DISTRIBUTION_AWARE_ACTIVE.active = True + try: + from keras.src.backend.torch import distribution_lib + from keras.src.distribution import TensorLayout + + mesh = distribution.device_mesh.backend_mesh + if str(res.device).split(":")[0] != mesh.device_type: + res = res.to(mesh.device_type) + layout = TensorLayout([None] * res.ndim, distribution.device_mesh) + return distribution_lib.distribute_tensor(res, layout) + finally: + _DISTRIBUTION_AWARE_ACTIVE.active = False + return res diff --git a/keras/src/backend/torch/distribution_lib.py b/keras/src/backend/torch/distribution_lib.py new file mode 100644 index 000000000000..cdfcd4a2f846 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib.py @@ -0,0 +1,426 @@ +"""Utilities for distribution strategy with Torch backend.""" + +import os + +import torch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor import Replicate +from torch.distributed.tensor import Shard +from torch.distributed.tensor import ( + distribute_tensor as torch_distribute_tensor, +) +from torch.distributed.tensor.parallel import ColwiseParallel +from torch.distributed.tensor.parallel import RowwiseParallel +from torch.distributed.tensor.parallel import parallelize_module + +from keras.src.backend.common import global_state + + +def list_devices(device_type=None): + """List all available devices for the given type. + + Args: + device_type: String, either "cpu", "gpu"/"cuda", or "xla". + If None, the default device type is used. + + Returns: + A list of device strings (e.g., ["cuda:0", "cuda:1"]). + """ + if device_type is None: + from keras.src.backend.torch.core import get_device + + device_type = str(get_device()).split(":")[0] + else: + device_type = device_type.lower() + if device_type == "gpu": + device_type = "cuda" + + if device_type == "cuda": + num_devices = torch.cuda.device_count() + elif device_type == "xla": + from keras.src.utils.module_utils import torch_xla + + if torch_xla.available: + import torch_xla.core.xla_model as xm + + num_devices = len(xm.get_xla_supported_devices()) + else: + num_devices = 0 + elif device_type == "cpu": + num_devices = 1 + else: + num_devices = 0 + + return [f"{device_type}:{i}" for i in range(num_devices)] + + +def get_device_count(device_type=None): + """Get the number of available devices for the given type. + + Args: + device_type: String, either "cpu", "gpu"/"cuda", or "xla". + + Returns: + Integer count of available devices. + """ + return len(list_devices(device_type)) + + +def initialize(job_addresses=None, num_processes=None, process_id=None): + """Initialize the Torch distributed process group. + + Args: + job_addresses: Optional string of comma-separated master addresses. + num_processes: Optional integer, total number of processes. + process_id: Optional integer, rank of the current process. + """ + if not torch.distributed.is_initialized(): + if job_addresses: + master_addr, master_port = job_addresses.split(",")[0].split(":") + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "127.0.0.1" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + + if num_processes is not None: + os.environ["WORLD_SIZE"] = str(num_processes) + if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = "1" + + if process_id is not None: + os.environ["RANK"] = str(process_id) + if "RANK" not in os.environ: + os.environ["RANK"] = "0" + + from keras.src.backend.torch.core import get_device + + device_type = str(get_device()).split(":")[0] + + if device_type == "xla": + backend = "xla" + elif device_type == "cuda": + backend = "nccl" + else: + backend = "gloo" + torch.distributed.init_process_group(backend=backend) + + +def num_processes(): + """Get the number of processes in the distributed group. + + Returns: + Integer number of processes. + """ + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + return 1 + + +def process_id(): + """Get the rank of the current process. + + Returns: + Integer rank of the process. + """ + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 + + +def _to_backend_mesh(device_mesh): + """Convert Keras DeviceMesh to Torch DeviceMesh. + + Args: + device_mesh: The Keras DeviceMesh instance. + + Returns: + A Torch DeviceMesh instance. + """ + from keras.src.backend.torch.core import get_device + + device_type = str(get_device()).split(":")[0] + mesh_shape = device_mesh.shape + return init_device_mesh( + device_type, mesh_shape, mesh_dim_names=device_mesh.axis_names + ) + + +def _to_backend_layout(tensor_layout): + """Convert Keras TensorLayout to Torch placements. + + Args: + tensor_layout: The Keras TensorLayout instance. + + Returns: + A tuple of (Torch DeviceMesh, list of placements). + """ + if tensor_layout.device_mesh is None: + raise ValueError( + "Cannot create sharding when device mesh is not set for " + "TensorLayout." + ) + + device_mesh = tensor_layout.device_mesh + torch_mesh = device_mesh.backend_mesh + + placements = [] + for mesh_dim_name in device_mesh.axis_names: + shard_dim = None + for i, axis in enumerate(tensor_layout.axes): + if axis == mesh_dim_name: + shard_dim = i + break + if shard_dim is not None: + placements.append(Shard(shard_dim)) + else: + placements.append(Replicate()) + + return (torch_mesh, placements) + + +class DDPModelWrapper(torch.nn.Module): + """A wrapper for Keras models to be used with PyTorch DDP. + + This wrapper avoids DDP's recursive traversal of Keras layer attributes, + which can lead to infinite recursion due to the way Keras tracks variables + and layers. + """ + + def __init__(self, keras_model): + super().__init__() + self._keras_model = [keras_model] + + def parameters(self, recurse=True): + for var in self._keras_model[0].variables: + yield var.value + + def named_parameters(self, prefix="", recurse=True, remove_duplicate=True): + for var in self._keras_model[0].variables: + yield prefix + var.path, var.value + + def forward(self, *args, **kwargs): + return self._keras_model[0](*args, **kwargs) + + +def distribute_variable(value, layout): + """Distribute a Torch variable based on the given layout. + + Args: + value: The Torch tensor or Parameter to distribute. + layout: The layout to apply (Torch mesh and placements). + + Returns: + A distributed Torch tensor or Parameter. + """ + is_parameter = isinstance(value, torch.nn.Parameter) + requires_grad = value.requires_grad if is_parameter else False + + sharded_tensor = distribute_tensor(value, layout) + + if is_parameter: + res = torch.nn.Parameter(sharded_tensor, requires_grad=requires_grad) + if hasattr(value, "constraint"): + res.constraint = value.constraint + else: + res.constraint = None + return res + return sharded_tensor + + +def distribute_tensor(tensor, layout): + """Distribute a Torch tensor based on the given layout. + + Args: + tensor: The Torch tensor to distribute. + layout: The layout to apply. Can be a Keras TensorLayout or a backend + layout tuple. + + Returns: + A distributed Torch tensor (DTensor). + """ + from keras.src.distribution import DataParallel + from keras.src.distribution import TensorLayout + + if isinstance(layout, TensorLayout): + layout = layout.backend_layout + + mesh, placements = layout + mesh_device_type = mesh.device_type + + distribution = global_state.get_global_attribute("distribution") + if ( + isinstance(distribution, DataParallel) + and distribution._is_multi_process + ): + return tensor + + if hasattr(tensor, "device_mesh"): + return tensor.redistribute(mesh, placements) + + if str(tensor.device).split(":")[0] != mesh_device_type: + tensor = tensor.to(mesh_device_type) + + if not tensor.is_leaf: + res = DTensor.from_local(tensor, mesh, placements, run_check=False) + else: + res = torch_distribute_tensor(tensor, mesh, placements) + + return res + + +def distribute_data_input(per_process_batch, layout, batch_dim_name): + """Distribute the input data based on the given layout. + + Args: + per_process_batch: The local Torch tensor for the current process. + layout: The layout to apply. + batch_dim_name: Name of the batch dimension (unused). + + Returns: + A distributed Torch tensor. + """ + return distribute_tensor(per_process_batch, layout) + + +def parallelize_layer(layer, distribution): + """Parallelize a layer based on the given distribution. + + Args: + layer: The Keras Layer or Model instance to parallelize. + distribution: The Keras Distribution instance. + """ + from torch.nn.parallel import DistributedDataParallel as DDP + + from keras.src.backend.torch.core import Variable + from keras.src.distribution import DataParallel + from keras.src.distribution import ModelParallel + + if not isinstance(distribution, (ModelParallel, DataParallel)): + return + + if getattr(layer, "_is_parallelized", False): + return + + mesh = distribution.device_mesh.backend_mesh + + if isinstance(distribution, ModelParallel): + layout_map = distribution._layout_map + variable_to_attr = {} + param_id_to_var = {id(var.value): var for var in layer.variables} + + def find_variables(obj): + if isinstance(obj, Variable): + return + + for _, child in obj.named_children(): + find_variables(child) + + for name, param in obj.named_parameters(recurse=False): + var = param_id_to_var.get(id(param)) + if var is not None: + style = _infer_parallel_style(var, layout_map, name) + if style is not None: + variable_to_attr[var.path] = (var, obj, name, style) + + find_variables(layer) + + module_plans = {} + for var_path, ( + var, + module, + attr_name, + style, + ) in variable_to_attr.items(): + if module not in module_plans: + module_plans[module] = {} + module_plans[module][attr_name] = style + setattr(module, attr_name, var.value) + + tp_mesh = mesh + if "model" in distribution.device_mesh.axis_names: + tp_mesh = mesh["model"] + + for module, sub_plan in module_plans.items(): + if isinstance(module, torch.nn.ParameterDict): + continue + parallelize_module(module, tp_mesh, sub_plan) + + for var_path, ( + var, + module, + attr_name, + style, + ) in variable_to_attr.items(): + sharded_param = getattr(module, attr_name) + if not hasattr(sharded_param, "placements"): + layout = layout_map[var.path] + sharded_param = distribute_variable(var.value, layout) + setattr(module, attr_name, sharded_param) + + if not isinstance(sharded_param, Variable): + var._value = sharded_param + if not hasattr(sharded_param, "constraint"): + sharded_param.constraint = var.constraint + + if hasattr(layer, "_torch_params"): + for var in layer.variables: + if var.path in layer.torch_params: + layer.torch_params[var.path] = var.value + + if ( + isinstance(distribution, DataParallel) + and distribution._is_multi_process + ): + from keras.src.models import Model + + if isinstance(layer, Model): + from keras.src.backend.torch.core import get_device + + device = get_device() + + wrapper_module = DDPModelWrapper(layer) + if "cuda" in str(device): + device_ids = [torch.cuda.current_device()] + layer._ddp_wrapper = DDP(wrapper_module, device_ids=device_ids) + else: + layer._ddp_wrapper = DDP(wrapper_module) + + layer._is_parallelized = True + + +def _infer_parallel_style(variable, layout_map, attr_name): + """Infer PyTorch ParallelStyle from Keras LayoutMap. + + Args: + variable: The Keras Variable instance. + layout_map: The LayoutMap for the current distribution. + attr_name: Name of the attribute in the PyTorch module. + + Returns: + A Torch ParallelStyle instance (ColwiseParallel or RowwiseParallel), + or None if no parallel style is applicable. + """ + layout = layout_map[variable.path] + if layout is None or not any(axis is not None for axis in layout.axes): + return None + + model_dim = "model" + if model_dim in layout.axes: + shard_idx = layout.axes.index(model_dim) + if ( + "kernel" in attr_name + or "embeddings" in attr_name + or "weight" in attr_name + ): + if shard_idx == 1: + return ColwiseParallel() + elif shard_idx == 0: + return RowwiseParallel() + elif "bias" in attr_name: + if shard_idx == 0: + return ColwiseParallel() + return None diff --git a/keras/src/backend/torch/distribution_lib_test.py b/keras/src/backend/torch/distribution_lib_test.py new file mode 100644 index 000000000000..4b071b52a6e4 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib_test.py @@ -0,0 +1,103 @@ +"""Tests for distribution_lib.py.""" + +from unittest import mock + +import pytest +import torch + +from keras.src import backend +from keras.src import testing +from keras.src.backend.torch import distribution_lib +from keras.src.distribution import distribution_lib as keras_distribution_lib + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Backend specific test", +) +class TorchDistributionLibTest(testing.TestCase): + def test_list_devices(self): + devices = distribution_lib.list_devices("cpu") + self.assertEqual(devices, ["cpu:0"]) + + def test_get_device_count(self): + count = distribution_lib.get_device_count("cpu") + self.assertEqual(count, 1) + + def test_num_processes(self): + self.assertEqual(distribution_lib.num_processes(), 1) + + def test_process_id(self): + self.assertEqual(distribution_lib.process_id(), 0) + + @mock.patch("keras.src.backend.torch.distribution_lib.init_device_mesh") + def test_to_backend_mesh(self, mock_init_device_mesh): + mesh = keras_distribution_lib.DeviceMesh( + shape=(1,), axis_names=["data"], devices=["cpu:0"] + ) + distribution_lib._to_backend_mesh(mesh) + mock_init_device_mesh.assert_called_once() + + def test_to_backend_layout(self): + mock_backend_mesh = mock.MagicMock() + mock_backend_mesh.device_type = "cpu" + + mesh = keras_distribution_lib.DeviceMesh( + shape=(2,), axis_names=["model"], devices=["cpu:0", "cpu:0"] + ) + mesh._backend_mesh = mock_backend_mesh + + layout = keras_distribution_lib.TensorLayout( + axes=["model", None], device_mesh=mesh + ) + + torch_mesh, placements = distribution_lib._to_backend_layout(layout) + + self.assertEqual(torch_mesh, mock_backend_mesh) + self.assertEqual(len(placements), 1) + from torch.distributed.tensor import Shard + + self.assertIsInstance(placements[0], Shard) + self.assertEqual(placements[0].dim, 0) + + @mock.patch("keras.src.backend.torch.distribution_lib.distribute_tensor") + def test_distribute_variable(self, mock_distribute_tensor): + value = torch.nn.Parameter(torch.randn(4, 4)) + layout = mock.MagicMock() + mock_distribute_tensor.return_value = value.data + + res = distribution_lib.distribute_variable(value, layout) + self.assertIsInstance(res, torch.nn.Parameter) + mock_distribute_tensor.assert_called_once_with(value, layout) + + def test_infer_parallel_style(self): + from keras.src.backend.torch.core import Variable + from keras.src.distribution import LayoutMap + from keras.src.distribution import TensorLayout + + mesh = keras_distribution_lib.DeviceMesh( + shape=(2,), axis_names=["model"], devices=["cpu:0", "cpu:0"] + ) + + # Test ColwiseParallel + layout_map = LayoutMap(mesh) + layout_map["kernel_col"] = TensorLayout([None, "model"]) + variable = mock.MagicMock(spec=Variable) + variable.path = "kernel_col" + from torch.distributed.tensor.parallel import ColwiseParallel + + style = distribution_lib._infer_parallel_style( + variable, layout_map, "kernel" + ) + self.assertIsInstance(style, ColwiseParallel) + + # Test RowwiseParallel + layout_map = LayoutMap(mesh) + layout_map["kernel_row"] = TensorLayout(["model", None]) + variable.path = "kernel_row" + from torch.distributed.tensor.parallel import RowwiseParallel + + style = distribution_lib._infer_parallel_style( + variable, layout_map, "kernel" + ) + self.assertIsInstance(style, RowwiseParallel) diff --git a/keras/src/backend/torch/layer.py b/keras/src/backend/torch/layer.py index da05f32ddfb4..21238d734720 100644 --- a/keras/src/backend/torch/layer.py +++ b/keras/src/backend/torch/layer.py @@ -18,6 +18,14 @@ def _post_build(self): return self._track_variables() + from keras.src.backend.common import global_state + + distribution = global_state.get_global_attribute("distribution") + if distribution is not None: + from keras.src.backend.torch import distribution_lib + + distribution_lib.parallelize_layer(self, distribution) + def _track_variables(self): # set torch_params attribute will have module automatically track # parameters. @@ -38,6 +46,13 @@ def named_parameters( ) def forward(self, *args, **kwargs): + if hasattr(self, "_ddp_wrapper") and not getattr( + self, "_in_ddp_forward", False + ): + self._in_ddp_forward = True + res = self._ddp_wrapper(*args, **kwargs) + self._in_ddp_forward = False + return res return Operation.__call__(self, *args, **kwargs) def _setattr_hook(self, name, value): diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 6aaef8b2d1a5..8720886672e9 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2921,6 +2921,57 @@ def dot_product_attention( mask=mask, scale=scale, ) + + if ( + hasattr(query, "device_mesh") + or hasattr(key, "device_mesh") + or hasattr(value, "device_mesh") + ): + if scale is None: + scale = query.shape[-1] ** 0.5 + q = backend.numpy.transpose(query, (2, 0, 1, 3)) + k = backend.numpy.transpose(key, (2, 0, 1, 3)) + v = backend.numpy.transpose(value, (2, 0, 1, 3)) + + logits = ( + backend.numpy.matmul(q, backend.numpy.transpose(k, (0, 1, 3, 2))) + / scale + ) + + if is_causal: + t = logits.shape[-2] + s = logits.shape[-1] + i = backend.numpy.arange(t)[:, None] + j = backend.numpy.arange(s) + mask_causal = i < j + logits = backend.numpy.where(mask_causal, -float("inf"), logits) + + if mask is not None: + if len(mask.shape) == 4: + mask_t = backend.numpy.transpose(mask, (1, 0, 2, 3)) + else: + mask_t = mask + logits = backend.numpy.where(mask_t, logits, -float("inf")) + if bias is not None: + if len(bias.shape) == 4: + bias_t = backend.numpy.transpose(bias, (1, 0, 2, 3)) + else: + bias_t = bias + logits = logits + bias_t + + probs = backend.nn.softmax(logits, axis=-1) + res = backend.numpy.matmul(probs, v) + out = backend.numpy.transpose(res, (1, 2, 0, 3)) + if hasattr(out, "device_mesh") and hasattr(out, "placements"): + from torch.distributed.tensor import Replicate + from torch.distributed.tensor import Shard + + if any(isinstance(p, Shard) for p in out.placements): + out = out.redistribute( + out.device_mesh, [Replicate()] * out.device_mesh.ndim + ) + return out + return backend.nn.dot_product_attention( query, key,