Skip to content

Commit eb58d29

Browse files
committed
[math] Forward declare sgemm_ in custom derivatives header
Forward declare `sgemm_` in custom derivatives header, instead of using the forward declaration wrapped inside the SOFIE namespace. Otherwise, the `CladDerivator.h` header is broken in case ROOT was built without TMVA. This change is covered by the SOFIE unit test `TestGemmDerivative`. Follows up on #18364.
1 parent d5b14f6 commit eb58d29

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

math/mathcore/inc/Math/CladDerivator.h

+16-11
Original file line numberDiff line numberDiff line change
@@ -1137,9 +1137,17 @@ inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a,
11371137
} // namespace Math
11381138
} // namespace ROOT
11391139

1140-
namespace TMVA {
1141-
namespace Experimental {
1142-
namespace SOFIE {
1140+
} // namespace custom_derivatives
1141+
} // namespace clad
1142+
1143+
// Forward declare BLAS functions.
1144+
extern "C" void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
1145+
const float *alpha, const float *A, const int *lda, const float *B, const int *ldb,
1146+
const float *beta, float *C, const int *ldc);
1147+
1148+
namespace clad::custom_derivatives {
1149+
1150+
namespace TMVA::Experimental::SOFIE {
11431151

11441152
inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha,
11451153
const float *A, const float *B, float beta, const float *C, float *_d_output, bool *,
@@ -1148,7 +1156,7 @@ inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, i
11481156
{
11491157
// TODO:
11501158
// - handle transa and transb cases correctly
1151-
if ( transa || transb ) {
1159+
if (transa || transb) {
11521160
return;
11531161
}
11541162

@@ -1161,8 +1169,8 @@ inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, i
11611169

11621170
// _d_A, _d_B
11631171
// note: beta needs to be one because we want to add to _d_A and _d_B instead of overwriting it.
1164-
::TMVA::Experimental::SOFIE::BLAS::sgemm_(&cn, &ct, &m, &k, &n, &alpha, _d_output, &m, B, &k, &one, _d_A, &m);
1165-
::TMVA::Experimental::SOFIE::BLAS::sgemm_(&ct, &cn, &k, &n, &m, &alpha, A, &m, _d_output, &m, &one, _d_B, &k);
1172+
::sgemm_(&cn, &ct, &m, &k, &n, &alpha, _d_output, &m, B, &k, &one, _d_A, &m);
1173+
::sgemm_(&ct, &cn, &k, &n, &m, &alpha, A, &m, _d_output, &m, &one, _d_B, &k);
11661174

11671175
// _d_alpha, _d_beta, _d_C
11681176
int sizeC = n * m;
@@ -1173,11 +1181,8 @@ inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, i
11731181
}
11741182
}
11751183

1176-
} // namespace SOFIE
1177-
} // namespace Experimental
1178-
} // namespace TMVA
1184+
} // namespace TMVA::Experimental::SOFIE
11791185

1180-
} // namespace custom_derivatives
1181-
} // namespace clad
1186+
} // namespace clad::custom_derivatives
11821187

11831188
#endif // CLAD_DERIVATOR

0 commit comments

Comments
 (0)