@@ -49,7 +49,10 @@ public class Game : MidiExtractor<GameOptions> {
4949 InferenceSession ? segmenterSession ;
5050 InferenceSession ? estimatorSession ;
5151 InferenceSession ? bd2durSession ;
52+ RunOptions ? runOptions ;
5253 bool sessionsLoaded = false ;
54+ bool disposed = false ;
55+ volatile bool stopping = false ;
5356 GameConfig config ;
5457 string Location ;
5558
@@ -97,11 +100,23 @@ public Game(string? location) {
97100 /// </summary>
private void EnsureSessionsLoaded() {
    if (sessionsLoaded) return;
    if (stopping) {
        throw new OperationCanceledException();
    }

    runOptions = new RunOptions();
    // Interrupt() may have run while runOptions was still null, in which case
    // it could not flag termination itself; re-check now that it is published.
    if (stopping) {
        runOptions.Terminate = true;
        throw new OperationCanceledException();
    }

    // Build every session into a local first. If any CreateSession call throws,
    // the sessions created so far are disposed right away instead of lingering
    // in the fields, where a later retry would overwrite (and leak) them.
    InferenceSession? encoder = null;
    InferenceSession? segmenter = null;
    InferenceSession? estimator = null;
    InferenceSession? bd2dur = null;
    try {
        encoder = CreateSession("encoder.onnx", OnnxRunnerChoice.CPUForCoreML);
        segmenter = CreateSession("segmenter.onnx", OnnxRunnerChoice.Default);
        estimator = CreateSession("estimator.onnx", OnnxRunnerChoice.Default);
        bd2dur = CreateSession("bd2dur.onnx", OnnxRunnerChoice.Default);
    } catch {
        bd2dur?.Dispose();
        estimator?.Dispose();
        segmenter?.Dispose();
        encoder?.Dispose();
        throw;
    }

    // All four sessions exist: publish them together and mark loading complete.
    encoderSession = encoder;
    segmenterSession = segmenter;
    estimatorSession = estimator;
    bd2durSession = bd2dur;
    sessionsLoaded = true;

    // An interrupt that arrived during the (potentially slow) model loading must
    // still cancel the caller; the sessions stay loaded for any later use.
    if (stopping) {
        runOptions.Terminate = true;
        throw new OperationCanceledException();
    }
}
106121
107122 protected override bool SupportsBatch => true ;
@@ -131,59 +146,76 @@ private List<List<TranscribedNote>> RunPipeline(List<float[]> batch, GameOptions
131146 var waveform = new DenseTensor < float > ( waveformData , new [ ] { B , maxLen } ) ;
132147 var duration = new DenseTensor < float > ( durationData , new [ ] { B } ) ;
133148
134- // 1. Encoder
135- var ( xSeg , xEst , maskT ) = RunEncoder ( waveform , duration ) ;
149+ try {
150+ // 1. Encoder
151+ var ( xSeg , xEst , maskT ) = RunEncoder ( waveform , duration ) ;
136152
137- // 2. Segmentation (D3PM loop)
138- int T = xSeg . Dimensions [ 1 ] ;
139- Tensor < bool > knownBoundaries = new DenseTensor < bool > ( new [ ] { B , T } ) ;
140- Tensor < bool > boundaries = new DenseTensor < bool > ( new [ ] { B , T } ) ;
153+ // 2. Segmentation (D3PM loop)
154+ int T = xSeg . Dimensions [ 1 ] ;
155+ Tensor < bool > knownBoundaries = new DenseTensor < bool > ( new [ ] { B , T } ) ;
156+ Tensor < bool > boundaries = new DenseTensor < bool > ( new [ ] { B , T } ) ;
141157
142- Tensor < long > ? language = null ;
143- if ( config . Languages != null ) {
144- int languageId = ResolveLanguageId ( options . LanguageCode ) ;
145- language = new DenseTensor < long > (
146- Enumerable . Repeat ( ( long ) languageId , B ) . ToArray ( ) , new [ ] { B } ) ;
147- }
148-
149- var segThreshold = new DenseTensor < float > ( new [ ] { options . BoundaryThreshold } , Array . Empty < int > ( ) ) ;
150- var radius = new DenseTensor < long > ( new long [ ] { options . BoundaryRadius } , Array . Empty < int > ( ) ) ;
158+ Tensor < long > ? language = null ;
159+ if ( config . Languages != null ) {
160+ int languageId = ResolveLanguageId ( options . LanguageCode ) ;
161+ language = new DenseTensor < long > (
162+ Enumerable . Repeat ( ( long ) languageId , B ) . ToArray ( ) , new [ ] { B } ) ;
163+ }
151164
152- if ( config . Loop ) {
153- float step = 1.0f / options . SamplingSteps ;
154- for ( int i = 0 ; i < options . SamplingSteps ; i ++ ) {
155- var t = new DenseTensor < float > (
156- Enumerable . Repeat ( i * step , B ) . ToArray ( ) , new [ ] { B } ) ;
157- boundaries = RunSegmenter ( xSeg , knownBoundaries , boundaries , t , maskT , language , segThreshold , radius ) ;
165+ var segThreshold = new DenseTensor < float > ( new [ ] { options . BoundaryThreshold } , Array . Empty < int > ( ) ) ;
166+ var radius = new DenseTensor < long > ( new long [ ] { options . BoundaryRadius } , Array . Empty < int > ( ) ) ;
167+
168+ if ( config . Loop ) {
169+ float step = 1.0f / options . SamplingSteps ;
170+ for ( int i = 0 ; i < options . SamplingSteps ; i ++ ) {
171+ var t = new DenseTensor < float > (
172+ Enumerable . Repeat ( i * step , B ) . ToArray ( ) , new [ ] { B } ) ;
173+ boundaries = RunSegmenter ( xSeg , knownBoundaries , boundaries , t , maskT , language , segThreshold , radius ) ;
174+ }
175+ } else {
176+ boundaries = RunSegmenter ( xSeg , knownBoundaries , null , null , maskT , language , segThreshold , radius ) ;
158177 }
159- } else {
160- boundaries = RunSegmenter ( xSeg , knownBoundaries , null , null , maskT , language , segThreshold , radius ) ;
161- }
162178
163- // 3. Boundaries to durations
164- var ( durations , maskN ) = RunBd2Dur ( boundaries , maskT ) ;
165- int N = maskN . Dimensions [ 1 ] ;
179+ // 3. Boundaries to durations
180+ var ( durations , maskN ) = RunBd2Dur ( boundaries , maskT ) ;
181+ int N = maskN . Dimensions [ 1 ] ;
166182
167- // 4. Estimation
168- var scoreThreshold = new DenseTensor < float > ( new [ ] { options . ScoreThreshold } , Array . Empty < int > ( ) ) ;
169- var ( presence , scores ) = RunEstimator ( xEst , boundaries , maskT , maskN , scoreThreshold ) ;
183+ // 4. Estimation
184+ var scoreThreshold = new DenseTensor < float > ( new [ ] { options . ScoreThreshold } , Array . Empty < int > ( ) ) ;
185+ var ( presence , scores ) = RunEstimator ( xEst , boundaries , maskT , maskN , scoreThreshold ) ;
170186
171- // 5. Split results per batch item
172- var results = new List < List < TranscribedNote > > ( B ) ;
173- for ( int b = 0 ; b < B ; b ++ ) {
174- var notes = new List < TranscribedNote > ( N ) ;
175- for ( int i = 0 ; i < N ; i ++ ) {
176- if ( ! maskN [ b , i ] ) break ;
177- notes . Add ( new TranscribedNote ( durations [ b , i ] , scores [ b , i ] , presence [ b , i ] ) ) ;
187+ // 5. Split results per batch item
188+ var results = new List < List < TranscribedNote > > ( B ) ;
189+ for ( int b = 0 ; b < B ; b ++ ) {
190+ var notes = new List < TranscribedNote > ( N ) ;
191+ for ( int i = 0 ; i < N ; i ++ ) {
192+ if ( ! maskN [ b , i ] ) break ;
193+ notes . Add ( new TranscribedNote ( durations [ b , i ] , scores [ b , i ] , presence [ b , i ] ) ) ;
194+ }
195+
196+ results . Add ( notes ) ;
178197 }
179198
180- results . Add ( notes ) ;
199+ return results ;
200+ } catch ( OnnxRuntimeException ) {
201+ if ( runOptions != null && runOptions . Terminate ) {
202+ throw new OperationCanceledException ( ) ;
203+ }
204+ throw ;
181205 }
206+ }
182207
183- return results ;
// Requests cancellation: raises the stopping flag and, when run options exist
// and this instance has not been disposed, sets their Terminate flag.
public override void Interrupt() {
    stopping = true;

    // Once disposed, the run options must not be touched any more.
    if (disposed) {
        return;
    }

    // Nothing to terminate if the run options have not been created yet;
    // the stopping flag alone is what EnsureSessionsLoaded checks.
    if (runOptions is RunOptions activeOptions) {
        activeOptions.Terminate = true;
    }
}
185214
186215 protected override void DisposeManaged ( ) {
216+ if ( disposed ) return ;
217+ disposed = true ;
218+ runOptions ? . Dispose ( ) ;
187219 encoderSession ? . Dispose ( ) ;
188220 segmenterSession ? . Dispose ( ) ;
189221 estimatorSession ? . Dispose ( ) ;
@@ -228,7 +260,7 @@ private int ResolveLanguageId(string? languageCode) {
228260 NamedOnnxValue . CreateFromTensor ( "duration" , duration ) ,
229261 } ;
230262
231- using var outputs = encoderSession ! . Run ( inputs ) ;
263+ using var outputs = encoderSession ! . Run ( inputs , encoderSession . OutputNames , runOptions ) ;
232264
233265 var xSeg = outputs . First ( o => o . Name == "x_seg" ) . AsTensor < float > ( ) . ToDenseTensor ( ) ;
234266 var xEst = outputs . First ( o => o . Name == "x_est" ) . AsTensor < float > ( ) . ToDenseTensor ( ) ;
@@ -267,7 +299,7 @@ private Tensor<bool> RunSegmenter(
267299 inputs . Add ( NamedOnnxValue . CreateFromTensor ( "threshold" , threshold ) ) ;
268300 inputs . Add ( NamedOnnxValue . CreateFromTensor ( "radius" , radius ) ) ;
269301
270- using var outputs = segmenterSession ! . Run ( inputs ) ;
302+ using var outputs = segmenterSession ! . Run ( inputs , segmenterSession . OutputNames , runOptions ) ;
271303 var boundaries = outputs . First ( o => o . Name == "boundaries" ) . AsTensor < bool > ( ) . ToDenseTensor ( ) ;
272304 return boundaries ;
273305 }
@@ -282,7 +314,7 @@ private Tensor<bool> RunSegmenter(
282314 NamedOnnxValue . CreateFromTensor ( "maskT" , maskT ) ,
283315 } ;
284316
285- using var outputs = bd2durSession ! . Run ( inputs ) ;
317+ using var outputs = bd2durSession ! . Run ( inputs , bd2durSession . OutputNames , runOptions ) ;
286318 var durations = outputs . First ( o => o . Name == "durations" ) . AsTensor < float > ( ) . ToDenseTensor ( ) ;
287319 var maskN = outputs . First ( o => o . Name == "maskN" ) . AsTensor < bool > ( ) . ToDenseTensor ( ) ;
288320
@@ -303,7 +335,7 @@ private Tensor<bool> RunSegmenter(
303335 NamedOnnxValue . CreateFromTensor ( "threshold" , threshold ) ,
304336 } ;
305337
306- using var outputs = estimatorSession ! . Run ( inputs ) ;
338+ using var outputs = estimatorSession ! . Run ( inputs , estimatorSession . OutputNames , runOptions ) ;
307339 var presence = outputs . First ( o => o . Name == "presence" ) . AsTensor < bool > ( ) . ToDenseTensor ( ) ;
308340 var scores = outputs . First ( o => o . Name == "scores" ) . AsTensor < float > ( ) . ToDenseTensor ( ) ;
309341 return ( presence , scores ) ;
0 commit comments