Skip to content

Commit 62ffd32

Browse files
committed
Migrate AbstractArray to qp.math
1 parent 6198602 commit 62ffd32

4 files changed

Lines changed: 123 additions & 39 deletions

File tree

pennylane/math/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535

3636
import autoray as ar
3737

38+
from .abstract_types import (
39+
AbstractArray,
40+
AbstractBool,
41+
AbstractComplex,
42+
AbstractFloat,
43+
AbstractInt,
44+
AbstractWires,
45+
)
3846
from .binary_linalg import (
3947
binary_decimals,
4048
binary_finite_reduced_row_echelon,
@@ -210,6 +218,12 @@ def __getattr__(name):
210218

211219

212220
__all__ = [
221+
"AbstractArray",
222+
"AbstractBool",
223+
"AbstractComplex",
224+
"AbstractFloat",
225+
"AbstractInt",
226+
"AbstractWires",
213227
"add",
214228
"allclose",
215229
"allequal",

pennylane/math/abstract_types.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2026 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""This module contains data structures to represent abstract arrays."""
15+
16+
from dataclasses import dataclass
17+
from math import prod
18+
from numbers import Number
19+
20+
import numpy as np
21+
22+
23+
@dataclass(frozen=True)
24+
class AbstractArray:
25+
"""Abstract array type."""
26+
27+
shape: tuple[int, ...]
28+
dtype: np.dtype | type[Number] = np.int64
29+
30+
def __post_init__(self):
31+
object.__setattr__(self, "shape", tuple(self.shape))
32+
object.__setattr__(self, "dtype", np.dtype(self.dtype))
33+
34+
@property
35+
def size(self) -> int:
36+
"""Total number of elements."""
37+
return prod(self.shape)
38+
39+
@property
40+
def T(self) -> "AbstractArray":
41+
"""Transpose view of the array."""
42+
return AbstractArray(self.shape[::-1], self.dtype)
43+
44+
@property
45+
def ndim(self) -> int:
46+
"""Number of dimensions."""
47+
return len(self.shape)
48+
49+
def __getitem__(self, *_, **__):
50+
raise IndexError("Cannot index into an abstract array.")
51+
52+
def __setitem__(self, *_, **__):
53+
raise IndexError("Cannot index into an abstract array.")
54+
55+
def __len__(self) -> int:
56+
if not self.shape:
57+
raise TypeError("len() of unsized object.")
58+
return self.shape[0]
59+
60+
def __eq__(self, other: "AbstractArray") -> bool:
61+
# This should probably just raise an error
62+
if isinstance(other, AbstractArray):
63+
return self.shape == other.shape and self.dtype == other.dtype
64+
65+
raise TypeError("Tried to check equality against an abstract array.")
66+
67+
def __hash__(self) -> int:
68+
return hash((self.shape, self.dtype))
69+
70+
71+
AbstractBool = AbstractArray((), bool)
72+
AbstractInt = AbstractArray((), int)
73+
AbstractFloat = AbstractArray((), float)
74+
AbstractComplex = AbstractArray((), complex)
75+
76+
77+
@dataclass(frozen=True)
78+
class AbstractWires(AbstractArray):
79+
"""Abstract wires."""
80+
81+
num_wires: int
82+
83+
def __post_init__(self):
84+
object.__setattr__(self, "shape", (self.num_wires,))
85+
object.__setattr__(self, "dtype", int)
86+
87+
def __eq__(self, other: "AbstractWires"):
88+
if isinstance(other, AbstractWires):
89+
return self.num_wires == other.num_wires
90+
91+
raise TypeError("Tried to check equality against an abstract wire register.")
92+
93+
def __hash__(self):
94+
return hash(("AbstractWires", self.num_wires))

pennylane/math/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
from pennylane import math
2626

27+
from .abstract_types import AbstractArray
28+
2729

2830
def allequal(tensor1, tensor2, **kwargs):
2931
"""Returns True if two tensors are element-wise equal along a given axis.
@@ -417,6 +419,9 @@ def function(x):
417419
Abstract: True
418420
<tf.Tensor: shape=(), dtype=float32, numpy=0.26>
419421
"""
422+
if isinstance(tensor, AbstractArray):
423+
return True
424+
420425
interface = like or math.get_interface(tensor)
421426

422427
if interface == "jax":

pennylane/templates/core.py

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
2222
~Subroutine
2323
~SubroutineOp
24-
~AbstractArray
2524
~change_op_basis_subroutine_resource_rep
2625
~adjoint_subroutine_resource_rep
2726
~subroutine_resource_rep
@@ -32,8 +31,7 @@
3231
from collections import defaultdict
3332
from collections.abc import Callable
3433
from copy import deepcopy
35-
from dataclasses import dataclass
36-
from functools import lru_cache, reduce, update_wrapper
34+
from functools import lru_cache, update_wrapper
3735
from importlib.util import find_spec
3836
from inspect import BoundArguments, Signature, signature
3937
from typing import Any, ParamSpec
@@ -58,34 +56,6 @@
5856
has_jax = find_spec("jax") is not None
5957

6058

61-
@dataclass(frozen=True)
62-
class AbstractArray:
63-
"""An abstract representation of an array that contains the shape and dtype
64-
attributes necessary for resource calculations.
65-
66-
This class is used with :func:`~pennylane.templates.subroutine_resource_rep`
67-
for specifying abstract information about a :class:`~.Subroutine` for
68-
purposes of resource calculations used with graph decompositions.
69-
70-
Args:
71-
shape (tuple(int)): the dimensions of the array. ``()`` corresponds to a scalar.
72-
dtype (type): the data type of the array. Defaults to ``np.dtype(int)`` for easier use in specifying
73-
wires.
74-
"""
75-
76-
shape: tuple[int, ...]
77-
dtype: np.dtype = np.dtype(int)
78-
79-
def __len__(self):
80-
return reduce(lambda a, b: a * b, self.shape)
81-
82-
def __post_init__(self):
83-
if math.get_interface(self.dtype) == "torch":
84-
dummy = math.array((), dtype=self.dtype, like="torch")
85-
object.__setattr__(self, "dtype", dummy.numpy().dtype)
86-
object.__setattr__(self, "dtype", np.dtype(self.dtype))
87-
88-
8959
def _make_signature_key(subroutine: "Subroutine", *args, **kwargs):
9060
bound = subroutine.signature.bind(*args, **kwargs)
9161
bound.apply_defaults()
@@ -207,7 +177,8 @@ def S(params, wires, rotation):
207177
208178
.. code-block:: python
209179
210-
from pennylane.templates import AbstractArray, subroutine_resource_rep
180+
from pennylane.math import AbstractArray
181+
from pennylane.templates import subroutine_resource_rep
211182
212183
class MyOp(qp.operation.Operation):
213184
pass
@@ -254,12 +225,12 @@ def _create_signature_key(
254225
if arg in static_argnames:
255226
key.append(val)
256227
elif arg in wire_argnames:
257-
key.append(AbstractArray(shape=(len(val),), dtype=int))
228+
key.append(math.AbstractArray(shape=(len(val),), dtype=int))
258229
else:
259230
leaves, struct = flatten(val)
260231

261232
shapes = (
262-
AbstractArray(shape=math.shape(l), dtype=getattr(l, "dtype", type(l)))
233+
math.AbstractArray(shape=math.shape(l), dtype=getattr(l, "dtype", type(l)))
263234
for l in leaves
264235
)
265236
key.append((struct, tuple(shapes)))
@@ -427,11 +398,11 @@ def _default_resources(subroutine: "Subroutine", *args, **kwargs) -> defaultdict
427398
sig = subroutine.signature.bind(*args, **kwargs)
428399
for arg in subroutine.dynamic_argnames:
429400
avals, struct = flatten(sig.arguments[arg])
430-
if avals and isinstance(avals[0], AbstractArray):
401+
if avals and isinstance(avals[0], math.AbstractArray):
431402
params = (np.empty(shape=aval.shape, dtype=aval.dtype) for aval in avals)
432403
sig.arguments[arg] = unflatten(params, struct)
433404
for arg in subroutine.wire_argnames:
434-
if isinstance(sig.arguments[arg], AbstractArray):
405+
if isinstance(sig.arguments[arg], math.AbstractArray):
435406
sig.arguments[arg] = list(range(sig.arguments[arg].shape[0]))
436407
with queuing.AnnotatedQueue() as q:
437408
subroutine.definition(**sig.arguments)
@@ -633,7 +604,7 @@ def RXLayer(params, wires):
633604
For example, we should be able to calculate the resources using the :class:`~.AbstractArray`
634605
class.
635606
636-
>>> from pennylane.templates import AbstractArray
607+
>>> from pennylane.math import AbstractArray
637608
>>> abstract_params = AbstractArray((10,), float)
638609
>>> abstract_wires = AbstractArray((10,))
639610
>>> RXLayer.compute_resources(abstract_params, abstract_wires)
@@ -644,7 +615,8 @@ def RXLayer(params, wires):
644615
645616
.. code-block:: python
646617
647-
from pennylane.templates import AbstractArray, subroutine_resource_rep
618+
from pennylane.math import AbstractArray
619+
from pennylane.templates import subroutine_resource_rep
648620
649621
class MyOp(qp.operation.Operation):
650622
pass
@@ -918,7 +890,6 @@ def _(*args, **kwargs):
918890
__all__ = [
919891
"Subroutine",
920892
"SubroutineOp",
921-
"AbstractArray",
922893
"subroutine_resource_rep",
923894
"CollectedSubroutine",
924895
"adjoint_subroutine_resource_rep",

0 commit comments

Comments
 (0)