Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions mlir/include/Ion/IR/IonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/BuiltinTypes.td"

include "Ion/IR/IonDialect.td"
include "Ion/IR/IonInterfaces.td"
Expand Down Expand Up @@ -211,20 +212,21 @@ def MeasurePulseOp : BasePulseOp<"measure_pulse"> {
}

def ReadoutBitOp : Ion_Op<"readout_bit", [AllTypesMatch<["in_qubit", "out_qubit"]>]> {
let summary = "Read out the classical measurement result from a qubit after a measure pulse.";
let summary = "Read out measurement count for a mid-circuit measurement.";

let description = [{
This op reads out the classical measurement result from a ParallelProtocolOp containing a
MeasurePulseOp, and threads the qubit wire through.
Threads the qubit SSA through and yields a 32-bit readout value (e.g. photon count or a
host-provided classical value). Callers that need a binary result can compare `cnt_val` to
zero.
}];

let arguments = (ins
QubitType: $in_qubit
);

let results = (outs
I1: $mres,
QubitType: $out_qubit
QubitType: $out_qubit,
I32: $cnt_val
);

let assemblyFormat = [{
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/Ion/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ void populateConversionPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateIonPulseToRTIOPatterns(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, const IonInfo &ionInfo,
mlir::DenseMap<mlir::Value, mlir::Value> &qextractToMemrefMap);
void populateIonMeasurePulseToRTIOPatterns(
mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
mlir::DenseMap<mlir::Value, mlir::Value> &qextractToMemrefMap);
void populateIonReadoutBitToRTIOPatterns(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns);
void populateParallelProtocolToRTIOPatterns(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns);
void populateIonToRTIOFinalizePatterns(mlir::RewritePatternSet &patterns);
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/Ion/Transforms/ValueTracing.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ auto traceValueWithCallback(mlir::Value value, CallbackT &&callback)
else if (auto op = dyn_cast<rtio::RTIOQubitToChannelOp>(defOp)) {
worklist.push(op.getQubit());
}
else if (auto op = dyn_cast<ion::ReadoutBitOp>(defOp)) {
worklist.push(op.getInQubit());
}
else if (auto op = dyn_cast<quantum::InsertOp>(defOp)) {
Value inQreg = op.getInQreg();
Value qubit = op.getQubit();
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/RTIO/IR/RTIOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "RTIO/IR/RTIODialect.td"

Expand Down Expand Up @@ -213,6 +214,21 @@ def RTIORPCOp : RTIO_Op<"rpc"> {
let hasVerifier = 1;
}

def RTIOReadoutOp : RTIO_Op<"readout"> {
let summary = "Read measurement count from a measurement pulse event";
let description = [{
Consumes the `!rtio.event` produced by measurement pulse and yields the
hardware readout count (e.g. photon counts) once that pulse has completed.
}];

let arguments = (ins RTIOEventType:$event);
let results = (outs I32:$count);

let assemblyFormat = [{
$event attr-dict `:` type($event) `->` type($count)
}];
}

def RTIOEmptyOp : RTIO_Op<"empty", [Pure]> {
let summary = "Create an empty event for sequencing";
let description = [{
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Ion/Transforms/ConversionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,16 +344,17 @@ struct ReadoutBitOpPattern : public OpConversionPattern<catalyst::ion::ReadoutBi
MLIRContext *ctx = this->getContext();
Type ptrType = LLVM::LLVMPointerType::get(ctx);

Type readoutFuncType = LLVM::LLVMFunctionType::get(IntegerType::get(ctx, 1), {ptrType});
Type i32Ty = IntegerType::get(ctx, 32);
Type readoutFuncType = LLVM::LLVMFunctionType::get(i32Ty, {ptrType});
LLVM::LLVMFuncOp readoutFnDecl = catalyst::ensureFunctionDeclaration<LLVM::LLVMFuncOp>(
rewriter, op, "__catalyst__oqd__readout_bit", readoutFuncType);

Value mres =
Value cntVal =
LLVM::CallOp::create(rewriter, loc, readoutFnDecl, ValueRange{adaptor.getInQubit()})
.getResult();

// Thread the qubit through unchanged; the physical qubit pointer is the same.
rewriter.replaceOp(op, {mres, adaptor.getInQubit()});
rewriter.replaceOp(op, {adaptor.getInQubit(), cntVal});
return success();
}
};
Expand Down
11 changes: 8 additions & 3 deletions mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,15 +624,20 @@ struct MeasureOpToMeasurePulsePattern : public mlir::OpRewritePattern<MeasureOp>

// Create a readout bit op to read out the classical measurement result
// and thread the qubit wire through.
auto readoutOp = ion::ReadoutBitOp::create(rewriter, loc, rewriter.getI1Type(),
ion::QubitType::get(ctx), ppOp.getResults()[0]);
Type readoutResultTys[] = {ion::QubitType::get(ctx), rewriter.getI32Type()};
auto readoutOp = ion::ReadoutBitOp::create(rewriter, loc, TypeRange(readoutResultTys),
ppOp.getResults()[0]);

auto qubitResults = convertIonQubitsToQuantumBits(rewriter, loc, readoutOp.getOutQubit());
if (!qubitResults.has_value()) {
return failure();
}

rewriter.replaceOp(op, {readoutOp.getMres(), qubitResults.value()[0]});
Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
Value mres = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
readoutOp.getCntVal(), zero);

rewriter.replaceOp(op, {mres, qubitResults.value()[0]});
return success();
}
};
Expand Down
101 changes: 101 additions & 0 deletions mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "Ion/Transforms/ValueTracing.h"
#include "Quantum/IR/QuantumDialect.h"
#include "RTIO/IR/RTIODialect.h"
#include "RTIO/IR/RTIOOps.h"

using namespace mlir;
using namespace catalyst;
Expand Down Expand Up @@ -451,6 +452,93 @@ struct PropagateEventsPattern : public OpRewritePattern<UnrealizedConversionCast
}
};

/// Convert ion.measure_pulse to rtio.pulse marked with `_measurement` for ARTIQ lowering.
struct MeasurePulseToRTIOPattern : public OpConversionPattern<ion::MeasurePulseOp> {
DenseMap<Value, Value> &qextractToMemrefMap;

MeasurePulseToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx,
DenseMap<Value, Value> &qextractToMemrefMap)
: OpConversionPattern<ion::MeasurePulseOp>(typeConverter, ctx),
qextractToMemrefMap(qextractToMemrefMap)
{
}

LogicalResult matchAndRewrite(ion::MeasurePulseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override
{
Location loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();

Value duration = op.getTime();
Value freqValue = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0));
Value phaseValue = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0));

auto beamAttr = op.getBeam();
int64_t transitionIndex = beamAttr.getTransitionIndex().getInt();
ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)});
IntegerAttr channelIdAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); // TTL0
auto channelType = rtio::ChannelType::get(ctx, "ttl", qualifiers, channelIdAttr);

Value memrefLoadValue = nullptr;
traceValueWithCallback<TraceMode::Qreg>(op.getInQubit(), [&](Value value) -> WalkResult {
if (qextractToMemrefMap.count(value)) {
memrefLoadValue = qextractToMemrefMap[value];
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (memrefLoadValue == nullptr) {
op->emitError("Failed to trace the memref load value for measure_pulse");
return failure();
}

Value channel =
rtio::RTIOQubitToChannelOp::create(rewriter, loc, channelType, memrefLoadValue);

auto eventType = rtio::EventType::get(ctx);
Value event = rtio::RTIOPulseOp::create(rewriter, loc, eventType, channel, duration,
freqValue, phaseValue, nullptr);
event.getDefiningOp()->setAttr("_measurement", rewriter.getUnitAttr());

rewriter.replaceOp(op, event);
return success();
}
};

/// Convert ion.readout_bit to rtio.readout (measurement count).
struct ReadoutBitToRTIOPattern : public OpConversionPattern<ion::ReadoutBitOp> {
ReadoutBitToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx)
: OpConversionPattern<ion::ReadoutBitOp>(typeConverter, ctx)
{
}

LogicalResult matchAndRewrite(ion::ReadoutBitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override
{
Location loc = op.getLoc();
rtio::RTIOPulseOp measPulse;
for (Operation *cur = op->getPrevNode(); cur; cur = cur->getPrevNode()) {
auto pulse = dyn_cast<rtio::RTIOPulseOp>(cur);
if (pulse && pulse->hasAttr("_measurement")) {
measPulse = pulse;
break;
}
}
if (!measPulse) {
return op->emitError(
"readout_bit: no preceding rtio.pulse(_measurement) in this block (expected "
"measure -> readout ordering after parallelprotocol lowering)");
}

auto readout =
rtio::RTIOReadoutOp::create(rewriter, loc, rewriter.getI32Type(), measPulse.getEvent());

rewriter.replaceOp(op, ValueRange{op.getInQubit(), readout.getCount()});
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -477,5 +565,18 @@ void populateIonToRTIOFinalizePatterns(RewritePatternSet &patterns)
patterns.add<ResolveChannelMappingPattern>(patterns.getContext());
}

void populateIonMeasurePulseToRTIOPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
DenseMap<Value, Value> &qextractToMemrefMap)
{
patterns.add<MeasurePulseToRTIOPattern>(typeConverter, patterns.getContext(),
qextractToMemrefMap);
}

void populateIonReadoutBitToRTIOPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
{
patterns.add<ReadoutBitToRTIOPattern>(typeConverter, patterns.getContext());
}

} // namespace ion
} // namespace catalyst
Loading
Loading