We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3d1d1b9 commit 2592a68Copy full SHA for 2592a68
experimental/torch_xla2/torch_xla2/tensor.py
@@ -452,12 +452,8 @@ def __exit__(self, *exc):
452
453
def _move_one_value(self, val):
454
if isinstance(val, torch.nn.Module):
455
- state_dict = self.to_xla(val.state_dict())
456
- val.load_state_dict(state_dict, assign=True)
457
- # Non-persistent buffers are not in state_dict
458
- for b_name, buffer in val.named_buffers():
459
- setattr(val, b_name, self.to_xla(buffer))
460
- return val
+ with self:
+ return val.to('jax')
461
if isinstance(val, XLATensor2):
462
return val
463
if isinstance(val, torch.Tensor):
0 commit comments