Skip to content

Commit

Permalink
Convert entire model to jax xla (#8330)
Browse files Browse the repository at this point in the history
  • Loading branch information
barney-s authored Oct 29, 2024
1 parent 3d1d1b9 commit 2592a68
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions experimental/torch_xla2/torch_xla2/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,8 @@ def __exit__(self, *exc):

def _move_one_value(self, val):
if isinstance(val, torch.nn.Module):
state_dict = self.to_xla(val.state_dict())
val.load_state_dict(state_dict, assign=True)
# Non-persistent buffers are not in state_dict
for b_name, buffer in val.named_buffers():
setattr(val, b_name, self.to_xla(buffer))
return val
with self:
return val.to('jax')
if isinstance(val, XLATensor2):
return val
if isinstance(val, torch.Tensor):
Expand Down

0 comments on commit 2592a68

Please sign in to comment.