Skip to content

Commit 3faf348

Browse files
committed
[HW] Finalize Part 3: Structural Vectorization and tests
1 parent 1848e56 commit 3faf348

File tree

1 file changed

+181
-83
lines changed

1 file changed

+181
-83
lines changed

lib/Dialect/HW/Transforms/HWVectorization.cpp

Lines changed: 181 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ class Vectorizer {
161161

162162
// 2. If it wasn't vectorized (or if it has multiple sources), try
163163
// Structural.
164-
if (!transformed && canVectorizeStructurally(oldOutputVal)) {
164+
if (!transformed && !hasCrossBitDependencies(oldOutputVal) &&
165+
canVectorizeStructurally(oldOutputVal)) {
165166
rewriter.setInsertionPointAfterValue(oldOutputVal);
166167

167168
unsigned width = cast<IntegerType>(oldOutputVal.getType()).getWidth();
@@ -189,126 +190,223 @@ class Vectorizer {
189190
llvm::DenseMap<Value, BitArray> bitArrays;
190191
hw::HWModuleOp module;
191192

192-
// Returns true if `output` can be replaced by a single wide operation.
193+
/// Analyzes the logic cones of each bit in a vector to ensure they are
194+
/// independent or share only safe global values (like constants or block
195+
/// args). This prevents illegal vectorization of interdependent bit lanes.
196+
bool hasCrossBitDependencies(mlir::Value outputVal) {
197+
unsigned bitWidth = cast<IntegerType>(outputVal.getType()).getWidth();
198+
if (bitWidth <= 1)
199+
return false;
200+
201+
// Collect all SSA values involved in the computation of each bit lane.
202+
std::vector<llvm::DenseSet<mlir::Value>> bitCones(bitWidth);
203+
for (unsigned i = 0; i < bitWidth; ++i) {
204+
mlir::Value bitSource = findBitSource(outputVal, i);
205+
if (bitSource) {
206+
collectLogicCone(bitSource, bitCones[i]);
207+
}
208+
}
209+
210+
// Check for intersections between bit cones. Overlap is only allowed if
211+
// the shared values are considered "safe" (e.g., global control signals).
212+
for (unsigned i = 0; i < bitWidth; ++i) {
213+
for (unsigned j = i + 1; j < bitWidth; ++j) {
214+
for (mlir::Value val : bitCones[i]) {
215+
if (bitCones[j].count(val)) {
216+
if (!isSafeSharedValue(val)) {
217+
return true;
218+
}
219+
}
220+
}
221+
}
222+
}
223+
return false;
224+
}
225+
226+
/// Recursively traverses the defining operations of a value to build a
227+
/// set of all values in its transitive logic cone.
228+
void collectLogicCone(mlir::Value val, llvm::DenseSet<mlir::Value> &cone) {
229+
if (cone.count(val)) {
230+
return;
231+
}
232+
cone.insert(val);
233+
234+
Operation *definingOp = val.getDefiningOp();
235+
// Stop traversal at block arguments, constants, or when the source is gone.
236+
if (!definingOp || isa<BlockArgument>(val) ||
237+
isa<hw::ConstantOp>(definingOp)) {
238+
return;
239+
}
240+
241+
for (Value operand : definingOp->getOperands()) {
242+
collectLogicCone(operand, cone);
243+
}
244+
}
245+
246+
/// Determines if a shared value is safe for vectorization. Only constants
247+
/// and block arguments are safe to share between bit lanes. Any intermediate
248+
/// operation is considered unsafe as it may introduce cross-lane
249+
/// dependencies.
250+
bool isSafeSharedValue(mlir::Value val) {
251+
return val &&
252+
(isa<BlockArgument>(val) || val.getDefiningOp<hw::ConstantOp>());
253+
}
254+
255+
/// Checks if a logic cone is composed of structurally equivalent slices
256+
/// that can be merged into a vector operation.
193257
///
194-
/// The check succeeds when every bit slice i of `output` is produced by a
195-
/// subgraph that is isomorphic to the bit-0 subgraph up to a uniform
196-
/// bit-index offset of i (see areSubgraphsEquivalent).
258+
/// The check succeeds when every bit slice i of the output is produced by a
259+
/// subgraph that is isomorphic to the bit-0 subgraph (slice0).
197260
bool canVectorizeStructurally(Value output) {
198-
unsigned width = cast<IntegerType>(output.getType()).getWidth();
261+
unsigned bitWidth = cast<IntegerType>(output.getType()).getWidth();
262+
if (bitWidth <= 1)
263+
return false;
199264

200-
// A 1-bit value is already scalar; nothing to vectorize.
201-
if (width <= 1)
265+
Value slice0Val = findBitSource(output, 0);
266+
if (!slice0Val)
202267
return false;
203268

204-
Value slice0 = findBitSource(output, 0);
205-
if (!slice0)
269+
Value slice1Val = findBitSource(output, 1);
270+
if (!slice1Val)
206271
return false;
207272

208-
// Compare each bit slice N against the base slice 0 to ensure isomorphism.
209-
for (unsigned i = 1; i < width; ++i) {
210-
Value sliceN = findBitSource(output, i);
211-
if (!sliceN)
273+
auto extract0 = slice0Val.getDefiningOp<comb::ExtractOp>();
274+
auto extract1 = slice1Val.getDefiningOp<comb::ExtractOp>();
275+
276+
if (!extract0 || !extract1 || extract0.getInput() != extract1.getInput()) {
277+
for (unsigned i = 1; i < bitWidth; ++i) {
278+
Value sliceNVal = findBitSource(output, i);
279+
if (!sliceNVal || !sliceNVal.getDefiningOp())
280+
return false;
281+
llvm::DenseMap<mlir::Value, mlir::Value> map;
282+
if (!areSubgraphsEquivalent(slice0Val, sliceNVal, i, 1, map)) {
283+
return false;
284+
}
285+
}
286+
return true;
287+
}
288+
289+
int stride = (int)extract1.getLowBit() - (int)extract0.getLowBit();
290+
291+
for (unsigned i = 1; i < bitWidth; ++i) {
292+
Value sliceNVal = findBitSource(output, i);
293+
if (!sliceNVal || !sliceNVal.getDefiningOp())
212294
return false;
213295

214-
DenseMap<Value, Value> map;
215-
if (!areSubgraphsEquivalent(slice0, sliceN, i, map))
296+
llvm::DenseMap<mlir::Value, mlir::Value> map;
297+
if (!areSubgraphsEquivalent(slice0Val, sliceNVal, i, stride, map)) {
216298
return false;
299+
}
217300
}
218301
return true;
219302
}
220303

221-
/// Recursively checks whether subgraph `b` is isomorphic to subgraph `a`
222-
/// under the assumption that all ExtractOp low-bit indices in `b` are
223-
/// exactly `index` greater than those in `a`.
304+
/// Recursively compares two subgraphs to determine if they are isomorphic
305+
/// with respect to a constant bit-stride.
224306
///
225-
/// Shared control values (block arguments, constants) are treated as
226-
/// identical across all bit lanes.
227-
///
228-
/// `map` caches already-verified pairs (a → b) to avoid redundant traversal
229-
/// and to handle DAGs (values with multiple uses) correctly.
230-
bool areSubgraphsEquivalent(Value a, Value b, unsigned index,
231-
DenseMap<Value, Value> &map) {
232-
233-
// If `a` was already mapped, verify it points to the same `b`.
234-
if (map.count(a))
235-
return map[a] == b;
236-
237-
Operation *opA = a.getDefiningOp();
238-
Operation *opB = b.getDefiningOp();
239-
240-
// Leaf case: ExtractOp – same source, low-bit shifted by `index`.
241-
if (auto exA = dyn_cast_or_null<comb::ExtractOp>(opA)) {
242-
auto exB = dyn_cast_or_null<comb::ExtractOp>(opB);
243-
if (!exB || exA.getInput() != exB.getInput())
244-
return false;
245-
246-
// The bit index in slice N must be exactly `index` ahead of slice 0.
247-
if (exB.getLowBit() != exA.getLowBit() + (int)index)
248-
return false;
249-
250-
map[a] = b;
251-
return true;
307+
/// It assumes that all ExtractOp low-bit indices in the second subgraph
308+
/// are exactly (sliceIndex * stride) greater than those in the first.
309+
/// Caches results in slice0ToNMap to handle DAGs efficiently.
310+
bool areSubgraphsEquivalent(Value slice0Val, Value sliceNVal,
311+
unsigned sliceIndex, int stride,
312+
DenseMap<Value, Value> &slice0ToNMap) {
313+
314+
if (slice0ToNMap.count(slice0Val))
315+
return slice0ToNMap[slice0Val] == sliceNVal;
316+
317+
Operation *op0 = slice0Val.getDefiningOp();
318+
Operation *opN = sliceNVal.getDefiningOp();
319+
320+
if (auto extract0 = dyn_cast_or_null<comb::ExtractOp>(op0)) {
321+
auto extractN = dyn_cast_or_null<comb::ExtractOp>(opN);
322+
323+
if (extractN && extract0.getInput() == extractN.getInput() &&
324+
extractN.getLowBit() == (unsigned)((int)extract0.getLowBit() +
325+
(int)sliceIndex * stride)) {
326+
slice0ToNMap[slice0Val] = sliceNVal;
327+
return true;
328+
}
329+
return false;
252330
}
253331

254-
// Leaf case: block arguments and constants are considered equivalent when
255-
// they are the *exact same* SSA value (shared across all bit lanes, e.g.,
256-
// a mux select signal).
257-
if (!opA && !opB) {
258-
map[a] = b;
332+
if (slice0Val == sliceNVal && (mlir::isa<BlockArgument>(slice0Val) ||
333+
mlir::isa<hw::ConstantOp>(op0))) {
334+
slice0ToNMap[slice0Val] = sliceNVal;
259335
return true;
260336
}
261337

262-
// Interior node: both must be the same operation kind with the same arity.
263-
if (!opA || !opB || opA->getName() != opB->getName() ||
264-
opA->getNumOperands() != opB->getNumOperands())
338+
if (!op0 || !opN || op0->getName() != opN->getName() ||
339+
op0->getNumOperands() != opN->getNumOperands())
265340
return false;
266341

267-
// Recurse into all operand pairs.
268-
for (unsigned i = 0; i < opA->getNumOperands(); ++i)
269-
if (!areSubgraphsEquivalent(opA->getOperand(i), opB->getOperand(i), index,
270-
map))
342+
for (unsigned i = 0; i < op0->getNumOperands(); ++i) {
343+
if (!areSubgraphsEquivalent(op0->getOperand(i), opN->getOperand(i),
344+
sliceIndex, stride, slice0ToNMap))
271345
return false;
346+
}
272347

273-
map[a] = b;
348+
slice0ToNMap[slice0Val] = sliceNVal;
274349
return true;
275350
}
276351

277-
/// Traverses ConcatOps to locate the defining 1-bit Value for `bitIndex`
278-
/// within `vectorVal`.
352+
/// Traverses through ConcatOps and basic logic gates to locate the
353+
/// original 1-bit source for a specific bit index.
279354
///
280-
/// Returns nullptr if the bit cannot be traced to a concrete 1-bit source
281-
/// (e.g., the concat is not fully decomposed into 1-bit pieces).
355+
/// Returns the 1-bit Value or nullptr if the bit cannot be traced back
356+
/// to a concrete scalar source
282357
Value findBitSource(Value vectorVal, unsigned bitIndex) {
283358

284-
// A 1-bit block argument is its own source at index 0.
285-
if (auto arg = dyn_cast<BlockArgument>(vectorVal))
286-
return arg.getType().isInteger(1) && bitIndex == 0 ? arg : nullptr;
359+
if (auto blockArg = dyn_cast<BlockArgument>(vectorVal)) {
360+
if (blockArg.getType().isInteger(1)) {
361+
return blockArg;
362+
}
363+
return nullptr;
364+
}
287365

288366
Operation *op = vectorVal.getDefiningOp();
289-
if (!op)
367+
if (!op) {
290368
return nullptr;
369+
}
291370

292-
// Decompose ConcatOp: comb.concat lists operands MSB→LSB, so we walk
293-
// them and track the cumulative bit offset from LSB upward.
294-
if (auto concat = dyn_cast<comb::ConcatOp>(op)) {
295-
// `cur` starts at the total width and decreases as we peel off operands.
296-
unsigned cur = cast<IntegerType>(vectorVal.getType()).getWidth();
297-
for (Value operand : concat.getInputs()) {
298-
unsigned w = cast<IntegerType>(operand.getType()).getWidth();
299-
cur -= w;
300-
// `bitIndex` falls inside this operand's range [cur, cur+w].
301-
if (bitIndex >= cur && bitIndex < cur + w)
302-
return findBitSource(operand, bitIndex - cur);
371+
if (op->getNumResults() == 1 && op->getResult(0).getType().isInteger(1)) {
372+
return op->getResult(0);
373+
}
374+
375+
if (auto constOp = dyn_cast<hw::ConstantOp>(op)) {
376+
if (constOp.getType().isInteger(1)) {
377+
return constOp.getResult();
303378
}
304379
return nullptr;
305380
}
306381

307-
// Any other 1-bit result (AND, OR, XOR, MUX, …) is its own source at
308-
// bit position 0.
309-
if (op->getNumResults() == 1 && op->getResult(0).getType().isInteger(1) &&
310-
bitIndex == 0)
311-
return op->getResult(0);
382+
if (auto concat = dyn_cast<comb::ConcatOp>(op)) {
383+
unsigned currentBit = cast<IntegerType>(vectorVal.getType()).getWidth();
384+
for (Value operand : concat.getInputs()) {
385+
unsigned operandWidth = cast<IntegerType>(operand.getType()).getWidth();
386+
currentBit -= operandWidth;
387+
if (bitIndex >= currentBit && bitIndex < currentBit + operandWidth) {
388+
return findBitSource(operand, bitIndex - currentBit);
389+
}
390+
}
391+
} else if (auto orOp = dyn_cast<comb::OrOp>(op)) {
392+
if (auto source = findBitSource(orOp.getInputs()[1], bitIndex)) {
393+
if (auto sourceConst =
394+
dyn_cast_or_null<hw::ConstantOp>(source.getDefiningOp())) {
395+
if (!sourceConst.getValue().isZero())
396+
return source;
397+
} else {
398+
return source;
399+
}
400+
}
401+
return findBitSource(orOp.getInputs()[0], bitIndex);
402+
} else if (auto andOp = dyn_cast<comb::AndOp>(op)) {
403+
Value lhs = andOp.getInputs()[0];
404+
Value rhs = andOp.getInputs()[1];
405+
if (isa_and_nonnull<hw::ConstantOp>(rhs.getDefiningOp()))
406+
return findBitSource(lhs, bitIndex);
407+
if (isa_and_nonnull<hw::ConstantOp>(lhs.getDefiningOp()))
408+
return findBitSource(rhs, bitIndex);
409+
}
312410

313411
return nullptr;
314412
}

0 commit comments

Comments
 (0)