diff --git a/mlir/include/Ion/IR/IonOps.td b/mlir/include/Ion/IR/IonOps.td index 33f8c21195..d7c9a69f76 100644 --- a/mlir/include/Ion/IR/IonOps.td +++ b/mlir/include/Ion/IR/IonOps.td @@ -22,6 +22,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/BuiltinTypes.td" include "Ion/IR/IonDialect.td" include "Ion/IR/IonInterfaces.td" @@ -211,11 +212,12 @@ def MeasurePulseOp : BasePulseOp<"measure_pulse"> { } def ReadoutBitOp : Ion_Op<"readout_bit", [AllTypesMatch<["in_qubit", "out_qubit"]>]> { - let summary = "Read out the classical measurement result from a qubit after a measure pulse."; + let summary = "Read out measurement count for a mid-circuit measurement."; let description = [{ - This op reads out the classical measurement result from a ParallelProtocolOp containing a - MeasurePulseOp, and threads the qubit wire through. + Threads the qubit SSA through and yields a 32-bit readout value (e.g. photon count or a + host-provided classical value). Callers that need a binary result can compare `cnt_val` to + zero. }]; let arguments = (ins @@ -223,8 +225,8 @@ def ReadoutBitOp : Ion_Op<"readout_bit", [AllTypesMatch<["in_qubit", "out_qubit" ); let results = (outs - I1: $mres, - QubitType: $out_qubit + QubitType: $out_qubit, + I32: $cnt_val ); let assemblyFormat = [{ diff --git a/mlir/include/Ion/Transforms/Patterns.h b/mlir/include/Ion/Transforms/Patterns.h index e5e68941ef..515c3ca1df 100644 --- a/mlir/include/Ion/Transforms/Patterns.h +++ b/mlir/include/Ion/Transforms/Patterns.h @@ -33,6 +33,11 @@ void populateConversionPatterns(mlir::LLVMTypeConverter &typeConverter, void populateIonPulseToRTIOPatterns(mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns, const IonInfo &ionInfo, mlir::DenseMap &qextractToMemrefMap); +void populateIonMeasurePulseToRTIOPatterns( + mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::DenseMap &qextractToMemrefMap); +void populateIonReadoutBitToRTIOPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); void populateParallelProtocolToRTIOPatterns(mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns); void populateIonToRTIOFinalizePatterns(mlir::RewritePatternSet &patterns); diff --git a/mlir/include/Ion/Transforms/ValueTracing.h b/mlir/include/Ion/Transforms/ValueTracing.h index 4c14f4dbfd..0f86d9fdfe 100644 --- a/mlir/include/Ion/Transforms/ValueTracing.h +++ b/mlir/include/Ion/Transforms/ValueTracing.h @@ -119,6 +119,9 @@ auto traceValueWithCallback(mlir::Value value, CallbackT &&callback) else if (auto op = dyn_cast(defOp)) { worklist.push(op.getQubit()); } + else if (auto op = dyn_cast(defOp)) { + worklist.push(op.getInQubit()); + } else if (auto op = dyn_cast(defOp)) { Value inQreg = op.getInQreg(); Value qubit = op.getQubit(); diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index c81c84434e..30a0d67a31 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "RTIO/IR/RTIODialect.td" @@ -213,6 +214,21 @@ def RTIORPCOp : RTIO_Op<"rpc"> { let hasVerifier = 1; } +def RTIOReadoutOp : RTIO_Op<"readout"> { + let summary = "Read measurement count from a measurement pulse event"; + let description = [{ + Consumes the `!rtio.event` produced by measurement pulse and yields the + hardware readout count (e.g. photon counts) once that pulse has completed. + }]; + + let arguments = (ins RTIOEventType:$event); + let results = (outs I32:$count); + + let assemblyFormat = [{ + $event attr-dict `:` type($event) `->` type($count) + }]; +} + def RTIOEmptyOp : RTIO_Op<"empty", [Pure]> { let summary = "Create an empty event for sequencing"; let description = [{ diff --git a/mlir/lib/Ion/Transforms/ConversionPatterns.cpp b/mlir/lib/Ion/Transforms/ConversionPatterns.cpp index 2ef9957af1..e5ab77984c 100644 --- a/mlir/lib/Ion/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Ion/Transforms/ConversionPatterns.cpp @@ -344,16 +344,17 @@ struct ReadoutBitOpPattern : public OpConversionPatterngetContext(); Type ptrType = LLVM::LLVMPointerType::get(ctx); - Type readoutFuncType = LLVM::LLVMFunctionType::get(IntegerType::get(ctx, 1), {ptrType}); + Type i32Ty = IntegerType::get(ctx, 32); + Type readoutFuncType = LLVM::LLVMFunctionType::get(i32Ty, {ptrType}); LLVM::LLVMFuncOp readoutFnDecl = catalyst::ensureFunctionDeclaration( rewriter, op, "__catalyst__oqd__readout_bit", readoutFuncType); - Value mres = + Value cntVal = LLVM::CallOp::create(rewriter, loc, readoutFnDecl, ValueRange{adaptor.getInQubit()}) .getResult(); // Thread the qubit through unchanged; the physical qubit pointer is the same. - rewriter.replaceOp(op, {mres, adaptor.getInQubit()}); + rewriter.replaceOp(op, {adaptor.getInQubit(), cntVal}); return success(); } }; diff --git a/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp b/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp index 8ae6177044..57b1916529 100644 --- a/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp +++ b/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp @@ -624,15 +624,20 @@ struct MeasureOpToMeasurePulsePattern : public mlir::OpRewritePattern // Create a readout bit op to read out the classical measurement result // and thread the qubit wire through. - auto readoutOp = ion::ReadoutBitOp::create(rewriter, loc, rewriter.getI1Type(), - ion::QubitType::get(ctx), ppOp.getResults()[0]); + Type readoutResultTys[] = {ion::QubitType::get(ctx), rewriter.getI32Type()}; + auto readoutOp = ion::ReadoutBitOp::create(rewriter, loc, TypeRange(readoutResultTys), + ppOp.getResults()[0]); auto qubitResults = convertIonQubitsToQuantumBits(rewriter, loc, readoutOp.getOutQubit()); if (!qubitResults.has_value()) { return failure(); } - rewriter.replaceOp(op, {readoutOp.getMres(), qubitResults.value()[0]}); + Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); + Value mres = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, + readoutOp.getCntVal(), zero); + + rewriter.replaceOp(op, {mres, qubitResults.value()[0]}); return success(); } }; diff --git a/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp b/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp index 9706091062..0c160850f2 100644 --- a/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp +++ b/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp @@ -23,6 +23,7 @@ #include "Ion/Transforms/ValueTracing.h" #include "Quantum/IR/QuantumDialect.h" #include "RTIO/IR/RTIODialect.h" +#include "RTIO/IR/RTIOOps.h" using namespace mlir; using namespace catalyst; @@ -451,6 +452,93 @@ struct PropagateEventsPattern : public OpRewritePattern { + DenseMap &qextractToMemrefMap; + + MeasurePulseToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx, + DenseMap &qextractToMemrefMap) + : OpConversionPattern(typeConverter, ctx), + qextractToMemrefMap(qextractToMemrefMap) + { + } + + LogicalResult matchAndRewrite(ion::MeasurePulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Value duration = op.getTime(); + Value freqValue = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); + Value phaseValue = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); + + auto beamAttr = op.getBeam(); + int64_t transitionIndex = beamAttr.getTransitionIndex().getInt(); + ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)}); + IntegerAttr channelIdAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); // TTL0 + auto channelType = rtio::ChannelType::get(ctx, "ttl", qualifiers, channelIdAttr); + + Value memrefLoadValue = nullptr; + traceValueWithCallback(op.getInQubit(), [&](Value value) -> WalkResult { + if (qextractToMemrefMap.count(value)) { + memrefLoadValue = qextractToMemrefMap[value]; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (memrefLoadValue == nullptr) { + op->emitError("Failed to trace the memref load value for measure_pulse"); + return failure(); + } + + Value channel = + rtio::RTIOQubitToChannelOp::create(rewriter, loc, channelType, memrefLoadValue); + + auto eventType = rtio::EventType::get(ctx); + Value event = rtio::RTIOPulseOp::create(rewriter, loc, eventType, channel, duration, + freqValue, phaseValue, nullptr); + event.getDefiningOp()->setAttr("_measurement", rewriter.getUnitAttr()); + + rewriter.replaceOp(op, event); + return success(); + } +}; + +/// Convert ion.readout_bit to rtio.readout (measurement count). +struct ReadoutBitToRTIOPattern : public OpConversionPattern { + ReadoutBitToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) + { + } + + LogicalResult matchAndRewrite(ion::ReadoutBitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + rtio::RTIOPulseOp measPulse; + for (Operation *cur = op->getPrevNode(); cur; cur = cur->getPrevNode()) { + auto pulse = dyn_cast(cur); + if (pulse && pulse->hasAttr("_measurement")) { + measPulse = pulse; + break; + } + } + if (!measPulse) { + return op->emitError( + "readout_bit: no preceding rtio.pulse(_measurement) in this block (expected " + "measure -> readout ordering after parallelprotocol lowering)"); + } + + auto readout = + rtio::RTIOReadoutOp::create(rewriter, loc, rewriter.getI32Type(), measPulse.getEvent()); + + rewriter.replaceOp(op, ValueRange{op.getInQubit(), readout.getCount()}); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -477,5 +565,18 @@ void populateIonToRTIOFinalizePatterns(RewritePatternSet &patterns) patterns.add(patterns.getContext()); } +void populateIonMeasurePulseToRTIOPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + DenseMap &qextractToMemrefMap) +{ + patterns.add(typeConverter, patterns.getContext(), + qextractToMemrefMap); +} + +void populateIonReadoutBitToRTIOPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) +{ + patterns.add(typeConverter, patterns.getContext()); +} + } // namespace ion } // namespace catalyst diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index c0a8d9f000..1dd8b1b183 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -44,6 +44,9 @@ namespace ion { namespace { +constexpr StringLiteral rtioTransferMeasurementResults = "__rtio_transfer_measurement_results"; +constexpr StringLiteral rtioInitDataset = "__rtio_init_dataset"; + /// Load a JSON file and convert it to an rtio.config attribute FailureOr loadDeviceDbAsConfig(MLIRContext *ctx, StringRef filePath) { @@ -105,15 +108,26 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { { ConversionTarget target(baseTarget); target.addIllegalOp(); - + target.addIllegalOp(); RewritePatternSet patterns(ctx); populateIonPulseToRTIOPatterns(typeConverter, patterns, ionInfo, qextractToMemrefMap); + populateIonMeasurePulseToRTIOPatterns(typeConverter, patterns, qextractToMemrefMap); if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { return failure(); } return success(); } + LogicalResult IonReadoutBitConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, + TypeConverter &typeConverter, MLIRContext *ctx) + { + ConversionTarget target(baseTarget); + target.addIllegalOp(); + RewritePatternSet patterns(ctx); + populateIonReadoutBitToRTIOPatterns(typeConverter, patterns); + return applyPartialConversion(funcOp, target, std::move(patterns)); + } + LogicalResult ParallelProtocolConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, TypeConverter &typeConverter, MLIRContext *ctx) { @@ -209,6 +223,75 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return success(); } + void createMeasurementHelperFunctions(func::FuncOp funcOp) + { + MLIRContext *ctx = funcOp.getContext(); + SmallVector readouts; + funcOp.walk([&](rtio::RTIOReadoutOp readout) { readouts.push_back(readout); }); + + if (readouts.empty()) { + return; + } + + auto module = funcOp->getParentOfType(); + OpBuilder builder(ctx); + Location loc = funcOp.getLoc(); + + // Create __rtio_init_dataset: wraps the init_dataset RPC + { + auto funcType = FunctionType::get(ctx, {}, {}); + builder.setInsertionPointToEnd(module.getBody()); + auto initFunc = func::FuncOp::create(builder, loc, rtioInitDataset, funcType); + initFunc.setPrivate(); + + Block *entryBlock = initFunc.addEntryBlock(); + builder.setInsertionPointToStart(entryBlock); + rtio::RTIORPCOp::create(builder, loc, TypeRange{}, + SymbolRefAttr::get(ctx, "init_dataset"), UnitAttr::get(ctx), + builder.getI32IntegerAttr(static_cast(1)), + ValueRange{}); + func::ReturnOp::create(builder, loc); + } + + // Create __rtio_transfer_measurement_results: + // loops over a memref of readout counts and sends each via RPC + { + auto numReadouts = static_cast(readouts.size()); + auto memrefType = MemRefType::get({numReadouts}, builder.getI32Type()); + auto funcType = FunctionType::get(ctx, {memrefType}, {}); + + builder.setInsertionPointToEnd(module.getBody()); + auto transferFunc = + func::FuncOp::create(builder, loc, rtioTransferMeasurementResults, funcType); + transferFunc.setPrivate(); + + Block *entryBlock = transferFunc.addEntryBlock(); + builder.setInsertionPointToStart(entryBlock); + + Value memrefArg = entryBlock->getArgument(0); + Value lb = arith::ConstantIndexOp::create(builder, loc, 0); + Value ub = arith::ConstantIndexOp::create(builder, loc, numReadouts); + Value step = arith::ConstantIndexOp::create(builder, loc, 1); + + scf::ForOp forOp = scf::ForOp::create(builder, loc, lb, ub, step); + { + OpBuilder::InsertionGuard forGuard(builder); + builder.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value indexI32 = arith::IndexCastOp::create(builder, loc, builder.getI32Type(), iv); + Value count = memref::LoadOp::create(builder, loc, memrefArg, ValueRange{iv}); + // Host RPC: transfer_measurement_result(self, index, value) + rtio::RTIORPCOp::create(builder, loc, TypeRange{}, + SymbolRefAttr::get(ctx, "transfer_measurement_result"), + UnitAttr::get(ctx), + builder.getI32IntegerAttr(static_cast(2)), + ValueRange{indexI32, count}); + } + + func::ReturnOp::create(builder, loc); + } + } + LogicalResult FinalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) { RewritePatternSet patterns(ctx); @@ -223,6 +306,9 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return failure(); } + // Create helper function definitions (no kernel modifications) + createMeasurementHelperFunctions(funcOp); + // Clean up unused quantum/ion/memref/linalg/builtin ops after patterns cleanupUnusedOps(funcOp); @@ -355,6 +441,50 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { } } + // Remove other unused functions, keep the kernel and RTIO helpers only. + void eraseNonKernelFunctions(ModuleOp module, func::FuncOp kernelFunc) + { + for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { + StringRef name = funcOp.getName(); + if (name != kernelFunc.getName() && name != rtioInitDataset && + name != rtioTransferMeasurementResults) { + funcOp.erase(); + } + } + } + + LogicalResult lowerKernelToRtio(ModuleOp module, func::FuncOp kernelFunc, IonInfo ionInfo, + MLIRContext *ctx) + { + // Drop one pulse from the protocol + dropOnePulseFromProtocol(kernelFunc); + + // Map qreg alloc / extract to memref for rtio.pulse channel construction + DenseMap qregToMemrefMap; + DenseMap qextractToMemrefMap; + initializeMemrefMap(kernelFunc, module, qregToMemrefMap, qextractToMemrefMap, ctx); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [&](ion::PulseType type) -> Type { return rtio::EventType::get(ctx); }); + + ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + if (failed(IonPulseConversion(kernelFunc, target, typeConverter, ionInfo, + qextractToMemrefMap, ctx)) || + failed(ParallelProtocolConversion(kernelFunc, target, typeConverter, ctx)) || + failed(IonReadoutBitConversion(kernelFunc, target, typeConverter, ctx)) || + failed(SCFStructuralConversion(kernelFunc, target, typeConverter, ctx)) || + failed(FinalizeKernelFunction(kernelFunc, ctx))) { + return failure(); + } + + eraseNonKernelFunctions(module, kernelFunc); + return success(); + } + void runOnOperation() override { MLIRContext *ctx = &getContext(); @@ -404,40 +534,10 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { func::FuncOp newQnodeFunc = createKernelFunction(qnodeFunc, kernelName, builder); module.insert(qnodeFunc, newQnodeFunc); - // drop one of the pulse from the certain protocol - // the way we handle the dropped pulse will be updated in the future - dropOnePulseFromProtocol(newQnodeFunc); - - // Construct mapping from qreg alloc and qreg extract to memref - // In the later conversion, we use the mapping to construct the channel for rtio.pulse - DenseMap qregToMemrefMap; - DenseMap qextractToMemrefMap; - initializeMemrefMap(newQnodeFunc, module, qregToMemrefMap, qextractToMemrefMap, ctx); - - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion( - [&](ion::PulseType type) -> Type { return rtio::EventType::get(ctx); }); - - ConversionTarget target(*ctx); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - - // prepare kernel function - if (failed(IonPulseConversion(newQnodeFunc, target, typeConverter, ionInfo, - qextractToMemrefMap, ctx)) || - failed(ParallelProtocolConversion(newQnodeFunc, target, typeConverter, ctx)) || - failed(SCFStructuralConversion(newQnodeFunc, target, typeConverter, ctx)) || - failed(FinalizeKernelFunction(newQnodeFunc, ctx))) { + if (failed(lowerKernelToRtio(module, newQnodeFunc, ionInfo, ctx))) { newQnodeFunc->emitError("Failed to convert to rtio dialect"); return signalPassFailure(); } - - // remove other unused functions, only keep the kernel function - for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { - if (funcOp.getName().str() != newQnodeFunc.getName().str()) { - funcOp.erase(); - } - } } }; diff --git a/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp b/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp index ad58964f3a..7fb7c03159 100644 --- a/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp +++ b/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp @@ -14,6 +14,10 @@ #pragma once +#include +#include +#include + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -22,6 +26,7 @@ #include "Catalyst/Utils/EnsureFunctionDeclaration.h" #include "RTIO/IR/RTIOOps.h" // For ConfigAttr +#include "Utils.hpp" namespace catalyst { namespace rtio { @@ -42,11 +47,20 @@ constexpr StringLiteral delayMu = "delay_mu"; constexpr StringLiteral rtioOutput = "rtio_output"; constexpr StringLiteral rtioInit = "rtio_init"; constexpr StringLiteral rtioGetCounter = "rtio_get_counter"; +constexpr StringLiteral rtioInitDataset = "__rtio_init_dataset"; +constexpr StringLiteral rtioTransferMeasurementResults = "__rtio_transfer_measurement_results"; constexpr StringLiteral kernel = "__kernel__"; // ARTIQ RPC runtime constexpr StringLiteral rpcSend = "rpc_send"; constexpr StringLiteral rpcSendAsync = "rpc_send_async"; constexpr StringLiteral rpcRecv = "rpc_recv"; +// RTIO input FIFO read +constexpr StringLiteral rtioInputTimestamp = "rtio_input_timestamp"; +// Measurement helper functions +constexpr StringLiteral gateRisingMu = "__rtio_gate_rising_mu"; +constexpr StringLiteral mockMeasure = "__rtio_mock_measure"; +constexpr StringLiteral rtioCount = "__rtio_count"; +constexpr StringLiteral waitUntilMu = "__rtio_wait_until_mu"; } // namespace ARTIQFuncNames //===----------------------------------------------------------------------===// @@ -54,6 +68,7 @@ constexpr StringLiteral rpcRecv = "rpc_recv"; //===----------------------------------------------------------------------===// namespace ARTIQHardwareConfig { +// Hardware constants constexpr double nanosecondPeriod = 1e-9; constexpr double ftwScaleFactor = 4.294967296; // 2^32 / 1e9 constexpr double powScaleFactor = 65536.0; // 2^16 @@ -69,8 +84,45 @@ constexpr int32_t spiFlagsReleaseCS = 10; // SPI_CS_POLARITY | SPI_END (CS high constexpr int64_t ioUpdatePulseWidth = 8; constexpr int64_t refPeriodMu = 8; // RTIO reference period (Kasli = 8ns @ 125MHz RTIO clock) constexpr int64_t minTTLPulseMu = 8; // Minimum TTL pulse duration to avoid 0 duration events + +// Simulated fluorescence / DMD during measurement window +constexpr int64_t simPhotonPulseOnMu = 100; +constexpr int64_t simPhotonGapMu = 10; +constexpr int64_t simPhotonPeriodMu = simPhotonPulseOnMu + simPhotonGapMu; +constexpr int64_t measurementPhotonLeadMu = 10; +constexpr int64_t measurementStartOffsetMu = 100000; +constexpr int32_t defaultTtlInOutGateLatencyMu = 104; + +/// TTL6 trigger pulse fires 200 mu before start for Red Pitaya oscilloscope acquisition. +constexpr int64_t scopeTriggerLeadMu = 200; + +/// Compute a Poisson-distributed simulated photon count +inline int32_t computeSimulatedPhotonCount(int64_t durationMu) +{ + int64_t maxPhotons = (durationMu - measurementPhotonLeadMu) / simPhotonPeriodMu; + if (maxPhotons <= 0) { + return 0; + } + + // convert duration (ns) to us + double durationUs = static_cast(durationMu) / 1000.0; + + static std::mt19937 rng(std::random_device{}()); + std::poisson_distribution dist(durationUs); + int32_t count = dist(rng); + + return std::min(static_cast(count), maxPhotons); +} + } // namespace ARTIQHardwareConfig +struct MeasurementChannelAddrs { + int32_t gateRisingAddr = 0; + int32_t countChannel = 0; + int32_t acquisitionOutputAddr = 0; + int32_t dmdOutputAddr = 0; +}; + //===----------------------------------------------------------------------===// // ARTIQ Runtime Builder //===----------------------------------------------------------------------===// @@ -137,9 +189,26 @@ class ARTIQRuntimeBuilder { return call.getResult(); } + void waitUntilMu(Value time) + { + ensureWaitUntilMuFunc(); + auto func = getModule().lookupSymbol(ARTIQFuncNames::waitUntilMu); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{time}); + call.setTailCallKind(LLVM::TailCallKind::Tail); + } + // Duration conversion Value secToMu(Value durationSec) { + // Constant fold if the duration is a known constant + if (auto cst = durationSec.getDefiningOp()) { + if (auto fAttr = dyn_cast(cst.getValue())) { + double sec = fAttr.getValueAsDouble(); + int64_t mu = + static_cast(std::round(sec / ARTIQHardwareConfig::nanosecondPeriod)); + return constI64(mu); + } + } ensureSecToMuFunc(); auto func = getModule().lookupSymbol(ARTIQFuncNames::secToMu); auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{durationSec}); @@ -207,11 +276,70 @@ class ARTIQRuntimeBuilder { return call.getResult(); } + /// i64 rtio_input_timestamp(i64 deadline_mu, i32 channel), returns timestamp or -1. + Value rtioInputTimestamp(Value deadlineMu, Value inputChannelI32) + { + auto func = ensureFunc(ARTIQFuncNames::rtioInputTimestamp, + LLVM::LLVMFunctionType::get(i64Ty, {i64Ty, i32Ty})); + auto call = + LLVM::CallOp::create(builder, getLoc(), func, ValueRange{deadlineMu, inputChannelI32}); + return call.getResult(); + } + // TTL operations void ttlOn(Value channelAddr) { rtioOutput(channelAddr, constI32(1)); } void ttlOff(Value channelAddr) { rtioOutput(channelAddr, constI32(0)); } + /// void __rtio_gate_rising_mu(i32 sens_addr, i64 duration_mu) + void gateRisingMu(Value sensAddr, Value durationMu) + { + auto func = ensureFunc(ARTIQFuncNames::gateRisingMu, + LLVM::LLVMFunctionType::get(voidTy, {i32Ty, i64Ty})); + LLVM::CallOp::create(builder, getLoc(), func, ValueRange{sensAddr, durationMu}); + } + + /// void __rtio_mock_measure(i64 start_mu, i32 ttl7_addr, i64 photon_count) + void mockMeasure(Value startMu, Value ttl7Addr, Value photonCount) + { + auto func = ensureFunc(ARTIQFuncNames::mockMeasure, + LLVM::LLVMFunctionType::get(voidTy, {i64Ty, i32Ty, i64Ty})); + LLVM::CallOp::create(builder, getLoc(), func, ValueRange{startMu, ttl7Addr, photonCount}); + } + + /// i32 __rtio_count(i64 deadline, i32 channel), returns number of edges. + Value rtioCount(Value deadline, Value channel) + { + auto func = ensureFunc(ARTIQFuncNames::rtioCount, + LLVM::LLVMFunctionType::get(i32Ty, {i64Ty, i32Ty})); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{deadline, channel}); + return call.getResult(); + } + + /// Get the measurement channel addresses from the device_db + static MeasurementChannelAddrs getMeasurementChannelAddresses(ModuleOp module) + { + MeasurementChannelAddrs out; + auto configAttr = module->getAttrOfType(ConfigAttr::getModuleAttrName()); + if (!configAttr) { + return out; + } + + static constexpr StringRef kDb = "device_db"; + static constexpr StringRef kArgs = "arguments"; + static constexpr StringRef kCh = "channel"; + + int64_t ttl0Raw = device_db_detail::intAtPath(configAttr, {kDb, "ttl0", kArgs, kCh}); + int64_t ttl6Raw = device_db_detail::intAtPath(configAttr, {kDb, "ttl6", kArgs, kCh}); + int64_t ttl7Raw = device_db_detail::intAtPath(configAttr, {kDb, "ttl7", kArgs, kCh}); + + out.gateRisingAddr = static_cast((ttl0Raw << 8) | 2); + out.countChannel = static_cast(ttl0Raw); + out.acquisitionOutputAddr = static_cast(ttl6Raw << 8); + out.dmdOutputAddr = static_cast(ttl7Raw << 8); + return out; + } + // Constant creation helpers Value constI32(int32_t val) { @@ -238,9 +366,18 @@ class ARTIQRuntimeBuilder { /// This should be called before lowering patterns that depend on these functions. void ensureHelperFunctions() { + // Timing helpers ensureSecToMuFunc(); + ensureWaitUntilMuFunc(); + + // Frequency setting ensureConfigSpiFunc(); ensureSetFrequencyFunc(); + + // Measurement helper functions + ensureGateRisingMuFunc(); + ensureMockMeasureFunc(); + ensureCountFunc(); } private: @@ -423,32 +560,187 @@ class ARTIQRuntimeBuilder { auto configAttr = module->getAttrOfType(ConfigAttr::getModuleAttrName()); assert(configAttr && "rtio.config attribute not found on module"); - auto getChannel = [&](ArrayRef path) -> int64_t { - Attribute current = configAttr; - for (StringRef key : path) { - if (auto dict = dyn_cast(current)) { - current = dict.get(key); - } - else if (auto cfg = dyn_cast(current)) { - current = cfg.get(key); - } - else { - return 0; - } - } - return cast(current).getInt(); - }; - - int64_t spiChannel = getChannel({"device_db", "spi_urukul0", "arguments", "channel"}); + int64_t spiChannel = device_db_detail::intAtPath( + configAttr, {"device_db", "spi_urukul0", "arguments", "channel"}); // chip_select from urukul0_ch0 is the base CS (typically 4) // ch0->CS=4, ch1->CS=5, ch2->CS=6, ch3->CS=7 - int64_t csBase = getChannel({"device_db", "urukul0_ch0", "arguments", "chip_select"}); - int64_t ioUpdateChannel = - getChannel({"device_db", "ttl_urukul0_io_update", "arguments", "channel"}); + int64_t csBase = device_db_detail::intAtPath( + configAttr, {"device_db", "urukul0_ch0", "arguments", "chip_select"}); + int64_t ioUpdateChannel = device_db_detail::intAtPath( + configAttr, {"device_db", "ttl_urukul0_io_update", "arguments", "channel"}); return {static_cast(spiChannel << 8), static_cast(csBase), static_cast(ioUpdateChannel << 8)}; } + + /// Opens the TTL0 sensitivity gate for `duration_mu` + void ensureGateRisingMuFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::gateRisingMu)) { + return; + } + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i32Ty, i64Ty}); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::gateRisingMu, + funcTy, LLVM::Linkage::Internal); + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + Value sensAddr = entry->getArgument(0); + Value durationMu = entry->getArgument(1); + + rtioOutput(sensAddr, constI32(1)); + delayMu(durationMu); + rtioOutput(sensAddr, constI32(0)); + LLVM::ReturnOp::create(builder, getLoc(), ValueRange{}); + } + + /// Simulated photon events: + /// ``` + /// for i in 0..: + /// fire TTL7 at start_mu + leadMu + i * periodMu + /// ``` + void ensureMockMeasureFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::mockMeasure)) { + return; + } + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i64Ty, i32Ty, i64Ty}); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::mockMeasure, funcTy, + LLVM::Linkage::Internal); + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + Value startMu = entry->getArgument(0); + Value ttl7Addr = entry->getArgument(1); + Value photonCount = entry->getArgument(2); + + Value periodMu = constI64(ARTIQHardwareConfig::simPhotonPeriodMu); + Value leadMu = constI64(ARTIQHardwareConfig::measurementPhotonLeadMu); + Value onMu = constI64(ARTIQHardwareConfig::simPhotonPulseOnMu); + + Block *loopHead = new Block(); + Block *body = new Block(); + Block *exit = new Block(); + func.getBody().push_back(loopHead); + func.getBody().push_back(body); + func.getBody().push_back(exit); + loopHead->addArgument(i64Ty, getLoc()); + + Value c0 = constI64(0); + LLVM::BrOp::create(builder, getLoc(), ValueRange{c0}, loopHead); + + builder.setInsertionPointToStart(loopHead); + Value iv = loopHead->getArgument(0); + Value cond = + arith::CmpIOp::create(builder, getLoc(), arith::CmpIPredicate::slt, iv, photonCount); + LLVM::CondBrOp::create(builder, getLoc(), cond, body, exit); + + builder.setInsertionPointToStart(body); + Value offset = arith::MulIOp::create(builder, getLoc(), iv, periodMu); + Value base = arith::AddIOp::create(builder, getLoc(), startMu, leadMu); + Value tPulse = arith::AddIOp::create(builder, getLoc(), base, offset); + atMu(tPulse); + ttlOn(ttl7Addr); + delayMu(onMu); + ttlOff(ttl7Addr); + Value ivNext = arith::AddIOp::create(builder, getLoc(), iv, constI64(1)); + LLVM::BrOp::create(builder, getLoc(), ValueRange{ivNext}, loopHead); + + builder.setInsertionPointToStart(exit); + LLVM::ReturnOp::create(builder, getLoc(), ValueRange{}); + } + + /// Count the number of edges on the given channel before the deadline + /// ``` + /// count = 0 + /// while rtio_input_timestamp(deadline, channel) > -1: + /// count++ + /// return count + /// ``` + void ensureCountFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::rtioCount)) { + return; + } + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(i32Ty, {i64Ty, i32Ty}); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::rtioCount, funcTy, + LLVM::Linkage::Internal); + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + Value deadline = entry->getArgument(0); + Value channel = entry->getArgument(1); + + Block *loopHead = new Block(); + Block *done = new Block(); + func.getBody().push_back(loopHead); + func.getBody().push_back(done); + loopHead->addArgument(i32Ty, getLoc()); + + LLVM::BrOp::create(builder, getLoc(), ValueRange{constI32(0)}, loopHead); + + builder.setInsertionPointToStart(loopHead); + Value count = loopHead->getArgument(0); + Value ts = rtioInputTimestamp(deadline, channel); + Value more = + arith::CmpIOp::create(builder, getLoc(), arith::CmpIPredicate::sgt, ts, constI64(-1)); + Value countNext = arith::AddIOp::create(builder, getLoc(), count, constI32(1)); + LLVM::CondBrOp::create(builder, getLoc(), more, loopHead, ValueRange{countNext}, done, + ValueRange{}); + + builder.setInsertionPointToStart(done); + LLVM::ReturnOp::create(builder, getLoc(), ValueRange{count}); + } + + /// Busy-wait until the counter >= time + void ensureWaitUntilMuFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::waitUntilMu)) + return; + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i64Ty}); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::waitUntilMu, funcTy, + LLVM::Linkage::Internal); + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + Value time = entry->getArgument(0); + + Block *loopHead = new Block(); + Block *exit = new Block(); + func.getBody().push_back(loopHead); + func.getBody().push_back(exit); + + LLVM::BrOp::create(builder, getLoc(), ValueRange{}, loopHead); + + builder.setInsertionPointToStart(loopHead); + Value counter = rtioGetCounter(); + Value cond = + arith::CmpIOp::create(builder, getLoc(), arith::CmpIPredicate::slt, counter, time); + LLVM::CondBrOp::create(builder, getLoc(), cond, loopHead, ValueRange{}, exit, ValueRange{}); + + builder.setInsertionPointToStart(exit); + LLVM::ReturnOp::create(builder, getLoc(), ValueRange{}); + } }; } // namespace rtio diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp index 14c1928073..114972ecc6 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp @@ -20,9 +20,10 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -128,10 +129,9 @@ class PulseScheduler { // return the next events to process SmallVector processEvent(Value event) { - SmallVector nextEvents; auto consumers = getEventConsumers(event); if (consumers.empty()) { - return nextEvents; + return {}; } // Group pulses by channel, respecting grouping predicate @@ -160,7 +160,7 @@ class PulseScheduler { } if (channelPulses.empty()) { - return nextEvents; + return {}; } // Extend chains on each channel @@ -350,6 +350,7 @@ void decomposeFrequencyPulses(ScheduleGroupsMap &pulseGroups) } // Find root pulses (pulses whose wait isn't produced by another pulse in this group) + // And skip `_measurement` pulses. DenseMap channelRoots; for (auto *op : groupOps) { auto pulse = cast(op); @@ -359,7 +360,7 @@ void decomposeFrequencyPulses(ScheduleGroupsMap &pulseGroups) return cast(other).getEvent() == wait; }); - if (isRoot) { + if (isRoot && !pulse->hasAttr("_measurement")) { Value channel = pulse.getChannel(); if (!channelRoots.count(channel)) { channelRoots[channel] = pulse; @@ -442,9 +443,6 @@ void decomposeFrequencyPulses(ScheduleGroupsMap &pulseGroups) struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase { using RTIOEventToARTIQPassBase::RTIOEventToARTIQPassBase; - // id -> callee name mapping - llvm::ArrayRef> getRPCIdMap() const { return rpcIdMap; } - void runOnOperation() override { ModuleOp module = getOperation(); @@ -454,7 +452,7 @@ struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase pulseGroups; module.walk([&](func::FuncOp funcOp) { - PulseScheduler scheduler(funcOp, builder, sameChannelSameFrequency); + PulseScheduler scheduler(funcOp, builder, canGroup); pulseGroups[funcOp] = scheduler.schedule(); }); @@ -491,8 +489,13 @@ struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase callee name. - llvm::SmallVector> rpcIdMap; - - static bool sameChannelSameFrequency(RTIOPulseOp ref, RTIOPulseOp candidate) + static bool canGroup(RTIOPulseOp ref, RTIOPulseOp candidate) { + // If either pulse is a measurement, they cannot be grouped + if (ref->hasAttr("_measurement") || candidate->hasAttr("_measurement")) { + return false; + } + + // And only group pulses on the same channel and frequency if (ref.getChannel() == candidate.getChannel()) { return ref.getFrequency() == candidate.getFrequency(); } @@ -553,6 +559,91 @@ struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase(ARTIQFuncNames::kernel); + if (!kernelFunc) { + return success(); + } + + auto transferFunc = + module.lookupSymbol(ARTIQFuncNames::rtioTransferMeasurementResults); + auto initFunc = module.lookupSymbol(ARTIQFuncNames::rtioInitDataset); + + if (!transferFunc) { + return success(); + } + + for (Block &block : kernelFunc.getBody()) { + Operation *terminator = block.getTerminator(); + if (!terminator || !isa(terminator)) { + continue; + } + + Location loc = kernelFunc.getLoc(); + OpBuilder::InsertionGuard guard(builder); + + // 1. call @__rtio_init_dataset at the start + if (initFunc) { + builder.setInsertionPointToStart(&block); + func::CallOp::create(builder, loc, initFunc, ValueRange{}); + } + + // Collect rtio.readout ops in block order (before they become __rtio_count). + SmallVector readouts; + for (Operation &op : block) { + if (auto readout = dyn_cast(&op)) { + readouts.push_back(readout); + } + } + + if (readouts.empty()) { + continue; + } + + auto memrefType = cast(transferFunc.getArgumentTypes()[0]); + + // Allocate the results buffer at the top of the block + builder.setInsertionPointToStart(&block); + Value alloc = memref::AllocaOp::create(builder, loc, memrefType); + + // 2. Store each readout count right after its defining op + for (auto [i, readout] : llvm::enumerate(readouts)) { + builder.setInsertionPointAfter(readout); + Value idx = arith::ConstantIndexOp::create(builder, loc, static_cast(i)); + memref::StoreOp::create(builder, loc, readout.getCount(), alloc, ValueRange{idx}); + } + + // 3. call @__rtio_transfer_measurement_results + wait before return + builder.setInsertionPoint(terminator); + func::CallOp::create(builder, loc, transferFunc, ValueRange{alloc}); + + // 4. wait_until_mu(now_mu()) so async RPCs flush before return + ARTIQRuntimeBuilder artiq(builder, kernelFunc); + artiq.waitUntilMu(artiq.nowMu()); + } + return success(); } diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp index 8c9a8e0917..8a6b83fa7e 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" @@ -41,6 +43,10 @@ struct PulseOpLowering : public OpConversionPattern { { ARTIQRuntimeBuilder artiq(rewriter, op); + if (op->hasAttr("_measurement")) { + return lowerMeasurementPulse(op, adaptor, rewriter, artiq); + } + // Set timeline position artiq.atMu(adaptor.getWait()); @@ -102,6 +108,51 @@ struct PulseOpLowering : public OpConversionPattern { rewriter.replaceOp(op, newTime); return success(); } + + LogicalResult lowerMeasurementPulse(RTIOPulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + ARTIQRuntimeBuilder &artiq) const + { + ModuleOp mod = op->getParentOfType(); + Location loc = op.getLoc(); + + MeasurementChannelAddrs ch = ARTIQRuntimeBuilder::getMeasurementChannelAddresses(mod); + if (ch.gateRisingAddr == 0 || ch.dmdOutputAddr == 0) { + return op->emitError( + "measurement lowering requires device_db ttl0 (TTLInOut) and ttl7 (TTLOut)"); + } + + Value startMu = adaptor.getWait(); + artiq.atMu(startMu); + + // 1. gate_rising_mu(sensitivity_addr, duration_mu) + Value gateDurationMu = artiq.secToMu(adaptor.getDuration()); + artiq.gateRisingMu(artiq.constI32(ch.gateRisingAddr), gateDurationMu); + + // 2. Compute Poisson-distributed photon count for mock_measure + int32_t photonCount = 0; + if (auto cst = adaptor.getDuration().getDefiningOp()) { + if (auto fAttr = dyn_cast(cst.getValue())) { + double durationSec = fAttr.getValueAsDouble(); + int64_t durationMu = static_cast(std::round(durationSec / 1e-9)); + photonCount = ARTIQHardwareConfig::computeSimulatedPhotonCount(durationMu); + } + } + + // 3. mock_measure(start_mu, dmd_output_addr, photon_count) + artiq.mockMeasure(startMu, artiq.constI32(ch.dmdOutputAddr), + artiq.constI64(static_cast(photonCount))); + + // 4. Advance timeline to gateEnd + offset for underflow protection. + Value gateEnd = arith::AddIOp::create(rewriter, loc, startMu, gateDurationMu); + Value nextBase = arith::AddIOp::create( + rewriter, loc, gateEnd, artiq.constI64(ARTIQHardwareConfig::measurementStartOffsetMu)); + artiq.atMu(nextBase); + + Value newTime = artiq.nowMu(); + rewriter.replaceOp(op, newTime); + return success(); + } }; struct SyncOpLowering : public OpConversionPattern { @@ -157,6 +208,34 @@ struct ChannelOpLowering : public OpConversionPattern { } }; +struct ReadoutOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RTIOReadoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + ModuleOp mod = op->getParentOfType(); + Location loc = op.getLoc(); + + MeasurementChannelAddrs ch = ARTIQRuntimeBuilder::getMeasurementChannelAddresses(mod); + ARTIQRuntimeBuilder artiq(rewriter, op); + Value eventMu = adaptor.getEvent(); + + // The measurement pulse event = gateEnd + offset. + // We subtract the offset and add the gate latency to get the safer execution timing. + Value gateEnd = arith::SubIOp::create( + rewriter, loc, eventMu, artiq.constI64(ARTIQHardwareConfig::measurementStartOffsetMu)); + artiq.atMu(gateEnd); + + Value gateLatencyMu = artiq.constI64(ARTIQHardwareConfig::defaultTtlInOutGateLatencyMu); + Value deadline = arith::AddIOp::create(rewriter, loc, gateEnd, gateLatencyMu); + Value count = artiq.rtioCount(deadline, artiq.constI32(ch.countChannel)); + + rewriter.replaceOp(op, count); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Rewrite Patterns //===----------------------------------------------------------------------===// @@ -283,8 +362,8 @@ namespace rtio { void populateRTIOToARTIQConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); } void populateRTIORewritePatterns(RewritePatternSet &patterns) diff --git a/mlir/lib/RTIO/Transforms/Utils.hpp b/mlir/lib/RTIO/Transforms/Utils.hpp index ce2f84a84a..a6d0459c3f 100644 --- a/mlir/lib/RTIO/Transforms/Utils.hpp +++ b/mlir/lib/RTIO/Transforms/Utils.hpp @@ -15,6 +15,7 @@ #pragma once #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "RTIO/IR/RTIOOps.h" @@ -22,6 +23,49 @@ namespace catalyst { namespace rtio { +/// Helpers for reading nested keys from module rtio.config (device_db JSON). +namespace device_db_detail { + +/// Descend one level in nested DictionaryAttr / ConfigAttr. +inline mlir::Attribute descendByKey(mlir::Attribute parent, llvm::StringRef key) +{ + if (!parent) { + return {}; + } + if (auto dict = mlir::dyn_cast(parent)) { + return dict.get(key); + } + if (auto cfg = mlir::dyn_cast(parent)) { + return cfg.get(key); + } + return {}; +} + +/// Follow a chain of keys from root. Returns null if any step is missing or not a container. +inline mlir::Attribute walkAttrPath(mlir::Attribute root, llvm::ArrayRef path) +{ + mlir::Attribute current = root; + for (llvm::StringRef key : path) { + current = descendByKey(current, key); + if (!current) { + return {}; + } + } + return current; +} + +/// Integer at path +inline int64_t intAtPath(mlir::Attribute root, llvm::ArrayRef path) +{ + mlir::Attribute leaf = walkAttrPath(root, path); + if (auto intAttr = mlir::dyn_cast_or_null(leaf)) { + return intAttr.getInt(); + } + return 0; +} + +} // namespace device_db_detail + /// Extract the static channel ID from an RTIO channel type. inline int32_t extractChannelId(mlir::Value channelValue) { @@ -40,17 +84,10 @@ inline mlir::Value computeChannelDeviceAddr(mlir::OpBuilder &builder, mlir::Oper auto configAttr = mod->getAttrOfType(ConfigAttr::getModuleAttrName()); assert(configAttr && "configAttr not found"); - // Get base channel from config - mlir::Attribute current = configAttr; - for (llvm::StringRef key : {"device_db", "ttl_urukul0_sw0", "arguments", "channel"}) { - if (auto dict = mlir::dyn_cast(current)) { - current = dict.get(key); - } - else if (auto cfg = mlir::dyn_cast(current)) { - current = cfg.get(key); - } - } - int64_t channelBase = mlir::cast(current).getInt(); + mlir::Attribute leaf = device_db_detail::walkAttrPath( + configAttr, {"device_db", "ttl_urukul0_sw0", "arguments", "channel"}); + assert(leaf && "device_db.ttl_urukul0_sw0.arguments.channel missing"); + int64_t channelBase = mlir::cast(leaf).getInt(); llvm::APInt channelIdAPInt; assert(mlir::matchPattern(channelValue, mlir::m_ConstantInt(&channelIdAPInt)) && diff --git a/mlir/test/Ion/Dialect.mlir b/mlir/test/Ion/Dialect.mlir index e57616edfc..f8bdfc1cd2 100644 --- a/mlir/test/Ion/Dialect.mlir +++ b/mlir/test/Ion/Dialect.mlir @@ -159,7 +159,7 @@ func.func @example_parallel_protocol_two_qubits(%arg0: f64) -> (!ion.qubit, !ion return %3#0, %3#1: !ion.qubit, !ion.qubit } -func.func @example_measure_pulse(%arg0: f64) -> i1 { +func.func @example_measure_pulse(%arg0: f64) -> i32 { %0 = quantum.alloc( 1) : !quantum.reg // CHECK: [[q0:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit @@ -192,11 +192,11 @@ func.func @example_measure_pulse(%arg0: f64) -> i1 { ion.yield %arg1: !ion.qubit } - // CHECK: [[mres:%.+]], [[out_qubit:%.+]] = ion.readout_bit [[pp]] : i1, !ion.qubit - %mres, %out_qubit = ion.readout_bit %pp : i1, !ion.qubit + // CHECK: [[out_qubit:%.+]], [[cnt:%.+]] = ion.readout_bit [[pp]] : !ion.qubit, i32 + %out_qubit, %cnt_val = ion.readout_bit %pp : !ion.qubit, i32 - // CHECK: return [[mres]] : i1 - return %mres: i1 + // CHECK: return [[cnt]] : i32 + return %cnt_val: i32 } diff --git a/mlir/test/Ion/IonToRTIO.mlir b/mlir/test/Ion/IonToRTIO.mlir index 9399a5d0fa..e4cd49d1d2 100644 --- a/mlir/test/Ion/IonToRTIO.mlir +++ b/mlir/test/Ion/IonToRTIO.mlir @@ -650,3 +650,46 @@ module @if_circuit { return %alloc_9, %alloc : memref<4xi64>, memref<4xi64> } } + +// ----- + +// measure_pulse +// CHECK-LABEL: func.func @__kernel__() +module @measure_ion_to_rtio { + func.func public @circuit_measure() -> i32 attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage, qnode} { + %0 = ion.ion {charge = -1.000000e+00 : f64, levels = [#ion.level