Skip to content

[torchax]: JittableModule statedict handling #9195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions torchax/test/test_statedict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
import torch
from torch.utils import _pytree as pytree

from torchax import (interop, mesh_util, tensor)


class Model(torch.nn.Module):
  """Minimal one-layer network used as a fixture for state-dict tests."""

  def __init__(self):
    super().__init__()
    # Single affine projection: 10 input features -> 5 output features.
    self.linear = torch.nn.Linear(10, 5)

  def forward(self, x):
    # Apply the projection to the input batch and return the result.
    return self.linear(x)


# Module-level fixtures shared by all tests below: build an FSDP mesh and a
# sharded, jittable copy of Model once (the tests mutate this shared model,
# so they are order-sensitive by design).
mesh = mesh_util.Mesh.fsdp_mesh()
model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))


class TestTensorStateDict(unittest.TestCase):
  """Tests for JittableModule.cpu_state_dict / load_state_dict.

  Uses the module-level ``model`` fixture; tests mutate the shared model
  and are therefore order-sensitive.
  """

  def test_get_statedict(self):
    # cpu_state_dict must return plain torch CPU tensors, never XLA tensors.
    state_dict_cpu = model.cpu_state_dict()
    is_xla_tensor = pytree.tree_map(lambda t: isinstance(t, tensor.Tensor),
                                    state_dict_cpu)
    assert not any(
        is_xla_tensor.values()), "State dict should not contain XLA tensors"

  def test_load_statedict(self):
    state_dict_cpu = model.cpu_state_dict()
    zeros_state_dict = pytree.tree_map(torch.zeros_like, state_dict_cpu)
    model.load_state_dict(zeros_state_dict)
    # BUG FIX: previously this asserted on the zeros dict that was just
    # constructed (a tautology). Re-read the model's state to verify that
    # load_state_dict actually wrote the zeros into the module.
    reloaded = model.cpu_state_dict()
    is_zeros = pytree.tree_map(lambda t: torch.equal(t, torch.zeros_like(t)),
                               reloaded)
    assert all(is_zeros.values()), "State dict should be zeros"

  def test_load_statedict_partial(self):
    # Loading with strict=False must report the dropped key as missing and
    # still apply the keys that were provided.
    state_dict_cpu = model.cpu_state_dict()
    del state_dict_cpu['_model.linear.bias']
    state_dict_cpu = pytree.tree_map(torch.ones_like, state_dict_cpu)
    key_check = model.load_state_dict(state_dict_cpu, strict=False)
    assert key_check.missing_keys == [
        '_model.linear.bias'
    ], "Missing keys should be '_model.linear.bias'"
    linear_weight = model.state_dict()['_model.linear.weight']
    assert torch.equal(
        linear_weight,
        torch.ones_like(linear_weight)), "Linear weight should be ones"


# Run the test suite when invoked directly (python test_statedict.py).
if __name__ == '__main__':
  unittest.main()
83 changes: 83 additions & 0 deletions torchax/torchax/interop.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Mapping, Any
import collections
import copy
import functools
Expand Down Expand Up @@ -126,6 +127,88 @@ def call(*args, **kwargs):

self._jitted[key] = call

def cpu_state_dict(self, *args, **kwargs):
  """Return the module's state dict with every tensor moved to the CPU.

  Convenience wrapper around :meth:`state_dict` so the result can be
  handed directly to ``torch.save`` without device-bound tensors.

  Returns:
    Mapping[str, Any]: parameter/buffer names mapped to torch CPU tensors.
  """

  def _to_cpu(value):
    return value.cpu()

  return pytree.tree_map(_to_cpu, super().state_dict(*args, **kwargs))

def load_state_dict(self,
                    state_dict: Mapping[str, Any],
                    strict: bool = True,
                    assign: bool = False):
  """
  Wrapper for load_state_dict

  This function assumes a torch CPU state dict and will transfer the
  parameters to the correct device and sharding before loading them
  into the model.

  Args:
      state_dict (Mapping[str, Any]): A mapping of parameter names to their values (in torch CPU)
      strict (bool, optional): whether to strictly enforce that the keys
          in :attr:`state_dict` match the keys returned by this module's
          :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
      assign (bool, optional): When set to ``False``, the properties of the tensors
          in the current module are preserved whereas setting it to ``True`` preserves
          properties of the Tensors in the state dict. The only
          exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
          for which the value from the module is preserved.
          Default: ``False``

  Returns:
      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
          * **missing_keys** is a list of str containing any keys that are expected
              by this module but missing from the provided ``state_dict``.
          * **unexpected_keys** is a list of str containing the keys that are not
              expected by this module but present in the provided ``state_dict``.
  """
  # Move tensors to JAX to have an easier time extracting sharding information.
  current_state_dict = super().state_dict()
  current_state_dict = jax_view(current_state_dict)

  # Create out-shardings that either reuse the current state dict sharding
  # or replicate weights the module does not currently hold.
  def extract_sharding_or_replicate(name):
    if name in current_state_dict:
      return current_state_dict[name].sharding
    # NOTE(review): present keys yield a jax.sharding.Sharding while this
    # branch yields a PartitionSpec -- confirm pjit accepts the mix.
    return jax.sharding.PartitionSpec()

  output_shards = {
      name: extract_sharding_or_replicate(name) for name in state_dict
  }

  def convert_to_jax_array_if_needed(t):
    # torchax tensors already wrap a jax array: unwrap instead of copying.
    if isinstance(t, torchax.tensor.Tensor):
      return jax_view(t)
    if isinstance(t, torch.Tensor):
      # Plain torch (CPU) tensor: convert to a jax array.
      return tensor.t2j(t)
    return t

  # BUG FIX: the converter above was previously defined but never used --
  # tensor.t2j was applied unconditionally, which mishandles entries that
  # are already torchax tensors. Route every leaf through the converter.
  state_dict = pytree.tree_map(convert_to_jax_array_if_needed, state_dict)
  # Convert OrderedDict to a regular dict for pjit type-safety checks.
  state_dict = dict(state_dict)
  # Identity jit with explicit out_shardings places/shards every array.
  jitted = jax_jit(
      lambda t: t, kwargs_for_jax_jit={"out_shardings": output_shards})
  state_dict = jitted(state_dict)
  # View the jax arrays as torch tensors so the base-class loader can
  # assign them into the module.
  state_dict = torch_view(state_dict)

  return super().load_state_dict(state_dict, strict, assign)


class CompileMixin:

Expand Down
Loading