Skip to content

Commit c290d56

Browse files
committed
mse/nmse optimization
1 parent 38a53ff commit c290d56

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

include/dsplib/math.h

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,20 +253,12 @@ real_t norm(span_real x, int p = 2);
253253
real_t norm(span_cmplx x, int p = 2);
254254

255255
//Mean squared error
256-
inline real_t mse(span_real x, span_real y) {
257-
return mean(abs2(arr_real(x) - arr_real(y)));
258-
}
259-
inline real_t mse(span_cmplx x, span_cmplx y) {
260-
return mean(abs2(arr_cmplx(x) - arr_cmplx(y)));
261-
}
256+
real_t mse(span_real x, span_real y);
257+
real_t mse(span_cmplx x, span_cmplx y);
262258

263259
//Normalized mean squared error
264-
inline real_t nmse(span_real x, span_real y) {
265-
return mse(x, y) / sum(abs2(x));
266-
}
267-
inline real_t nmse(span_cmplx x, span_cmplx y) {
268-
return mse(x, y) / sum(abs2(x));
269-
}
260+
real_t nmse(span_real x, span_real y);
261+
real_t nmse(span_cmplx x, span_cmplx y);
270262

271263
//signum function
272264
constexpr int sign(const real_t& x) noexcept {

lib/math.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,51 @@ real_t norm(span_cmplx x, int p) {
678678
return power(sum(power(abs(x), p)), real_t(1) / p);
679679
}
680680

681+
//-------------------------------------------------------------------------------------------------
682+
namespace {
683+
684+
template<typename T>
685+
real_t _mse(span_t<T> x, span_t<T> y) {
686+
DSPLIB_ASSERT(x.size() == y.size(), "arrays sizes must be equal");
687+
const int n = x.size();
688+
real_t s = 0;
689+
for (int i = 0; i < n; ++i) {
690+
s += abs2(x[i] - y[i]);
691+
}
692+
return s / n;
693+
}
694+
695+
template<typename T>
696+
real_t _nmse(span_t<T> x, span_t<T> y) {
697+
DSPLIB_ASSERT(x.size() == y.size(), "arrays sizes must be equal");
698+
const int n = x.size();
699+
real_t s = 0;
700+
real_t d = 0;
701+
for (int i = 0; i < n; ++i) {
702+
s += abs2(x[i] - y[i]);
703+
d += abs2(x[i]);
704+
}
705+
return (s / n) / d;
706+
}
707+
708+
} // namespace
709+
710+
real_t mse(span_real x, span_real y) {
711+
return _mse(x, y);
712+
}
713+
714+
real_t mse(span_cmplx x, span_cmplx y) {
715+
return _mse(x, y);
716+
}
717+
718+
real_t nmse(span_real x, span_real y) {
719+
return _nmse(x, y);
720+
}
721+
722+
real_t nmse(span_cmplx x, span_cmplx y) {
723+
return _nmse(x, y);
724+
}
725+
681726
//-------------------------------------------------------------------------------------------------
682727
arr_real deg2rad(span_real x) {
683728
auto y = arr_real(x);

0 commit comments

Comments
 (0)