Skip to content
Draft
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
elif backend() == "torch":
from keras.src.backend.torch import * # noqa: F403
from keras.src.backend.torch.core import Variable as BackendVariable

distribution_lib = None
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.torch import core
from keras.src.backend.torch import distribution_lib
from keras.src.backend.torch import image
from keras.src.backend.torch import linalg
from keras.src.backend.torch import math
Expand Down
259 changes: 246 additions & 13 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import contextlib
import functools
import os
import threading

import ml_dtypes
import numpy as np
import torch
import torch.func

from keras.src import tree
from keras.src.backend.common import KerasVariable
Expand All @@ -19,6 +21,7 @@
from keras.src.backend.common.stateless_scope import in_stateless_scope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.config import floatx
from keras.src.utils.module_utils import torch_xla

SUPPORTS_SPARSE_TENSORS = False
SUPPORTS_RAGGED_TENSORS = False
Expand All @@ -35,6 +38,8 @@
DEFAULT_DEVICE = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
DEFAULT_DEVICE = "xpu"
elif torch_xla.available:
DEFAULT_DEVICE = "xla"
else:
DEFAULT_DEVICE = "cpu"

Expand Down Expand Up @@ -102,17 +107,48 @@ def to_torch_dtype(dtype):


class Variable(KerasVariable):
def __init__(self, *args, layout=None, **kwargs):
    """Torch variable with an optional distribution layout.

    Args:
        layout: Optional tensor layout describing how this variable is
            distributed across a device mesh. `None` means "resolve from
            the active distribution (if any) at initialization time".
    """
    # Must be set BEFORE super().__init__: the base constructor triggers
    # `_initialize`, which reads `self._layout`.
    self._layout = layout
    super().__init__(*args, **kwargs)

def _initialize_layout(self):
    """Resolve this variable's layout from the active distribution.

    An explicitly provided layout is left untouched. Otherwise, when a
    global distribution is set, the distribution is asked for a layout;
    if it returns none, a fully replicated layout over the
    distribution's device mesh is used as the fallback.
    """
    if self._layout is not None:
        return
    distribution = global_state.get_global_attribute("distribution")
    if distribution is None:
        return
    layout = distribution.get_variable_layout(self)
    if layout is None:
        from keras.src.distribution import TensorLayout

        # Replicate on every mesh axis by default.
        layout = TensorLayout(
            [None] * len(self._shape), distribution.device_mesh
        )
    self._layout = layout

def _initialize(self, value):
    """Create the backing `torch.nn.Parameter` for this variable.

    An existing `torch.nn.Parameter` is reused as-is; any other value is
    converted to a tensor and wrapped in a new parameter on the current
    device. If a layout was resolved (i.e. a distribution is active),
    the parameter is then distributed accordingly.

    Args:
        value: Initial value (tensor-like or `torch.nn.Parameter`).
    """
    self._shape = self._validate_shape(value.shape)
    # Resolve `self._layout` from the active distribution before the
    # parameter is created, so it can be distributed below.
    self._initialize_layout()
    if isinstance(value, torch.nn.Parameter):
        # Reuse same parameter
        self._value = value
    else:
        # Convert once and reuse for both the grad check and the
        # parameter itself (the diff previously converted twice).
        tensor_value = convert_to_tensor(value, dtype=self._dtype)
        # Gradients are only defined for floating-point tensors;
        # int/bool variables must not require grad even when trainable.
        requires_grad = self.trainable and torch.is_floating_point(
            tensor_value
        )
        self._value = torch.nn.Parameter(
            tensor_value,
            requires_grad=requires_grad,
        ).to(get_device())
    if self._layout is not None:
        from keras.src.backend.torch import distribution_lib

        self._value = distribution_lib.distribute_variable(
            self._value, self._layout
        )

def _direct_assign(self, value):
    """Copy `value` into the backing parameter in place, without autograd.

    Args:
        value: Tensor-like new value; must be broadcast/copy-compatible
            with the current parameter.
    """
    if self._layout is not None:
        from keras.src.backend.torch import distribution_lib

        # Match the variable's sharding before the in-place copy.
        value = distribution_lib.distribute_variable(value, self._layout)
    with torch.no_grad():
        # In-place copy keeps the parameter object (and its identity in
        # optimizers) intact.
        self.value.copy_(value)

Expand All @@ -122,13 +158,15 @@ def _convert_to_tensor(self, value, dtype=None):
# Overload native accessor.
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
args = [arg.value if isinstance(arg, Variable) else arg for arg in args]
def unwrap(x):
if isinstance(x, Variable):
return x.value
return x

if kwargs is None:
kwargs = {}
kwargs = {
key: value.value if isinstance(value, Variable) else value
for key, value in kwargs.items()
}
args = tree.map_structure(unwrap, args)
kwargs = tree.map_structure(unwrap, kwargs)
return func(*args, **kwargs)

def __array__(self, dtype=None):
Expand Down Expand Up @@ -208,17 +246,21 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
return x
if dtype is None:
if isinstance(x, bool):
return torch.as_tensor(x, dtype=torch.bool, device=get_device())
res = torch.as_tensor(x, dtype=torch.bool, device=get_device())
return maybe_distribute_tensor(res)
elif isinstance(x, int):
res = torch.as_tensor(x, dtype=torch.int32, device=get_device())
return maybe_distribute_tensor(res)
if x < -(2**31) or x >= 2**31:
return torch.as_tensor(
x, dtype=torch.int64, device=get_device()
)
return torch.as_tensor(x, dtype=torch.int32, device=get_device())
elif isinstance(x, float):
return torch.as_tensor(
res = torch.as_tensor(
x, dtype=to_torch_dtype(floatx()), device=get_device()
)
return maybe_distribute_tensor(res)

# Convert to np in case of any array-like that is not list or tuple.
if not isinstance(x, (list, tuple)):
Expand All @@ -240,14 +282,16 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
*[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
)
dtype = to_torch_dtype(dtype)
return torch.as_tensor(x, dtype=dtype, device=get_device())
res = torch.as_tensor(x, dtype=dtype, device=get_device())
return maybe_distribute_tensor(res)


def convert_to_numpy(x):
def transform(x):
if is_tensor(x):
if x.requires_grad:
x = x.detach()
x = _ensure_replicated_local(x)
# Tensor has to be moved to CPU before converting to numpy.
if x.device != torch.device("cpu"):
x = x.cpu()
Expand Down Expand Up @@ -287,9 +331,10 @@ def cast(x, dtype):
x = x.value
if is_tensor(x):
if x.dtype == dtype:
return x
res = x
else:
return x.to(dtype)
res = x.to(dtype)
return maybe_distribute_tensor(res)
return convert_to_tensor(x, dtype)


Expand Down Expand Up @@ -614,6 +659,8 @@ def scatter_update(inputs, indices, updates, reduction=None):
def slice(inputs, start_indices, shape):
shape_dtype = to_torch_dtype("int64")
inputs = convert_to_tensor(inputs)
if hasattr(inputs, "device_mesh"):
inputs = _ensure_replicated_local(inputs)
start_indices = convert_to_tensor(start_indices).to(shape_dtype)
shape = convert_to_tensor(shape).to(shape_dtype)

Expand All @@ -622,7 +669,8 @@ def slice(inputs, start_indices, shape):
python_slice(start_index, start_index + length)
for start_index, length in zip(start_indices, shape)
]
return inputs[slices]
res = inputs[slices]
return maybe_distribute_tensor(res)


def slice_update(inputs, start_indices, updates):
Expand All @@ -631,14 +679,20 @@ def slice_update(inputs, start_indices, updates):
start_indices = convert_to_tensor(start_indices).to(shape_dtype)
updates = convert_to_tensor(updates)

if hasattr(inputs, "device_mesh"):
inputs = _ensure_replicated_local(inputs)

if hasattr(updates, "device_mesh"):
updates = _ensure_replicated_local(updates)

python_slice = __builtins__["slice"]
slices = [
python_slice(start_index, start_index + update_length)
for start_index, update_length in zip(start_indices, updates.shape)
]
outputs = torch.clone(inputs)
outputs[slices] = updates
return outputs
return maybe_distribute_tensor(outputs)


def switch(index, branches, *operands):
Expand Down Expand Up @@ -765,3 +819,182 @@ def backward(ctx, grad_output):
if not isinstance(grads, tuple):
grads = (grads,)
return (None,) + grads


def _is_sharded(tensor):
"""Check if a tensor is sharded across a device mesh."""
if hasattr(tensor, "placements"):
from torch.distributed.tensor import Shard

return any(isinstance(p, Shard) for p in tensor.placements)
return False


def _ensure_replicated_local(x):
    """Ensure a sharded tensor is replicated and converted to a local tensor."""
    # Anything that is not a DTensor passes through untouched.
    if not (is_tensor(x) and hasattr(x, "device_mesh")):
        return x
    from torch.distributed.tensor import Replicate

    if _is_sharded(x):
        mesh = x.device_mesh
        # Gather shards so every rank holds the full tensor.
        x = x.redistribute(mesh, [Replicate()] * mesh.ndim)
    return x.to_local()


def _sharding_aware_op(fn):
    """Decorator to make a Torch operation aware of sharded tensors.

    For DTensor receivers, the op runs on the replicated local tensor
    and every tensor in the result is re-distributed; plain tensors go
    straight to the original op.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, "device_mesh"):
            return fn(self, *args, **kwargs)
        out = fn(_ensure_replicated_local(self), *args, **kwargs)
        # `fn` may return a tensor or a structure of tensors (e.g. unbind).
        return tree.map_structure(maybe_distribute_tensor, out)

    return wrapper


# Keep a handle on the unpatched __getitem__ so the sharding-aware
# replacement can delegate to it.
_original_getitem = torch.Tensor.__getitem__


def _sharding_aware_getitem(self, *args, **kwargs):
    """Sharding-aware implementation of __getitem__.

    DTensors are replicated to a local tensor before indexing and the
    result is re-distributed; plain tensors use the original behavior.
    """
    if hasattr(self, "device_mesh"):
        # __getitem__ receives exactly one positional index object.
        index = args[0]
        # NOTE(review): converting a list/ndarray index to a tuple changes
        # torch indexing semantics — t[[0, 1]] is advanced (row) indexing,
        # while t[(0, 1)] indexes one element per dimension. Presumably
        # intentional for the DTensor path; confirm against callers.
        if isinstance(index, (list, np.ndarray)):
            index = tuple(index)
        res = _original_getitem(_ensure_replicated_local(self), index, **kwargs)
        return maybe_distribute_tensor(res)
    return _original_getitem(self, *args, **kwargs)


# Keep a handle on the unpatched torch.unbind for delegation.
_original_unbind_fn = torch.unbind


def _sharding_aware_unbind_fn(input, *args, **kwargs):
    """Sharding-aware implementation of torch.unbind."""
    if not hasattr(input, "device_mesh"):
        return _original_unbind_fn(input, *args, **kwargs)
    # Replicate the DTensor, unbind locally, then re-distribute each slice.
    pieces = _original_unbind_fn(
        _ensure_replicated_local(input), *args, **kwargs
    )
    return tree.map_structure(maybe_distribute_tensor, pieces)


# Keep a handle on the unpatched torch.broadcast_to for delegation.
_original_broadcast_to = torch.broadcast_to


def _sharding_aware_broadcast_to(input, *args, **kwargs):
    """Sharding-aware implementation of torch.broadcast_to."""
    if not hasattr(input, "device_mesh"):
        return _original_broadcast_to(input, *args, **kwargs)
    # Replicate the DTensor, broadcast locally, then re-distribute.
    local = _ensure_replicated_local(input)
    return maybe_distribute_tensor(
        _original_broadcast_to(local, *args, **kwargs)
    )


# Keep a handle on the unpatched torch.einsum for delegation.
_original_einsum = torch.einsum


def _sharding_aware_einsum(subscripts, *operands, **kwargs):
    """Sharding-aware implementation of torch.einsum.

    Any DTensor operand is replicated to a local tensor; if at least one
    was found, the result is re-distributed. Note this handles only the
    string-subscripts call form — `subscripts` itself is never unwrapped.
    """
    saw_dtensor = False
    local_operands = []
    for operand in operands:
        if is_tensor(operand) and hasattr(operand, "device_mesh"):
            saw_dtensor = True
            local_operands.append(_ensure_replicated_local(operand))
        else:
            local_operands.append(operand)
    if not saw_dtensor:
        return _original_einsum(subscripts, *operands, **kwargs)
    return maybe_distribute_tensor(
        _original_einsum(subscripts, *local_operands, **kwargs)
    )


# Monkey-patch torch so common view/reshape/indexing operations handle
# DTensors transparently: each wrapper replicates sharded inputs to a
# local tensor, applies the original op, and re-distributes the result.
torch.Tensor.reshape = _sharding_aware_op(torch.Tensor.reshape)
torch.Tensor.view = _sharding_aware_op(torch.Tensor.view)
torch.Tensor.expand = _sharding_aware_op(torch.Tensor.expand)
torch.Tensor.unbind = _sharding_aware_op(torch.Tensor.unbind)
torch.Tensor.squeeze = _sharding_aware_op(torch.Tensor.squeeze)
torch.Tensor.unsqueeze = _sharding_aware_op(torch.Tensor.unsqueeze)
torch.Tensor.__getitem__ = _sharding_aware_getitem
torch.unbind = _sharding_aware_unbind_fn
torch.broadcast_to = _sharding_aware_broadcast_to
torch.einsum = _sharding_aware_einsum

# Keep a handle on the unpatched Tensor.detach for delegation.
_original_detach = torch.Tensor.detach


def _sharding_aware_detach(self, *args, **kwargs):
    """Sharding-aware implementation of torch.Tensor.detach."""
    if not hasattr(self, "device_mesh"):
        return _original_detach(self, *args, **kwargs)
    # Replicate the DTensor, detach the local tensor, re-distribute.
    detached = _original_detach(
        _ensure_replicated_local(self), *args, **kwargs
    )
    return maybe_distribute_tensor(detached)


# Install the sharding-aware replacement for Tensor.detach.
torch.Tensor.detach = _sharding_aware_detach


def _distribution_aware_creation_op(fn):
    """Decorator to make a tensor creation operation distribution-aware.

    Tensors created while a Keras distribution is active are returned as
    (replicated) DTensors via `maybe_distribute_tensor`.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return maybe_distribute_tensor(fn(*args, **kwargs))

    return wrapper


# NOTE(review): import-time warm-up call. Presumably this forces
# torch.func/functorch to finish its lazy initialization BEFORE the
# creation ops below are monkey-patched, so functorch internals do not
# run through the distribution-aware wrappers. TODO: confirm the exact
# failure this avoids and document it here.
torch.func.jvp(lambda x: x, (torch.tensor(1.0),), (torch.tensor(1.0),))


# Wrap torch's tensor-creation entry points so tensors created while a
# Keras distribution is active come back distributed.
for name in [
    "arange",
    "ones",
    "zeros",
    "eye",
    "full",
    "linspace",
    "logspace",
    "ones_like",
    "zeros_like",
    "rand",
    "randn",
    "randint",
]:
    # Guarded with hasattr in case a torch build lacks one of the names.
    if hasattr(torch, name):
        setattr(
            torch, name, _distribution_aware_creation_op(getattr(torch, name))
        )


# Thread-local re-entrancy flag: stops maybe_distribute_tensor from
# recursing when distribute_tensor itself creates tensors through the
# patched creation ops above.
_DISTRIBUTION_AWARE_ACTIVE = threading.local()


def maybe_distribute_tensor(res):
    """Distribute a tensor if a distribution is currently active.

    Non-tensors, tensors that are already DTensors, and calls made while
    a distribution step is already in flight (re-entrancy guard) are
    returned unchanged. Otherwise the tensor is moved to the mesh's
    device type if needed and distributed with a fully replicated layout.
    """
    if not is_tensor(res) or hasattr(res, "device_mesh"):
        return res
    if getattr(_DISTRIBUTION_AWARE_ACTIVE, "active", False):
        # Already inside a distribute call on this thread — avoid recursion.
        return res
    distribution = global_state.get_global_attribute("distribution")
    if distribution is None:
        return res
    _DISTRIBUTION_AWARE_ACTIVE.active = True
    try:
        from keras.src.backend.torch import distribution_lib
        from keras.src.distribution import TensorLayout

        mesh = distribution.device_mesh.backend_mesh
        # Compare device types only ("cuda:0" -> "cuda") before moving.
        if str(res.device).split(":")[0] != mesh.device_type:
            res = res.to(mesh.device_type)
        replicated = TensorLayout([None] * res.ndim, distribution.device_mesh)
        return distribution_lib.distribute_tensor(res, replicated)
    finally:
        _DISTRIBUTION_AWARE_ACTIVE.active = False
Loading