@@ -8,7 +8,7 @@ Utilities for filling noise inputs for an inference model.
88
99use std:: cell:: RefCell ;
1010
11- use crate :: { batcher:: ScratchPadView , inferer:: Inferer } ;
11+ use crate :: { batcher:: ScratchPadView , inferer:: Inferer , prelude :: ModelWrapper } ;
1212use anyhow:: { bail, Result } ;
1313use perchance:: PerchanceContext ;
1414use rand:: thread_rng;
@@ -112,6 +112,13 @@ impl NoiseGenerator for HighQualityNoiseGenerator {
112112 }
113113}
114114
115+ struct EpsilonInjectorState < NG : NoiseGenerator > {
116+ count : usize ,
117+ index : usize ,
118+ generator : NG ,
119+
120+ inputs : Vec < ( String , Vec < usize > ) > ,
121+ }
115122/// The [`EpsilonInjector`] wraps an inferer to add noise values as one of the input data points. This is useful for
116123/// continuous action policies where you might have trained your agent to follow a stochastic policy trained with the
117124/// reparametrization trick.
@@ -120,11 +127,8 @@ impl NoiseGenerator for HighQualityNoiseGenerator {
120127/// wrapper.
121128pub struct EpsilonInjector < T : Inferer , NG : NoiseGenerator = HighQualityNoiseGenerator > {
122129 inner : T ,
123- count : usize ,
124- index : usize ,
125- generator : NG ,
126130
127- inputs : Vec < ( String , Vec < usize > ) > ,
131+ state : EpsilonInjectorState < NG > ,
128132}
129133
130134impl < T > EpsilonInjector < T , HighQualityNoiseGenerator >
@@ -169,11 +173,12 @@ where
169173
170174 Ok ( Self {
171175 inner : inferer,
172- index,
173- count,
174- generator,
175-
176- inputs,
176+ state : EpsilonInjectorState {
177+ index,
178+ count,
179+ generator,
180+ inputs,
181+ } ,
177182 } )
178183 }
179184}
@@ -188,15 +193,15 @@ where
188193 }
189194
190195 fn infer_raw ( & self , batch : & mut ScratchPadView < ' _ > ) -> Result < ( ) , anyhow:: Error > {
191- let total_count = self . count * batch. len ( ) ;
192- let output = batch. input_slot_mut ( self . index ) ;
193- self . generator . generate ( total_count, output) ;
196+ let total_count = self . state . count * batch. len ( ) ;
197+ let output = batch. input_slot_mut ( self . state . index ) ;
198+ self . state . generator . generate ( total_count, output) ;
194199
195200 self . inner . infer_raw ( batch)
196201 }
197202
198203 fn input_shapes ( & self ) -> & [ ( String , Vec < usize > ) ] {
199- & self . inputs
204+ & self . state . inputs
200205 }
201206
202207 fn raw_input_shapes ( & self ) -> & [ ( String , Vec < usize > ) ] {
@@ -215,3 +220,97 @@ where
215220 self . inner . end_agent ( id) ;
216221 }
217222}
223+
224+ pub struct EpsilonInjectorWrapper < Inner : ModelWrapper , NG : NoiseGenerator > {
225+ inner : Inner ,
226+ state : EpsilonInjectorState < NG > ,
227+ }
228+
229+ impl < Inner : ModelWrapper > EpsilonInjectorWrapper < Inner , HighQualityNoiseGenerator > {
230+ /// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
231+ ///
232+ /// This function will use [`HighQualityNoiseGenerator`] as the noise source.
233+ ///
234+ /// # Errors
235+ ///
236+ /// Will return an error if the provided key doesn't match an input on the model.
237+ pub fn wrap (
238+ inner : Inner ,
239+ inferer : & dyn Inferer ,
240+ key : & str ,
241+ ) -> Result < EpsilonInjectorWrapper < Inner , HighQualityNoiseGenerator > > {
242+ Self :: with_generator ( inner, inferer, HighQualityNoiseGenerator :: default ( ) , key)
243+ }
244+ }
245+
246+ impl < Inner , NG > EpsilonInjectorWrapper < Inner , NG >
247+ where
248+ Inner : ModelWrapper ,
249+ NG : NoiseGenerator ,
250+ {
251+ /// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
252+ ///
253+ /// # Errors
254+ ///
255+ /// Will return an error if the provided key doesn't match an input on the model.
256+ pub fn with_generator (
257+ inner : Inner ,
258+ inferer : & dyn Inferer ,
259+ generator : NG ,
260+ key : & str ,
261+ ) -> Result < Self > {
262+ let inputs = inferer. input_shapes ( ) ;
263+
264+ let ( index, count) = match inputs. iter ( ) . enumerate ( ) . find ( |( _, ( k, _) ) | k == key) {
265+ Some ( ( index, ( _, shape) ) ) => ( index, shape. iter ( ) . product ( ) ) ,
266+ None => bail ! ( "model has no input key {:?}" , key) ,
267+ } ;
268+
269+ let inputs = inputs
270+ . iter ( )
271+ . filter ( |( k, _) | * k != key)
272+ . map ( |( k, v) | ( k. to_owned ( ) , v. to_owned ( ) ) )
273+ . collect :: < Vec < _ > > ( ) ;
274+
275+ Ok ( Self {
276+ inner,
277+ state : EpsilonInjectorState {
278+ index,
279+ count,
280+ generator,
281+ inputs,
282+ } ,
283+ } )
284+ }
285+ }
286+
287+ impl < Inner , NG > ModelWrapper for EpsilonInjectorWrapper < Inner , NG >
288+ where
289+ Inner : ModelWrapper ,
290+ NG : NoiseGenerator ,
291+ {
292+ fn invoke ( & self , inferer : & impl Inferer , batch : & mut ScratchPadView < ' _ > ) -> anyhow:: Result < ( ) > {
293+ self . inner . invoke ( inferer, batch) ?;
294+ let total_count = self . state . count * batch. len ( ) ;
295+ let output = batch. input_slot_mut ( self . state . index ) ;
296+ self . state . generator . generate ( total_count, output) ;
297+
298+ self . inner . invoke ( inferer, batch)
299+ }
300+
301+ fn input_shapes < ' a > ( & ' a self , _inferer : & ' a dyn Inferer ) -> & ' a [ ( String , Vec < usize > ) ] {
302+ self . state . inputs . as_ref ( )
303+ }
304+
305+ fn output_shapes < ' a > ( & ' a self , inferer : & ' a dyn Inferer ) -> & ' a [ ( String , Vec < usize > ) ] {
306+ self . inner . output_shapes ( inferer)
307+ }
308+
309+ fn begin_agent ( & self , id : u64 ) {
310+ self . inner . begin_agent ( id)
311+ }
312+
313+ fn end_agent ( & self , id : u64 ) {
314+ self . inner . end_agent ( id)
315+ }
316+ }
0 commit comments