Skip to content

Commit 2592a68

Browse files
authored
Convert entire model to jax xla (#8330)
1 parent 3d1d1b9 commit 2592a68

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

experimental/torch_xla2/torch_xla2/tensor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -452,12 +452,8 @@ def __exit__(self, *exc):
452452

453453
def _move_one_value(self, val):
454454
if isinstance(val, torch.nn.Module):
455-
state_dict = self.to_xla(val.state_dict())
456-
val.load_state_dict(state_dict, assign=True)
457-
# Non-persistent buffers are not in state_dict
458-
for b_name, buffer in val.named_buffers():
459-
setattr(val, b_name, self.to_xla(buffer))
460-
return val
455+
with self:
456+
return val.to('jax')
461457
if isinstance(val, XLATensor2):
462458
return val
463459
if isinstance(val, torch.Tensor):

0 commit comments

Comments
 (0)