Skip to content

Commit 8266475

Browse files
authored
Fix view and reshape ops when shape is passed as a kwarg. This essentially (#1426)
consolidates the argument parsing for all the operations.
1 parent 0c83135 commit 8266475

File tree

2 files changed

+176
-29
lines changed

2 files changed

+176
-29
lines changed

physicsnemo/domain_parallel/shard_utils/view_ops.py

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -668,32 +668,51 @@ def sharded_view(tensor: ShardTensor, target_shape: Sequence[int]) -> ShardTenso
668668

669669

670670
# ---------------------------------------------------------------------------
671-
# __torch_function__ handlers
671+
# __torch_function__ handlers: argument repackaging
672672
# ---------------------------------------------------------------------------
673673

674674

675-
def _extract_view_shape(args: tuple[Any, ...]) -> tuple[ShardTensor, tuple[int, ...]]:
676-
r"""Extract tensor and target shape from ``__torch_function__`` args.
677-
678-
Handles both ``x.view(a, b, c)`` and ``x.view((a, b, c))`` calling
679-
conventions.
675+
def _reshape_args(*shape_args: Any) -> tuple[int, ...]:
676+
r"""Normalize shape arguments to a single tuple of ints.
680677
681-
Parameters
682-
----------
683-
args : tuple
684-
Positional arguments from ``__torch_function__``.
685-
686-
Returns
687-
-------
688-
tuple[ShardTensor, tuple[int, ...]]
689-
The input tensor and the target shape.
678+
Handles both a single sequence (e.g. ``(2, 3, 4)``) and variadic ints
679+
(e.g. ``2, 3, 4``) as used by ``Tensor.view`` and ``Tensor.reshape``.
680+
"""
681+
if len(shape_args) == 1 and isinstance(shape_args[0], (tuple, list, torch.Size)):
682+
return tuple(shape_args[0])
683+
return tuple(shape_args)
684+
685+
686+
def extract_view_and_reshape_arguments(
687+
*args: Any, **kwargs: Any
688+
) -> tuple[
689+
ShardTensor,
690+
tuple[int, ...] | None,
691+
torch.dtype | None,
692+
]:
693+
r"""Extract (tensor, shape, dtype) from view/reshape __torch_function__ args.
694+
695+
Used by Tensor.view, Tensor.reshape, torch.reshape, and aten.view.default.
696+
For view(dtype), returns (tensor, None, dtype). Otherwise returns
697+
(tensor, shape, None) with shape normalized to tuple[int, ...].
690698
"""
691699
tensor = args[0]
692-
if len(args) == 2 and isinstance(args[1], (tuple, list, torch.Size)):
693-
shape = tuple(args[1])
694-
else:
695-
shape = tuple(args[1:])
696-
return tensor, shape
700+
# If there is a dtype, catch and exit early:
701+
if len(args) == 2 and isinstance(args[1], torch.dtype):
702+
# Honestly this execution path makes no sense to me ...
703+
return (tensor, None, args[1])
704+
# If it's in kwargs, use that:
705+
shape = kwargs.get("shape", None)
706+
if shape is not None:
707+
return (tensor, shape, None)
708+
# Otherwise, all remaning args get massaged into a tuple:
709+
shape = _reshape_args(*args[1:])
710+
return (tensor, shape, None)
711+
712+
713+
# ---------------------------------------------------------------------------
714+
# __torch_function__ handlers
715+
# ---------------------------------------------------------------------------
697716

698717

699718
def view_wrapper(
@@ -703,9 +722,13 @@ def view_wrapper(
703722
kwargs: dict[str, Any],
704723
) -> ShardTensor:
705724
r"""``__torch_function__`` handler for ``torch.Tensor.view``."""
706-
if len(args) == 2 and isinstance(args[1], torch.dtype):
707-
return _sharded_view_dtype(args[0], args[1])
708-
tensor, shape = _extract_view_shape(args)
725+
tensor, shape, dtype = extract_view_and_reshape_arguments(*args, **kwargs)
726+
if dtype is not None:
727+
return _sharded_view_dtype(tensor, dtype)
728+
if shape is None:
729+
raise ValueError(
730+
"ShardTensor.view_wrapper: Shape is required for view operation"
731+
)
709732
return sharded_view(tensor, shape)
710733

711734

@@ -716,7 +739,15 @@ def reshape_wrapper(
716739
kwargs: dict[str, Any],
717740
) -> ShardTensor:
718741
r"""``__torch_function__`` handler for ``torch.Tensor.reshape``."""
719-
tensor, shape = _extract_view_shape(args)
742+
tensor, shape, dtype = extract_view_and_reshape_arguments(*args, **kwargs)
743+
if dtype is not None:
744+
raise ValueError(
745+
"ShardTensor.reshape_wrapper: Dtype is not supported for reshape operation"
746+
)
747+
if shape is None:
748+
raise ValueError(
749+
"ShardTensor.reshape_wrapper: Shape is required for reshape operation"
750+
)
720751
return sharded_view(tensor, shape)
721752

722753

@@ -727,7 +758,11 @@ def torch_reshape_wrapper(
727758
kwargs: dict[str, Any],
728759
) -> ShardTensor:
729760
r"""``__torch_function__`` handler for ``torch.reshape``."""
730-
tensor, shape = _extract_view_shape(args)
761+
tensor, shape, _ = extract_view_and_reshape_arguments(*args, **kwargs)
762+
if shape is None:
763+
raise ValueError(
764+
"ShardTensor.torch_reshape_wrapper: Shape is required for reshape operation"
765+
)
731766
return sharded_view(tensor, shape)
732767

733768

@@ -789,8 +824,11 @@ def aten_view_wrapper(
789824
ShardTensor
790825
Viewed ShardTensor.
791826
"""
792-
tensor = args[0]
793-
shape = args[1]
827+
tensor, shape, _ = extract_view_and_reshape_arguments(*args, **kwargs)
828+
if shape is None:
829+
raise ValueError(
830+
"ShardTensor.aten_view_wrapper: Shape is required for view operation"
831+
)
794832
return sharded_view(tensor, shape)
795833

796834

test/domain_parallel/ops/test_view_ops.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,19 @@ def forward(self, tensor: torch.Tensor):
7676
return tensor.view(self.target_shape)
7777

7878

79+
class ViewVariadicWrapper(torch.nn.Module):
    """Module exercising ``tensor.view(*shape)`` — the variadic-int form."""

    def __init__(self, target_shape: tuple[int, ...]):
        super().__init__()
        # Stored as given; unpacked into individual ints at call time.
        self.target_shape = target_shape

    def forward(self, tensor: torch.Tensor):
        dims = self.target_shape
        return tensor.view(*dims)
88+
89+
7990
class ReshapeWrapper(torch.nn.Module):
80-
"""Wrapper class for testing tensor.reshape operation."""
91+
"""Wrapper class for testing tensor.reshape(shape) with shape as a single tuple."""
8192

8293
def __init__(self, target_shape: tuple[int, ...]):
8394
super().__init__()
@@ -87,8 +98,19 @@ def forward(self, tensor: torch.Tensor):
8798
return tensor.reshape(self.target_shape)
8899

89100

101+
class ReshapeVariadicWrapper(torch.nn.Module):
    """Module exercising ``tensor.reshape(*shape)`` — the variadic-int form."""

    def __init__(self, target_shape: tuple[int, ...]):
        super().__init__()
        # Stored as given; unpacked into individual ints at call time.
        self.target_shape = target_shape

    def forward(self, tensor: torch.Tensor):
        dims = self.target_shape
        return tensor.reshape(*dims)
110+
111+
90112
class TorchReshapeWrapper(torch.nn.Module):
91-
"""Wrapper class for testing torch.reshape operation."""
113+
"""Wrapper class for testing torch.reshape(tensor, shape) with shape as tuple."""
92114

93115
def __init__(self, target_shape: tuple[int, ...]):
94116
super().__init__()
@@ -98,6 +120,28 @@ def forward(self, tensor: torch.Tensor):
98120
return torch.reshape(tensor, self.target_shape)
99121

100122

123+
class TorchReshapeListWrapper(torch.nn.Module):
    """Module exercising ``torch.reshape(tensor, shape)`` with a list shape."""

    def __init__(self, target_shape: tuple[int, ...]):
        super().__init__()
        self.target_shape = target_shape

    def forward(self, tensor: torch.Tensor):
        # Convert to a list at call time to exercise the list-shape path.
        as_list = list(self.target_shape)
        return torch.reshape(tensor, as_list)
132+
133+
134+
class TorchReshapeKwargWrapper(torch.nn.Module):
    """Module exercising ``torch.reshape(tensor, shape=...)`` — shape as kwarg."""

    def __init__(self, target_shape: tuple[int, ...]):
        super().__init__()
        self.target_shape = target_shape

    def forward(self, tensor: torch.Tensor):
        # Keyword form exercises the kwargs path of the handler.
        new_shape = self.target_shape
        return torch.reshape(tensor, shape=new_shape)
143+
144+
101145
class ViewRoundTrip(torch.nn.Module):
102146
"""View to merge last two dims, then view back to the original shape.
103147
@@ -343,6 +387,71 @@ def test_torch_reshape_operation(
343387
)
344388

345389

390+
@pytest.mark.multigpu_static
@pytest.mark.parametrize(
    "wrapper_cls,arg_style",
    [
        (ViewWrapper, "tuple"),
        (ViewVariadicWrapper, "variadic"),
        (ReshapeWrapper, "tuple"),
        (ReshapeVariadicWrapper, "variadic"),
        (TorchReshapeWrapper, "tuple"),
        (TorchReshapeListWrapper, "list"),
        (TorchReshapeKwargWrapper, "kwarg"),
    ],
    ids=[
        "view_tuple",
        "view_variadic",
        "reshape_tuple",
        "reshape_variadic",
        "torch_reshape_tuple",
        "torch_reshape_list",
        "torch_reshape_kwarg",
    ],
)
@pytest.mark.parametrize("backward", [False, True])
def test_view_reshape_argument_permutations(
    distributed_mesh,
    wrapper_cls,
    arg_style,
    backward,
):
    """Exercise every argument permutation of view/reshape on a ShardTensor.

    Covers ``tensor.view(shape)``, ``tensor.view(*shape)``,
    ``tensor.reshape(shape)``, ``tensor.reshape(*shape)``,
    ``torch.reshape(tensor, shape)``, ``torch.reshape(tensor, list(shape))``,
    and ``torch.reshape(tensor, shape=...)``.
    """
    # NOTE: ``arg_style`` is informational only — it mirrors the test ids;
    # the wrapper class itself determines the calling convention.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    dm = DistributedManager()

    input_shape = (4, 128, 8, 4)
    target_shape = (4, 128, 32)
    placements = (Shard(1),)

    full_tensor = torch.rand(
        input_shape, device=dm.device, requires_grad=backward
    )

    # Shard the full tensor along dim 1 across the mesh.
    sharded = scatter_tensor(
        full_tensor,
        global_src=0,
        mesh=distributed_mesh,
        placements=placements,
        requires_grad=backward,
    )

    # Compare sharded vs. single-device results (and grads when requested).
    numerical_shard_tensor_check(
        distributed_mesh,
        wrapper_cls(target_shape=target_shape),
        [sharded],
        {},
        check_grads=backward,
    )
453+
454+
346455
@pytest.mark.multigpu_static
347456
@pytest.mark.parametrize("backward", [False, True])
348457
def test_view_shard_on_non_viewed_dim(

0 commit comments

Comments
 (0)