Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def Rows : MagicInst; // given a transpose, normal rows, normal cols get the tru
def Concat : MagicInst;

def ShadowNoInc : MagicInst;
def Dep : MagicInst;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what the heck is dep?


class Binop<string _s, list<string> _tys> {
string s = _s;
Expand Down Expand Up @@ -672,6 +673,45 @@ def trtrs : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $nrhs, $A, $l
]
>;

def trsv : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $A, $lda, $x, $incx),
["x"],
[cblas_layout, uplo, trans, diag, len, mld<["uplo", "n", "n"]>, vinc<["n"]>],
[
/* A */
(Seq<["tri", "triangular", "n"], [], 1>
(BlasCall<"lacpy"> $layout, $uplo, $n, $n, (Shadow $A), use<"tri">, $n),
(BlasCall<"ger">
$layout,
$n,
$n,
Constant<"-1">,
(Rows $trans,
(Concat (Shadow $x)),
(Concat $x)),
(Rows $trans,
(Concat $x),
(Concat (Shadow $x))),
use<"tri">, $n),

(BlasCall<"copy">
(ISelect (is_nonunit $diag), ConstantInt<0>, $n),
(First (Shadow $A)), (Add $lda, ConstantInt<1>),
use<"tri">, (Add $n, ConstantInt<1>)),

(BlasCall<"lacpy"> $layout, $uplo, $n, $n, use<"tri">, $n, (Shadow $A))
),
/* x */ (BlasCall<"trsv"> $layout, $uplo, transpose<"trans">, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x)),
]
,
(Seq<["tmp", "vector", "n"], [], 0>
(BlasCall<"copy"> $n, (Dep (Shadow $A), $x), use<"tmp">, ConstantInt<1>),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think dep should be necessary here at all [and you should delete the concept?]

(BlasCall<"trmv"> $layout, $uplo, $trans, $diag, $n, (Shadow $A), use<"tmp">, ConstantInt<1>),
(BlasCall<"axpy"> $n, Constant<"-1">, (Dep (Shadow $A), use<"tmp">), ConstantInt<1>, (Shadow $x)),
(BlasCall<"axpy"> (ISelect (is_nonunit $diag), ConstantInt<0>, $n), Constant<"1">, (Dep (Shadow $A), $x), (Shadow $x)),
(BlasCall<"trsv"> $layout, $uplo, $trans, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x))
)
>;

def spr2 : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy, $ap),
["ap"],
[cblas_layout, uplo, len, fp, vinc<["n"]>, vinc<["n"]>, ap<["n"]>],
Expand Down
42 changes: 42 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/blas/trsv_f.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S -enzyme-detect-readthrow=0 | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -enzyme-detect-readthrow=0 -S | FileCheck %s

target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

declare dso_local void @__enzyme_fwddiff(...)

declare void @dtrsv_64_(i8*, i8*, i8*, i64*, double*, i64*, double*, i64*, i64, i64, i64)

define void @f(double* %A, double* %x) {
entry:
%uplo = alloca i8, align 1
%trans = alloca i8, align 1
%diag = alloca i8, align 1
%n = alloca i64, align 8
%lda = alloca i64, align 8
%incx = alloca i64, align 8
store i8 85, i8* %uplo, align 1
store i8 78, i8* %trans, align 1
store i8 85, i8* %diag, align 1
store i64 4, i64* %n, align 8
store i64 4, i64* %lda, align 8
store i64 1, i64* %incx, align 8
call void @dtrsv_64_(i8* %uplo, i8* %trans, i8* %diag, i64* %n, double* %A, i64* %lda, double* %x, i64* %incx, i64 1, i64 1, i64 1)
ret void
}

define void @active(double* %A, double* %dA, double* %x, double* %dx) {
entry:
call void (...) @__enzyme_fwddiff(void (double*, double*)* @f, metadata !"enzyme_dup", double* %A, double* %dA, metadata !"enzyme_dup", double* %x, double* %dx)
ret void
}

; CHECK-LABEL: define internal void @fwddiffef(
; CHECK: call void @dtrsv_64_
; CHECK: call void @dcopy_64_
; CHECK: call void @dtrmv_64_
; CHECK: call void @daxpy_64_
; CHECK: call void @daxpy_64_
; CHECK: call void @dtrsv_64_
; CHECK: ret void
41 changes: 41 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas/trsv_f.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S -enzyme-detect-readthrow=0 | FileCheck %s

target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

declare void @dtrsv_64_(i8*, i8*, i8*, i64*, double*, i64*, double*, i64*, i64, i64, i64)

define void @f(double* %A, double* %x) {
entry:
%uplo = alloca i8, align 1
%trans = alloca i8, align 1
%diag = alloca i8, align 1
%n = alloca i64, align 8
%lda = alloca i64, align 8
%incx = alloca i64, align 8
store i8 85, i8* %uplo, align 1
store i8 78, i8* %trans, align 1
store i8 78, i8* %diag, align 1
store i64 4, i64* %n, align 8
store i64 4, i64* %lda, align 8
store i64 1, i64* %incx, align 8
call void @dtrsv_64_(i8* %uplo, i8* %trans, i8* %diag, i64* %n, double* %A, i64* %lda, double* %x, i64* %incx, i64 1, i64 1, i64 1)
ret void
}

declare void @__enzyme_autodiff(...)

define void @active(double* %A, double* %dA, double* %x, double* %dx) {
entry:
call void (...) @__enzyme_autodiff(void (double*, double*)* @f, metadata !"enzyme_dup", double* %A, double* %dA, metadata !"enzyme_dup", double* %x, double* %dx)
ret void
}

; CHECK: define internal void @diffef(double* %A, double* %"A'", double* %x, double* %"x'")
; CHECK: call void @dtrsv_64_
; CHECK: invertentry:
; CHECK: call void @dtrsv_64_
; CHECK: call void @dlacpy_64_
; CHECK: call void @dger_64_
; CHECK: call void @dcopy_64_
; CHECK: call void @dlacpy_64_
15 changes: 11 additions & 4 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,12 @@ void rev_call_arg(bool forward, const DagInit *ruleDag,
os << "); })";
return;
}
if (Def->getName() == "Dep") {
if (Dag->getNumArgs() != 2)
PrintFatalError(pattern.getLoc(), "only 2-arg Dep operands supported");
rev_call_arg(forward, Dag, pattern, 1, os, vars);
return;
}
if (Def->getName() == "ld") {
if (Dag->getNumArgs() != 5)
PrintFatalError(pattern.getLoc(), "only 5-arg ld operands supported");
Expand Down Expand Up @@ -1493,7 +1499,7 @@ void rev_call_args(bool forward, Twine argName, const TGPattern &pattern,
n = 1;
if (func == "gemm" || func == "syrk" || func == "syr2k" || func == "symm")
n = 2;
if (func == "trmv" || func == "trtrs")
if (func == "trmv" || func == "trsv" || func == "trtrs")
n = 3;
if (func == "trmm" || func == "trsm")
n = 4;
Expand Down Expand Up @@ -2354,9 +2360,10 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,

os << " auto bb_name = Builder2.GetInsertBlock()->getName();\n";
for (size_t iteri = 0; iteri < activeArgs.size(); iteri++) {
// trtrs do in reversed arg order.
size_t i = (pattern.getName() != "trtrs") ? iteri
: (activeArgs.size() - 1 - iteri);
// trtrs and trsv solve x cotangent before forming dA.
size_t i = (pattern.getName() != "trtrs" && pattern.getName() != "trsv")
? iteri
: (activeArgs.size() - 1 - iteri);
StringRef extraCond;
auto rule = rules[i];
const size_t actArg = activeArgs[i];
Expand Down