Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 128 additions & 164 deletions mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"

Expand All @@ -33,6 +35,9 @@ namespace ion {

namespace {

constexpr StringLiteral kPulseGroupAttr = "_group";
constexpr StringLiteral kParallelProtocolIdAttr = "parallel_protocol_id";

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
Expand All @@ -46,6 +51,82 @@ Value awaitEvents(ArrayRef<Value> events, PatternRewriter &rewriter)
return rtio::RTIOSyncOp::create(rewriter, rewriter.getUnknownLoc(), eventType, events);
}

/// Extract the qubit index from a memref.load value.
/// For given `memref.load @qubit_map[%cN]`, returns `N`.
static FailureOr<int64_t> getQubitIndex(Value memrefLoadValue)
{
auto loadOp = memrefLoadValue.getDefiningOp<memref::LoadOp>();
if (!loadOp || loadOp.getIndices().size() != 1) {
return failure();
}
IntegerAttr indexAttr;
if (!matchPattern(loadOp.getIndices()[0], m_Constant<IntegerAttr>(&indexAttr))) {
return failure();
}
return indexAttr.getInt();
}

static std::optional<double> getConstF64(Value v)
{
FloatAttr attr;
if (matchPattern(v, m_Constant<FloatAttr>(&attr))) {
return attr.getValueAsDouble();
}
return std::nullopt;
}

/// Find a pulse with the same (frequency, phase) tone.
static rtio::RTIOPulseOp findSameTonePulse(ArrayRef<rtio::RTIOPulseOp> pulses,
rtio::RTIOPulseOp pulse)
{
auto f = getConstF64(pulse.getFrequency());
auto p = getConstF64(pulse.getPhase());
if (!f || !p) {
return nullptr;
}

auto found = llvm::find_if(pulses, [=](rtio::RTIOPulseOp target) {
return getConstF64(target.getFrequency()) == f && getConstF64(target.getPhase()) == p;
});
return found != pulses.end() ? *found : nullptr;
}

/// Merge qubit qualifiers from src into dst.
static void mergeChannelQualifiers(rtio::RTIOPulseOp dst, rtio::RTIOPulseOp src,
PatternRewriter &rewriter, MLIRContext *ctx, Location loc)
{
auto dstCh = llvm::cast<rtio::ChannelType>(dst.getChannel().getType());
auto srcCh = llvm::cast<rtio::ChannelType>(src.getChannel().getType());

SetVector<int64_t> qubits;

// Merge qualifiers from dst and src
for (auto q : dstCh.getQualifiers()) {
qubits.insert(llvm::cast<IntegerAttr>(q).getInt());
}
for (auto q : srcCh.getQualifiers()) {
qubits.insert(llvm::cast<IntegerAttr>(q).getInt());
}

SmallVector<Attribute> quals;
for (int64_t q : qubits) {
quals.push_back(rewriter.getI64IntegerAttr(q));
}

auto mergedType = rtio::ChannelType::get(ctx, "dds", rewriter.getArrayAttr(quals),
rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
Value newCh = rtio::RTIOChannelOp::create(rewriter, loc, mergedType);
rewriter.moveOpBefore(newCh.getDefiningOp(), dst);

Value oldCh = dst.getChannel();
dst.getChannelMutable().assign(newCh);

// Remove unused channel op
if (oldCh.getDefiningOp() && oldCh.use_empty()) {
rewriter.eraseOp(oldCh.getDefiningOp());
}
}

//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -85,6 +166,10 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern<ion::ParallelP
MLIRContext *ctx = rewriter.getContext();
Location loc = op.getLoc();

auto protocolIdAttr = op->getAttrOfType<IntegerAttr>(kParallelProtocolIdAttr);
assert(protocolIdAttr && "parallel protocol must have parallel protocol id");
int64_t protocolId = protocolIdAttr.getInt();

Block *regionBlock = &op.getBodyRegion().front();
IRMapping irMapping;
SmallVector<Value> inQubits;
Expand All @@ -106,68 +191,46 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern<ion::ParallelP

Value inputSyncEvent = awaitEvents(llvm::to_vector(events), rewriter);

// Clone operations from the region to outside
SmallVector<Value> pulseEvents;
DenseMap<Value, int64_t> qubitToOffset;

// we cache the channel to index mapping to avoid multiple lookups
DenseMap<Value, Value> cache;
// Clone all operations from the region
SmallVector<rtio::RTIOPulseOp> clonedPulses;
for (auto &regionOp : regionBlock->without_terminator()) {
auto *clonedOp = rewriter.clone(regionOp, irMapping);
if (auto pulseOp = dyn_cast<rtio::RTIOPulseOp>(clonedOp)) {
// set wait event for the pulse operation
pulseOp.setWait(inputSyncEvent);

Value index = nullptr;

SmallVector<Value> chain;
traceValueWithCallback<TraceMode::Qreg>(
pulseOp.getChannel(), [&](Value value) -> WalkResult {
if (cache.count(value)) {
index = cache[value];
return WalkResult::interrupt();
}
chain.push_back(value);
if (auto loadOp =
llvm::dyn_cast_if_present<memref::LoadOp>(value.getDefiningOp())) {
index = loadOp.getIndices()[0];

// cache the channel to index mapping
cache[pulseOp.getChannel()] = index;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (index == nullptr) {
op->emitError("Failed to trace the channel index");
return failure();
}
if (auto pulseOp = dyn_cast<rtio::RTIOPulseOp>(clonedOp))
clonedPulses.push_back(pulseOp);
irMapping.map(regionOp.getResults(), clonedOp->getResults());
}

// update cache
for (Value value : chain) {
cache[value] = index;
// Merge same-tone pulses
SmallVector<rtio::RTIOPulseOp> survivors;
for (auto pulse : clonedPulses) {
if (auto match = findSameTonePulse(survivors, pulse)) {
// Merge channel qualifiers to `match`, and erase `pulse`
mergeChannelQualifiers(match, pulse, rewriter, ctx, loc);
Value ch = pulse.getChannel();
rewriter.eraseOp(pulse);
if (ch.getDefiningOp() && ch.use_empty()) {
rewriter.eraseOp(ch.getDefiningOp());
}
pulseOp->setAttr("offset", rewriter.getI64IntegerAttr(qubitToOffset[index]));

// the same qubit may appear multiple times in the parallel protocol
// so we need to increment the offset for each appearance
qubitToOffset[index]++;

pulseEvents.push_back(pulseOp.getEvent());
}
irMapping.map(regionOp.getResults(), clonedOp->getResults());
else {
survivors.push_back(pulse);
}
}

// Create sync operation from pulse events (must have at least one after Phase 1)
assert(pulseEvents.size() > 0 &&
"must have at least one pulse operation after parallel protocol conversion");
// Set wait dependency and _group on each surviving pulse
SmallVector<Value> pulseEvents;
for (auto pulse : survivors) {
pulse.setWait(inputSyncEvent);
pulse->setAttr(kPulseGroupAttr, rewriter.getI64IntegerAttr(protocolId));
pulseEvents.push_back(pulse.getEvent());
}

assert(!pulseEvents.empty() &&
"must have at least one pulse after parallel protocol conversion");
Value outputSyncEvent = awaitEvents(llvm::to_vector(pulseEvents), rewriter);

SmallVector<Value> results;
for (Value result : op.getResults()) {
// unrealized conversion cast sync event to result type
auto event = UnrealizedConversionCastOp::create(rewriter, loc, result.getType(),
outputSyncEvent);
results.push_back(event.getResult(0));
Expand All @@ -188,10 +251,9 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern<ion::ParallelP
/// ```
/// will be converted to:
/// ```
/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?>
/// ... // other pulse parameters settings
/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds", [N : i64], 0>
/// %event = rtio.pulse %ch duration(%duration) frequency(%freq) phase(%phase)
/// : !rtio.channel<"dds", ?> -> !rtio.event
/// : !rtio.channel<"dds", [N : i64], 0> -> !rtio.event
/// ```
struct PulseToRTIOPattern : public OpConversionPattern<ion::PulseOp> {
IonInfo ionInfo;
Expand Down Expand Up @@ -240,10 +302,7 @@ struct PulseToRTIOPattern : public OpConversionPattern<ion::PulseOp> {
Value phaseValue =
arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(phase));

// Convert the qubit to a channel
ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)});
auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, nullptr);

// Resolve qubit index and set it as the qualifier
Value memrefLoadValue = nullptr;
traceValueWithCallback<TraceMode::Qreg>(op.getInQubit(), [&](Value value) -> WalkResult {
if (qextractToMemrefMap.count(value)) {
Expand All @@ -258,8 +317,16 @@ struct PulseToRTIOPattern : public OpConversionPattern<ion::PulseOp> {
return failure();
}

Value channel =
rtio::RTIOQubitToChannelOp::create(rewriter, loc, channelType, memrefLoadValue);
auto qubitIdx = getQubitIndex(memrefLoadValue);
if (failed(qubitIdx)) {
op->emitError("Failed to resolve qubit index from memref load");
return failure();
}
ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(*qubitIdx)});

IntegerAttr channelIdAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, channelIdAttr);
Value channel = rtio::RTIOChannelOp::create(rewriter, loc, channelType);

// Create rtio.pulse
auto eventType = rtio::EventType::get(ctx);
Expand All @@ -271,108 +338,6 @@ struct PulseToRTIOPattern : public OpConversionPattern<ion::PulseOp> {
}
};

/// Resolve the static channel mapping for the rtio.qubit_to_channel operation
///
/// It's expecting `qubit_to_channel` has the following def-use chain:
/// memref.global w/ constants -> memref.get_global -> memref.load -> qubit_to_channel
///
/// Example:
/// ```
/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?>
/// ```
/// will be converted to:
/// ```
/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds">
/// ```
struct ResolveChannelMappingPattern : public OpRewritePattern<rtio::RTIOQubitToChannelOp> {
ResolveChannelMappingPattern(MLIRContext *ctx)
: OpRewritePattern<rtio::RTIOQubitToChannelOp>(ctx)
{
}

LogicalResult matchAndRewrite(rtio::RTIOQubitToChannelOp op,
PatternRewriter &rewriter) const override
{
Location loc = op.getLoc();
Value qubit = op.getQubit();

auto loadOp = qubit.getDefiningOp<memref::LoadOp>();
if (!loadOp) {
return failure();
}

Value memref = loadOp.getMemRef();
auto getGlobalOp = memref.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
return failure();
}

StringRef globalName = getGlobalOp.getName();
ModuleOp module = op->getParentOfType<ModuleOp>();
if (!module) {
return failure();
}
auto globalOp = module.lookupSymbol<memref::GlobalOp>(globalName);
if (!globalOp) {
return failure();
}

auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
return failure();
}

auto denseAttr = llvm::dyn_cast<DenseIntElementsAttr>(*initialValue);
if (!denseAttr) {
return failure();
}

ValueRange indices = loadOp.getIndices();
if (indices.size() != 1) {
return failure();
}

IntegerAttr indexAttr;
if (!matchPattern(indices[0], m_Constant<IntegerAttr>(&indexAttr))) {
return failure();
}

int64_t index = indexAttr.getInt();

size_t denseSize = denseAttr.size();
if (index < 0 || static_cast<size_t>(index) >= denseSize) {
return failure();
}

APInt channelIdValue = denseAttr.getValues<APInt>()[index];

auto originalChannelType = llvm::dyn_cast<rtio::ChannelType>(op.getChannel().getType());
if (!originalChannelType) {
return failure();
}
StringRef kind = originalChannelType.getKind();
ArrayAttr qualifiers = originalChannelType.getQualifiers();

// channel should have exactly one use before lowering to channel op
assert(op.getChannel().hasOneUse() && "channel should have exactly one use");

auto pulseOp = cast<rtio::RTIOPulseOp>(*op.getChannel().getUsers().begin());
int64_t offset = cast<IntegerAttr>(pulseOp->getAttr("offset")).getInt();

IntegerAttr channelIdAttr = rewriter.getIntegerAttr(
rewriter.getIndexType(), (channelIdValue.getSExtValue() * 2 + offset));

auto resolvedChannelType =
rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr);

Value channel = rtio::RTIOChannelOp::create(rewriter, loc, resolvedChannelType);

rewriter.replaceOp(op, channel);

return success();
}
};

/// Propagates RTIO events from chain of operations to event types.
///
/// Steps:
Expand Down Expand Up @@ -493,8 +458,7 @@ struct MeasurePulseToRTIOPattern : public OpConversionPattern<ion::MeasurePulseO
return failure();
}

Value channel =
rtio::RTIOQubitToChannelOp::create(rewriter, loc, channelType, memrefLoadValue);
Value channel = rtio::RTIOChannelOp::create(rewriter, loc, channelType);

auto eventType = rtio::EventType::get(ctx);
Value event = rtio::RTIOPulseOp::create(rewriter, loc, eventType, channel, duration,
Expand Down Expand Up @@ -562,7 +526,7 @@ void populateParallelProtocolToRTIOPatterns(TypeConverter &typeConverter,
void populateIonToRTIOFinalizePatterns(RewritePatternSet &patterns)
{
patterns.add<PropagateEventsPattern>(patterns.getContext());
patterns.add<ResolveChannelMappingPattern>(patterns.getContext());
// patterns.add<ResolveChannelMappingPattern>(patterns.getContext());
}

void populateIonMeasurePulseToRTIOPatterns(TypeConverter &typeConverter,
Expand Down
Loading
Loading