Add BLAS trsv derivative rule#2828
Conversation
Adds the `def trsv` rule for the real triangular vector solve, with both forward and reverse mode derivatives. Mirrors the structure of `trtrs`. Outside the new `def`, the only required tablegen additions are: - `def Dep : MagicInst;` with a 6-line passthrough in `rev_call_arg`. The forward rule needs to gate two `axpy` and one `copy` call on `Shadow $A` being active, but `Shadow $A` is not a direct BLAS arg of those calls. `Dep` wraps an arg so activity analysis sees the dependency without it being emitted as a real BLAS argument. - `trsv` joins `trmv`/`trtrs` in the n=3 byRef-suffix arity table (three trailing string-length args for uplo/trans/diag). - `trsv` joins `trtrs` in the reversed activeArg iteration order, since reverse mode needs the solved x cotangent before forming dA. No changes to Utils.cpp/Utils.h or to other BLAS rules.
| def Concat : MagicInst; | ||
|
|
||
| def ShadowNoInc : MagicInst; | ||
| def Dep : MagicInst; |
| ] | ||
| , | ||
| (Seq<["tmp", "vector", "n"], [], 0> | ||
| (BlasCall<"copy"> $n, (Dep (Shadow $A), $x), use<"tmp">, ConstantInt<1>), |
There was a problem hiding this comment.
I don't think dep should be necessary here at all [and you should delete the concept?]
|
Thanks for the look. To answer the direct question: I tried removing it and the codegen breaks. The forward rule's
The forward JVP for trsv is A few options I can take, whichever you prefer:
I'm happy to do (2) immediately if you want me to unblock the simple reverse part — let me know. Or if there's a mechanism I missed that does the activity-gating without a new MagicInst, I'd appreciate the pointer and I'll refactor. |
Adds the
def trsvrule for the real BLAS triangular vector solve, with both forward and reverse mode derivatives. Split out of #2825 per @wsmoses's request to land per-rule PRs with minimal infrastructure churn.What's added
def trsvinBlasDerivatives.td, following the same reverse-mode shape astrtrs(single RHS vector instead of matrix). Forward mode is a 5-callSequsingcopy/trmv/axpy/axpy/trsv.def Dep : MagicInst;(1 line) plus a 6-line passthrough inrev_call_arg. The forward rule needs to gate twoaxpycalls and onecopyonShadow $Aactivity, butShadow $Ais not a direct BLAS argument of those calls.Depwraps an arg soif_rule_condition_innersees the activity dependency without emitting the shadow as a real BLAS argument. Noemit_daghandler is needed —Deponly appears nested insideBlasCallargs, andif_rule_condition_inneralready recurses through arbitrary DagInits transparently.trsvjoinstrmv/trtrsin then=3byRef-suffix arity table (three trailing Fortran string-length args foruplo/trans/diag) and in the reversed activeArg iteration order (reverse mode must solve thexcotangent before formingdA, identical dependency shape totrtrs).What's intentionally not here
enzyme/Enzyme/Utils.cpporUtils.h.trsm,potrf,potrs, or theuplo_to_*/side_to_*matchers.mat_ld/MatAdd/side_squaremachinery — those belong to follow-up PRs fortrsm/potrs.Tests
test/Enzyme/ForwardMode/blas/trsv_f.ll,test/Enzyme/ReverseMode/blas/trsv_f.llninja LLVMEnzyme-16,lit -sv test/Enzyme(1002 passed, 10 expected-fail; no regressions).Follow-ups will add
trsm,potrs, and forward-mode rules forpotrf/etc. as separate PRs.