Skip to content

Commit deba0f1

Browse files
committed
fix(MetaTensor): astype with torch dtype now returns MetaTensor preserving metadata
When calling MetaTensor.astype() with a torch dtype (e.g. torch.int32), the result was a plain torch.Tensor, silently losing all metadata (affine, spacing, applied transforms, etc.). The root cause was that out_type was hardcoded to torch.Tensor instead of the actual type of self. Fix by using type(self) as out_type when a torch dtype is requested, so that convert_data_type() receives output_type=MetaTensor, sets track_meta=True, and preserves metadata through the dtype cast. The analyzer module already annotated the result of astype(torch.int16) as MetaTensor, relying on this contract. Updated test to assert the result is an instance of MetaTensor and that the metadata key is preserved after the cast. Closes #8202 Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
1 parent ef2acfb commit deba0f1

2 files changed

Lines changed: 9 additions & 4 deletions

File tree

monai/data/meta_tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,8 @@ def astype(self, dtype, device=None, *_args, **_kwargs):
442442
_kwargs: additional kwargs (currently unused).
443443
444444
Returns:
445-
data array instance
445+
``MetaTensor`` when a torch dtype is given (metadata is preserved),
446+
or ``np.ndarray`` when a numpy dtype is given.
446447
"""
447448
if isinstance(dtype, str):
448449
mod_str, *dtype = dtype.split(".", 1)
@@ -453,7 +454,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs):
453454

454455
out_type: type[torch.Tensor] | type[np.ndarray] | None
455456
if mod_str == "torch":
456-
out_type = torch.Tensor
457+
out_type = type(self)
457458
elif mod_str in ("numpy", "np"):
458459
out_type = np.ndarray
459460
else:

tests/data/meta_tensor/test_meta_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,12 @@ def test_astype(self):
434434
for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.uint16):
435435
self.assertIsInstance(t.astype(np_types), np.ndarray)
436436
for pt_types in ("torch.float", torch.float, "torch.float64"):
437-
self.assertIsInstance(t.astype(pt_types), torch.Tensor)
438-
self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor)
437+
result = t.astype(pt_types)
438+
self.assertIsInstance(result, MetaTensor)
439+
self.assertEqual(result.meta.get("fname"), "filename")
440+
result = t.astype("torch.float", device="cpu")
441+
self.assertIsInstance(result, MetaTensor)
442+
self.assertEqual(result.meta.get("fname"), "filename")
439443

440444
def test_transforms(self):
441445
key = "im"

0 commit comments

Comments
 (0)