|
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
9 | 9 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 10 | +#include "mlir/IR/Attributes.h" |
10 | 11 | #include "mlir/IR/Matchers.h" |
11 | 12 | #include "mlir/Support/LLVM.h" |
12 | 13 | #include "llvm/ADT/APSInt.h" |
@@ -280,27 +281,28 @@ std::optional<APInt> constantTripCount( |
280 | 281 | computeUbMinusLb) { |
281 | 282 | // This is the bitwidth used to return 0 when loop does not execute. |
282 | 283 | // We infer it from the type of the bound if it isn't an index type. |
283 | | - bool isIndex = true; |
284 | | - auto getBitwidth = [&](OpFoldResult ofr) -> int { |
285 | | - if (auto attr = dyn_cast<Attribute>(ofr)) { |
286 | | - if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
287 | | - if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) { |
288 | | - isIndex = intType.isIndex(); |
289 | | - return intType.getWidth(); |
290 | | - } |
291 | | - } |
| 284 | + auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> { |
| 285 | + if (auto intAttr = |
| 286 | + dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) { |
| 287 | + if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) |
| 288 | + return std::make_tuple(intType.getWidth(), intType.isIndex()); |
292 | 289 | } else { |
293 | 290 | auto val = cast<Value>(ofr); |
294 | | - if (auto intType = dyn_cast<IntegerType>(val.getType())) { |
295 | | - isIndex = intType.isIndex(); |
296 | | - return intType.getWidth(); |
297 | | - } |
| 291 | + if (auto intType = dyn_cast<IntegerType>(val.getType())) |
| 292 | + return std::make_tuple(intType.getWidth(), intType.isIndex()); |
298 | 293 | } |
299 | | - return IndexType::kInternalStorageBitWidth; |
| 294 | + return std::make_tuple(IndexType::kInternalStorageBitWidth, true); |
300 | 295 | }; |
301 | | - int bitwidth = getBitwidth(lb); |
302 | | - assert(bitwidth == getBitwidth(ub) && |
303 | | - "lb and ub must have the same bitwidth"); |
| 296 | + auto [bitwidth, isIndex] = getBitwidth(lb); |
| 297 | + // This would better be an assert, but unfortunately it breaks scf.for_all |
| 298 | + // which is missing attributes and SSA value optionally for its bounds, and |
| 299 | + // uses Index type for the dynamic bounds but i64 for the static bounds. This |
| 300 | + // is broken... |
| 301 | + if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) { |
| 302 | + LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs " |
| 303 | + << lb; |
| 304 | + return std::nullopt; |
| 305 | + } |
304 | 306 | if (lb == ub) |
305 | 307 | return APInt(bitwidth, 0); |
306 | 308 |
|
|
0 commit comments