Skip to content
2 changes: 0 additions & 2 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
elif backend() == "torch":
from keras.src.backend.torch import * # noqa: F403
from keras.src.backend.torch.core import Variable as BackendVariable

distribution_lib = None
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.torch import core
from keras.src.backend.torch import distribution_lib
from keras.src.backend.torch import image
from keras.src.backend.torch import linalg
from keras.src.backend.torch import math
Expand Down
254 changes: 240 additions & 14 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import functools
import os
import threading

import ml_dtypes
import numpy as np
Expand All @@ -19,6 +20,7 @@
from keras.src.backend.common.stateless_scope import in_stateless_scope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.config import floatx
from keras.src.utils.module_utils import torch_xla

SUPPORTS_SPARSE_TENSORS = False
SUPPORTS_RAGGED_TENSORS = False
Expand All @@ -35,6 +37,8 @@
DEFAULT_DEVICE = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
DEFAULT_DEVICE = "xpu"
elif torch_xla.available:
DEFAULT_DEVICE = "xla"
else:
DEFAULT_DEVICE = "cpu"

Expand Down Expand Up @@ -102,17 +106,50 @@ def to_torch_dtype(dtype):


class Variable(KerasVariable):
def __init__(self, *args, layout=None, **kwargs):
    # `layout` optionally pins this variable to a distribution layout
    # (presumably a keras TensorLayout or a backend layout — resolved
    # lazily in `_initialize_layout` when left as None).
    # NOTE: `_layout` must be set before `super().__init__`, which may
    # trigger `_initialize` (and hence `_initialize_layout`).
    self._layout = layout
    super().__init__(*args, **kwargs)

def _initialize_layout(self):
    """Resolve this variable's layout from the active distribution.

    Does nothing when a layout was passed explicitly at construction
    time or when no global distribution is set. Otherwise asks the
    distribution for a layout, falling back to a fully-replicated
    layout, and unwraps it to the backend representation.
    """
    distribution = global_state.get_global_attribute("distribution")
    if self._layout is not None or distribution is None:
        return
    layout = distribution.get_variable_layout(self)
    if layout is None:
        from keras.src.distribution import TensorLayout

        # No opinion from the distribution: replicate on every axis.
        layout = TensorLayout(
            [None] * len(self._shape), distribution.device_mesh
        )
    if hasattr(layout, "backend_layout"):
        layout = layout.backend_layout
    self._layout = layout

def _initialize(self, value):
    """Create the backing `torch.nn.Parameter` for `value`.

    Validates the shape, resolves the distribution layout, wraps the
    value in a Parameter on the current device, and distributes it
    across the mesh when a layout is present.
    """
    self._shape = self._validate_shape(value.shape)
    self._initialize_layout()
    if isinstance(value, torch.nn.Parameter):
        # Reuse same parameter
        self._value = value
    else:
        # Convert once and reuse: the tensor is needed both for the
        # floating-point check and as the Parameter data.
        tensor = convert_to_tensor(value, dtype=self._dtype)
        # Gradients are only meaningful for trainable floating-point
        # values; requesting them on int/bool tensors raises in torch.
        requires_grad = self.trainable and torch.is_floating_point(tensor)
        self._value = torch.nn.Parameter(
            tensor,
            requires_grad=requires_grad,
        ).to(get_device())
    if self._layout is not None:
        from keras.src.backend.torch import distribution_lib

        self._value = distribution_lib.distribute_variable(
            self._value, self._layout
        )

def _direct_assign(self, value):
    # Copy `value` into the backing tensor in place.
    if self._layout is not None:
        from keras.src.backend.torch import distribution_lib

        # Match the variable's sharding/replication first so the
        # in-place copy_ below sees compatible (DTensor vs. local)
        # operands.
        value = distribution_lib.distribute_variable(value, self._layout)
    with torch.no_grad():
        # The assignment must not be recorded on the autograd tape.
        self.value.copy_(value)

Expand All @@ -122,13 +159,15 @@ def _convert_to_tensor(self, value, dtype=None):
# Overload native accessor.
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Dispatch torch ops on Variables to their underlying tensors.

    Unwraps every `Variable` found anywhere in `args`/`kwargs`
    (including inside nested lists/tuples/dicts) to its `.value`
    tensor, then calls `func` as usual.
    """

    def unwrap(x):
        if isinstance(x, Variable):
            return x.value
        return x

    if kwargs is None:
        kwargs = {}
    # Variables may be nested inside structures, so map over the whole
    # tree rather than unwrapping only the top level. (The previous
    # top-level-only unwrapping pass was redundant with this and has
    # been removed.)
    args = tree.map_structure(unwrap, args)
    kwargs = tree.map_structure(unwrap, kwargs)
    return func(*args, **kwargs)

def __array__(self, dtype=None):
Expand Down Expand Up @@ -208,11 +247,11 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
return x
if dtype is None:
if isinstance(x, bool):
return torch.as_tensor(x, dtype=torch.bool, device=get_device())
res = torch.as_tensor(x, dtype=torch.bool, device=get_device())
elif isinstance(x, int):
return torch.as_tensor(x, dtype=torch.int32, device=get_device())
res = torch.as_tensor(x, dtype=torch.int32, device=get_device())
elif isinstance(x, float):
return torch.as_tensor(
res = torch.as_tensor(
x, dtype=to_torch_dtype(floatx()), device=get_device()
)

Expand All @@ -236,14 +275,16 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
*[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
)
dtype = to_torch_dtype(dtype)
return torch.as_tensor(x, dtype=dtype, device=get_device())
res = torch.as_tensor(x, dtype=dtype, device=get_device())
return maybe_distribute_tensor(res)


def convert_to_numpy(x):
def transform(x):
if is_tensor(x):
if x.requires_grad:
x = x.detach()
x = _ensure_replicated_local(x)
# Tensor has to be moved to CPU before converting to numpy.
if x.device != torch.device("cpu"):
x = x.cpu()
Expand Down Expand Up @@ -283,9 +324,10 @@ def cast(x, dtype):
x = x.value
if is_tensor(x):
if x.dtype == dtype:
return x
res = x
else:
return x.to(dtype)
res = x.to(dtype)
return maybe_distribute_tensor(res)
return convert_to_tensor(x, dtype)


Expand Down Expand Up @@ -610,6 +652,8 @@ def scatter_update(inputs, indices, updates, reduction=None):
def slice(inputs, start_indices, shape):
shape_dtype = to_torch_dtype("int64")
inputs = convert_to_tensor(inputs)
if hasattr(inputs, "device_mesh"):
inputs = _ensure_replicated_local(inputs)
start_indices = convert_to_tensor(start_indices).to(shape_dtype)
shape = convert_to_tensor(shape).to(shape_dtype)

Expand All @@ -618,7 +662,8 @@ def slice(inputs, start_indices, shape):
python_slice(start_index, start_index + length)
for start_index, length in zip(start_indices, shape)
]
return inputs[slices]
res = inputs[slices]
return maybe_distribute_tensor(res)


def slice_update(inputs, start_indices, updates):
Expand All @@ -627,14 +672,20 @@ def slice_update(inputs, start_indices, updates):
start_indices = convert_to_tensor(start_indices).to(shape_dtype)
updates = convert_to_tensor(updates)

if hasattr(inputs, "device_mesh"):
inputs = _ensure_replicated_local(inputs)

if hasattr(updates, "device_mesh"):
updates = _ensure_replicated_local(updates)

python_slice = __builtins__["slice"]
slices = [
python_slice(start_index, start_index + update_length)
for start_index, update_length in zip(start_indices, updates.shape)
]
outputs = torch.clone(inputs)
outputs[slices] = updates
return outputs
return maybe_distribute_tensor(outputs)


def switch(index, branches, *operands):
Expand Down Expand Up @@ -759,3 +810,178 @@ def backward(ctx, grad_output):
if not isinstance(grads, tuple):
grads = (grads,)
return (None,) + grads


def _is_sharded(tensor):
"""Check if a tensor is sharded across a device mesh."""
if hasattr(tensor, "placements"):
from torch.distributed.tensor import Shard

return any(isinstance(p, Shard) for p in tensor.placements)
return False


def _ensure_replicated_local(x):
    """Ensure a sharded tensor is replicated and converted to a local tensor."""
    if not (is_tensor(x) and hasattr(x, "device_mesh")):
        # Not a DTensor: nothing to gather.
        return x
    from torch.distributed.tensor import Replicate

    if _is_sharded(x):
        # Gather the shards: replicate on every mesh dimension.
        replicated = [Replicate()] * x.device_mesh.ndim
        x = x.redistribute(x.device_mesh, replicated)
    return x.to_local()


def _sharding_aware_op(fn):
    """Decorator to make a Torch operation aware of sharded tensors."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, "device_mesh"):
            # Plain tensor: call through unchanged.
            return fn(self, *args, **kwargs)
        # DTensor: gather to a local tensor, run the op, then
        # re-distribute every tensor in the (possibly nested) result.
        local_result = fn(_ensure_replicated_local(self), *args, **kwargs)
        return tree.map_structure(maybe_distribute_tensor, local_result)

    return wrapper


_original_getitem = torch.Tensor.__getitem__


def _sharding_aware_getitem(self, *args, **kwargs):
"""Sharding-aware implementation of __getitem__."""
if hasattr(self, "device_mesh"):
index = args[0]
if isinstance(index, (list, np.ndarray)):
index = tuple(index)
res = _original_getitem(_ensure_replicated_local(self), index, **kwargs)
return maybe_distribute_tensor(res)
return _original_getitem(self, *args, **kwargs)


_original_unbind_fn = torch.unbind


def _sharding_aware_unbind_fn(input, *args, **kwargs):
"""Sharding-aware implementation of torch.unbind."""
if hasattr(input, "device_mesh"):
res = _original_unbind_fn(
_ensure_replicated_local(input), *args, **kwargs
)
return tree.map_structure(maybe_distribute_tensor, res)
return _original_unbind_fn(input, *args, **kwargs)


_original_broadcast_to = torch.broadcast_to


def _sharding_aware_broadcast_to(input, *args, **kwargs):
"""Sharding-aware implementation of torch.broadcast_to."""
if hasattr(input, "device_mesh"):
res = _original_broadcast_to(
_ensure_replicated_local(input), *args, **kwargs
)
return maybe_distribute_tensor(res)
return _original_broadcast_to(input, *args, **kwargs)


_original_einsum = torch.einsum


def _sharding_aware_einsum(subscripts, *operands, **kwargs):
    """Sharding-aware implementation of torch.einsum."""
    localized = []
    saw_dtensor = False
    for operand in operands:
        if is_tensor(operand) and hasattr(operand, "device_mesh"):
            # Gather DTensor operands to local tensors.
            saw_dtensor = True
            localized.append(_ensure_replicated_local(operand))
        else:
            localized.append(operand)
    if not saw_dtensor:
        # All-local fast path: call through unchanged.
        return _original_einsum(subscripts, *operands, **kwargs)
    result = _original_einsum(subscripts, *localized, **kwargs)
    return maybe_distribute_tensor(result)


# Monkey-patch shape-manipulation entry points so sharded (DTensor) inputs
# are transparently gathered to a local tensor, operated on, and
# re-distributed. Plain tensors pass through the original implementations.
# Method-level patches (self-dispatched ops):
torch.Tensor.reshape = _sharding_aware_op(torch.Tensor.reshape)
torch.Tensor.view = _sharding_aware_op(torch.Tensor.view)
torch.Tensor.expand = _sharding_aware_op(torch.Tensor.expand)
torch.Tensor.unbind = _sharding_aware_op(torch.Tensor.unbind)
torch.Tensor.squeeze = _sharding_aware_op(torch.Tensor.squeeze)
torch.Tensor.unsqueeze = _sharding_aware_op(torch.Tensor.unsqueeze)
# __getitem__ needs index normalization, so it has a dedicated wrapper.
torch.Tensor.__getitem__ = _sharding_aware_getitem
# Function-level patches (module-namespace ops):
torch.unbind = _sharding_aware_unbind_fn
torch.broadcast_to = _sharding_aware_broadcast_to
torch.einsum = _sharding_aware_einsum

_original_detach = torch.Tensor.detach


def _sharding_aware_detach(self, *args, **kwargs):
"""Sharding-aware implementation of torch.Tensor.detach."""
if hasattr(self, "device_mesh"):
res = _original_detach(_ensure_replicated_local(self), *args, **kwargs)
return maybe_distribute_tensor(res)
return _original_detach(self, *args, **kwargs)


torch.Tensor.detach = _sharding_aware_detach


def _distribution_aware_creation_op(fn):
    """Decorator to make a tensor creation operation distribution-aware."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Create the tensor normally, then hand it to the active
        # distribution (if any) so it lands on the device mesh.
        return maybe_distribute_tensor(fn(*args, **kwargs))

    return wrapper


# Factory functions whose fresh outputs should be distributed whenever a
# global distribution is active.
_CREATION_OP_NAMES = (
    "arange",
    "ones",
    "zeros",
    "eye",
    "full",
    "linspace",
    "logspace",
    "ones_like",
    "zeros_like",
    "rand",
    "randn",
    "randint",
)

for name in _CREATION_OP_NAMES:
    original_op = getattr(torch, name, None)
    if original_op is not None:
        setattr(torch, name, _distribution_aware_creation_op(original_op))


# Per-thread re-entrancy guard: distributing a tensor itself creates
# tensors via the patched creation ops, which would otherwise recurse
# back into `maybe_distribute_tensor`.
_DISTRIBUTION_AWARE_ACTIVE = threading.local()


def maybe_distribute_tensor(res):
    """Distribute a tensor if a distribution is currently active.

    Returns `res` unchanged when it is not a tensor, is already a
    DTensor, no global distribution is set, or this call is re-entrant.
    Otherwise moves `res` to the mesh's device type if needed and
    distributes it fully replicated.
    """
    if not is_tensor(res) or hasattr(res, "device_mesh"):
        return res
    if getattr(_DISTRIBUTION_AWARE_ACTIVE, "active", False):
        return res
    distribution = global_state.get_global_attribute("distribution")
    if distribution is None:
        return res
    _DISTRIBUTION_AWARE_ACTIVE.active = True
    # try/finally: if the device move or distribute_tensor raises, the
    # guard must still be cleared — otherwise distribution would be
    # silently disabled for this thread from then on.
    try:
        from keras.src.backend.torch import distribution_lib
        from keras.src.distribution import TensorLayout

        mesh = distribution.device_mesh.backend_mesh
        if str(res.device).split(":")[0] != mesh.device_type:
            # The local tensor must live on the mesh's device type
            # before it can be placed on the mesh.
            res = res.to(mesh.device_type)
        layout = TensorLayout([None] * res.ndim, distribution.device_mesh)
        res = distribution_lib.distribute_tensor(res, layout)
    finally:
        _DISTRIBUTION_AWARE_ACTIVE.active = False
    return res
Loading
Loading