[Relax][Backend] Fix TVM crashes with default relax pipeline when opt_level=1: InternalError: Check failed: (slot->value_computed) is false #18491
Conversation
if (it != slot_map_.end() && !it->second->value_computed) {
  // If it's a variable, mark it as ready for computation
  if (expr.as<tir::VarNode>()) {
    it->second->value_computed = true;
Marking arbitrary tir::VarNode instances as value_computed = true in VisitExpr_ is incorrect. In VMShapeLower, only Relax ShapeVars (from function parameters/match_cast) or IntImm constants can be safely marked as computed, because only these have runtime values accessible to the VM.
I think the correct way is to fix this earlier in the pipeline (shape inference/canonicalization), either by symbolizing composite PrimExpr (introducing Relax ShapeVars for them) or by simplifying them before VMShapeLower runs.
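To make the reviewer's point concrete, here is a simplified, TVM-free model of the slot bookkeeping in VMShapeLower. The names (`Slot`, `value_computed`, `slot_map`) loosely mirror the real pass, but this is an illustrative sketch, not TVM code: only plain shape vars with runtime values and integer constants can be marked computed, so a composite element like `n + 1` trips the check.

```python
class Slot:
    """Toy stand-in for VMShapeLower's per-expression slot."""
    def __init__(self, expr):
        self.expr = expr            # the PrimExpr this slot holds
        self.value_computed = False

def lower_shape(shape_exprs, runtime_known_vars):
    """Walk a ShapeExpr's elements the way VMShapeLower does (simplified).

    Only plain shape vars with runtime values (function params /
    match_cast) and integer constants can be marked computed;
    a composite expression such as 'n + 1' has no value yet.
    """
    slot_map = {e: Slot(e) for e in shape_exprs}
    for expr, slot in slot_map.items():
        if isinstance(expr, int) or expr in runtime_known_vars:
            slot.value_computed = True
    for expr, slot in slot_map.items():
        # Mirrors the failing ICHECK(slot->value_computed)
        if not slot.value_computed:
            raise RuntimeError(
                f"Check failed: (slot->value_computed) is false for {expr!r}")
    return [s.expr for s in slot_map.values()]

# A shape [n, 4] lowers fine; a composite element like "n + 1" crashes,
# which is the same failure mode as reported in #17876.
print(lower_shape(["n", 4], runtime_known_vars={"n"}))  # → ['n', 4]
try:
    lower_shape(["n + 1"], runtime_known_vars={"n"})
except RuntimeError as e:
    print("crash:", e)
```

Blindly flipping `value_computed` for any `tir::VarNode` (as the diff above does) silences the check without making a value actually available to the VM.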
Hi @tlopex
Thanks for the suggestion.
Should we extend ComputePrimValue() to handle ShapeExpr nodes with compound PrimExpr?
https://github.com/apache/tvm/blob/main/src/relax/transform/compute_prim_value.cc#L65
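A toy constant folder, using Python's `ast` module, illustrates the scope limitation discussed here (this is a hedged sketch, not ComputePrimValue itself): folding of this kind can only reduce statically evaluable expressions to constants, while a symbolic `n + 1` necessarily stays symbolic.

```python
import ast
import operator

# Supported binary operators for the toy folder.
OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}

def try_fold(expr_src):
    """Return an int if expr_src is a constant expression, else None."""
    def ev(node):
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in OPS:
            left, right = ev(node.left), ev(node.right)
            if left is not None and right is not None:
                return OPS[type(node.op)](left, right)
        return None  # symbolic names and anything else: not foldable
    return ev(ast.parse(expr_src, mode="eval").body)

print(try_fold("2 + 3"))  # → 5: foldable to an IntImm-like constant
print(try_fold("n + 1"))  # → None: symbolic, folding cannot resolve it
```

This is why, as the next reply argues, extending ComputePrimValue-style folding cannot by itself handle a runtime-symbolic shape element.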
Well, ComputePrimValue is intended only for evaluating statically evaluable PrimExpr into IntImm (constant folding), so I think extending ComputePrimValue would not address the root issue.
The real problem is that VMShapeLower cannot consume composite PrimExpr directly. The correct solution here should be to canonicalize ShapeExpr earlier by introducing a Relax ShapeVar binding for any non-trivial PrimExpr.
Just like:
# 1. Compute the symbolic value first (Canonicalization)
s1 = R.prim_value(n + 1)
# 2. Pass the computed var to the shape (VMShapeLower is happy now)
lv = R.call_tir(cls.func, (x,), R.shape([s1]), dtype="float32")
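The canonicalization step above can be sketched as a tiny rewrite pass. This assumes a toy expression representation (shape elements are plain var names as `str`, `int` constants, or tuples like `("add", "n", 1)` for composite PrimExpr); it is not TVM's API, just an illustration of hoisting composite elements into fresh bindings so the lowered shape only ever sees vars and constants.

```python
import itertools

_fresh = itertools.count()  # fresh-name counter for hoisted vars

def canonicalize_shape(shape):
    """Hoist composite shape elements into fresh var bindings.

    Returns (bindings, new_shape) where bindings are (var, expr)
    pairs to emit before the call (the R.prim_value step above).
    """
    bindings = []
    new_shape = []
    for elem in shape:
        if isinstance(elem, tuple):        # composite PrimExpr
            var = f"s{next(_fresh)}"
            bindings.append((var, elem))   # e.g. s0 = n + 1
            new_shape.append(var)          # shape now holds only a var
        else:
            new_shape.append(elem)         # plain var or constant
    return bindings, new_shape

bindings, shape = canonicalize_shape([("add", "n", 1), 4])
print(bindings)  # the hoisted binding for n + 1
print(shape)     # only vars and constants remain
```

After this rewrite, every shape element is trivially consumable by a VMShapeLower-style pass, matching the `s1 = R.prim_value(n + 1)` pattern shown above.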
Hi Committers,
This PR tries to fix issue #17876. Any suggestions would be appreciated.
Root Cause
VMShapeLower crashed when processing a ShapeExpr containing composite PrimExpr that weren't computed yet.
Solution
Modified VisitExpr_(const ShapeExprNode* op) in vm_shape_lower.cc to:
1. Mark uncomputed variables as ready for computation
2. Trigger EmitOutstandingPrimExprCompute() to resolve the dependency chain
3. Ensure all expressions are computed before calling MakeSymbolicShapeArg
Added test case: test_composite_shape_expression_fix() to prevent future occurrences.
An alternative direction raised in review is symbolizing the composite PrimExpr earlier in the pipeline (canonicalization).
ModifiedVisitExpr_(const ShapeExprNode* op)invm_shape_lower.ccto:1. Mark uncomputed variables as ready for computation2. Trigger EmitOutstandingPrimExprCompute() to resolve the dependency chain3. Ensure all expressions are computed before callingMakeSymbolicShapeArgAdded test case:test_composite_shape_expression_fix()to prevent future occurrences.symbolizing the composite PrimExpr in the pipeline: canonicalization