@@ -668,32 +668,51 @@ def sharded_view(tensor: ShardTensor, target_shape: Sequence[int]) -> ShardTenso
668668
669669
670670# ---------------------------------------------------------------------------
671- # __torch_function__ handlers
671+ # __torch_function__ handlers: argument repackaging
672672# ---------------------------------------------------------------------------
673673
674674
675- def _extract_view_shape (args : tuple [Any , ...]) -> tuple [ShardTensor , tuple [int , ...]]:
676- r"""Extract tensor and target shape from ``__torch_function__`` args.
677-
678- Handles both ``x.view(a, b, c)`` and ``x.view((a, b, c))`` calling
679- conventions.
675+ def _reshape_args (* shape_args : Any ) -> tuple [int , ...]:
676+ r"""Normalize shape arguments to a single tuple of ints.
680677
681- Parameters
682- ----------
683- args : tuple
684- Positional arguments from ``__torch_function__``.
685-
686- Returns
687- -------
688- tuple[ShardTensor, tuple[int, ...]]
689- The input tensor and the target shape.
678+ Handles both a single sequence (e.g. ``(2, 3, 4)``) and variadic ints
679+ (e.g. ``2, 3, 4``) as used by ``Tensor.view`` and ``Tensor.reshape``.
680+ """
681+ if len (shape_args ) == 1 and isinstance (shape_args [0 ], (tuple , list , torch .Size )):
682+ return tuple (shape_args [0 ])
683+ return tuple (shape_args )
684+
685+
686+ def extract_view_and_reshape_arguments (
687+ * args : Any , ** kwargs : Any
688+ ) -> tuple [
689+ ShardTensor ,
690+ tuple [int , ...] | None ,
691+ torch .dtype | None ,
692+ ]:
693+ r"""Extract (tensor, shape, dtype) from view/reshape __torch_function__ args.
694+
695+ Used by Tensor.view, Tensor.reshape, torch.reshape, and aten.view.default.
696+ For view(dtype), returns (tensor, None, dtype). Otherwise returns
697+ (tensor, shape, None) with shape normalized to tuple[int, ...].
690698 """
691699 tensor = args [0 ]
692- if len (args ) == 2 and isinstance (args [1 ], (tuple , list , torch .Size )):
693- shape = tuple (args [1 ])
694- else :
695- shape = tuple (args [1 :])
696- return tensor , shape
700+ # If there is a dtype, catch and exit early:
701+ if len (args ) == 2 and isinstance (args [1 ], torch .dtype ):
702+ # Honestly this execution path makes no sense to me ...
703+ return (tensor , None , args [1 ])
704+ # If it's in kwargs, use that:
705+ shape = kwargs .get ("shape" , None )
706+ if shape is not None :
707+ return (tensor , shape , None )
708+ # Otherwise, all remaning args get massaged into a tuple:
709+ shape = _reshape_args (* args [1 :])
710+ return (tensor , shape , None )
711+
712+
713+ # ---------------------------------------------------------------------------
714+ # __torch_function__ handlers
715+ # ---------------------------------------------------------------------------
697716
698717
699718def view_wrapper (
@@ -703,9 +722,13 @@ def view_wrapper(
703722 kwargs : dict [str , Any ],
704723) -> ShardTensor :
705724 r"""``__torch_function__`` handler for ``torch.Tensor.view``."""
706- if len (args ) == 2 and isinstance (args [1 ], torch .dtype ):
707- return _sharded_view_dtype (args [0 ], args [1 ])
708- tensor , shape = _extract_view_shape (args )
725+ tensor , shape , dtype = extract_view_and_reshape_arguments (* args , ** kwargs )
726+ if dtype is not None :
727+ return _sharded_view_dtype (tensor , dtype )
728+ if shape is None :
729+ raise ValueError (
730+ "ShardTensor.view_wrapper: Shape is required for view operation"
731+ )
709732 return sharded_view (tensor , shape )
710733
711734
@@ -716,7 +739,15 @@ def reshape_wrapper(
716739 kwargs : dict [str , Any ],
717740) -> ShardTensor :
718741 r"""``__torch_function__`` handler for ``torch.Tensor.reshape``."""
719- tensor , shape = _extract_view_shape (args )
742+ tensor , shape , dtype = extract_view_and_reshape_arguments (* args , ** kwargs )
743+ if dtype is not None :
744+ raise ValueError (
745+ "ShardTensor.reshape_wrapper: Dtype is not supported for reshape operation"
746+ )
747+ if shape is None :
748+ raise ValueError (
749+ "ShardTensor.reshape_wrapper: Shape is required for reshape operation"
750+ )
720751 return sharded_view (tensor , shape )
721752
722753
@@ -727,7 +758,11 @@ def torch_reshape_wrapper(
727758 kwargs : dict [str , Any ],
728759) -> ShardTensor :
729760 r"""``__torch_function__`` handler for ``torch.reshape``."""
730- tensor , shape = _extract_view_shape (args )
761+ tensor , shape , _ = extract_view_and_reshape_arguments (* args , ** kwargs )
762+ if shape is None :
763+ raise ValueError (
764+ "ShardTensor.torch_reshape_wrapper: Shape is required for reshape operation"
765+ )
731766 return sharded_view (tensor , shape )
732767
733768
@@ -789,8 +824,11 @@ def aten_view_wrapper(
789824 ShardTensor
790825 Viewed ShardTensor.
791826 """
792- tensor = args [0 ]
793- shape = args [1 ]
827+ tensor , shape , _ = extract_view_and_reshape_arguments (* args , ** kwargs )
828+ if shape is None :
829+ raise ValueError (
830+ "ShardTensor.aten_view_wrapper: Shape is required for view operation"
831+ )
794832 return sharded_view (tensor , shape )
795833
796834
0 commit comments