Skip to content

Commit 5a03229

Browse files
authored
feat: add NAM neural amp modeler stage (#243)
1 parent fa01ce1 commit 5a03229

18 files changed

Lines changed: 625 additions & 2 deletions

File tree

Cargo.lock

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

nam/.gitkeep

Whitespace-only changes.

rustortion-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ rustfft = "6.4"
1717
realfft = "3.5"
1818
arc-swap = "1.8"
1919
assert_no_alloc = { version = "1.1", features = ["warn_debug"] }
20+
nam-rs = "0.1.0"
2021

2122
[dev-dependencies]
2223
criterion = { version = "0.8", features = ["html_reports"] }

rustortion-core/src/amp/stages/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod eq;
66
pub mod filter;
77
pub mod level;
88
pub mod multiband_saturator;
9+
pub mod nam;
910
pub mod noise_gate;
1011
pub mod poweramp;
1112
pub mod preamp;
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
use log::warn;
2+
use nam_rs::WaveNet;
3+
use serde::{Deserialize, Serialize};
4+
5+
use crate::amp::stages::Stage;
6+
use crate::amp::stages::common::db_to_lin;
7+
use crate::nam::registry;
8+
9+
/// Valid range for the input/output gain knobs, matching the UI and plugin params.
10+
const GAIN_DB_MIN: f32 = -24.0;
11+
const GAIN_DB_MAX: f32 = 24.0;
12+
13+
/// A Neural Amp Modeler stage running a WaveNet `.nam` model.
14+
///
15+
/// With no model loaded the stage is a passthrough. Input/output gain are applied
16+
/// around the model and the wet output is blended with the dry signal via `mix`.
17+
pub struct NamStage {
18+
wavenet: Option<WaveNet>,
19+
input_gain: f32,
20+
output_gain: f32,
21+
mix: f32,
22+
/// Native sample rate of the loaded model (0.0 if none), for UI display.
23+
native_sample_rate: f32,
24+
/// True if the model's native rate differs from the engine rate.
25+
sample_rate_mismatch: bool,
26+
}
27+
28+
impl NamStage {
29+
const fn passthrough(input_gain: f32, output_gain: f32, mix: f32) -> Self {
30+
Self {
31+
wavenet: None,
32+
input_gain,
33+
output_gain,
34+
mix,
35+
native_sample_rate: 0.0,
36+
sample_rate_mismatch: false,
37+
}
38+
}
39+
}
40+
41+
impl Stage for NamStage {
42+
fn process(&mut self, input: f32) -> f32 {
43+
let Some(wavenet) = self.wavenet.as_mut() else {
44+
return input;
45+
};
46+
let wet = wavenet.process_sample(input * self.input_gain) * self.output_gain;
47+
self.mix.mul_add(wet - input, input)
48+
}
49+
50+
fn set_parameter(&mut self, name: &str, value: f32) -> Result<(), &'static str> {
51+
match name {
52+
"input_gain_db" => {
53+
if (GAIN_DB_MIN..=GAIN_DB_MAX).contains(&value) {
54+
self.input_gain = db_to_lin(value);
55+
Ok(())
56+
} else {
57+
Err("Input gain must be between -24 and 24 dB")
58+
}
59+
}
60+
"output_gain_db" => {
61+
if (GAIN_DB_MIN..=GAIN_DB_MAX).contains(&value) {
62+
self.output_gain = db_to_lin(value);
63+
Ok(())
64+
} else {
65+
Err("Output gain must be between -24 and 24 dB")
66+
}
67+
}
68+
"mix" => {
69+
if (0.0..=1.0).contains(&value) {
70+
self.mix = value;
71+
Ok(())
72+
} else {
73+
Err("Mix must be between 0.0 and 1.0")
74+
}
75+
}
76+
"native_sample_rate" | "sample_rate_mismatch" => Err("Parameter is read-only"),
77+
_ => Err("Unknown parameter"),
78+
}
79+
}
80+
81+
fn get_parameter(&self, name: &str) -> Result<f32, &'static str> {
82+
match name {
83+
"input_gain_db" => Ok(20.0 * self.input_gain.log10()),
84+
"output_gain_db" => Ok(20.0 * self.output_gain.log10()),
85+
"mix" => Ok(self.mix),
86+
"native_sample_rate" => Ok(self.native_sample_rate),
87+
"sample_rate_mismatch" => Ok(f32::from(u8::from(self.sample_rate_mismatch))),
88+
_ => Err("Unknown parameter name"),
89+
}
90+
}
91+
}
92+
93+
// --- Config ---
94+
95+
#[derive(Debug, Clone, Serialize, Deserialize)]
96+
pub struct NamConfig {
97+
/// Display name of the selected model, or `None` for passthrough.
98+
#[serde(default)]
99+
pub model_name: Option<String>,
100+
pub input_gain_db: f32,
101+
pub output_gain_db: f32,
102+
pub mix: f32,
103+
#[serde(default)]
104+
pub bypassed: bool,
105+
}
106+
107+
impl Default for NamConfig {
108+
fn default() -> Self {
109+
Self {
110+
model_name: None,
111+
input_gain_db: 0.0,
112+
output_gain_db: 0.0,
113+
mix: 1.0,
114+
bypassed: false,
115+
}
116+
}
117+
}
118+
119+
impl NamConfig {
120+
/// Build a runnable stage. Resolves the model from the global registry and
121+
/// allocates the `WaveNet` here (off the real-time thread). On any failure the
122+
/// stage falls back to passthrough with a warning.
123+
pub fn to_stage(&self, sample_rate: f32) -> NamStage {
124+
let input_gain = db_to_lin(self.input_gain_db.clamp(GAIN_DB_MIN, GAIN_DB_MAX));
125+
let output_gain = db_to_lin(self.output_gain_db.clamp(GAIN_DB_MIN, GAIN_DB_MAX));
126+
let mix = self.mix.clamp(0.0, 1.0);
127+
128+
let Some(name) = self.model_name.as_deref() else {
129+
return NamStage::passthrough(input_gain, output_gain, mix);
130+
};
131+
132+
let Some(model) = registry::get(name) else {
133+
warn!("NAM model '{name}' not found in registry; using passthrough");
134+
return NamStage::passthrough(input_gain, output_gain, mix);
135+
};
136+
137+
let native_sample_rate = model.sample_rate() as f32;
138+
let sample_rate_mismatch = (native_sample_rate - sample_rate).abs() > 1.0;
139+
if sample_rate_mismatch {
140+
warn!(
141+
"NAM model '{name}' native rate {native_sample_rate} Hz differs from engine \
142+
rate {sample_rate} Hz; tone may be affected"
143+
);
144+
}
145+
146+
match WaveNet::new(&model) {
147+
Ok(wavenet) => NamStage {
148+
wavenet: Some(wavenet),
149+
input_gain,
150+
output_gain,
151+
mix,
152+
native_sample_rate,
153+
sample_rate_mismatch,
154+
},
155+
Err(e) => {
156+
warn!("Failed to build NAM model '{name}': {e}; using passthrough");
157+
NamStage::passthrough(input_gain, output_gain, mix)
158+
}
159+
}
160+
}
161+
}
162+
163+
#[cfg(test)]
164+
mod tests {
165+
use super::*;
166+
167+
#[test]
168+
fn passthrough_when_no_model() {
169+
let stage = NamConfig::default().to_stage(48_000.0);
170+
let mut stage = stage;
171+
for x in [-1.0, 0.0, 0.25, 0.9] {
172+
assert_eq!(stage.process(x), x);
173+
}
174+
}
175+
176+
#[test]
177+
fn gain_and_mix_round_trip() {
178+
let mut stage = NamConfig::default().to_stage(48_000.0);
179+
stage.set_parameter("mix", 0.5).unwrap();
180+
assert!((stage.get_parameter("mix").unwrap() - 0.5).abs() < 1e-6);
181+
182+
stage.set_parameter("input_gain_db", 6.0).unwrap();
183+
assert!((stage.get_parameter("input_gain_db").unwrap() - 6.0).abs() < 1e-3);
184+
185+
assert!(stage.set_parameter("mix", 2.0).is_err());
186+
assert!(stage.set_parameter("native_sample_rate", 1.0).is_err());
187+
188+
// Gains outside ±24 dB (and NaN) are rejected.
189+
assert!(stage.set_parameter("input_gain_db", 30.0).is_err());
190+
assert!(stage.set_parameter("output_gain_db", -30.0).is_err());
191+
assert!(stage.set_parameter("input_gain_db", f32::NAN).is_err());
192+
}
193+
}

rustortion-core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@ pub mod amp;
3333
pub mod audio;
3434
pub mod ir;
3535
pub mod metronome;
36+
pub mod nam;
3637
pub mod preset;
3738
pub mod tuner;

rustortion-core/src/nam/loader.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use std::collections::BTreeMap;
2+
use std::path::Path;
3+
use std::sync::Arc;
4+
5+
use anyhow::{Context, Result};
6+
use log::{info, warn};
7+
use nam_rs::NamModel;
8+
9+
/// Scans a directory for `*.nam` files and parses each into memory.
10+
///
11+
/// Parsing happens once, at construction (off the real-time thread). Models are
12+
/// keyed by display name (the file stem). Unparseable files are skipped with a
13+
/// warning rather than failing the whole scan, matching the IR loader's tolerant
14+
/// behaviour.
15+
pub struct NamLoader {
16+
models: BTreeMap<String, Arc<NamModel>>,
17+
}
18+
19+
impl NamLoader {
20+
/// Scan `directory` and parse every `*.nam` file found.
21+
///
22+
/// A missing directory yields an empty loader (warn, not error) so the app can
23+
/// run without a nam folder present.
24+
pub fn new(directory: &Path) -> Result<Self> {
25+
let mut models = BTreeMap::new();
26+
27+
if !directory.is_dir() {
28+
warn!(
29+
"NAM directory '{}' does not exist; no models loaded",
30+
directory.display()
31+
);
32+
return Ok(Self { models });
33+
}
34+
35+
let entries = std::fs::read_dir(directory)
36+
.with_context(|| format!("Failed to read NAM directory '{}'", directory.display()))?;
37+
38+
for entry in entries {
39+
let entry = match entry {
40+
Ok(entry) => entry,
41+
Err(e) => {
42+
warn!("Skipping unreadable entry in NAM directory: {e}");
43+
continue;
44+
}
45+
};
46+
let path = entry.path();
47+
if path.extension().and_then(|e| e.to_str()) != Some("nam") {
48+
continue;
49+
}
50+
let Some(name) = path.file_stem().and_then(|s| s.to_str()).map(str::to_owned) else {
51+
continue;
52+
};
53+
54+
match std::fs::read_to_string(&path)
55+
.map_err(anyhow::Error::from)
56+
.and_then(|json| NamModel::from_json_str(&json).map_err(anyhow::Error::from))
57+
{
58+
Ok(model) => {
59+
info!(
60+
"Loaded NAM model '{name}' ({} Hz)",
61+
model.sample_rate() as u32
62+
);
63+
models.insert(name, Arc::new(model));
64+
}
65+
Err(e) => warn!("Skipping NAM file '{}': {e}", path.display()),
66+
}
67+
}
68+
69+
Ok(Self { models })
70+
}
71+
72+
/// Sorted list of available model display names.
73+
#[must_use]
74+
pub fn available_names(&self) -> Vec<String> {
75+
self.models.keys().cloned().collect()
76+
}
77+
78+
/// Look up a parsed model by display name.
79+
#[must_use]
80+
pub fn get(&self, name: &str) -> Option<Arc<NamModel>> {
81+
self.models.get(name).cloned()
82+
}
83+
84+
/// All parsed models, for populating the global registry.
85+
pub fn models(&self) -> impl Iterator<Item = (&String, &Arc<NamModel>)> {
86+
self.models.iter()
87+
}
88+
}

rustortion-core/src/nam/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//! NAM (Neural Amp Modeler) model loading and a process-global parsed-model
2+
//! registry.
3+
//!
4+
//! `.nam` models are parsed (and the `WaveNet` allocated) off the real-time thread.
5+
//! The [`loader`] scans a directory and parses every `*.nam` file into memory at
6+
//! startup; the [`registry`] makes those parsed models reachable from
7+
//! `StageConfig::to_runtime`, which has no other handle to the loader.
8+
9+
pub mod loader;
10+
pub mod registry;
11+
12+
pub use loader::NamLoader;
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//! Process-global registry of parsed NAM models.
2+
//!
3+
//! `StageConfig::to_runtime(sample_rate)` builds stages off the real-time thread but
4+
//! has no handle to the [`NamLoader`](super::loader::NamLoader). Since the nam folder
5+
//! is a singleton resource, a process-global registry lets `NamConfig::to_runtime`
6+
//! resolve a model by name without threading a loader through every `to_runtime` call
7+
//! site.
8+
9+
use std::collections::HashMap;
10+
use std::sync::{Arc, OnceLock, RwLock};
11+
12+
use nam_rs::NamModel;
13+
14+
use super::loader::NamLoader;
15+
16+
type Store = RwLock<HashMap<String, Arc<NamModel>>>;
17+
18+
static NAM_REGISTRY: OnceLock<Store> = OnceLock::new();
19+
20+
fn store() -> &'static Store {
21+
NAM_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
22+
}
23+
24+
/// Populate (or replace) the global registry from a loader's parsed models.
25+
pub fn init_from_loader(loader: &NamLoader) {
26+
let mut map = store()
27+
.write()
28+
.unwrap_or_else(std::sync::PoisonError::into_inner);
29+
map.clear();
30+
for (name, model) in loader.models() {
31+
map.insert(name.clone(), Arc::clone(model));
32+
}
33+
}
34+
35+
/// Look up a parsed model by display name.
36+
#[must_use]
37+
pub fn get(name: &str) -> Option<Arc<NamModel>> {
38+
let map = store()
39+
.read()
40+
.unwrap_or_else(std::sync::PoisonError::into_inner);
41+
map.get(name).cloned()
42+
}
43+
44+
/// Sorted list of available model display names.
45+
#[must_use]
46+
pub fn available_names() -> Vec<String> {
47+
let map = store()
48+
.read()
49+
.unwrap_or_else(std::sync::PoisonError::into_inner);
50+
let mut names: Vec<String> = map.keys().cloned().collect();
51+
names.sort();
52+
names
53+
}

0 commit comments

Comments
 (0)