@@ -161,7 +161,8 @@ class Vectorizer {
161161
162162 // 2. If it wasn't vectorized (or if it has multiple sources), try
163163 // Structural.
164- if (!transformed && canVectorizeStructurally (oldOutputVal)) {
164+ if (!transformed && !hasCrossBitDependencies (oldOutputVal) &&
165+ canVectorizeStructurally (oldOutputVal)) {
165166 rewriter.setInsertionPointAfterValue (oldOutputVal);
166167
167168 unsigned width = cast<IntegerType>(oldOutputVal.getType ()).getWidth ();
@@ -189,126 +190,223 @@ class Vectorizer {
189190 llvm::DenseMap<Value, BitArray> bitArrays;
190191 hw::HWModuleOp module ;
191192
192- // Returns true if `output` can be replaced by a single wide operation.
193+ // / Analyzes the logic cones of each bit in a vector to ensure they are
194+ // / independent or share only safe global values (like constants or block
195+ // / args). This prevents illegal vectorization of interdependent bit lanes.
196+ bool hasCrossBitDependencies (mlir::Value outputVal) {
197+ unsigned bitWidth = cast<IntegerType>(outputVal.getType ()).getWidth ();
198+ if (bitWidth <= 1 )
199+ return false ;
200+
201+ // Collect all SSA values involved in the computation of each bit lane.
202+ std::vector<llvm::DenseSet<mlir::Value>> bitCones (bitWidth);
203+ for (unsigned i = 0 ; i < bitWidth; ++i) {
204+ mlir::Value bitSource = findBitSource (outputVal, i);
205+ if (bitSource) {
206+ collectLogicCone (bitSource, bitCones[i]);
207+ }
208+ }
209+
210+ // Check for intersections between bit cones. Overlap is only allowed if
211+ // the shared values are considered "safe" (e.g., global control signals).
212+ for (unsigned i = 0 ; i < bitWidth; ++i) {
213+ for (unsigned j = i + 1 ; j < bitWidth; ++j) {
214+ for (mlir::Value val : bitCones[i]) {
215+ if (bitCones[j].count (val)) {
216+ if (!isSafeSharedValue (val)) {
217+ return true ;
218+ }
219+ }
220+ }
221+ }
222+ }
223+ return false ;
224+ }
225+
226+ // / Recursively traverses the defining operations of a value to build a
227+ // / set of all values in its transitive logic cone.
228+ void collectLogicCone (mlir::Value val, llvm::DenseSet<mlir::Value> &cone) {
229+ if (cone.count (val)) {
230+ return ;
231+ }
232+ cone.insert (val);
233+
234+ Operation *definingOp = val.getDefiningOp ();
235+ // Stop traversal at block arguments, constants, or when the source is gone.
236+ if (!definingOp || isa<BlockArgument>(val) ||
237+ isa<hw::ConstantOp>(definingOp)) {
238+ return ;
239+ }
240+
241+ for (Value operand : definingOp->getOperands ()) {
242+ collectLogicCone (operand, cone);
243+ }
244+ }
245+
246+ // / Determines if a shared value is safe for vectorization. Only constants
247+ // / and block arguments are safe to share between bit lanes. Any intermediate
248+ // / operation is considered unsafe as it may introduce cross-lane
249+ // / dependencies.
250+ bool isSafeSharedValue (mlir::Value val) {
251+ return val &&
252+ (isa<BlockArgument>(val) || val.getDefiningOp <hw::ConstantOp>());
253+ }
254+
255+ // / Checks if a logic cone is composed of structurally equivalent slices
256+ // / that can be merged into a vector operation.
193257 // /
194- // / The check succeeds when every bit slice i of `output` is produced by a
195- // / subgraph that is isomorphic to the bit-0 subgraph up to a uniform
196- // / bit-index offset of i (see areSubgraphsEquivalent).
258+ // / The check succeeds when every bit slice i of the output is produced by a
259+ // / subgraph that is isomorphic to the bit-0 subgraph (slice0).
197260 bool canVectorizeStructurally (Value output) {
198- unsigned width = cast<IntegerType>(output.getType ()).getWidth ();
261+ unsigned bitWidth = cast<IntegerType>(output.getType ()).getWidth ();
262+ if (bitWidth <= 1 )
263+ return false ;
199264
200- // A 1-bit value is already scalar; nothing to vectorize.
201- if (width <= 1 )
265+ Value slice0Val = findBitSource (output, 0 );
266+ if (!slice0Val )
202267 return false ;
203268
204- Value slice0 = findBitSource (output, 0 );
205- if (!slice0 )
269+ Value slice1Val = findBitSource (output, 1 );
270+ if (!slice1Val )
206271 return false ;
207272
208- // Compare each bit slice N against the base slice 0 to ensure isomorphism.
209- for (unsigned i = 1 ; i < width; ++i) {
210- Value sliceN = findBitSource (output, i);
211- if (!sliceN)
273+ auto extract0 = slice0Val.getDefiningOp <comb::ExtractOp>();
274+ auto extract1 = slice1Val.getDefiningOp <comb::ExtractOp>();
275+
276+ if (!extract0 || !extract1 || extract0.getInput () != extract1.getInput ()) {
277+ for (unsigned i = 1 ; i < bitWidth; ++i) {
278+ Value sliceNVal = findBitSource (output, i);
279+ if (!sliceNVal || !sliceNVal.getDefiningOp ())
280+ return false ;
281+ llvm::DenseMap<mlir::Value, mlir::Value> map;
282+ if (!areSubgraphsEquivalent (slice0Val, sliceNVal, i, 1 , map)) {
283+ return false ;
284+ }
285+ }
286+ return true ;
287+ }
288+
289+ int stride = (int )extract1.getLowBit () - (int )extract0.getLowBit ();
290+
291+ for (unsigned i = 1 ; i < bitWidth; ++i) {
292+ Value sliceNVal = findBitSource (output, i);
293+ if (!sliceNVal || !sliceNVal.getDefiningOp ())
212294 return false ;
213295
214- DenseMap<Value, Value> map;
215- if (!areSubgraphsEquivalent (slice0, sliceN , i, map))
296+ llvm:: DenseMap<mlir:: Value, mlir:: Value> map;
297+ if (!areSubgraphsEquivalent (slice0Val, sliceNVal , i, stride, map)) {
216298 return false ;
299+ }
217300 }
218301 return true ;
219302 }
220303
221- // / Recursively checks whether subgraph `b` is isomorphic to subgraph `a`
222- // / under the assumption that all ExtractOp low-bit indices in `b` are
223- // / exactly `index` greater than those in `a`.
304+ // / Recursively compares two subgraphs to determine if they are isomorphic
305+ // / with respect to a constant bit-stride.
224306 // /
225- // / Shared control values (block arguments, constants) are treated as
226- // / identical across all bit lanes.
227- // /
228- // / `map` caches already-verified pairs (a → b) to avoid redundant traversal
229- // / and to handle DAGs (values with multiple uses) correctly.
230- bool areSubgraphsEquivalent (Value a, Value b, unsigned index,
231- DenseMap<Value, Value> &map) {
232-
233- // If `a` was already mapped, verify it points to the same `b`.
234- if (map.count (a))
235- return map[a] == b;
236-
237- Operation *opA = a.getDefiningOp ();
238- Operation *opB = b.getDefiningOp ();
239-
240- // Leaf case: ExtractOp – same source, low-bit shifted by `index`.
241- if (auto exA = dyn_cast_or_null<comb::ExtractOp>(opA)) {
242- auto exB = dyn_cast_or_null<comb::ExtractOp>(opB);
243- if (!exB || exA.getInput () != exB.getInput ())
244- return false ;
245-
246- // The bit index in slice N must be exactly `index` ahead of slice 0.
247- if (exB.getLowBit () != exA.getLowBit () + (int )index)
248- return false ;
249-
250- map[a] = b;
251- return true ;
307+ // / It assumes that all ExtractOp low-bit indices in the second subgraph
308+ // / are exactly (sliceIndex * stride) greater than those in the first.
309+ // / Caches results in slice0ToNMap to handle DAGs efficiently.
310+ bool areSubgraphsEquivalent (Value slice0Val, Value sliceNVal,
311+ unsigned sliceIndex, int stride,
312+ DenseMap<Value, Value> &slice0ToNMap) {
313+
314+ if (slice0ToNMap.count (slice0Val))
315+ return slice0ToNMap[slice0Val] == sliceNVal;
316+
317+ Operation *op0 = slice0Val.getDefiningOp ();
318+ Operation *opN = sliceNVal.getDefiningOp ();
319+
320+ if (auto extract0 = dyn_cast_or_null<comb::ExtractOp>(op0)) {
321+ auto extractN = dyn_cast_or_null<comb::ExtractOp>(opN);
322+
323+ if (extractN && extract0.getInput () == extractN.getInput () &&
324+ extractN.getLowBit () == (unsigned )((int )extract0.getLowBit () +
325+ (int )sliceIndex * stride)) {
326+ slice0ToNMap[slice0Val] = sliceNVal;
327+ return true ;
328+ }
329+ return false ;
252330 }
253331
254- // Leaf case: block arguments and constants are considered equivalent when
255- // they are the *exact same* SSA value (shared across all bit lanes, e.g.,
256- // a mux select signal).
257- if (!opA && !opB) {
258- map[a] = b;
332+ if (slice0Val == sliceNVal && (mlir::isa<BlockArgument>(slice0Val) ||
333+ mlir::isa<hw::ConstantOp>(op0))) {
334+ slice0ToNMap[slice0Val] = sliceNVal;
259335 return true ;
260336 }
261337
262- // Interior node: both must be the same operation kind with the same arity.
263- if (!opA || !opB || opA->getName () != opB->getName () ||
264- opA->getNumOperands () != opB->getNumOperands ())
338+ if (!op0 || !opN || op0->getName () != opN->getName () ||
339+ op0->getNumOperands () != opN->getNumOperands ())
265340 return false ;
266341
267- // Recurse into all operand pairs.
268- for (unsigned i = 0 ; i < opA->getNumOperands (); ++i)
269- if (!areSubgraphsEquivalent (opA->getOperand (i), opB->getOperand (i), index,
270- map))
342+ for (unsigned i = 0 ; i < op0->getNumOperands (); ++i) {
343+ if (!areSubgraphsEquivalent (op0->getOperand (i), opN->getOperand (i),
344+ sliceIndex, stride, slice0ToNMap))
271345 return false ;
346+ }
272347
273- map[a ] = b ;
348+ slice0ToNMap[slice0Val ] = sliceNVal ;
274349 return true ;
275350 }
276351
277- // / Traverses ConcatOps to locate the defining 1-bit Value for `bitIndex`
278- // / within `vectorVal` .
352+ // / Traverses through ConcatOps and basic logic gates to locate the
353+ // / original 1-bit source for a specific bit index .
279354 // /
280- // / Returns nullptr if the bit cannot be traced to a concrete 1-bit source
281- // / (e.g., the concat is not fully decomposed into 1-bit pieces).
355+ // / Returns the 1-bit Value or nullptr if the bit cannot be traced back
356+ // / to a concrete scalar source
282357 Value findBitSource (Value vectorVal, unsigned bitIndex) {
283358
284- // A 1-bit block argument is its own source at index 0.
285- if (auto arg = dyn_cast<BlockArgument>(vectorVal))
286- return arg.getType ().isInteger (1 ) && bitIndex == 0 ? arg : nullptr ;
359+ if (auto blockArg = dyn_cast<BlockArgument>(vectorVal)) {
360+ if (blockArg.getType ().isInteger (1 )) {
361+ return blockArg;
362+ }
363+ return nullptr ;
364+ }
287365
288366 Operation *op = vectorVal.getDefiningOp ();
289- if (!op)
367+ if (!op) {
290368 return nullptr ;
369+ }
291370
292- // Decompose ConcatOp: comb.concat lists operands MSB→LSB, so we walk
293- // them and track the cumulative bit offset from LSB upward.
294- if (auto concat = dyn_cast<comb::ConcatOp>(op)) {
295- // `cur` starts at the total width and decreases as we peel off operands.
296- unsigned cur = cast<IntegerType>(vectorVal.getType ()).getWidth ();
297- for (Value operand : concat.getInputs ()) {
298- unsigned w = cast<IntegerType>(operand.getType ()).getWidth ();
299- cur -= w;
300- // `bitIndex` falls inside this operand's range [cur, cur+w].
301- if (bitIndex >= cur && bitIndex < cur + w)
302- return findBitSource (operand, bitIndex - cur);
371+ if (op->getNumResults () == 1 && op->getResult (0 ).getType ().isInteger (1 )) {
372+ return op->getResult (0 );
373+ }
374+
375+ if (auto constOp = dyn_cast<hw::ConstantOp>(op)) {
376+ if (constOp.getType ().isInteger (1 )) {
377+ return constOp.getResult ();
303378 }
304379 return nullptr ;
305380 }
306381
307- // Any other 1-bit result (AND, OR, XOR, MUX, …) is its own source at
308- // bit position 0.
309- if (op->getNumResults () == 1 && op->getResult (0 ).getType ().isInteger (1 ) &&
310- bitIndex == 0 )
311- return op->getResult (0 );
382+ if (auto concat = dyn_cast<comb::ConcatOp>(op)) {
383+ unsigned currentBit = cast<IntegerType>(vectorVal.getType ()).getWidth ();
384+ for (Value operand : concat.getInputs ()) {
385+ unsigned operandWidth = cast<IntegerType>(operand.getType ()).getWidth ();
386+ currentBit -= operandWidth;
387+ if (bitIndex >= currentBit && bitIndex < currentBit + operandWidth) {
388+ return findBitSource (operand, bitIndex - currentBit);
389+ }
390+ }
391+ } else if (auto orOp = dyn_cast<comb::OrOp>(op)) {
392+ if (auto source = findBitSource (orOp.getInputs ()[1 ], bitIndex)) {
393+ if (auto sourceConst =
394+ dyn_cast_or_null<hw::ConstantOp>(source.getDefiningOp ())) {
395+ if (!sourceConst.getValue ().isZero ())
396+ return source;
397+ } else {
398+ return source;
399+ }
400+ }
401+ return findBitSource (orOp.getInputs ()[0 ], bitIndex);
402+ } else if (auto andOp = dyn_cast<comb::AndOp>(op)) {
403+ Value lhs = andOp.getInputs ()[0 ];
404+ Value rhs = andOp.getInputs ()[1 ];
405+ if (isa_and_nonnull<hw::ConstantOp>(rhs.getDefiningOp ()))
406+ return findBitSource (lhs, bitIndex);
407+ if (isa_and_nonnull<hw::ConstantOp>(lhs.getDefiningOp ()))
408+ return findBitSource (rhs, bitIndex);
409+ }
312410
313411 return nullptr ;
314412 }
0 commit comments