diff --git a/compiler/README.md b/compiler/README.md index e10beda3e0..a5fb5dc983 100644 --- a/compiler/README.md +++ b/compiler/README.md @@ -78,6 +78,8 @@ Code generation options: **-cir** **--check-integer-range** check float to integer range conversion. + **-nsa** **--numeric-stability-analysis** run static analysis for numerical instability and enable automatic precision compensation. + **-exp10** **--generate-exp10** pow(10,x) replaced by possibly faster exp10(x). **-os** **--one-sample** generate one sample computation. diff --git a/compiler/global.cpp b/compiler/global.cpp index 6b484b364c..a752365bbc 100644 --- a/compiler/global.cpp +++ b/compiler/global.cpp @@ -461,6 +461,7 @@ void global::reset() gFTZMode = 0; gRangeUI = false; gFreezeUI = false; + gNumericStabilityAnalysis = false; gFloatSize = 1; // -single by default gFixedPointSize = AP_INT_MAX_W; // Special -1 value will be used to generate fixpoint_t type @@ -514,6 +515,7 @@ void global::reset() gNamespace = ""; gFullParentheses = false; gCheckIntRange = false; + gNumericStabilityAnalysis = false; gReprC = true; gNarrowingLimit = 0; @@ -836,6 +838,9 @@ void global::printCompilationOptions(stringstream& dst, bool backend) if (gCheckIntRange) { dst << "-cir "; } + if (gNumericStabilityAnalysis) { + dst << "-nsa "; + } if (gExtControl) { dst << "-ec "; } @@ -1581,6 +1586,10 @@ bool global::processCmdline(int argc, const char* argv[]) } else if (isCmd(argv[i], "-cir", "--check-integer-range")) { gCheckIntRange = true; i += 1; + } else if (isCmd(argv[i], "-nsa", "--numeric-stability-analysis")) { + gNumericStabilityAnalysis = true; + gAllWarning = true; + i += 1; } else if (isCmd(argv[i], "-noreprc", "--no-reprc")) { gReprC = false; i += 1; @@ -2251,6 +2260,10 @@ string global::printHelp() sstr << tab << "-cir --check-integer-range check float to integer range conversion." << endl; + sstr << tab + << "-nsa --numeric-stability-analysis run static analysis for numerical " + "instability and enable automatic precision compensation." + << endl; sstr << tab << "-exp10 --generate-exp10 pow(10,x) replaced by possibly faster exp10(x)." diff --git a/compiler/global.hh b/compiler/global.hh index 3c14a12c93..a31b46ecea 100644 --- a/compiler/global.hh +++ b/compiler/global.hh @@ -300,6 +300,7 @@ struct global { bool gFullParentheses; // -fp option, generate less parenthesis in some textual backends: // C/C++, Cmajor, Dlang, Rust bool gCheckIntRange; // -cir option, check float to integer range conversion + bool gNumericStabilityAnalysis; // -nsa option, run static analysis and automatic precision compensation bool gReprC; // (Rust) Force dsp struct layout to follow C ABI std::string gClassName; // -cn option, name of the generated dsp class, by default 'mydsp' diff --git a/compiler/libcode.cpp b/compiler/libcode.cpp index c44b61b499..da73250a6c 100644 --- a/compiler/libcode.cpp +++ b/compiler/libcode.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -1351,6 +1352,82 @@ LIBFAUST_API Tree DSPToBoxes(const string& name_app, const string& dsp_content, static void* createFactoryAux1(void* arg) { + struct NumericStabilityAnalyzer : public SignalVisitor { + int fDelayCount = 0; + int fRecCount = 0; + int fDivCount = 0; + int fNearUnityCount = 0; + bool fRisky = false; + + static bool isNearUnity(Tree x) + { + double coeff; + if (!isSigReal(x, &coeff)) { + return false; + } + double magnitude = std::fabs(coeff); + return (magnitude >= 0.95 && magnitude < 1.0); + } + + explicit NumericStabilityAnalyzer(Tree root) + { + visitRoot(root); + // Heuristic risk gate: recursion plus either division, a deep delay cascade, + // or feedback gains close to the unit circle. + fRisky = (fRecCount > 0) && ((fDivCount > 0) || (fDelayCount >= 4) || (fNearUnityCount > 0)); + } + + void visit(Tree sig) + { + int op; + Tree x; + Tree y; + Tree t0; + Tree t1; + Tree var; + Tree le; + + if (isSigDelay1(sig, t0)) { + fDelayCount++; + } else if (isSigDelay(sig, t0, t1)) { + fDelayCount++; + } else if (isRec(sig, var, le)) { + fRecCount++; + } else if (isSigBinOp(sig, &op, x, y)) { + if (op == kDiv) { + fDivCount++; + } else if ((op == kMul) && (isNearUnity(x) || isNearUnity(y))) { + fNearUnityCount++; + } + } + SignalVisitor::visit(sig); + } + }; + + auto runNumericStabilityAnalysis = [](Tree signals) { + NumericStabilityAnalyzer analyzer(signals); + if (!analyzer.fRisky) { + return; + } + + { + stringstream warning; + warning << "WARNING : numeric stability analysis detected a recursive structure with " + << analyzer.fDelayCount << " delays, " << analyzer.fDivCount + << " divisions and " << analyzer.fNearUnityCount + << " near-unity gains. Consider a more stable filter form.\n"; + gWarningMessages.push_back(warning.str()); + } + + // Automatic compensation: promote internal precision when currently in single precision. + if (gGlobal->gFloatSize == 1) { + gGlobal->gFloatSize = 2; + gWarningMessages.push_back( + "WARNING : numeric stability analysis promoted internal precision from single to " + "double.\n"); + } + }; + try { CallContext* context = static_cast(arg); string name_app = context->fNameApp; @@ -1424,6 +1501,10 @@ static void* createFactoryAux1(void* arg) cout << "\n\n"; } + if (gGlobal->gNumericStabilityAnalysis) { + runNumericStabilityAnalysis(lsignals); + } + endTiming("propagation"); /************************************************************************* @@ -1446,6 +1527,79 @@ static void* createFactoryAux1(void* arg) static void* createFactoryAux2(void* arg) { + struct NumericStabilityAnalyzer : public SignalVisitor { + int fDelayCount = 0; + int fRecCount = 0; + int fDivCount = 0; + int fNearUnityCount = 0; + bool fRisky = false; + + static bool isNearUnity(Tree x) + { + double coeff; + if (!isSigReal(x, &coeff)) { + return false; + } + double magnitude = std::fabs(coeff); + return (magnitude >= 0.95 && magnitude < 1.0); + } + + explicit NumericStabilityAnalyzer(Tree root) + { + visitRoot(root); + fRisky = (fRecCount > 0) && ((fDivCount > 0) || (fDelayCount >= 4) || (fNearUnityCount > 0)); + } + + void visit(Tree sig) + { + int op; + Tree x; + Tree y; + Tree t0; + Tree t1; + Tree var; + Tree le; + + if (isSigDelay1(sig, t0)) { + fDelayCount++; + } else if (isSigDelay(sig, t0, t1)) { + fDelayCount++; + } else if (isRec(sig, var, le)) { + fRecCount++; + } else if (isSigBinOp(sig, &op, x, y)) { + if (op == kDiv) { + fDivCount++; + } else if ((op == kMul) && (isNearUnity(x) || isNearUnity(y))) { + fNearUnityCount++; + } + } + SignalVisitor::visit(sig); + } + }; + + auto runNumericStabilityAnalysis = [](Tree signals) { + NumericStabilityAnalyzer analyzer(signals); + if (!analyzer.fRisky) { + return; + } + + { + stringstream warning; + warning << "WARNING : numeric stability analysis detected a recursive structure with " + << analyzer.fDelayCount << " delays, " << analyzer.fDivCount + << " divisions and " << analyzer.fNearUnityCount + << " near-unity gains. Consider a more stable filter form.\n"; + gWarningMessages.push_back(warning.str()); + } + + if (gGlobal->gFloatSize == 1) { + gGlobal->gFloatSize = 2; + gWarningMessages.push_back( + "WARNING : numeric stability analysis promoted internal precision from single to " + "double.\n"); + } + }; + // Keep the maximum index of inputs signals struct MaxInputsCounter : public SignalVisitor { int fMaxInputs = 0; @@ -1489,6 +1643,10 @@ static void* createFactoryAux2(void* arg) gGlobal->initDocumentNames(); + if (gGlobal->gNumericStabilityAnalysis) { + runNumericStabilityAnalysis(signals2); + } + // Open output file openOutfile(); diff --git a/documentation/man/README.md b/documentation/man/README.md index e10beda3e0..a5fb5dc983 100644 --- a/documentation/man/README.md +++ b/documentation/man/README.md @@ -78,6 +78,8 @@ Code generation options: **-cir** **--check-integer-range** check float to integer range conversion. + **-nsa** **--numeric-stability-analysis** run static analysis for numerical instability and enable automatic precision compensation. + **-exp10** **--generate-exp10** pow(10,x) replaced by possibly faster exp10(x). **-os** **--one-sample** generate one sample computation. diff --git a/tests/warning-tests/CMakeLists.txt b/tests/warning-tests/CMakeLists.txt index 4c942cd51f..eea752f1a5 100644 --- a/tests/warning-tests/CMakeLists.txt +++ b/tests/warning-tests/CMakeLists.txt @@ -11,3 +11,10 @@ foreach(test ${tests}) ) endforeach() + +add_test(NAME numeric_stability_analysis_nsa + COMMAND faust -nsa ${CMAKE_CURRENT_SOURCE_DIR}/numeric-stability-analysis.dsp) + +set_property(TEST numeric_stability_analysis_nsa + PROPERTY PASS_REGULAR_EXPRESSION "numeric stability analysis" +) diff --git a/tests/warning-tests/numeric-stability-analysis.dsp b/tests/warning-tests/numeric-stability-analysis.dsp new file mode 100644 index 0000000000..0214e3c3af --- /dev/null +++ b/tests/warning-tests/numeric-stability-analysis.dsp @@ -0,0 +1,6 @@ +// WARNING : numeric stability analysis should detect this recursive near-unity gain loop when -nsa is enabled. + +process = loop +with { + loop = +(1.0) ~ (*(0.99995)); +};