2121
2222 ~Subroutine
2323 ~SubroutineOp
24- ~AbstractArray
2524 ~change_op_basis_subroutine_resource_rep
2625 ~adjoint_subroutine_resource_rep
2726 ~subroutine_resource_rep
3231from collections import defaultdict
3332from collections .abc import Callable
3433from copy import deepcopy
35- from dataclasses import dataclass
36- from functools import lru_cache , reduce , update_wrapper
34+ from functools import lru_cache , update_wrapper
3735from importlib .util import find_spec
3836from inspect import BoundArguments , Signature , signature
3937from typing import Any , ParamSpec
5856has_jax = find_spec ("jax" ) is not None
5957
6058
61- @dataclass (frozen = True )
62- class AbstractArray :
63- """An abstract representation of an array that contains the shape and dtype
64- attributes necessary for resource calculations.
65-
66- This class is used with :func:`~pennylane.templates.subroutine_resource_rep`
67- for specifying abstract information about a :class:`~.Subroutine` for
68- purposes of resource calculations used with graph decompositions.
69-
70- Args:
71- shape (tuple(int)): the dimensions of the array. ``()`` corresponds to a scalar.
72- dtype (type): the data type of the array. Defaults to ``np.dtype(int)`` for easier use in specifying
73- wires.
74- """
75-
76- shape : tuple [int , ...]
77- dtype : np .dtype = np .dtype (int )
78-
79- def __len__ (self ):
80- return reduce (lambda a , b : a * b , self .shape )
81-
82- def __post_init__ (self ):
83- if math .get_interface (self .dtype ) == "torch" :
84- dummy = math .array ((), dtype = self .dtype , like = "torch" )
85- object .__setattr__ (self , "dtype" , dummy .numpy ().dtype )
86- object .__setattr__ (self , "dtype" , np .dtype (self .dtype ))
87-
88-
8959def _make_signature_key (subroutine : "Subroutine" , * args , ** kwargs ):
9060 bound = subroutine .signature .bind (* args , ** kwargs )
9161 bound .apply_defaults ()
@@ -207,7 +177,8 @@ def S(params, wires, rotation):
207177
208178 .. code-block:: python
209179
210- from pennylane.templates import AbstractArray, subroutine_resource_rep
180+ from pennylane.math import AbstractArray
181+ from pennylane.templates import subroutine_resource_rep
211182
212183 class MyOp(qp.operation.Operation):
213184 pass
@@ -254,12 +225,12 @@ def _create_signature_key(
254225 if arg in static_argnames :
255226 key .append (val )
256227 elif arg in wire_argnames :
257- key .append (AbstractArray (shape = (len (val ),), dtype = int ))
228+ key .append (math . AbstractArray (shape = (len (val ),), dtype = int ))
258229 else :
259230 leaves , struct = flatten (val )
260231
261232 shapes = (
262- AbstractArray (shape = math .shape (l ), dtype = getattr (l , "dtype" , type (l )))
233+ math . AbstractArray (shape = math .shape (l ), dtype = getattr (l , "dtype" , type (l )))
263234 for l in leaves
264235 )
265236 key .append ((struct , tuple (shapes )))
@@ -427,11 +398,11 @@ def _default_resources(subroutine: "Subroutine", *args, **kwargs) -> defaultdict
427398 sig = subroutine .signature .bind (* args , ** kwargs )
428399 for arg in subroutine .dynamic_argnames :
429400 avals , struct = flatten (sig .arguments [arg ])
430- if avals and isinstance (avals [0 ], AbstractArray ):
401+ if avals and isinstance (avals [0 ], math . AbstractArray ):
431402 params = (np .empty (shape = aval .shape , dtype = aval .dtype ) for aval in avals )
432403 sig .arguments [arg ] = unflatten (params , struct )
433404 for arg in subroutine .wire_argnames :
434- if isinstance (sig .arguments [arg ], AbstractArray ):
405+ if isinstance (sig .arguments [arg ], math . AbstractArray ):
435406 sig .arguments [arg ] = list (range (sig .arguments [arg ].shape [0 ]))
436407 with queuing .AnnotatedQueue () as q :
437408 subroutine .definition (** sig .arguments )
@@ -633,7 +604,7 @@ def RXLayer(params, wires):
633604 For example, we should be able to calculate the resources using the :class:`~.AbstractArray`
634605 class.
635606
636- >>> from pennylane.templates import AbstractArray
607+ >>> from pennylane.math import AbstractArray
637608 >>> abstract_params = AbstractArray((10,), float)
638609 >>> abstract_wires = AbstractArray((10,))
639610 >>> RXLayer.compute_resources(abstract_params, abstract_wires)
@@ -644,7 +615,8 @@ def RXLayer(params, wires):
644615
645616 .. code-block:: python
646617
647- from pennylane.templates import AbstractArray, subroutine_resource_rep
618+ from pennylane.math import AbstractArray
619+ from pennylane.templates import subroutine_resource_rep
648620
649621 class MyOp(qp.operation.Operation):
650622 pass
@@ -918,7 +890,6 @@ def _(*args, **kwargs):
918890__all__ = [
919891 "Subroutine" ,
920892 "SubroutineOp" ,
921- "AbstractArray" ,
922893 "subroutine_resource_rep" ,
923894 "CollectedSubroutine" ,
924895 "adjoint_subroutine_resource_rep" ,
0 commit comments