Skip to content

Commit 49a6fcc

Browse files
authored
Fix Rust RSI moving-average type and max-value regression (#4382)
The Rust-core RSI ignored the `ma_type` argument (inner gain/loss averages were always built as Exponential) and never advanced `last_value` on zero-loss bars, pinning RSI at `rsi_max` (1.0) even after real down-moves. Both defects diverged from the Cython source of truth; the second is the same bug fixed for Cython in #2703. - Resolve `ma_type` once and pass it into `MovingAverageFactory::create` for both averages. - Advance `last_value` in the zero-loss branch before the early return. - Add regression tests: ma_type distinctness, recovery below max after losses, and a Wilder golden series vs published reference values.
1 parent 89becc7 commit 49a6fcc

1 file changed

Lines changed: 76 additions & 4 deletions

File tree

  • crates/indicators/src/momentum

crates/indicators/src/momentum/rsi.rs

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,16 @@ impl RelativeStrengthIndex {
9595
/// Creates a new [`RelativeStrengthIndex`] instance.
9696
#[must_use]
9797
pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
98+
let ma_type = ma_type.unwrap_or(MovingAverageType::Exponential);
9899
Self {
99100
period,
100-
ma_type: ma_type.unwrap_or(MovingAverageType::Exponential),
101+
ma_type,
101102
value: 0.0,
102103
last_value: 0.0,
103104
count: 0,
104105
has_inputs: false,
105-
average_gain: MovingAverageFactory::create(MovingAverageType::Exponential, period),
106-
average_loss: MovingAverageFactory::create(MovingAverageType::Exponential, period),
106+
average_gain: MovingAverageFactory::create(ma_type, period),
107+
average_loss: MovingAverageFactory::create(ma_type, period),
107108
rsi_max: 1.0,
108109
initialized: false,
109110
}
@@ -132,6 +133,7 @@ impl RelativeStrengthIndex {
132133

133134
if self.average_loss.value() == 0.0 {
134135
self.value = self.rsi_max;
136+
self.last_value = value;
135137
return;
136138
}
137139

@@ -150,7 +152,10 @@ mod tests {
150152
use nautilus_model::data::{Bar, QuoteTick, TradeTick};
151153
use rstest::rstest;
152154

153-
use crate::{indicator::Indicator, momentum::rsi::RelativeStrengthIndex, stubs::*};
155+
use crate::{
156+
average::MovingAverageType, indicator::Indicator, momentum::rsi::RelativeStrengthIndex,
157+
stubs::*,
158+
};
154159

155160
#[rstest]
156161
fn test_rsi_initialized(rsi_10: RelativeStrengthIndex) {
@@ -271,4 +276,71 @@ mod tests {
271276
assert!(!rsi_10.has_inputs());
272277
assert_eq!(rsi_10.value, 0.0);
273278
}
279+
280+
// Feeds `values` through a fresh RSI of the given `ma_type` and returns the final value.
281+
fn run_rsi(values: &[f64], period: usize, ma_type: MovingAverageType) -> f64 {
282+
let mut rsi = RelativeStrengthIndex::new(period, Some(ma_type));
283+
for &v in values {
284+
rsi.update_raw(v);
285+
}
286+
rsi.value
287+
}
288+
289+
#[rstest]
290+
fn test_ma_type_is_plumbed_into_inner_averages() {
291+
// The `ma_type` argument must reach the inner gain/loss averages, so distinct
292+
// moving-average types must produce distinct output on the same series.
293+
// Previously all types collapsed onto Exponential (see issue: v2 RSI ignores ma_type).
294+
let series = [
295+
44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
296+
45.61, 46.28, 46.28,
297+
];
298+
299+
let wilder = run_rsi(&series, 14, MovingAverageType::Wilder);
300+
let simple = run_rsi(&series, 14, MovingAverageType::Simple);
301+
let exponential = run_rsi(&series, 14, MovingAverageType::Exponential);
302+
303+
assert_ne!(wilder, simple);
304+
assert_ne!(wilder, exponential);
305+
assert_ne!(simple, exponential);
306+
}
307+
308+
#[rstest]
309+
fn test_recovers_below_max_after_losses() {
310+
// Regression for the flat-1.0 defect (mirrors Cython fix #2703): once real down-moves
311+
// arrive, RSI must fall below `rsi_max` rather than staying pinned at 1.0 because
312+
// `last_value` was never advanced on zero-loss bars.
313+
let mut values: Vec<f64> = (1..=15).map(f64::from).collect();
314+
values.extend([14.0, 12.0, 9.0, 5.0, 2.0]);
315+
316+
let value = run_rsi(&values, 14, MovingAverageType::Wilder);
317+
assert!(
318+
value < 1.0,
319+
"RSI should drop below rsi_max after losses, was {value}"
320+
);
321+
}
322+
323+
#[rstest]
324+
fn test_wilder_golden_series() {
325+
// Golden reference: up 1..15 then down 14, 12, 9, 5, 2 with period 14, Wilder MA.
326+
// Expected Wilder RSI values (×100) after each down-move, per the published reference.
327+
let base: Vec<f64> = (1..=15).map(f64::from).collect();
328+
let downs = [14.0, 12.0, 9.0, 5.0, 2.0];
329+
let expected = [0.8935, 0.7269, 0.5586, 0.4192, 0.3489];
330+
331+
let mut rsi = RelativeStrengthIndex::new(14, Some(MovingAverageType::Wilder));
332+
for &v in &base {
333+
rsi.update_raw(v);
334+
}
335+
336+
for (i, &v) in downs.iter().enumerate() {
337+
rsi.update_raw(v);
338+
assert!(
339+
(rsi.value - expected[i]).abs() < 1e-4,
340+
"step {i}: expected {}, was {}",
341+
expected[i],
342+
rsi.value
343+
);
344+
}
345+
}
274346
}

0 commit comments

Comments
 (0)