Skip to content

Commit 4df4d86

Browse files
committed
wip on wrappers
1 parent 298c685 commit 4df4d86

File tree

6 files changed

+483
-90
lines changed

6 files changed

+483
-90
lines changed

crates/cervo-cli/src/commands/benchmark.rs

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use anyhow::{bail, Result};
22
use cervo::asset::AssetData;
3+
use cervo::core::epsilon::EpsilonInjectorWrapper;
4+
use cervo::core::model::{BaseCase, Model, ModelWrapper};
35
use cervo::core::prelude::{Batcher, Inferer, InfererExt, State};
4-
use cervo::core::recurrent::{RecurrentInfo, RecurrentTracker};
6+
use cervo::core::recurrent::{RecurrentInfo, RecurrentTracker, RecurrentTrackerWrapper};
57
use clap::Parser;
68
use clap::ValueEnum;
79
use serde::Serialize;
@@ -222,55 +224,68 @@ pub fn build_inputs_from_desc(
222224
.collect()
223225
}
224226

225-
fn do_run(mut inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
226-
let shapes = inferer.input_shapes().to_vec();
227+
fn do_run(
228+
wrapper: impl ModelWrapper,
229+
mut inferer: impl Inferer,
230+
batch_size: usize,
231+
config: &Args,
232+
) -> Result<Record> {
233+
let mut model = Model::new(wrapper, inferer);
234+
235+
let shapes = model.input_shapes().to_vec();
227236
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
228237
for id in 0..batch_size {
229-
inferer.begin_agent(id as u64);
238+
model.begin_agent(id as u64);
230239
}
231-
let res = execute_load_metrics(batch_size, observations, config.count, &mut inferer)?;
240+
let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?;
232241
for id in 0..batch_size {
233-
inferer.end_agent(id as u64);
242+
model.end_agent(id as u64);
234243
}
235244

236245
Ok(res)
237246
}
238247

239248
fn run_apply_epsilon_config(
249+
wrapper: impl ModelWrapper,
240250
inferer: impl Inferer,
241251
batch_size: usize,
242252
config: &Args,
243253
) -> Result<Record> {
244254
if let Some(epsilon) = config.with_epsilon.as_ref() {
245-
let inferer = inferer.with_default_epsilon(epsilon)?;
246-
do_run(inferer, batch_size, config)
255+
let wrapper = EpsilonInjectorWrapper::wrap(wrapper, &inferer, epsilon)?;
256+
do_run(wrapper, inferer, batch_size, config)
247257
} else {
248-
do_run(inferer, batch_size, config)
258+
do_run(wrapper, inferer, batch_size, config)
249259
}
250260
}
251261

252-
fn run_apply_recurrent(inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
262+
fn run_apply_recurrent(
263+
wrapper: impl ModelWrapper,
264+
inferer: impl Inferer,
265+
batch_size: usize,
266+
config: &Args,
267+
) -> Result<Record> {
253268
if let Some(recurrent) = config.recurrent.as_ref() {
254269
if matches!(recurrent, RecurrentConfig::None) {
255-
run_apply_epsilon_config(inferer, batch_size, config)
270+
run_apply_epsilon_config(wrapper, inferer, batch_size, config)
256271
} else {
257-
let inferer = match recurrent {
272+
let wrapper = match recurrent {
258273
RecurrentConfig::None => unreachable!(),
259-
RecurrentConfig::Auto => RecurrentTracker::wrap(inferer),
274+
RecurrentConfig::Auto => RecurrentTrackerWrapper::wrap(wrapper, &inferer),
260275
RecurrentConfig::Mapped(map) => {
261276
let infos = map
262277
.iter()
263278
.cloned()
264279
.map(|(inkey, outkey)| RecurrentInfo { inkey, outkey })
265280
.collect::<Vec<_>>();
266-
RecurrentTracker::new(inferer, infos)
281+
RecurrentTrackerWrapper::new(wrapper, &inferer, infos)
267282
}
268283
}?;
269284

270-
run_apply_epsilon_config(inferer, batch_size, config)
285+
run_apply_epsilon_config(wrapper, inferer, batch_size, config)
271286
}
272287
} else {
273-
run_apply_epsilon_config(inferer, batch_size, config)
288+
run_apply_epsilon_config(wrapper, inferer, batch_size, config)
274289
}
275290
}
276291

@@ -289,7 +304,7 @@ pub(super) fn run(config: Args) -> Result<()> {
289304
}
290305
};
291306

292-
let record = run_apply_recurrent(inferer, batch_size, &config)?;
307+
let record = run_apply_recurrent(BaseCase, inferer, batch_size, &config)?;
293308

294309
// Print Text
295310
if matches!(config.output, OutputFormat::Text) {

crates/cervo-cli/src/commands/run.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,7 @@ pub(super) fn run(config: Args) -> Result<()> {
8484

8585
let elapsed = if let Some(epsilon) = config.with_epsilon.as_ref() {
8686
let inferer = inferer.with_default_epsilon(epsilon)?;
87-
// TODO[TSolberg]: Issue #31.
88-
let shapes = inferer
89-
.raw_input_shapes()
90-
.iter()
91-
.filter(|(k, _)| k.as_str() != epsilon)
92-
.cloned()
93-
.collect::<Vec<_>>();
94-
95-
let observations = build_inputs_from_desc(config.batch_size as u64, &shapes);
87+
let observations = build_inputs_from_desc(config.batch_size as u64, inferer.input_shapes());
9688

9789
if config.print_input {
9890
print_input(&observations);

crates/cervo-core/src/epsilon.rs

Lines changed: 113 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Utilities for filling noise inputs for an inference model.
88

99
use std::cell::RefCell;
1010

11-
use crate::{batcher::ScratchPadView, inferer::Inferer};
11+
use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::ModelWrapper};
1212
use anyhow::{bail, Result};
1313
use perchance::PerchanceContext;
1414
use rand::thread_rng;
@@ -112,6 +112,13 @@ impl NoiseGenerator for HighQualityNoiseGenerator {
112112
}
113113
}
114114

115+
struct EpsilonInjectorState<NG: NoiseGenerator> {
116+
count: usize,
117+
index: usize,
118+
generator: NG,
119+
120+
inputs: Vec<(String, Vec<usize>)>,
121+
}
115122
/// The [`EpsilonInjector`] wraps an inferer to add noise values as one of the input data points. This is useful for
116123
/// continuous action policies where you might have trained your agent to follow a stochastic policy trained with the
117124
/// reparametrization trick.
@@ -120,11 +127,8 @@ impl NoiseGenerator for HighQualityNoiseGenerator {
120127
/// wrapper.
121128
pub struct EpsilonInjector<T: Inferer, NG: NoiseGenerator = HighQualityNoiseGenerator> {
122129
inner: T,
123-
count: usize,
124-
index: usize,
125-
generator: NG,
126130

127-
inputs: Vec<(String, Vec<usize>)>,
131+
state: EpsilonInjectorState<NG>,
128132
}
129133

130134
impl<T> EpsilonInjector<T, HighQualityNoiseGenerator>
@@ -169,11 +173,12 @@ where
169173

170174
Ok(Self {
171175
inner: inferer,
172-
index,
173-
count,
174-
generator,
175-
176-
inputs,
176+
state: EpsilonInjectorState {
177+
index,
178+
count,
179+
generator,
180+
inputs,
181+
},
177182
})
178183
}
179184
}
@@ -188,15 +193,15 @@ where
188193
}
189194

190195
fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
191-
let total_count = self.count * batch.len();
192-
let output = batch.input_slot_mut(self.index);
193-
self.generator.generate(total_count, output);
196+
let total_count = self.state.count * batch.len();
197+
let output = batch.input_slot_mut(self.state.index);
198+
self.state.generator.generate(total_count, output);
194199

195200
self.inner.infer_raw(batch)
196201
}
197202

198203
fn input_shapes(&self) -> &[(String, Vec<usize>)] {
199-
&self.inputs
204+
&self.state.inputs
200205
}
201206

202207
fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
@@ -215,3 +220,97 @@ where
215220
self.inner.end_agent(id);
216221
}
217222
}
223+
224+
pub struct EpsilonInjectorWrapper<Inner: ModelWrapper, NG: NoiseGenerator> {
225+
inner: Inner,
226+
state: EpsilonInjectorState<NG>,
227+
}
228+
229+
impl<Inner: ModelWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
230+
/// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
231+
///
232+
/// This function will use [`HighQualityNoiseGenerator`] as the noise source.
233+
///
234+
/// # Errors
235+
///
236+
/// Will return an error if the provided key doesn't match an input on the model.
237+
pub fn wrap(
238+
inner: Inner,
239+
inferer: &dyn Inferer,
240+
key: &str,
241+
) -> Result<EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator>> {
242+
Self::with_generator(inner, inferer, HighQualityNoiseGenerator::default(), key)
243+
}
244+
}
245+
246+
impl<Inner, NG> EpsilonInjectorWrapper<Inner, NG>
247+
where
248+
Inner: ModelWrapper,
249+
NG: NoiseGenerator,
250+
{
251+
/// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
252+
///
253+
/// # Errors
254+
///
255+
/// Will return an error if the provided key doesn't match an input on the model.
256+
pub fn with_generator(
257+
inner: Inner,
258+
inferer: &dyn Inferer,
259+
generator: NG,
260+
key: &str,
261+
) -> Result<Self> {
262+
let inputs = inferer.input_shapes();
263+
264+
let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
265+
Some((index, (_, shape))) => (index, shape.iter().product()),
266+
None => bail!("model has no input key {:?}", key),
267+
};
268+
269+
let inputs = inputs
270+
.iter()
271+
.filter(|(k, _)| *k != key)
272+
.map(|(k, v)| (k.to_owned(), v.to_owned()))
273+
.collect::<Vec<_>>();
274+
275+
Ok(Self {
276+
inner,
277+
state: EpsilonInjectorState {
278+
index,
279+
count,
280+
generator,
281+
inputs,
282+
},
283+
})
284+
}
285+
}
286+
287+
impl<Inner, NG> ModelWrapper for EpsilonInjectorWrapper<Inner, NG>
288+
where
289+
Inner: ModelWrapper,
290+
NG: NoiseGenerator,
291+
{
292+
fn invoke(&self, inferer: &impl Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
293+
self.inner.invoke(inferer, batch)?;
294+
let total_count = self.state.count * batch.len();
295+
let output = batch.input_slot_mut(self.state.index);
296+
self.state.generator.generate(total_count, output);
297+
298+
self.inner.invoke(inferer, batch)
299+
}
300+
301+
fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
302+
self.state.inputs.as_ref()
303+
}
304+
305+
fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
306+
self.inner.output_shapes(inferer)
307+
}
308+
309+
fn begin_agent(&self, id: u64) {
310+
self.inner.begin_agent(id)
311+
}
312+
313+
fn end_agent(&self, id: u64) {
314+
self.inner.end_agent(id)
315+
}
316+
}

crates/cervo-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub use tract_hir;
1515
pub mod batcher;
1616
pub mod epsilon;
1717
pub mod inferer;
18+
pub mod model;
1819
mod model_api;
1920
pub mod recurrent;
2021

@@ -29,6 +30,7 @@ pub mod prelude {
2930
InfererProvider, MemoizingDynamicInferer, Response, State,
3031
};
3132

33+
pub use super::model::ModelWrapper;
3234
pub use super::model_api::ModelApi;
3335
pub use super::recurrent::{RecurrentInfo, RecurrentTracker};
3436
}

0 commit comments

Comments
 (0)