From c73350ca900e5b162d57bf5110c007d1d9de91ef Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 7 Apr 2026 16:52:04 -0400 Subject: [PATCH 1/2] Update RTIO schedule for OQD's AWG arch --- mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp | 292 +++++------ mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 8 + mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp | 459 +++--------------- .../Transforms/RTIOEventToARTIQPatterns.cpp | 24 +- mlir/lib/RTIO/Transforms/Utils.hpp | 26 + mlir/test/Ion/IonToRTIO.mlir | 46 +- mlir/test/RTIO/RTIOEventToARTIQ.mlir | 40 +- 7 files changed, 286 insertions(+), 609 deletions(-) diff --git a/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp b/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp index 0c160850f2..06be3b9b3e 100644 --- a/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp +++ b/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" @@ -33,6 +35,9 @@ namespace ion { namespace { +constexpr StringLiteral kPulseGroupAttr = "_group"; +constexpr StringLiteral kParallelProtocolIdAttr = "parallel_protocol_id"; + //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// @@ -46,6 +51,82 @@ Value awaitEvents(ArrayRef events, PatternRewriter &rewriter) return rtio::RTIOSyncOp::create(rewriter, rewriter.getUnknownLoc(), eventType, events); } +/// Extract the qubit index from a memref.load value. +/// For given `memref.load @qubit_map[%cN]`, returns `N`. 
+static FailureOr getQubitIndex(Value memrefLoadValue) +{ + auto loadOp = memrefLoadValue.getDefiningOp(); + if (!loadOp || loadOp.getIndices().size() != 1) { + return failure(); + } + IntegerAttr indexAttr; + if (!matchPattern(loadOp.getIndices()[0], m_Constant(&indexAttr))) { + return failure(); + } + return indexAttr.getInt(); +} + +static std::optional getConstF64(Value v) +{ + FloatAttr attr; + if (matchPattern(v, m_Constant(&attr))) { + return attr.getValueAsDouble(); + } + return std::nullopt; +} + +/// Find a pulse with the same (frequency, phase) tone. +static rtio::RTIOPulseOp findSameTonePulse(ArrayRef pulses, + rtio::RTIOPulseOp pulse) +{ + auto f = getConstF64(pulse.getFrequency()); + auto p = getConstF64(pulse.getPhase()); + if (!f || !p) { + return nullptr; + } + + auto found = llvm::find_if(pulses, [=](rtio::RTIOPulseOp target) { + return getConstF64(target.getFrequency()) == f && getConstF64(target.getPhase()) == p; + }); + return found != pulses.end() ? *found : nullptr; +} + +/// Merge qubit qualifiers from src into dst. 
+static void mergeChannelQualifiers(rtio::RTIOPulseOp dst, rtio::RTIOPulseOp src, + PatternRewriter &rewriter, MLIRContext *ctx, Location loc) +{ + auto dstCh = llvm::cast(dst.getChannel().getType()); + auto srcCh = llvm::cast(src.getChannel().getType()); + + SetVector qubits; + + // Merge qualifiers from dst and src + for (auto q : dstCh.getQualifiers()) { + qubits.insert(llvm::cast(q).getInt()); + } + for (auto q : srcCh.getQualifiers()) { + qubits.insert(llvm::cast(q).getInt()); + } + + SmallVector quals; + for (int64_t q : qubits) { + quals.push_back(rewriter.getI64IntegerAttr(q)); + } + + auto mergedType = rtio::ChannelType::get(ctx, "dds", rewriter.getArrayAttr(quals), + rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + Value newCh = rtio::RTIOChannelOp::create(rewriter, loc, mergedType); + rewriter.moveOpBefore(newCh.getDefiningOp(), dst); + + Value oldCh = dst.getChannel(); + dst.getChannelMutable().assign(newCh); + + // Remove unused channel op + if (oldCh.getDefiningOp() && oldCh.use_empty()) { + rewriter.eraseOp(oldCh.getDefiningOp()); + } +} + //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// @@ -85,6 +166,10 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPatterngetAttrOfType(kParallelProtocolIdAttr); + assert(protocolIdAttr && "parallel protocol must have parallel protocol id"); + int64_t protocolId = protocolIdAttr.getInt(); + Block *regionBlock = &op.getBodyRegion().front(); IRMapping irMapping; SmallVector inQubits; @@ -106,68 +191,46 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern pulseEvents; - DenseMap qubitToOffset; - - // we cache the channel to index mapping to avoid multiple lookups - DenseMap cache; + // Clone all operations from the region + SmallVector clonedPulses; for (auto ®ionOp : regionBlock->without_terminator()) { auto *clonedOp = rewriter.clone(regionOp, 
irMapping); - if (auto pulseOp = dyn_cast(clonedOp)) { - // set wait event for the pulse operation - pulseOp.setWait(inputSyncEvent); - - Value index = nullptr; - - SmallVector chain; - traceValueWithCallback( - pulseOp.getChannel(), [&](Value value) -> WalkResult { - if (cache.count(value)) { - index = cache[value]; - return WalkResult::interrupt(); - } - chain.push_back(value); - if (auto loadOp = - llvm::dyn_cast_if_present(value.getDefiningOp())) { - index = loadOp.getIndices()[0]; - - // cache the channel to index mapping - cache[pulseOp.getChannel()] = index; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - - if (index == nullptr) { - op->emitError("Failed to trace the channel index"); - return failure(); - } + if (auto pulseOp = dyn_cast(clonedOp)) + clonedPulses.push_back(pulseOp); + irMapping.map(regionOp.getResults(), clonedOp->getResults()); + } - // update cache - for (Value value : chain) { - cache[value] = index; + // Merge same-tone pulses + SmallVector survivors; + for (auto pulse : clonedPulses) { + if (auto match = findSameTonePulse(survivors, pulse)) { + // Merge channel qualifiers to `match`, and erase `pulse` + mergeChannelQualifiers(match, pulse, rewriter, ctx, loc); + Value ch = pulse.getChannel(); + rewriter.eraseOp(pulse); + if (ch.getDefiningOp() && ch.use_empty()) { + rewriter.eraseOp(ch.getDefiningOp()); } - pulseOp->setAttr("offset", rewriter.getI64IntegerAttr(qubitToOffset[index])); - - // the same qubit may appear multiple times in the parallel protocol - // so we need to increment the offset for each appearance - qubitToOffset[index]++; - - pulseEvents.push_back(pulseOp.getEvent()); } - irMapping.map(regionOp.getResults(), clonedOp->getResults()); + else { + survivors.push_back(pulse); + } } - // Create sync operation from pulse events (must have at least one after Phase 1) - assert(pulseEvents.size() > 0 && - "must have at least one pulse operation after parallel protocol conversion"); + // Set wait 
dependency and _group on each surviving pulse + SmallVector pulseEvents; + for (auto pulse : survivors) { + pulse.setWait(inputSyncEvent); + pulse->setAttr(kPulseGroupAttr, rewriter.getI64IntegerAttr(protocolId)); + pulseEvents.push_back(pulse.getEvent()); + } + assert(!pulseEvents.empty() && + "must have at least one pulse after parallel protocol conversion"); Value outputSyncEvent = awaitEvents(llvm::to_vector(pulseEvents), rewriter); SmallVector results; for (Value result : op.getResults()) { - // unrealized conversion cast sync event to result type auto event = UnrealizedConversionCastOp::create(rewriter, loc, result.getType(), outputSyncEvent); results.push_back(event.getResult(0)); @@ -188,10 +251,9 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern !rtio.channel<"dds", ?> -/// ... // other pulse parameters settings +/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds", [N : i64], 0> /// %event = rtio.pulse %ch duration(%duration) frequency(%freq) phase(%phase) -/// : !rtio.channel<"dds", ?> -> !rtio.event +/// : !rtio.channel<"dds", [N : i64], 0> -> !rtio.event /// ``` struct PulseToRTIOPattern : public OpConversionPattern { IonInfo ionInfo; @@ -240,10 +302,7 @@ struct PulseToRTIOPattern : public OpConversionPattern { Value phaseValue = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(phase)); - // Convert the qubit to a channel - ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)}); - auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, nullptr); - + // Resolve qubit index and set it as the qualifier Value memrefLoadValue = nullptr; traceValueWithCallback(op.getInQubit(), [&](Value value) -> WalkResult { if (qextractToMemrefMap.count(value)) { @@ -258,8 +317,16 @@ struct PulseToRTIOPattern : public OpConversionPattern { return failure(); } - Value channel = - rtio::RTIOQubitToChannelOp::create(rewriter, loc, channelType, memrefLoadValue); + auto qubitIdx 
= getQubitIndex(memrefLoadValue); + if (failed(qubitIdx)) { + op->emitError("Failed to resolve qubit index from memref load"); + return failure(); + } + ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(*qubitIdx)}); + + IntegerAttr channelIdAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); + auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, channelIdAttr); + Value channel = rtio::RTIOChannelOp::create(rewriter, loc, channelType); // Create rtio.pulse auto eventType = rtio::EventType::get(ctx); @@ -271,108 +338,6 @@ struct PulseToRTIOPattern : public OpConversionPattern { } }; -/// Resolve the static channel mapping for the rtio.qubit_to_channel operation -/// -/// It's expecting `qubit_to_channel` has the following def-use chain: -/// memref.global w/ constants -> memref.get_global -> memref.load -> qubit_to_channel -/// -/// Example: -/// ``` -/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> -/// ``` -/// will be converted to: -/// ``` -/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds"> -/// ``` -struct ResolveChannelMappingPattern : public OpRewritePattern { - ResolveChannelMappingPattern(MLIRContext *ctx) - : OpRewritePattern(ctx) - { - } - - LogicalResult matchAndRewrite(rtio::RTIOQubitToChannelOp op, - PatternRewriter &rewriter) const override - { - Location loc = op.getLoc(); - Value qubit = op.getQubit(); - - auto loadOp = qubit.getDefiningOp(); - if (!loadOp) { - return failure(); - } - - Value memref = loadOp.getMemRef(); - auto getGlobalOp = memref.getDefiningOp(); - if (!getGlobalOp) { - return failure(); - } - - StringRef globalName = getGlobalOp.getName(); - ModuleOp module = op->getParentOfType(); - if (!module) { - return failure(); - } - auto globalOp = module.lookupSymbol(globalName); - if (!globalOp) { - return failure(); - } - - auto initialValue = globalOp.getInitialValue(); - if (!initialValue) { - return failure(); - } - - auto denseAttr = 
llvm::dyn_cast(*initialValue); - if (!denseAttr) { - return failure(); - } - - ValueRange indices = loadOp.getIndices(); - if (indices.size() != 1) { - return failure(); - } - - IntegerAttr indexAttr; - if (!matchPattern(indices[0], m_Constant(&indexAttr))) { - return failure(); - } - - int64_t index = indexAttr.getInt(); - - size_t denseSize = denseAttr.size(); - if (index < 0 || static_cast(index) >= denseSize) { - return failure(); - } - - APInt channelIdValue = denseAttr.getValues()[index]; - - auto originalChannelType = llvm::dyn_cast(op.getChannel().getType()); - if (!originalChannelType) { - return failure(); - } - StringRef kind = originalChannelType.getKind(); - ArrayAttr qualifiers = originalChannelType.getQualifiers(); - - // channel should have exactly one use before lowering to channel op - assert(op.getChannel().hasOneUse() && "channel should have exactly one use"); - - auto pulseOp = cast(*op.getChannel().getUsers().begin()); - int64_t offset = cast(pulseOp->getAttr("offset")).getInt(); - - IntegerAttr channelIdAttr = rewriter.getIntegerAttr( - rewriter.getIndexType(), (channelIdValue.getSExtValue() * 2 + offset)); - - auto resolvedChannelType = - rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr); - - Value channel = rtio::RTIOChannelOp::create(rewriter, loc, resolvedChannelType); - - rewriter.replaceOp(op, channel); - - return success(); - } -}; - /// Propagates RTIO events from chain of operations to event types. 
/// /// Steps: @@ -493,8 +458,7 @@ struct MeasurePulseToRTIOPattern : public OpConversionPattern(patterns.getContext()); - patterns.add(patterns.getContext()); + // patterns.add(patterns.getContext()); } void populateIonMeasurePulseToRTIOPatterns(TypeConverter &typeConverter, diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 1dd8b1b183..ee0acf4e31 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/CSE.h" @@ -46,6 +47,7 @@ namespace { constexpr StringLiteral rtioTransferMeasurementResults = "__rtio_transfer_measurement_results"; constexpr StringLiteral rtioInitDataset = "__rtio_init_dataset"; +constexpr StringLiteral kParallelProtocolIdAttr = "parallel_protocol_id"; /// Load a JSON file and convert it to an rtio.config attribute FailureOr loadDeviceDbAsConfig(MLIRContext *ctx, StringRef filePath) @@ -131,6 +133,12 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { LogicalResult ParallelProtocolConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, TypeConverter &typeConverter, MLIRContext *ctx) { + int64_t nextProtocolId = 0; + funcOp.walk([&](ion::ParallelProtocolOp op) { + op->setAttr(kParallelProtocolIdAttr, + IntegerAttr::get(IndexType::get(ctx), nextProtocolId++)); + }); + ConversionTarget target(baseTarget); target.addIllegalOp(); diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp index 114972ecc6..d1638e711b 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp @@ -14,6 +14,8 @@ #include +#include "llvm/ADT/STLExtras.h" + #include "mlir/Analysis/TopologicalSortUtils.h" 
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -46,394 +48,92 @@ namespace rtio { namespace { -//===----------------------------------------------------------------------===// -// Type Aliases -//===----------------------------------------------------------------------===// - -using ScheduleGroupsMap = DenseMap>; -using GroupingPredicate = - std::function; - -//===----------------------------------------------------------------------===// -// Pulse Scheduling -//===----------------------------------------------------------------------===// - -class PulseScheduler { - public: - PulseScheduler(func::FuncOp funcOp, OpBuilder &builder, GroupingPredicate predicate) - : funcOp(funcOp), builder(builder), groupingPredicate(std::move(predicate)) - { - } - - ScheduleGroupsMap schedule() - { - // Collect all pulses - funcOp.walk([&](rtio::RTIOPulseOp pulse) { allPulses.push_back(pulse); }); - - // Build consumer map - for (auto pulse : allPulses) { - if (auto producer = pulse.getWait().getDefiningOp()) { - pulseConsumers[producer].insert(pulse); - } - } - - processFromEmptyOps(); - return std::move(groups); - } - - private: - func::FuncOp funcOp; - OpBuilder &builder; - GroupingPredicate groupingPredicate; - - SmallVector allPulses; - DenseMap> pulseConsumers; - DenseSet processedEvents; - DenseSet processedPulses; - ScheduleGroupsMap groups; - int nextGroupId = 0; - - SmallVector getEventConsumers(Value event) - { - SmallVector consumers; - for (Operation *user : event.getUsers()) { - auto pulse = dyn_cast(user); - if (!pulse || pulse.getWait() != event) { - continue; - } - consumers.push_back(pulse); - } - return consumers; - } - - void processFromEmptyOps() - { - std::deque worklist; - funcOp.walk([&](rtio::RTIOEmptyOp emptyOp) { worklist.push_back(emptyOp.getResult()); }); - - while (!worklist.empty()) { - Value event = worklist.front(); - worklist.pop_front(); - - // check if event has already been processed - // if not, 
process the event and insert it into the processed events - if (!processedEvents.insert(event).second) { - continue; - } - - SmallVector newEvents = processEvent(event); - llvm::append_range(worklist, newEvents); - } - } - - // return the next events to process - SmallVector processEvent(Value event) - { - auto consumers = getEventConsumers(event); - if (consumers.empty()) { - return {}; - } - - // Group pulses by channel, respecting grouping predicate - DenseMap> channelPulses; - DenseMap channelLastPulse; - DenseMap channelBoundary; - SmallVector boundaryConsumers; - - // Initial - for (auto pulse : consumers) { - if (processedPulses.contains(pulse)) { - continue; - } - - int32_t channel = extractChannelId(pulse.getChannel()); - if (canJoinGroup(pulse, channelPulses)) { - channelPulses[channel].push_back(pulse); - channelLastPulse[channel] = pulse; - } - else { - if (!channelBoundary.count(channel)) { - channelBoundary[channel] = pulse; - } - boundaryConsumers.push_back(pulse); - } - } - - if (channelPulses.empty()) { - return {}; - } - - // Extend chains on each channel - extendChannelChains(channelPulses, channelLastPulse, channelBoundary); - - // Record group - recordGroup(channelPulses); - - // Create sync and update dependencies - return createSyncAndUpdateDeps(channelPulses, channelLastPulse, channelBoundary, - boundaryConsumers); - } - - bool canJoinGroup(rtio::RTIOPulseOp cand, - const DenseMap> &channelPulses) - { - for (auto &[ch, pulses] : channelPulses) { - if (!llvm::all_of(pulses, [&](auto pulse) { return groupingPredicate(pulse, cand); })) { - return false; - } - } - return true; - } - - void extendChannelChains(DenseMap> &channelPulses, - DenseMap &channelLastPulse, - DenseMap &channelBoundary) - { - DenseSet stopped; - - while (stopped.size() < channelPulses.size()) { - for (auto &[channel, pulses] : channelPulses) { - if (stopped.contains(channel)) { - continue; - } - - auto currentPulse = channelLastPulse[channel]; - 
processedPulses.insert(currentPulse); - - bool foundNext = false; - for (auto user : pulseConsumers[currentPulse]) { - int32_t userChannel = extractChannelId(user.getChannel()); - if (userChannel != channel || processedPulses.contains(user)) { - continue; - } - - if (groupingPredicate(currentPulse, user)) { - channelPulses[channel].push_back(user); - channelLastPulse[channel] = user; - } - else { - channelBoundary[channel] = user; - stopped.insert(channel); - } - foundNext = true; - break; - } - - if (!foundNext) { - stopped.insert(channel); - } - } - } - } - - void recordGroup(const DenseMap> &channelPulses) - { - int groupId = nextGroupId++; - auto &groupOps = groups[groupId]; - for (auto &[_, pulses] : channelPulses) { - for (auto pulse : pulses) { - groupOps.insert(pulse.getOperation()); - } - } +/// Schedule pulses for executing on ARTIQ +static void schedule(func::FuncOp funcOp, OpBuilder &builder) +{ + auto eventType = rtio::EventType::get(funcOp.getContext()); + if (funcOp.getBody().empty()) { + return; } + Block &block = funcOp.getBody().front(); - SmallVector - createSyncAndUpdateDeps(const DenseMap> &channelPulses, - DenseMap &channelLastPulse, - DenseMap &channelBoundary, - SmallVector &boundaryConsumers) - { - if (channelPulses.size() > 1 && !channelBoundary.empty()) { - return createSyncEvent(channelLastPulse, channelBoundary, boundaryConsumers); + // Collect pulses by _group + DenseMap> pulseGroups; + for (auto &op : block) { + if (auto p = dyn_cast(&op)) { + pulseGroups[pulseGroupId(p)].push_back(p); } - return collectNextEvents(channelLastPulse, channelBoundary, boundaryConsumers); } - - SmallVector createSyncEvent(DenseMap &channelLastPulse, - DenseMap &channelBoundary, - SmallVector &boundaryConsumers) - { - // Collect events to sync - SmallVector eventsToSync; - for (auto &entry : channelLastPulse) { - rtio::RTIOPulseOp pulse = entry.second; - eventsToSync.push_back(pulse.getEvent()); - } - - auto anyPulse = channelLastPulse.begin()->second; - 
builder.setInsertionPointAfter(anyPulse); - - auto eventType = rtio::EventType::get(builder.getContext()); - Value syncEvent = - rtio::RTIOSyncOp::create(builder, anyPulse.getLoc(), eventType, eventsToSync); - - // Update boundaries and consumers - for (auto &[_, pulse] : channelBoundary) { - pulse.setWait(syncEvent); - } - for (auto pulse : boundaryConsumers) { - pulse.setWait(syncEvent); - } - for (auto &entry : channelLastPulse) { - rtio::RTIOPulseOp pulse = entry.second; - for (auto user : pulseConsumers[pulse]) { - auto userChannel = extractChannelId(user.getChannel()); - if (!channelBoundary.count(userChannel) || channelBoundary[userChannel] != user) { - if (user.getWait() == pulse.getEvent()) { - user.setWait(syncEvent); - } - } - } - } - - return {syncEvent}; + if (pulseGroups.empty()) { + return; } - SmallVector collectNextEvents(DenseMap &channelLastPulse, - DenseMap &channelBoundary, - SmallVector &boundaryConsumers) - { - SmallVector nextEvents; + DenseSet visited; + DenseSet scheduled; + std::deque worklist; + Value chain; - for (auto &entry : channelBoundary) { - rtio::RTIOPulseOp pulse = entry.second; - nextEvents.push_back(pulse.getWait()); + for (auto &op : block) { + if (auto empty = dyn_cast(&op)) { + chain = empty.getResult(); + worklist.push_back(chain); + break; } - if (!boundaryConsumers.empty() && !channelLastPulse.empty()) { - rtio::RTIOPulseOp firstPulse = channelLastPulse.begin()->second; - Value lastEvent = firstPulse.getEvent(); - for (auto pulse : boundaryConsumers) { - pulse.setWait(lastEvent); - } - nextEvents.push_back(lastEvent); - } - for (auto &entry : channelLastPulse) { - rtio::RTIOPulseOp pulse = entry.second; - for (auto *user : pulse.getEvent().getUsers()) { - if (auto syncOp = dyn_cast(user)) { - nextEvents.push_back(syncOp.getSyncEvent()); - } - } - } - - return nextEvents; } -}; - -//===----------------------------------------------------------------------===// -// Frequency Decomposition 
-//===----------------------------------------------------------------------===// - -void decomposeFrequencyPulses(ScheduleGroupsMap &pulseGroups) -{ - if (pulseGroups.empty()) { + if (!chain) { return; } - auto firstOp = pulseGroups.begin()->second.front(); - OpBuilder builder(firstOp->getContext()); - - // Track last frequency per channel (to avoid redundant frequency settings) - DenseMap channelLastFreq; - - // Sort groups by ID for deterministic processing - SmallVector *>> sortedGroups; - for (auto &entry : pulseGroups) { - sortedGroups.push_back({entry.first, &entry.second}); - } - llvm::sort(sortedGroups, [](const auto &a, const auto &b) { return a.first < b.first; }); - - for (auto &[groupId, groupOpsPtr] : sortedGroups) { - auto &groupOps = *groupOpsPtr; - if (groupOps.empty()) { + while (!worklist.empty()) { + Value ev = worklist.front(); + worklist.pop_front(); + if (!visited.insert(ev).second) { continue; } - // Find root pulses (pulses whose wait isn't produced by another pulse in this group) - // And skip `_measurement` pulses. 
- DenseMap channelRoots; - for (auto *op : groupOps) { - auto pulse = cast(op); - Value wait = pulse.getWait(); - - bool isRoot = llvm::none_of(groupOps, [&](Operation *other) { - return cast(other).getEvent() == wait; - }); - - if (isRoot && !pulse->hasAttr("_measurement")) { - Value channel = pulse.getChannel(); - if (!channelRoots.count(channel)) { - channelRoots[channel] = pulse; + SmallVector consumers; + for (auto *user : ev.getUsers()) { + if (auto p = dyn_cast(user)) { + if (p.getWait() == ev && p->getBlock() == &block) { + consumers.push_back(p); } } - } - - if (channelRoots.empty()) { - continue; - } - - // Filter to channels needing frequency change - DenseMap needsFreqSet; - for (auto &entry : channelRoots) { - Value channel = entry.first; - rtio::RTIOPulseOp pulse = entry.second; - Value freq = pulse.getFrequency(); - auto it = channelLastFreq.find(channel); - if (it == channelLastFreq.end() || it->second != freq) { - needsFreqSet[channel] = pulse; - channelLastFreq[channel] = freq; + else if (auto s = dyn_cast(user)) { + if (s->getBlock() == &block) { + worklist.push_back(s.getSyncEvent()); + } } } - if (needsFreqSet.empty()) { - continue; - } - - // Collect original wait events - SmallVector originalWaits; - for (auto &entry : channelRoots) { - rtio::RTIOPulseOp pulse = entry.second; - Value wait = pulse.getWait(); - if (!llvm::is_contained(originalWaits, wait)) { - originalWaits.push_back(wait); + SmallVector newGroupIds; + for (auto p : consumers) { + int64_t gid = pulseGroupId(p); + if (scheduled.insert(gid).second) { + newGroupIds.push_back(gid); } } + llvm::sort(newGroupIds); - // Find first root pulse (for insertion point) - rtio::RTIOPulseOp firstRoot = nullptr; - for (auto &entry : channelRoots) { - rtio::RTIOPulseOp pulse = entry.second; - if (!firstRoot || pulse->isBeforeInBlock(firstRoot)) { - firstRoot = pulse; + for (int64_t gid : newGroupIds) { + auto &grp = pulseGroups[gid]; + for (auto p : grp) { + p.setWait(chain); } - } - 
builder.setInsertionPoint(firstRoot); - - // Create sync - Value chainStart = originalWaits.size() > 1 - ? rtio::RTIOSyncOp::create( - builder, firstRoot.getLoc(), - rtio::EventType::get(builder.getContext()), originalWaits) - : originalWaits[0]; - - // Create frequency setting chain - Value lastFreqEvent = chainStart; - for (auto &entry : needsFreqSet) { - rtio::RTIOPulseOp originalPulse = entry.second; - auto freqPulse = cast(builder.clone(*originalPulse.getOperation())); - freqPulse.setWait(lastFreqEvent); - freqPulse->setAttr("_frequency", builder.getUnitAttr()); - lastFreqEvent = freqPulse.getEvent(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfter(grp.back()); + SmallVector evts; + for (auto p : grp) { + evts.push_back(p.getEvent()); + } + chain = rtio::RTIOSyncOp::create(builder, grp.back().getLoc(), eventType, evts); } - // Update root pulses to wait on last frequency event - for (auto &entry : channelRoots) { - rtio::RTIOPulseOp pulse = entry.second; - pulse.setWait(lastFreqEvent); - } + for (auto p : consumers) + worklist.push_back(p.getEvent()); } } + } // namespace //===----------------------------------------------------------------------===// @@ -449,12 +149,9 @@ struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase pulseGroups; - module.walk([&](func::FuncOp funcOp) { - PulseScheduler scheduler(funcOp, builder, canGroup); - pulseGroups[funcOp] = scheduler.schedule(); - }); + // Schedule pulses + module.walk([&](func::FuncOp funcOp) { schedule(funcOp, builder); }); + sortAllBlocks(module); // Simplify sync operations { @@ -466,24 +163,6 @@ struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBasehasAttr("_measurement") || candidate->hasAttr("_measurement")) { - return false; - } - - // And only group pulses on the same channel and frequency - if (ref.getChannel() == candidate.getChannel()) { - return ref.getFrequency() == candidate.getFrequency(); - } - return true; - } - static void 
sortAllBlocks(ModuleOp module) { module.walk([](func::FuncOp funcOp) { diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp index 8a6b83fa7e..7bd618f7ee 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp @@ -70,10 +70,12 @@ struct PulseOpLowering : public OpConversionPattern { return op->emitError("Cannot find ") << ARTIQFuncNames::setFrequency << " function"; } + Type chTy = getTypeConverter()->convertType(op.getChannel().getType()); + Value chVal = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIntegerAttr(chTy, pulseGroupId(op))); Value amplitude = artiq.constF64(1.0); LLVM::CallOp::create(rewriter, op.getLoc(), setFreqFunc, - ValueRange{adaptor.getChannel(), adaptor.getFrequency(), - adaptor.getPhase(), amplitude}); + ValueRange{chVal, adaptor.getFrequency(), adaptor.getPhase(), amplitude}); Value newTime = artiq.nowMu(); rewriter.replaceOp(op, newTime); @@ -93,7 +95,8 @@ struct PulseOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter, ARTIQRuntimeBuilder &artiq) const { - Value channelAddr = computeChannelDeviceAddr(rewriter, op, adaptor.getChannel()); + Value channelAddr = + computeChannelDeviceAddrForId(rewriter, op, pulseGroupId(op)); Value durationMu = artiq.secToMu(adaptor.getDuration()); // Enforce minimum pulse duration to avoid 0 duratoin events @@ -271,15 +274,26 @@ struct DecomposePulsePattern : public OpRewritePattern { } }; -/// Removes redundant transitive dependencies from sync operations +/// Removes redundant syncs: unused result, unary identity, transitive operand redundancy. struct SimplifySyncPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(RTIOSyncOp op, PatternRewriter &rewriter) const override { + // Leftover merge with no consumer (e.g. duplicate sync in the schedule). 
+ if (op.getSyncEvent().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + auto events = op.getEvents(); - if (events.size() <= 1) { + if (events.empty()) return failure(); + + // Simplify unary sync + if (events.size() == 1) { + rewriter.replaceOp(op, events[0]); + return success(); } // Find events that aren't reachable from other events diff --git a/mlir/lib/RTIO/Transforms/Utils.hpp b/mlir/lib/RTIO/Transforms/Utils.hpp index a6d0459c3f..142a9314f7 100644 --- a/mlir/lib/RTIO/Transforms/Utils.hpp +++ b/mlir/lib/RTIO/Transforms/Utils.hpp @@ -75,6 +75,32 @@ inline int32_t extractChannelId(mlir::Value channelValue) return type.getChannelId().getInt(); } +/// Get group id from pulse +inline int64_t pulseGroupId(RTIOPulseOp op) +{ + assert(op->hasAttr("_group") && "pulse must have _group attr"); + auto a = op->getAttrOfType("_group"); + return a.getInt(); +} + +/// Device RTIO address for a static logical channel id (used when `_group` holds the lane id). +inline mlir::Value computeChannelDeviceAddrForId(mlir::OpBuilder &builder, mlir::Operation *op, + int64_t channelId) +{ + mlir::Location loc = op->getLoc(); + mlir::ModuleOp mod = op->getParentOfType(); + auto configAttr = mod->getAttrOfType(ConfigAttr::getModuleAttrName()); + assert(configAttr && "configAttr not found"); + + mlir::Attribute leaf = device_db_detail::walkAttrPath( + configAttr, {"device_db", "ttl_urukul0_sw0", "arguments", "channel"}); + assert(leaf && "device_db.ttl_urukul0_sw0.arguments.channel missing"); + int64_t channelBase = mlir::cast(leaf).getInt(); + + int32_t addr = static_cast((channelId + channelBase) << 8); + return mlir::arith::ConstantOp::create(builder, loc, builder.getI32IntegerAttr(addr)); +} + /// Compute the device address for a given channel value. 
inline mlir::Value computeChannelDeviceAddr(mlir::OpBuilder &builder, mlir::Operation *op, mlir::Value channelValue) diff --git a/mlir/test/Ion/IonToRTIO.mlir b/mlir/test/Ion/IonToRTIO.mlir index e4cd49d1d2..895c5491cf 100644 --- a/mlir/test/Ion/IonToRTIO.mlir +++ b/mlir/test/Ion/IonToRTIO.mlir @@ -45,8 +45,8 @@ module @circuit { %8 = arith.divf %6, %7 : f64 %9 = builtin.unrealized_conversion_cast %2 : !quantum.bit to !ion.qubit // CHECK: %[[EMPTY:.*]] = rtio.empty : !rtio.event - // CHECK: %[[CH:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> - // CHECK: %[[PULSE:.*]] = rtio.pulse %[[CH]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) wait(%[[EMPTY]]) + // CHECK: %[[CH:.*]] = rtio.channel : !rtio.channel<"dds", [1 : i64], 0> + // CHECK: %[[PULSE:.*]] = rtio.pulse %[[CH]] {{.*}} wait(%[[EMPTY]]) // CHECK: return %10 = ion.parallelprotocol(%9) : !ion.qubit { ^bb0(%arg0: !ion.qubit): @@ -118,7 +118,7 @@ module @cnot_circuit { %10 = builtin.unrealized_conversion_cast %2 : !quantum.bit to !ion.qubit // CHECK: %[[EMPTY:.*]] = rtio.empty : !rtio.event - // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> + // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", [0 : i64], 0> // CHECK: %[[P1:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[EMPTY]]) %11 = ion.parallelprotocol(%10) : !ion.qubit { ^bb0(%arg0: !ion.qubit): @@ -147,14 +147,12 @@ module @cnot_circuit { %20 = builtin.unrealized_conversion_cast %3 : !quantum.bit to !ion.qubit // CHECK: %[[SYNC1:.*]] = rtio.sync %[[P1]], %[[EMPTY]] : !rtio.event - // CHECK: %[[P2:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[SYNC1]]) - // CHECK: %[[CH1:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 1> - // CHECK: %[[P3:.*]] = rtio.pulse %[[CH1]] {{.*}} wait(%[[SYNC1]]) - // CHECK: %[[CH2:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> - // CHECK: %[[P4:.*]] = rtio.pulse %[[CH2]] {{.*}} wait(%[[SYNC1]]) - // CHECK: %[[CH3:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 3> - // 
CHECK: %[[P5:.*]] = rtio.pulse %[[CH3]] {{.*}} wait(%[[SYNC1]]) - // CHECK: %[[SYNC2:.*]] = rtio.sync %[[P2]], %[[P3]], %[[P4]], %[[P5]] : !rtio.event + // CHECK: rtio.pulse {{.*}} wait(%[[SYNC1]]) {_group = 1 + // CHECK: rtio.pulse {{.*}} wait(%[[SYNC1]]) {_group = 1 + // CHECK: %[[CH1:.*]] = rtio.channel : !rtio.channel<"dds", [1 : i64], 0> + // CHECK: rtio.pulse {{.*}} wait(%[[SYNC1]]) {_group = 1 + // CHECK: rtio.pulse {{.*}} wait(%[[SYNC1]]) {_group = 1 + // CHECK: %[[SYNC2:.*]] = rtio.sync {{.*}} : !rtio.event %21:2 = ion.parallelprotocol(%19, %20) : !ion.qubit, !ion.qubit { ^bb0(%arg0: !ion.qubit, %arg1: !ion.qubit): %54 = ion.pulse(%18 : f64) %arg0 {beam = #ion.beam, phase = 0.000000e+00 : f64} : !ion.pulse @@ -185,7 +183,7 @@ module @cnot_circuit { %29 = arith.divf %27, %28 : f64 %30 = builtin.unrealized_conversion_cast %22 : !quantum.bit to !ion.qubit - // CHECK: %[[P6:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[SYNC2]]) + // CHECK: %[[P6:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[SYNC2]]) {_group = 2 %31 = ion.parallelprotocol(%30) : !ion.qubit { ^bb0(%arg0: !ion.qubit): %54 = ion.pulse(%29 : f64) %arg0 {beam = #ion.beam, phase = 0.000000e+00 : f64} : !ion.pulse @@ -211,7 +209,7 @@ module @cnot_circuit { %38 = arith.divf %36, %37 : f64 %39 = builtin.unrealized_conversion_cast %23 : !quantum.bit to !ion.qubit - // CHECK: %[[P7:.*]] = rtio.pulse %[[CH2]] {{.*}} wait(%[[SYNC2]]) + // CHECK: %[[P7:.*]] = rtio.pulse %[[CH1]] {{.*}} wait(%[[SYNC2]]) {_group = 3 %40 = ion.parallelprotocol(%39) : !ion.qubit { ^bb0(%arg0: !ion.qubit): %54 = ion.pulse(%38 : f64) %arg0 {beam = #ion.beam, phase = 1.5707963267948966 : f64} : !ion.pulse @@ -237,7 +235,7 @@ module @cnot_circuit { %47 = arith.divf %45, %46 : f64 %48 = builtin.unrealized_conversion_cast %32 : !quantum.bit to !ion.qubit - // CHECK: %[[P8:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[P6]]) + // CHECK: %[[P8:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[P6]]) {_group = 4 %49 = ion.parallelprotocol(%48) : !ion.qubit { 
^bb0(%arg0: !ion.qubit): %54 = ion.pulse(%47 : f64) %arg0 {beam = #ion.beam, phase = 1.5707963267948966 : f64} : !ion.pulse @@ -293,7 +291,7 @@ module @sequential_circuit { %ion_q0 = builtin.unrealized_conversion_cast %q0 : !quantum.bit to !ion.qubit // CHECK: %[[EMPTY:.*]] = rtio.empty : !rtio.event - // CHECK: %[[CH:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> + // CHECK: %[[CH:.*]] = rtio.channel : !rtio.channel<"dds", [0 : i64], 0> // CHECK: %[[PULSE1:.*]] = rtio.pulse %[[CH]] {{.*}} wait(%[[EMPTY]]) %out1 = ion.parallelprotocol(%ion_q0) : !ion.qubit { ^bb0(%arg0: !ion.qubit): @@ -356,7 +354,7 @@ module @loop_circuit { // Gate before the loop // CHECK: %[[EMPTY:.*]] = rtio.empty : !rtio.event - // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> + // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", [0 : i64], 0> // CHECK: %[[P0:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[EMPTY]]) %10 = ion.parallelprotocol(%9) : !ion.qubit { ^bb0(%arg0: !ion.qubit): @@ -369,8 +367,8 @@ module @loop_circuit { // Loop // CHECK: %[[LOOP:.*]] = scf.for {{.*}} iter_args(%[[ARG:.*]] = %[[P0]]) -> (!rtio.event) { - // CHECK: %[[CH2_LOOP:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> - // CHECK: %[[P_Q1:.*]] = rtio.pulse %[[CH2_LOOP]] {{.*}} wait(%[[ARG]]) + // CHECK: %[[CH1_LOOP:.*]] = rtio.channel : !rtio.channel<"dds", [1 : i64], 0> + // CHECK: %[[P_Q1:.*]] = rtio.pulse %[[CH1_LOOP]] {{.*}} wait(%[[ARG]]) // CHECK: %[[P_Q0:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[ARG]]) // CHECK: %[[SYNC_LOOP:.*]] = rtio.sync %[[P_Q0]], %[[P_Q1]] : !rtio.event // CHECK: scf.yield %[[SYNC_LOOP]] : !rtio.event @@ -522,7 +520,7 @@ module @if_circuit { // Gate before the if // CHECK: %[[EMPTY:.*]] = rtio.empty : !rtio.event - // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> + // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", [0 : i64], 0> // CHECK: %[[P0:.*]] = rtio.pulse %[[CH0]] {{.*}} wait(%[[EMPTY]]) %11 = 
ion.parallelprotocol(%10) : !ion.qubit { ^bb0(%arg1: !ion.qubit): @@ -538,8 +536,8 @@ module @if_circuit { // TODO: it's not fully optimized yet, the P0 event is already dominated by the P_THEN event, // so it's no need to sync P0 and P_THEN. // CHECK: %[[IF_RESULT:.*]] = scf.if {{.*}} -> (!rtio.event) { - // CHECK: %[[CH2_THEN:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> - // CHECK: %[[P_THEN:.*]] = rtio.pulse %[[CH2_THEN]] {{.*}} wait(%[[P0]]) + // CHECK: %[[CH1_THEN:.*]] = rtio.channel : !rtio.channel<"dds", [1 : i64], 0> + // CHECK: %[[P_THEN:.*]] = rtio.pulse %[[CH1_THEN]] {{.*}} wait(%[[P0]]) // CHECK: %[[SYNC_THEN:.*]] = rtio.sync %[[P_THEN]], %[[P0]] : !rtio.event // CHECK: scf.yield %[[SYNC_THEN]] : !rtio.event // CHECK: } else { @@ -604,8 +602,8 @@ module @if_circuit { } // Gate after the if - // CHECK: %[[CH2_AFTER:.*]] = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> - // CHECK: %[[P_AFTER:.*]] = rtio.pulse %[[CH2_AFTER]] {{.*}} wait(%[[IF_RESULT]]) + // CHECK: %[[CH1_AFTER:.*]] = rtio.channel : !rtio.channel<"dds", [1 : i64], 0> + // CHECK: %[[P_AFTER:.*]] = rtio.pulse %[[CH1_AFTER]] {{.*}} wait(%[[IF_RESULT]]) %16 = quantum.extract %15[ 1] : !quantum.reg -> !quantum.bit %cst_5 = arith.constant 12.566370614359172 : f64 %17 = arith.remf %cst_0, %cst_5 : f64 @@ -687,7 +685,7 @@ module @measure_ion_to_rtio { } } -// CHECK: rtio.pulse {{.*}} {_measurement +// CHECK: rtio.pulse {{.*}}_measurement // CHECK: rtio.readout {{.*}} : !rtio.event -> i32 // CHECK: func.func private @__rtio_init_dataset() // CHECK: rtio.rpc @init_dataset diff --git a/mlir/test/RTIO/RTIOEventToARTIQ.mlir b/mlir/test/RTIO/RTIOEventToARTIQ.mlir index d562ffb7e9..93971240de 100644 --- a/mlir/test/RTIO/RTIOEventToARTIQ.mlir +++ b/mlir/test/RTIO/RTIOEventToARTIQ.mlir @@ -51,18 +51,22 @@ module @circuit attributes {rtio.config = #rtio.config<{core_addr = "172.31.9.64 %1 = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> %3 = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> 
- // Test rtio.pulse with wait on empty, should set frequency and generate TTL pulse + // Test rtio.pulse with wait on empty, generates TTL pulse // First pulse on channel 2 waiting on empty event // CHECK: llvm.call tail @now_mu() // CHECK: llvm.call tail @at_mu - // CHECK: llvm.call @__rtio_set_frequency + // CHECK: arith.maxsi + // CHECK: llvm.call tail @rtio_output + // CHECK: llvm.call fastcc tail @delay_mu + // CHECK: llvm.call tail @rtio_output // CHECK: llvm.call tail @now_mu() - %2 = rtio.pulse %1 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%0) {offset = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event + %2 = rtio.pulse %1 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%0) {_group = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event // Test parallel pulses, both wait on same event // CHECK: llvm.call tail @at_mu - // CHECK: llvm.call @__rtio_set_frequency - %4 = rtio.pulse %3 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%0) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + // CHECK: arith.maxsi + // CHECK: llvm.call tail @rtio_output + %4 = rtio.pulse %3 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%0) {_group = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event // Test sequential pulse on same channel (duration via max(duration_mu, minTTL), not __rtio_sec_to_mu) // CHECK: llvm.call tail @at_mu @@ -70,7 +74,7 @@ module @circuit attributes {rtio.config = #rtio.config<{core_addr = "172.31.9.64 // CHECK: llvm.call tail @rtio_output // CHECK: llvm.call fastcc tail @delay_mu // CHECK: llvm.call tail @rtio_output - %5 = rtio.pulse %3 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%4) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %5 = rtio.pulse %3 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%4) {_group = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event // Test rtio.sync, synchronizes multiple events using maxsi // CHECK: arith.maxsi @@ -78,12 +82,12 @@ module @circuit attributes {rtio.config = 
#rtio.config<{core_addr = "172.31.9.64 %6 = rtio.sync %5, %2 : !rtio.event // Test multiple parallel pulses after sync - %7 = rtio.pulse %3 duration(%cst_0) frequency(%cst_5) phase(%cst_7) wait(%6) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %7 = rtio.pulse %3 duration(%cst_0) frequency(%cst_5) phase(%cst_7) wait(%6) {_group = 1 : i64} : <"dds", [2 : i64], 0> -> !rtio.event %8 = rtio.channel : !rtio.channel<"dds", [2 : i64], 1> - %9 = rtio.pulse %8 duration(%cst_0) frequency(%cst_4) phase(%cst_7) wait(%6) {offset = 1 : i64} : <"dds", [2 : i64], 1> -> !rtio.event - %10 = rtio.pulse %1 duration(%cst_0) frequency(%cst_3) phase(%cst_7) wait(%6) {offset = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event + %9 = rtio.pulse %8 duration(%cst_0) frequency(%cst_4) phase(%cst_7) wait(%6) {_group = 1 : i64} : <"dds", [2 : i64], 1> -> !rtio.event + %10 = rtio.pulse %1 duration(%cst_0) frequency(%cst_3) phase(%cst_7) wait(%6) {_group = 1 : i64} : <"dds", [2 : i64], 2> -> !rtio.event %11 = rtio.channel : !rtio.channel<"dds", [2 : i64], 3> - %12 = rtio.pulse %11 duration(%cst_0) frequency(%cst_2) phase(%cst_7) wait(%6) {offset = 1 : i64} : <"dds", [2 : i64], 3> -> !rtio.event + %12 = rtio.pulse %11 duration(%cst_0) frequency(%cst_2) phase(%cst_7) wait(%6) {_group = 1 : i64} : <"dds", [2 : i64], 3> -> !rtio.event // Test sync with 4 events // CHECK: arith.maxsi @@ -93,9 +97,9 @@ module @circuit attributes {rtio.config = #rtio.config<{core_addr = "172.31.9.64 %13 = rtio.sync %7, %9, %10, %12 : !rtio.event // Final pulses after sync - %14 = rtio.pulse %3 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%13) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event - %15 = rtio.pulse %1 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%13) {offset = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event - %16 = rtio.pulse %3 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%14) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %14 = rtio.pulse %3 
duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%13) {_group = 2 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %15 = rtio.pulse %1 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%13) {_group = 2 : i64} : <"dds", [2 : i64], 2> -> !rtio.event + %16 = rtio.pulse %3 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%14) {_group = 3 : i64} : <"dds", [2 : i64], 0> -> !rtio.event // CHECK: return return @@ -117,14 +121,12 @@ module @simple_sequential attributes {rtio.config = #rtio.config<{core_addr = "1 %ch0 = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> - // First pulse, sets frequency and generates TTL - // CHECK: llvm.call @__rtio_set_frequency - // CHECK: llvm.call fastcc tail @delay_mu + // First pulse, generates TTL // CHECK: arith.maxsi // CHECK: llvm.call tail @rtio_output // CHECK: llvm.call fastcc tail @delay_mu // CHECK: llvm.call tail @rtio_output - %1 = rtio.pulse %ch0 duration(%cst_dur) frequency(%cst_freq) phase(%cst_phase) wait(%0) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %1 = rtio.pulse %ch0 duration(%cst_dur) frequency(%cst_freq) phase(%cst_phase) wait(%0) {_group = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event // Second pulse, sequential, waits for first // CHECK: llvm.call tail @at_mu @@ -132,7 +134,7 @@ module @simple_sequential attributes {rtio.config = #rtio.config<{core_addr = "1 // CHECK: llvm.call tail @rtio_output // CHECK: llvm.call fastcc tail @delay_mu // CHECK: llvm.call tail @rtio_output - %2 = rtio.pulse %ch0 duration(%cst_dur) frequency(%cst_freq) phase(%cst_phase) wait(%1) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %2 = rtio.pulse %ch0 duration(%cst_dur) frequency(%cst_freq) phase(%cst_phase) wait(%1) {_group = 1 : i64} : <"dds", [2 : i64], 0> -> !rtio.event // CHECK: return return @@ -232,7 +234,7 @@ module @measure_rtio_to_artiq attributes {rtio.config = #rtio.config<{device_db %cst_0 = arith.constant 1.000000e-04 : f64 %0 = rtio.empty : !rtio.event %1 = rtio.channel : 
!rtio.channel<"ttl", [1 : i64], 0> - %2 = rtio.pulse %1 duration(%cst_0) frequency(%cst) phase(%cst) wait(%0) {_measurement, offset = 0 : i64} : <"ttl", [1 : i64], 0> -> !rtio.event + %2 = rtio.pulse %1 duration(%cst_0) frequency(%cst) phase(%cst) wait(%0) {_group = 0 : i64, _measurement} : <"ttl", [1 : i64], 0> -> !rtio.event %3 = rtio.readout %2 : !rtio.event -> i32 return } From b323f6228a3105f726e447a201b9ae01e20075cd Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 8 Apr 2026 10:39:17 -0400 Subject: [PATCH 2/2] formatting and fix codecov --- mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp | 70 ++++++++++++------- .../Transforms/RTIOEventToARTIQPatterns.cpp | 12 ++-- 2 files changed, 49 insertions(+), 33 deletions(-) diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp index d1638e711b..6c64c6f40a 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp @@ -48,10 +48,50 @@ namespace rtio { namespace { +/// Schedule newly discovered pulse groups, set their wait to the current chain event and produce a +/// sync that becomes the new chain event. 
+static llvm::SmallVector +schedulePulseGroups(ArrayRef consumers, DenseSet &scheduled, + DenseMap> &pulseGroups, Value &chain, + OpBuilder &builder) +{ + auto eventType = rtio::EventType::get(builder.getContext()); + + SmallVector newGroupIds; + for (auto p : consumers) { + int64_t gid = pulseGroupId(p); + if (scheduled.insert(gid).second) { + newGroupIds.push_back(gid); + } + } + llvm::sort(newGroupIds); + + for (int64_t gid : newGroupIds) { + auto &grp = pulseGroups[gid]; + for (auto p : grp) { + p.setWait(chain); + } + + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfter(grp.back()); + SmallVector evts; + for (auto p : grp) { + evts.push_back(p.getEvent()); + } + chain = rtio::RTIOSyncOp::create(builder, grp.back().getLoc(), eventType, evts); + } + + llvm::SmallVector nextEvents; + for (auto p : consumers) { + nextEvents.push_back(p.getEvent()); + } + + return nextEvents; +} + /// Schedule pulses for executing on ARTIQ static void schedule(func::FuncOp funcOp, OpBuilder &builder) { - auto eventType = rtio::EventType::get(funcOp.getContext()); if (funcOp.getBody().empty()) { return; } @@ -105,32 +145,8 @@ static void schedule(func::FuncOp funcOp, OpBuilder &builder) } } - SmallVector newGroupIds; - for (auto p : consumers) { - int64_t gid = pulseGroupId(p); - if (scheduled.insert(gid).second) { - newGroupIds.push_back(gid); - } - } - llvm::sort(newGroupIds); - - for (int64_t gid : newGroupIds) { - auto &grp = pulseGroups[gid]; - for (auto p : grp) { - p.setWait(chain); - } - - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfter(grp.back()); - SmallVector evts; - for (auto p : grp) { - evts.push_back(p.getEvent()); - } - chain = rtio::RTIOSyncOp::create(builder, grp.back().getLoc(), eventType, evts); - } - - for (auto p : consumers) - worklist.push_back(p.getEvent()); + auto nextEvents = schedulePulseGroups(consumers, scheduled, pulseGroups, chain, builder); + llvm::append_range(worklist, nextEvents); } } diff --git 
a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp index 7bd618f7ee..d779846b7a 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp @@ -71,11 +71,12 @@ struct PulseOpLowering : public OpConversionPattern { } Type chTy = getTypeConverter()->convertType(op.getChannel().getType()); - Value chVal = arith::ConstantOp::create( - rewriter, op.getLoc(), rewriter.getIntegerAttr(chTy, pulseGroupId(op))); + Value chVal = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIntegerAttr(chTy, pulseGroupId(op))); Value amplitude = artiq.constF64(1.0); - LLVM::CallOp::create(rewriter, op.getLoc(), setFreqFunc, - ValueRange{chVal, adaptor.getFrequency(), adaptor.getPhase(), amplitude}); + LLVM::CallOp::create( + rewriter, op.getLoc(), setFreqFunc, + ValueRange{chVal, adaptor.getFrequency(), adaptor.getPhase(), amplitude}); Value newTime = artiq.nowMu(); rewriter.replaceOp(op, newTime); @@ -95,8 +96,7 @@ struct PulseOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter, ARTIQRuntimeBuilder &artiq) const { - Value channelAddr = - computeChannelDeviceAddrForId(rewriter, op, pulseGroupId(op)); + Value channelAddr = computeChannelDeviceAddrForId(rewriter, op, pulseGroupId(op)); Value durationMu = artiq.secToMu(adaptor.getDuration()); // Enforce minimum pulse duration to avoid 0 duratoin events