Skip to content

Commit ce9c2bb

Browse files
committed
add StatefulInferer::replace_inferer
1 parent dbb88c7 commit ce9c2bb

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
<!-- next-header -->
1111
## [Unreleased] - ReleaseDate
12+
13+
- Add `StatefulInferer::replace_inferer` which works with a `&mut
14+
StatefulInferer`, at the cost of requiring the inferer to be of the
15+
same type.
16+
1217
## [0.9.0] - 2025-09-04
1318

1419
- Breaking: `Inferer.begin_agent` and `Inferer.end_agent` now take

crates/cervo-core/src/wrapper.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ pub struct StatefulInferer<WrapStack: InfererWrapper, Inf: Inferer> {
110110
}
111111

112112
impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
113+
/// Construct a new [`StatefulInferer`] by wrapping the given
114+
/// inferer with the given wrapper stack.
113115
pub fn new(wrapper_stack: WrapStack, inferer: Inf) -> Self {
114116
Self {
115117
wrapper_stack,
@@ -136,6 +138,21 @@ impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
136138
})
137139
}
138140

141+
/// Replace the inner inferer with a new inferer while maintaining
142+
/// any state in wrappers.
143+
///
144+
/// Requires that the shapes of the policies are compatible, but
145+
/// they may be different concrete inferer implementations. If
146+
/// this check fails, will not change self.
147+
pub fn replace_inferer(&mut self, new_inferer: Inf) -> anyhow::Result<()> {
148+
if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) {
149+
Err(e)
150+
} else {
151+
self.inferer = new_inferer;
152+
Ok(())
153+
}
154+
}
155+
139156
/// Validate that [`Old`] and [`New`] are compatible with each
140157
/// other.
141158
pub fn check_compatible_shapes<Old: Inferer, New: Inferer>(

0 commit comments

Comments
 (0)