Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions src/oqd_core/compiler/analog/passes/canonicalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@
CanVerSortedOrder,
VerifyHilberSpaceDim,
)
from oqd_core.compiler.math.rules import (
DistributeMathExpr,
PartitionMathExpr,
ProperOrderMathExpr,
)
from oqd_core.compiler.math.passes import canonicalize_math_expr

########################################################################################

Expand Down Expand Up @@ -77,12 +73,6 @@
FixedPoint(Post(GatherMathExpr())),
)

math_chain = Chain(
FixedPoint(Post(DistributeMathExpr())),
FixedPoint(Post(ProperOrderMathExpr())),
FixedPoint(Post(PartitionMathExpr())),
)

verify_canonicalization = Chain(
Post(CanVerOperatorDistribute()),
Post(CanVerGatherMathExpr()),
Expand Down Expand Up @@ -128,6 +118,6 @@ def analog_operator_canonicalization(model):
FixedPoint(Post(PruneIdentity())),
FixedPoint(scale_terms_chain),
FixedPoint(Post(SortedOrder())),
math_chain,
canonicalize_math_expr,
verify_canonicalization,
)(model=model)
230 changes: 228 additions & 2 deletions src/oqd_core/compiler/atomic/canonicalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from oqd_compiler_infrastructure import Chain, Pre, RewriteRule
from functools import partial, reduce

########################################################################################
from oqd_compiler_infrastructure import Chain, Post, Pre, RewriteRule

from oqd_core.compiler.math.rules import SubstituteMathVar
from oqd_core.interface.atomic import Level, Transition
from oqd_core.interface.atomic.protocol import ParallelProtocol, SequentialProtocol
from oqd_core.interface.math import MathVar

########################################################################################


class UnrollLevelLabel(RewriteRule):
"""
Unrolls the [`Level`][oqd_core.interface.atomic.system.Level] labels present in [`Transitions`][oqd_core.interface.atomic.system.Transition].

Args:
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.

Returns:
model (AtomicCircuit):

Assumptions:
None

"""

def map_Ion(self, model):
Expand Down Expand Up @@ -54,6 +68,15 @@ def map_Transition(self, model):
class UnrollTransitionLabel(RewriteRule):
"""
Unrolls the [`Transition`][oqd_core.interface.atomic.system.Transition] labels present in [`Beams`][oqd_core.interface.atomic.protocol.Beam].

Args:
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.

Returns:
model (AtomicCircuit):

Assumptions:
None
"""

def map_System(self, model):
Expand Down Expand Up @@ -85,7 +108,210 @@ def map_Beam(self, model):
)


class ResolveNestedProtocol(RewriteRule):
"""
Unfolds nested protocols into a standard form with only 2 hierarchy levels, a sequential protocol of parallel protocols.

Args:
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.

Returns:
model (AtomicCircuit):

Assumptions:
None
"""

def __init__(self):
super().__init__()

self.durations = []

@classmethod
def _get_continuous_duration(self, model):
if isinstance(model, ParallelProtocol):
if len(model.sequence) == 1:
return model.sequence[0].duration

return min(map(lambda x: x.duration, model.sequence))

if isinstance(model, SequentialProtocol):
return self._get_continuous_duration(model.sequence[0])

return model.duration

@classmethod
def _cut_protocol(cls, model, continuous_duration):
if isinstance(model, ParallelProtocol):
pairs = list(
map(
partial(cls._cut_protocol, continuous_duration=continuous_duration),
model.sequence,
)
)

cut = reduce(lambda x, y: x + y, map(lambda x: x[0], pairs))

remainder = [r for r in map(lambda x: x[1], pairs) if r is not None]

if remainder:
return cut, ParallelProtocol(sequence=remainder)

return cut, None

if isinstance(model, SequentialProtocol):
cut, remainder = cls._cut_protocol(
model.sequence[0], continuous_duration=continuous_duration
)

if remainder:
return cut, SequentialProtocol(
sequence=[remainder, *model.sequence[1:]]
)
if model.sequence[1:]:
return cut, SequentialProtocol(sequence=model.sequence[1:])

return cut, None

cut = model.model_copy(deep=True)
if cut.duration == continuous_duration:
return [cut], None
cut.duration = continuous_duration

remainder = model.model_copy(deep=True)
remainder.duration = remainder.duration - continuous_duration

return [cut], remainder

def map_ParallelProtocol(self, model):
sequence = model.sequence

protocols = []
while sequence:
continuous_duration = min(map(self._get_continuous_duration, sequence))

pairs = list(
map(
partial(
self._cut_protocol, continuous_duration=continuous_duration
),
sequence,
)
)

protocols.append(
ParallelProtocol(
sequence=reduce(lambda x, y: x + y, map(lambda x: x[0], pairs))
)
)

sequence = [r for r in map(lambda x: x[1], pairs) if r is not None]

return SequentialProtocol(sequence=protocols)

def map_SequentialProtocol(self, model):
if len(model.sequence) == 1:
return model.sequence[0]

new_sequence = []
for subprotocol in model.sequence:
if isinstance(subprotocol, SequentialProtocol):
new_sequence.extend(
list(
map(
lambda x: x
if isinstance(x, ParallelProtocol)
else ParallelProtocol(sequence=[x]),
subprotocol.sequence,
)
)
)
elif isinstance(subprotocol, ParallelProtocol):
new_sequence.append(subprotocol)
else:
new_sequence.append(ParallelProtocol(sequence=[subprotocol]))
return model.__class__(sequence=new_sequence)

def map_Pulse(self, model):
return SequentialProtocol(sequence=[model])


class ResolveRelativeTime(RewriteRule):
"""
Handles conversion of relative time to absolute time.

Args:
model (AtomicCircuit): The rule only acts on [`AtomicCircuit`][oqd_core.interface.atomic.AtomicCircuit] objects.

Returns:
model (AtomicCircuit):

Assumptions:
None
"""

def __init__(self):
super().__init__()

def map_AtomicCircuit(self, model):
protocol = Post(
SubstituteMathVar(
variable=MathVar(name="s"), substitution=MathVar(name="t")
)
)(model.protocol)

return model.__class__(system=model.system, protocol=protocol)

@classmethod
def _get_duration(cls, model):
if isinstance(model, SequentialProtocol):
return reduce(
lambda x, y: x + y,
[cls._get_duration(p) for p in model.sequence],
)
if isinstance(model, ParallelProtocol):
return max(
*[cls._get_duration(p) for p in model.sequence],
)
return model.duration

def map_SequentialProtocol(self, model):
current_time = 0

new_sequence = []
for p in model.sequence:
duration = self._get_duration(p)

new_p = Post(
SubstituteMathVar(
variable=MathVar(name="s"),
substitution=MathVar(name="s") - current_time,
)
)(p)
new_sequence.append(new_p)

current_time += duration

return model.__class__(sequence=new_sequence)


########################################################################################

unroll_label_pass = Chain(
Pre(UnrollLevelLabel()),
Pre(UnrollTransitionLabel()),
)
"""
Pass that unrolls the references to levels and transitions
"""


def canonicalize_atomic_circuit_factory():
"""
Factory for creating a pass for canonicalizing an atomic circuit.
"""
return Chain(
unroll_label_pass,
Post(ResolveRelativeTime()),
Post(ResolveNestedProtocol()),
)
30 changes: 27 additions & 3 deletions src/oqd_core/compiler/math/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from oqd_compiler_infrastructure import Post

########################################################################################

from oqd_compiler_infrastructure import Chain, FixedPoint, Post

from oqd_core.compiler.math.rules import (
DistributeMathExpr,
EvaluateMathExpr,
PartitionMathExpr,
PrintMathExpr,
ProperOrderMathExpr,
PruneMathExpr,
SimplifyMathExpr,
)

Expand All @@ -27,6 +32,7 @@
"evaluate_math_expr",
"simplify_math_expr",
"print_math_expr",
"canonicalize_math_expr",
]

########################################################################################
Expand All @@ -36,7 +42,6 @@
Pass for evaluating math expression
"""


simplify_math_expr = Post(SimplifyMathExpr())
"""
Pass for simplifying math expression
Expand All @@ -46,3 +51,22 @@
"""
Pass for printing math expression
"""

canonicalize_math_expr = Chain(
FixedPoint(
Post(
Chain(
PruneMathExpr(),
SimplifyMathExpr(),
DistributeMathExpr(),
ProperOrderMathExpr(),
)
)
),
FixedPoint(Post(PartitionMathExpr())),
simplify_math_expr,
)

"""
Pass for canonicalizing math expression
"""
Loading
Loading