@@ -6,7 +6,7 @@ use ort::inputs;
66use ort:: value:: TensorRef ;
77use regex:: Regex ;
88
9- use crate :: asr:: recognizer:: { AsrError , AsrModel , Transcript } ;
9+ use crate :: asr:: recognizer:: { AsrError , AsrModel , InferenceConfig , Transcript } ;
1010
1111type DecoderState = ( Array3 < f32 > , Array3 < f32 > ) ;
1212
@@ -17,102 +17,72 @@ const MAX_TOKENS_PER_STEP: usize = 10;
1717static DECODE_SPACE_RE : LazyLock < Result < Regex , regex:: Error > > =
1818 LazyLock :: new ( || Regex :: new ( r"\A\s|\s\B|(\s)\b" ) ) ;
1919
20- pub ( crate ) struct DecoderWorkspace {
21- encoder_step : Array3 < f32 > ,
22- targets : Array2 < i32 > ,
23- target_length : Array1 < i32 > ,
24- state : DecoderState ,
20+ pub struct DecoderSession < ' m > {
21+ model : & ' m mut AsrModel ,
22+ workspace : DecoderWorkspace ,
23+ last_token : i32 ,
2524}
2625
27- impl DecoderWorkspace {
28- pub ( crate ) fn new ( session : & ort:: session:: Session ) -> Result < Self , AsrError > {
29- let encoder_dim = session
30- . inputs
31- . iter ( )
32- . find ( |input| input. name == "encoder_outputs" )
33- . and_then ( |input| input. input_type . tensor_shape ( ) )
34- . and_then ( |shape| shape. get ( 1 ) . copied ( ) )
35- . and_then ( |d| usize:: try_from ( d) . ok ( ) )
36- . unwrap_or ( 1024 ) ;
37-
38- let state1_shape = session
39- . inputs
40- . iter ( )
41- . find ( |input| input. name == "input_states_1" )
42- . ok_or_else ( || AsrError :: InputNotFound ( "input_states_1" . to_string ( ) ) ) ?
43- . input_type
44- . tensor_shape ( )
45- . ok_or_else ( || AsrError :: TensorShape ( "input_states_1" . to_string ( ) ) ) ?;
46-
47- let state2_shape = session
48- . inputs
49- . iter ( )
50- . find ( |input| input. name == "input_states_2" )
51- . ok_or_else ( || AsrError :: InputNotFound ( "input_states_2" . to_string ( ) ) ) ?
52- . input_type
53- . tensor_shape ( )
54- . ok_or_else ( || AsrError :: TensorShape ( "input_states_2" . to_string ( ) ) ) ?;
26+ impl < ' m > DecoderSession < ' m > {
27+ pub ( crate ) fn new (
28+ model : & ' m mut AsrModel ,
29+ workspace : Option < DecoderWorkspace > ,
30+ last_token : i32 ,
31+ ) -> Result < Self , AsrError > {
32+ let workspace = if let Some ( ws) = workspace {
33+ log:: debug!( "Reusing cached decoder workspace" ) ;
34+ ws
35+ } else {
36+ log:: debug!( "Initializing new decoder workspace" ) ;
37+ DecoderWorkspace :: new ( & model. decoder_joint ) ?
38+ } ;
5539
56- let state1 = Array :: zeros ( ( state1_shape[ 0 ] as usize , 1 , state1_shape[ 2 ] as usize ) ) ;
57- let state2 = Array :: zeros ( ( state2_shape[ 0 ] as usize , 1 , state2_shape[ 2 ] as usize ) ) ;
40+ log:: debug!( "Decoder session initialized with last_token={}" , last_token) ;
5841
5942 Ok ( Self {
60- encoder_step : Array :: zeros ( ( 1 , encoder_dim, 1 ) ) ,
61- targets : Array2 :: zeros ( ( 1 , 1 ) ) ,
62- target_length : Array1 :: from_vec ( vec ! [ 1 ] ) ,
63- state : ( state1, state2) ,
43+ model,
44+ workspace,
45+ last_token,
6446 } )
6547 }
6648
67- #[ inline]
68- pub ( crate ) fn reset_state ( & mut self ) {
69- self . state . 0 . fill ( 0.0 ) ;
70- self . state . 1 . fill ( 0.0 ) ;
71- }
72-
73- pub ( crate ) fn set_encoder_step ( & mut self , frame : & ArrayView1 < f32 > ) {
74- let mut view = self . encoder_step . index_axis_mut ( ndarray:: Axis ( 2 ) , 0 ) ;
75- let mut view = view. index_axis_mut ( ndarray:: Axis ( 0 ) , 0 ) ;
76- view. assign ( frame) ;
77- }
78-
79- pub ( crate ) fn set_target ( & mut self , token : i32 ) {
80- self . targets [ [ 0 , 0 ] ] = token;
49+ pub ( crate ) fn into_parts ( self ) -> ( DecoderWorkspace , i32 ) {
50+ ( self . workspace , self . last_token )
8151 }
82- }
8352
84- impl AsrModel {
8553 pub ( crate ) fn decode_sequence (
8654 & mut self ,
8755 encodings : & ArrayViewD < f32 > ,
8856 encodings_len : usize ,
57+ _config : & InferenceConfig ,
8958 ) -> Result < ( Vec < i32 > , Vec < usize > ) , AsrError > {
9059 let decode_start = Instant :: now ( ) ;
91- let mut tokens = Vec :: with_capacity ( encodings_len / 2 + 4 ) ;
92- let mut timestamps = Vec :: with_capacity ( encodings_len / 2 + 4 ) ;
93-
94- let workspace = & mut self . decoder_workspace ;
95- workspace. reset_state ( ) ;
60+ let mut tokens = Vec :: with_capacity ( std:: cmp:: max ( 1 , encodings_len / 2 ) ) ;
61+ let mut timestamps = Vec :: with_capacity ( std:: cmp:: max ( 1 , encodings_len / 2 ) ) ;
9662
9763 let mut t = 0 ;
9864 let mut emitted_tokens = 0 ;
9965
10066 while t < encodings_len {
10167 let encoder_step = encodings. slice ( ndarray:: s![ t, ..] ) ;
102- workspace. set_encoder_step ( & encoder_step) ;
68+ self . workspace . set_encoder_step ( & encoder_step) ;
10369
104- let target_token = tokens. last ( ) . copied ( ) . unwrap_or ( self . blank_idx ) ;
105- workspace. set_target ( target_token) ;
70+ let target_token = if let Some ( last) = tokens. last ( ) {
71+ * last
72+ } else {
73+ self . last_token
74+ } ;
75+ self . workspace . set_target ( target_token) ;
10676
10777 let inputs = inputs ! [
108- "encoder_outputs" => TensorRef :: from_array_view( workspace. encoder_step. view( ) ) ?,
109- "targets" => TensorRef :: from_array_view( workspace. targets. view( ) ) ?,
110- "target_length" => TensorRef :: from_array_view( workspace. target_length. view( ) ) ?,
111- "input_states_1" => TensorRef :: from_array_view( workspace. state. 0 . view( ) ) ?,
112- "input_states_2" => TensorRef :: from_array_view( workspace. state. 1 . view( ) ) ?,
78+ "encoder_outputs" => TensorRef :: from_array_view( self . workspace. encoder_step. view( ) ) ?,
79+ "targets" => TensorRef :: from_array_view( self . workspace. targets. view( ) ) ?,
80+ "target_length" => TensorRef :: from_array_view( self . workspace. target_length. view( ) ) ?,
81+ "input_states_1" => TensorRef :: from_array_view( self . workspace. state. 0 . view( ) ) ?,
82+ "input_states_2" => TensorRef :: from_array_view( self . workspace. state. 1 . view( ) ) ?,
11383 ] ;
11484
115- let outputs = self . decoder_joint . run ( inputs) ?;
85+ let outputs = self . model . decoder_joint . run ( inputs) ?;
11686
11787 let logits = outputs
11888 . get ( "outputs" )
@@ -127,13 +97,8 @@ impl AsrModel {
12797 ) )
12898 } ) ?;
12999
130- let vocab_logits = if logits. len ( ) > self . vocab_size {
131- log:: trace!(
132- "TDT model detected: splitting {} logits into vocab({}) + duration" ,
133- logits. len( ) ,
134- self . vocab_size
135- ) ;
136- & vocab_logits_slice[ ..self . vocab_size ]
100+ let vocab_logits = if logits. len ( ) > self . model . vocab_size {
101+ & vocab_logits_slice[ ..self . model . vocab_size ]
137102 } else {
138103 vocab_logits_slice
139104 } ;
@@ -142,10 +107,9 @@ impl AsrModel {
142107 . iter ( )
143108 . enumerate ( )
144109 . max_by ( |( _, a) , ( _, b) | a. partial_cmp ( b) . unwrap_or ( std:: cmp:: Ordering :: Equal ) )
145- . map ( |( idx, _) | idx as i32 )
146- . unwrap_or ( self . blank_idx ) ;
110+ . map_or ( self . model . blank_idx , |( idx, _) | idx as i32 ) ;
147111
148- if token != self . blank_idx {
112+ if token != self . model . blank_idx {
149113 let state1 = outputs
150114 . get ( "output_states_1" )
151115 . ok_or_else ( || AsrError :: OutputNotFound ( "output_states_1" . to_string ( ) ) ) ?
@@ -156,26 +120,37 @@ impl AsrModel {
156120 . try_extract_array :: < f32 > ( ) ?;
157121
158122 if let Ok ( state1_view) = state1. view ( ) . into_dimensionality :: < ndarray:: Ix3 > ( ) {
159- if workspace. state . 0 . shape ( ) == state1_view. shape ( ) {
160- workspace. state . 0 . assign ( & state1_view) ;
123+ if self . workspace . state . 0 . shape ( ) == state1_view. shape ( ) {
124+ self . workspace . state . 0 . assign ( & state1_view) ;
161125 } else {
162- workspace. state . 0 = state1_view. to_owned ( ) ;
126+ log:: warn!(
127+ "Decoder state_1 shape changed: {:?} -> {:?}" ,
128+ self . workspace. state. 0 . shape( ) ,
129+ state1_view. shape( )
130+ ) ;
131+ self . workspace . state . 0 = state1_view. to_owned ( ) ;
163132 }
164133 }
165134 if let Ok ( state2_view) = state2. view ( ) . into_dimensionality :: < ndarray:: Ix3 > ( ) {
166- if workspace. state . 1 . shape ( ) == state2_view. shape ( ) {
167- workspace. state . 1 . assign ( & state2_view) ;
135+ if self . workspace . state . 1 . shape ( ) == state2_view. shape ( ) {
136+ self . workspace . state . 1 . assign ( & state2_view) ;
168137 } else {
169- workspace. state . 1 = state2_view. to_owned ( ) ;
138+ log:: warn!(
139+ "Decoder state_2 shape changed: {:?} -> {:?}" ,
140+ self . workspace. state. 1 . shape( ) ,
141+ state2_view. shape( )
142+ ) ;
143+ self . workspace . state . 1 = state2_view. to_owned ( ) ;
170144 }
171145 }
172146
173147 tokens. push ( token) ;
174148 timestamps. push ( t) ;
175149 emitted_tokens += 1 ;
150+ self . last_token = token;
176151 }
177152
178- if token == self . blank_idx || emitted_tokens == MAX_TOKENS_PER_STEP {
153+ if token == self . model . blank_idx || emitted_tokens == MAX_TOKENS_PER_STEP {
179154 t += 1 ;
180155 emitted_tokens = 0 ;
181156 }
@@ -190,7 +165,74 @@ impl AsrModel {
190165
191166 Ok ( ( tokens, timestamps) )
192167 }
168+ }
169+
170+ pub ( crate ) struct DecoderWorkspace {
171+ encoder_step : Array3 < f32 > ,
172+ targets : Array2 < i32 > ,
173+ target_length : Array1 < i32 > ,
174+ state : DecoderState ,
175+ }
176+
177+ impl DecoderWorkspace {
178+ pub ( crate ) fn new ( session : & ort:: session:: Session ) -> Result < Self , AsrError > {
179+ let encoder_dim = session
180+ . inputs
181+ . iter ( )
182+ . find ( |input| input. name == "encoder_outputs" )
183+ . and_then ( |input| input. input_type . tensor_shape ( ) )
184+ . and_then ( |shape| shape. get ( 1 ) . copied ( ) )
185+ . and_then ( |d| usize:: try_from ( d) . ok ( ) ) ;
186+
187+ let encoder_dim = match encoder_dim {
188+ Some ( dim) => dim,
189+ None => {
190+ log:: warn!( "Could not determine encoder_dim from model, falling back to 1024" ) ;
191+ 1024
192+ }
193+ } ;
194+
195+ let state1_shape = session
196+ . inputs
197+ . iter ( )
198+ . find ( |input| input. name == "input_states_1" )
199+ . ok_or_else ( || AsrError :: InputNotFound ( "input_states_1" . to_string ( ) ) ) ?
200+ . input_type
201+ . tensor_shape ( )
202+ . ok_or_else ( || AsrError :: TensorShape ( "input_states_1" . to_string ( ) ) ) ?;
203+
204+ let state2_shape = session
205+ . inputs
206+ . iter ( )
207+ . find ( |input| input. name == "input_states_2" )
208+ . ok_or_else ( || AsrError :: InputNotFound ( "input_states_2" . to_string ( ) ) ) ?
209+ . input_type
210+ . tensor_shape ( )
211+ . ok_or_else ( || AsrError :: TensorShape ( "input_states_2" . to_string ( ) ) ) ?;
212+
213+ let state1 = Array :: zeros ( ( state1_shape[ 0 ] as usize , 1 , state1_shape[ 2 ] as usize ) ) ;
214+ let state2 = Array :: zeros ( ( state2_shape[ 0 ] as usize , 1 , state2_shape[ 2 ] as usize ) ) ;
215+
216+ Ok ( Self {
217+ encoder_step : Array :: zeros ( ( 1 , encoder_dim, 1 ) ) ,
218+ targets : Array2 :: zeros ( ( 1 , 1 ) ) ,
219+ target_length : Array1 :: from_vec ( vec ! [ 1 ] ) ,
220+ state : ( state1, state2) ,
221+ } )
222+ }
223+
224+ pub ( crate ) fn set_encoder_step ( & mut self , frame : & ArrayView1 < f32 > ) {
225+ let mut view = self . encoder_step . index_axis_mut ( ndarray:: Axis ( 2 ) , 0 ) ;
226+ let mut view = view. index_axis_mut ( ndarray:: Axis ( 0 ) , 0 ) ;
227+ view. assign ( frame) ;
228+ }
193229
230+ pub ( crate ) fn set_target ( & mut self , token : i32 ) {
231+ self . targets [ [ 0 , 0 ] ] = token;
232+ }
233+ }
234+
235+ impl AsrModel {
194236 pub ( crate ) fn decode_tokens ( & self , ids : Vec < i32 > , timestamps : Vec < usize > ) -> Transcript {
195237 let tokens: Vec < String > = ids
196238 . iter ( )
0 commit comments