Skip to content

Commit f2731f7

Browse files
authored
Merge pull request #74 from OpenQuantumDesign/canonicalization
Canonicalization
2 parents bbbdd53 + 3039873 commit f2731f7

6 files changed

Lines changed: 504 additions & 57 deletions

File tree

src/oqd_core/compiler/analog/passes/canonicalize.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@
3838
CanVerSortedOrder,
3939
VerifyHilberSpaceDim,
4040
)
41-
from oqd_core.compiler.math.rules import (
42-
DistributeMathExpr,
43-
PartitionMathExpr,
44-
ProperOrderMathExpr,
45-
)
41+
from oqd_core.compiler.math.passes import canonicalize_math_expr
4642

4743
########################################################################################
4844

@@ -77,12 +73,6 @@
7773
FixedPoint(Post(GatherMathExpr())),
7874
)
7975

80-
math_chain = Chain(
81-
FixedPoint(Post(DistributeMathExpr())),
82-
FixedPoint(Post(ProperOrderMathExpr())),
83-
FixedPoint(Post(PartitionMathExpr())),
84-
)
85-
8676
verify_canonicalization = Chain(
8777
Post(CanVerOperatorDistribute()),
8878
Post(CanVerGatherMathExpr()),
@@ -128,6 +118,6 @@ def analog_operator_canonicalization(model):
128118
FixedPoint(Post(PruneIdentity())),
129119
FixedPoint(scale_terms_chain),
130120
FixedPoint(Post(SortedOrder())),
131-
math_chain,
121+
canonicalize_math_expr,
132122
verify_canonicalization,
133123
)(model=model)

src/oqd_core/compiler/atomic/canonicalize.py

Lines changed: 228 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,31 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from oqd_compiler_infrastructure import Chain, Pre, RewriteRule
15+
from functools import partial, reduce
1616

17-
########################################################################################
17+
from oqd_compiler_infrastructure import Chain, Post, Pre, RewriteRule
18+
19+
from oqd_core.compiler.math.rules import SubstituteMathVar
1820
from oqd_core.interface.atomic import Level, Transition
21+
from oqd_core.interface.atomic.protocol import ParallelProtocol, SequentialProtocol
22+
from oqd_core.interface.math import MathVar
1923

2024
########################################################################################
2125

2226

2327
class UnrollLevelLabel(RewriteRule):
2428
"""
2529
Unrolls the [`Level`][oqd_core.interface.atomic.system.Level] labels present in [`Transitions`][oqd_core.interface.atomic.system.Transition].
30+
31+
Args:
32+
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.
33+
34+
Returns:
35+
model (AtomicCircuit):
36+
37+
Assumptions:
38+
None
39+
2640
"""
2741

2842
def map_Ion(self, model):
@@ -54,6 +68,15 @@ def map_Transition(self, model):
5468
class UnrollTransitionLabel(RewriteRule):
5569
"""
5670
Unrolls the [`Transition`][oqd_core.interface.atomic.system.Transition] labels present in [`Beams`][oqd_core.interface.atomic.protocol.Beam].
71+
72+
Args:
73+
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.
74+
75+
Returns:
76+
model (AtomicCircuit):
77+
78+
Assumptions:
79+
None
5780
"""
5881

5982
def map_System(self, model):
@@ -85,7 +108,210 @@ def map_Beam(self, model):
85108
)
86109

87110

111+
class ResolveNestedProtocol(RewriteRule):
112+
"""
113+
Unfolds nested protocols into a standard form with only 2 hierarchy levels, a sequential protocol of parallel protocols.
114+
115+
Args:
116+
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.
117+
118+
Returns:
119+
model (AtomicCircuit):
120+
121+
Assumptions:
122+
None
123+
"""
124+
125+
def __init__(self):
126+
super().__init__()
127+
128+
self.durations = []
129+
130+
@classmethod
131+
def _get_continuous_duration(self, model):
132+
if isinstance(model, ParallelProtocol):
133+
if len(model.sequence) == 1:
134+
return model.sequence[0].duration
135+
136+
return min(map(lambda x: x.duration, model.sequence))
137+
138+
if isinstance(model, SequentialProtocol):
139+
return self._get_continuous_duration(model.sequence[0])
140+
141+
return model.duration
142+
143+
@classmethod
144+
def _cut_protocol(cls, model, continuous_duration):
145+
if isinstance(model, ParallelProtocol):
146+
pairs = list(
147+
map(
148+
partial(cls._cut_protocol, continuous_duration=continuous_duration),
149+
model.sequence,
150+
)
151+
)
152+
153+
cut = reduce(lambda x, y: x + y, map(lambda x: x[0], pairs))
154+
155+
remainder = [r for r in map(lambda x: x[1], pairs) if r is not None]
156+
157+
if remainder:
158+
return cut, ParallelProtocol(sequence=remainder)
159+
160+
return cut, None
161+
162+
if isinstance(model, SequentialProtocol):
163+
cut, remainder = cls._cut_protocol(
164+
model.sequence[0], continuous_duration=continuous_duration
165+
)
166+
167+
if remainder:
168+
return cut, SequentialProtocol(
169+
sequence=[remainder, *model.sequence[1:]]
170+
)
171+
if model.sequence[1:]:
172+
return cut, SequentialProtocol(sequence=model.sequence[1:])
173+
174+
return cut, None
175+
176+
cut = model.model_copy(deep=True)
177+
if cut.duration == continuous_duration:
178+
return [cut], None
179+
cut.duration = continuous_duration
180+
181+
remainder = model.model_copy(deep=True)
182+
remainder.duration = remainder.duration - continuous_duration
183+
184+
return [cut], remainder
185+
186+
def map_ParallelProtocol(self, model):
187+
sequence = model.sequence
188+
189+
protocols = []
190+
while sequence:
191+
continuous_duration = min(map(self._get_continuous_duration, sequence))
192+
193+
pairs = list(
194+
map(
195+
partial(
196+
self._cut_protocol, continuous_duration=continuous_duration
197+
),
198+
sequence,
199+
)
200+
)
201+
202+
protocols.append(
203+
ParallelProtocol(
204+
sequence=reduce(lambda x, y: x + y, map(lambda x: x[0], pairs))
205+
)
206+
)
207+
208+
sequence = [r for r in map(lambda x: x[1], pairs) if r is not None]
209+
210+
return SequentialProtocol(sequence=protocols)
211+
212+
def map_SequentialProtocol(self, model):
213+
if len(model.sequence) == 1:
214+
return model.sequence[0]
215+
216+
new_sequence = []
217+
for subprotocol in model.sequence:
218+
if isinstance(subprotocol, SequentialProtocol):
219+
new_sequence.extend(
220+
list(
221+
map(
222+
lambda x: x
223+
if isinstance(x, ParallelProtocol)
224+
else ParallelProtocol(sequence=[x]),
225+
subprotocol.sequence,
226+
)
227+
)
228+
)
229+
elif isinstance(subprotocol, ParallelProtocol):
230+
new_sequence.append(subprotocol)
231+
else:
232+
new_sequence.append(ParallelProtocol(sequence=[subprotocol]))
233+
return model.__class__(sequence=new_sequence)
234+
235+
def map_Pulse(self, model):
236+
return SequentialProtocol(sequence=[model])
237+
238+
239+
class ResolveRelativeTime(RewriteRule):
240+
"""
241+
Handles conversion of relative time to absolute time.
242+
243+
Args:
244+
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.
245+
246+
Returns:
247+
model (AtomicCircuit):
248+
249+
Assumptions:
250+
None
251+
"""
252+
253+
def __init__(self):
254+
super().__init__()
255+
256+
def map_AtomicCircuit(self, model):
257+
protocol = Post(
258+
SubstituteMathVar(
259+
variable=MathVar(name="s"), substitution=MathVar(name="t")
260+
)
261+
)(model.protocol)
262+
263+
return model.__class__(system=model.system, protocol=protocol)
264+
265+
@classmethod
266+
def _get_duration(cls, model):
267+
if isinstance(model, SequentialProtocol):
268+
return reduce(
269+
lambda x, y: x + y,
270+
[cls._get_duration(p) for p in model.sequence],
271+
)
272+
if isinstance(model, ParallelProtocol):
273+
return max(
274+
*[cls._get_duration(p) for p in model.sequence],
275+
)
276+
return model.duration
277+
278+
def map_SequentialProtocol(self, model):
279+
current_time = 0
280+
281+
new_sequence = []
282+
for p in model.sequence:
283+
duration = self._get_duration(p)
284+
285+
new_p = Post(
286+
SubstituteMathVar(
287+
variable=MathVar(name="s"),
288+
substitution=MathVar(name="s") - current_time,
289+
)
290+
)(p)
291+
new_sequence.append(new_p)
292+
293+
current_time += duration
294+
295+
return model.__class__(sequence=new_sequence)
296+
297+
298+
########################################################################################
299+
88300
unroll_label_pass = Chain(
89301
Pre(UnrollLevelLabel()),
90302
Pre(UnrollTransitionLabel()),
91303
)
304+
"""
305+
Pass that unrolls the references to levels and transitions
306+
"""
307+
308+
309+
def canonicalize_atomic_circuit_factory():
310+
"""
311+
Factory for creating a pass for canonicalizing an atomic circuit.
312+
"""
313+
return Chain(
314+
unroll_label_pass,
315+
Post(ResolveRelativeTime()),
316+
Post(ResolveNestedProtocol()),
317+
)

src/oqd_core/compiler/math/passes.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from oqd_compiler_infrastructure import Post
16-
1715
########################################################################################
16+
17+
from oqd_compiler_infrastructure import Chain, FixedPoint, Post
18+
1819
from oqd_core.compiler.math.rules import (
20+
DistributeMathExpr,
1921
EvaluateMathExpr,
22+
PartitionMathExpr,
2023
PrintMathExpr,
24+
ProperOrderMathExpr,
25+
PruneMathExpr,
2126
SimplifyMathExpr,
2227
)
2328

@@ -27,6 +32,7 @@
2732
"evaluate_math_expr",
2833
"simplify_math_expr",
2934
"print_math_expr",
35+
"canonicalize_math_expr",
3036
]
3137

3238
########################################################################################
@@ -36,7 +42,6 @@
3642
Pass for evaluating math expression
3743
"""
3844

39-
4045
simplify_math_expr = Post(SimplifyMathExpr())
4146
"""
4247
Pass for simplifying math expression
@@ -46,3 +51,22 @@
4651
"""
4752
Pass for printing math expression
4853
"""
54+
55+
canonicalize_math_expr = Chain(
56+
FixedPoint(
57+
Post(
58+
Chain(
59+
PruneMathExpr(),
60+
SimplifyMathExpr(),
61+
DistributeMathExpr(),
62+
ProperOrderMathExpr(),
63+
)
64+
)
65+
),
66+
FixedPoint(Post(PartitionMathExpr())),
67+
simplify_math_expr,
68+
)
69+
70+
"""
71+
Pass for canonicalizing math expression
72+
"""

0 commit comments

Comments
 (0)