
Commit bcbcdb4

Merge pull request #451 from Xilinx/bump_to_7b11dfc0
[AutoBump] Merge with fixes of 7b11dfc (Oct 11) (80)
2 parents dcba58b + 28940f2 commit bcbcdb4

5 files changed: +381 -4 lines changed


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 13 additions & 4 deletions
@@ -7368,10 +7368,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantOne});
 
-    rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
-        op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
-        paddingSizeList, dialationList,
-        /*ceil_mode=*/constantFalse);
+    if (op.getResult(1).use_empty()) {
+      auto maxPool = rewriter.create<AtenMaxPool1dOp>(
+          loc, op.getType(0), input, kernelSizeList, strideList,
+          paddingSizeList, dialationList,
+          /*ceil_mode=*/constantFalse);
+      rewriter.replaceOp(op, {maxPool.getResult(), Value()});
+    } else {
+      auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
+          loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
+          paddingSizeList, dialationList,
+          /*ceil_mode=*/constantFalse);
+      rewriter.replaceOp(op, maxPool.getResults());
+    }
     return success();
   }
 };
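A quick way to sanity-check the equivalence this hunk relies on (a hypothetical snippet, not part of the commit; it assumes the output size evenly divides the input length, which is the case the decomposition handles): when the indices result of adaptive max pooling is unused, the op behaves like a plain max_pool1d.

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 12)
# adaptive_max_pool1d without return_indices yields only the pooled values,
# mirroring the new AtenMaxPool1dOp path taken when result 1 has no uses.
adaptive = F.adaptive_max_pool1d(x, output_size=4)
plain = F.max_pool1d(x, kernel_size=3, stride=3)
assert torch.equal(adaptive, plain)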

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp

Lines changed: 211 additions & 0 deletions
@@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
   return success();
 }
 
+LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
+                                       ValueTensorLiteralOp literalOp,
+                                       SmallVector<Value> &vals) {
+  // only supports splat ValueTensorLiterals for now. TODO: add support for
+  // small non-splat valuetensorliterals.
+  auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
+  if (!ty || !ty.hasSizes())
+    return failure();
+  auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
+  if (!attr)
+    return failure();
+  auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
+  if (!attrInt)
+    return failure();
+  IntegerType intty = cast<IntegerType>(attrInt.getType());
+  if (!intty.isSignedInteger())
+    return failure();
+  Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
+      literalOp.getLoc(), attrInt.getSInt());
+  vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
+  return success();
+}
+
 LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
   constexpr int64_t kMaxFold = 16;
   if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
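For intuition, a hypothetical Python analogue of what constructListFromLiteral does for a splat literal such as dense<3> : !torch.vtensor<[4],si64>: one scalar constant is materialized and the working list is padded with it up to the literal's length.

# Illustrative sketch only; the real code materializes a Torch::ConstantIntOp
# and resizes the SmallVector with that value.
splat_value, length = 3, 4
vals = []
vals += [splat_value] * length  # vals.resize(vals.size() + 4, materializedVal)
assert vals == [3, 3, 3, 3]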
@@ -351,6 +374,172 @@ class PropagateAtenSliceTensorPattern
 };
 } // namespace
 
+namespace {
+class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
+public:
+  using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenWhereSelfOp op,
+                                PatternRewriter &rewriter) const override {
+    Value condition = op.getCondition();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
+    if (!conditionTy || !conditionTy.hasSizes() ||
+        conditionTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad condition type");
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t conditionSize = selfTy.getSizes()[0];
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize || selfSize != conditionSize)
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: support for propogating with implicit broadcasting.");
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "arguments are dynamic or too big");
+
+    SmallVector<Value> conditionList, selfList, otherList;
+    if (failed(getListFromTensor(condition, conditionList)) ||
+        (int64_t)conditionList.size() != conditionSize)
+      return failure();
+
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // this should only occur if we did not generate IR with
+      // constructListFromLiteral
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> whereVals;
+    auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
+        ArrayRef<int64_t>({}), selfTy.getDtype());
+    auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
+        ArrayRef<int64_t>({}), conditionTy.getDtype());
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      Value rank0Cond = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0BoolTy, conditionList[i]);
+      Value rank0Self = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, selfList[i]);
+      Value rank0Other = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, otherList[i]);
+      Value rank0Where = rewriter.create<AtenWhereSelfOp>(
+          loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
+      whereVals.push_back(rewriter.create<AtenItemOp>(
+          loc, rewriter.getType<Torch::IntType>(), rank0Where));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
+public:
+  using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize)
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: support for propogating with implicit broadcasting.");
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
+        otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "self or other is dynamic or too big");
+
+    SmallVector<Value> selfList, otherList;
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // this should only occur if we did not generate IR with
+      // constructListFromLiteral
+      return failure();
+
+    SmallVector<Value> eqVals;
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      eqVals.push_back(
+          rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 public:
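The two new propagation patterns scalarize small, static, rank-1 where/eq computations into per-element arithmetic. A rough PyTorch-level picture of the rewrite (illustrative only; names and values are made up, and the real patterns cap the size at kMaxFold = 16 and reject implicit broadcasting):

import torch

cond = torch.tensor([True, False, True])
a = torch.tensor([1, 2, 3])
b = torch.tensor([7, 7, 7])  # e.g. a splat literal handled via constructListFromLiteral

# Element-by-element recomputation, as the scalarized IR does with
# PrimNumToTensorScalarOp + AtenWhereSelfOp + AtenItemOp per index.
where_elems = [ai if ci else bi
               for ci, ai, bi in zip(cond.tolist(), a.tolist(), b.tolist())]
assert torch.equal(torch.where(cond, a, b), torch.tensor(where_elems))

# PropagateAtenEqTensorPattern does the analogous thing with AtenEqIntOp per index.
eq_elems = [ai == bi for ai, bi in zip(a.tolist(), b.tolist())]
assert torch.equal(torch.eq(a, b), torch.tensor(eq_elems))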
@@ -454,6 +643,26 @@ class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
 };
 } // namespace
 
+namespace {
+class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
+public:
+  using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
+      return rewriter.notifyMatchFailure(op, "Unknown result shape");
+
+    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
+          op, resultTy, atenFull.getFillValue());
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 public:
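In PyTorch terms, the new fold recognizes a rank-0 squeeze of a full-like producer and keeps only the fill value (an illustrative snippet, not taken from the commit):

import torch

full = torch.full([1], 7)
# Squeezing dim 0 of a size-1 full tensor is just the scalar fill value,
# which is what PrimNumToTensorScalarOp produces after the fold.
assert torch.equal(full.squeeze(0), torch.tensor(7))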
@@ -694,6 +903,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
         PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
         FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
         FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
+        PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
+        FoldAtenSqueezeDimPattern,
         RemoveUnusedPattern<Torch::AtenIntBoolOp>,
         RemoveUnusedPattern<Torch::AtenEqIntOp>,
         RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,

lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp

Lines changed: 57 additions & 0 deletions
@@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 };
 } // namespace
 
+namespace {
+class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    auto context = op.getContext();
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    auto resultType = cast<BaseTensorType>(result.getType());
+    if (resultType.hasSizes() && resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "The result of aten.tensor is already a BaseTensorType.");
+    }
+
+    auto inputList = op.getOperand(0);
+    auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
+    if (!listConstruct) {
+      return rewriter.notifyMatchFailure(
+          op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
+    }
+
+    // Currently only support the 1d input list.
+    SmallVector<int64_t> sizes;
+    sizes.push_back(listConstruct->getOperands().size());
+    FailureOr<Type> torchType;
+    auto eleType = listConstruct->getOperands()[0].getType();
+    if (isa<Torch::IntType>(eleType)) {
+      torchType = getTypeForScalarType(op->getContext(),
+                                       torch_upstream::ScalarType::Long);
+    } else if (isa<Torch::FloatType>(eleType)) {
+      torchType = getTypeForScalarType(op->getContext(),
+                                       torch_upstream::ScalarType::Float);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Currently only support Int and Float Type.");
+    }
+    auto newResultType = ValueTensorType::get(context, sizes, *torchType);
+
+    Value originalTypedValue;
+    for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
+      if (!originalTypedValue) {
+        rewriter.setInsertionPointAfter(op);
+        originalTypedValue =
+            rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
+      }
+      use.set(originalTypedValue);
+    }
+
+    result.setType(newResultType);
+
+    return success();
+  }
+};
+} // namespace
+
 static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
                                                 int resultNum,
                                                 PatternRewriter &rewriter) {
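The pattern's dtype choice mirrors PyTorch's own defaults for aten.tensor on a 1-D list: integer elements map to ScalarType::Long (int64) and float elements to ScalarType::Float (float32). A small illustration, assuming a stock PyTorch build:

import torch

assert torch.tensor([2, 3, 4]).dtype == torch.int64      # Int elements -> Long
assert torch.tensor([2.0, 3.0]).dtype == torch.float32   # Float elements -> Float
assert torch.tensor([2, 3, 4]).shape == torch.Size([3])  # 1-D list -> static size [3]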
@@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
     populateFoldPrimUncheckedCastOpPattern(patterns, context);
     patterns.insert<DecomposeAtenSizeOp>(context);
     patterns.insert<RefineShapeCalculateOp>(context);
+    patterns.insert<InferTensorOp>(context);
 
     PrimIfOp::getCanonicalizationPatterns(patterns, context);
     Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 24 additions & 0 deletions
@@ -5941,6 +5941,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorAlloc1dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 6], torch.int, True),
+        ]
+    )
+    def forward(self, x):
+        res = torch.tensor([x.shape[0]])
+        return res
+
+
+@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
+def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 6))
+
+
+# ==============================================================================
+
+
 class ScalarTensorFloat32Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
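In plain PyTorch, the new e2e case boils down to the following check, which the refined static type of torch.tensor([x.shape[0]]) has to agree with (a standalone restatement for orientation, not part of the test suite):

import torch

x = torch.rand(2, 4, 6)
res = torch.tensor([x.shape[0]])
assert res.shape == torch.Size([1])
assert res.dtype == torch.int64
assert res.item() == 2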
