Skip to content

Commit e14a086

Browse files
authored
Merge branch 'stakira:master' into master
2 parents 9bf09de + f3d80c1 commit e14a086

File tree

15 files changed

+584
-252
lines changed

15 files changed

+584
-252
lines changed

OpenUtau.Core/Analysis/Game.cs

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ public class Game : MidiExtractor<GameOptions> {
4949
InferenceSession? segmenterSession;
5050
InferenceSession? estimatorSession;
5151
InferenceSession? bd2durSession;
52+
RunOptions? runOptions;
5253
bool sessionsLoaded = false;
54+
bool disposed = false;
55+
volatile bool stopping = false;
5356
GameConfig config;
5457
string Location;
5558

@@ -97,11 +100,23 @@ public Game(string? location) {
97100
/// </summary>
98101
private void EnsureSessionsLoaded() {
    // Loading happens at most once; later calls are no-ops.
    if (sessionsLoaded) {
        return;
    }
    // An interrupt that arrived before loading began: give up before allocating anything.
    if (stopping) {
        throw new OperationCanceledException();
    }

    // Create the run options before the sessions so Interrupt() has an object to flag.
    // The same instance is handed to the session Run calls in this class.
    runOptions = new RunOptions();
    ThrowIfStopping();

    encoderSession = CreateSession("encoder.onnx", OnnxRunnerChoice.CPUForCoreML);
    segmenterSession = CreateSession("segmenter.onnx", OnnxRunnerChoice.Default);
    estimatorSession = CreateSession("estimator.onnx", OnnxRunnerChoice.Default);
    bd2durSession = CreateSession("bd2dur.onnx", OnnxRunnerChoice.Default);
    sessionsLoaded = true;

    // Re-check once everything is loaded: the stop request may have been raised
    // while the sessions were being created.
    ThrowIfStopping();

    // When a stop has been requested, marks the run options as terminated and
    // aborts the load with OperationCanceledException; otherwise does nothing.
    void ThrowIfStopping() {
        if (!stopping) {
            return;
        }
        runOptions!.Terminate = true;
        throw new OperationCanceledException();
    }
}
106121

107122
protected override bool SupportsBatch => true;
@@ -131,59 +146,76 @@ private List<List<TranscribedNote>> RunPipeline(List<float[]> batch, GameOptions
131146
var waveform = new DenseTensor<float>(waveformData, new[] { B, maxLen });
132147
var duration = new DenseTensor<float>(durationData, new[] { B });
133148

134-
// 1. Encoder
135-
var (xSeg, xEst, maskT) = RunEncoder(waveform, duration);
149+
try {
150+
// 1. Encoder
151+
var (xSeg, xEst, maskT) = RunEncoder(waveform, duration);
136152

137-
// 2. Segmentation (D3PM loop)
138-
int T = xSeg.Dimensions[1];
139-
Tensor<bool> knownBoundaries = new DenseTensor<bool>(new[] { B, T });
140-
Tensor<bool> boundaries = new DenseTensor<bool>(new[] { B, T });
153+
// 2. Segmentation (D3PM loop)
154+
int T = xSeg.Dimensions[1];
155+
Tensor<bool> knownBoundaries = new DenseTensor<bool>(new[] { B, T });
156+
Tensor<bool> boundaries = new DenseTensor<bool>(new[] { B, T });
141157

142-
Tensor<long>? language = null;
143-
if (config.Languages != null) {
144-
int languageId = ResolveLanguageId(options.LanguageCode);
145-
language = new DenseTensor<long>(
146-
Enumerable.Repeat((long)languageId, B).ToArray(), new[] { B });
147-
}
148-
149-
var segThreshold = new DenseTensor<float>(new[] { options.BoundaryThreshold }, Array.Empty<int>());
150-
var radius = new DenseTensor<long>(new long[] { options.BoundaryRadius }, Array.Empty<int>());
158+
Tensor<long>? language = null;
159+
if (config.Languages != null) {
160+
int languageId = ResolveLanguageId(options.LanguageCode);
161+
language = new DenseTensor<long>(
162+
Enumerable.Repeat((long)languageId, B).ToArray(), new[] { B });
163+
}
151164

152-
if (config.Loop) {
153-
float step = 1.0f / options.SamplingSteps;
154-
for (int i = 0; i < options.SamplingSteps; i++) {
155-
var t = new DenseTensor<float>(
156-
Enumerable.Repeat(i * step, B).ToArray(), new[] { B });
157-
boundaries = RunSegmenter(xSeg, knownBoundaries, boundaries, t, maskT, language, segThreshold, radius);
165+
var segThreshold = new DenseTensor<float>(new[] { options.BoundaryThreshold }, Array.Empty<int>());
166+
var radius = new DenseTensor<long>(new long[] { options.BoundaryRadius }, Array.Empty<int>());
167+
168+
if (config.Loop) {
169+
float step = 1.0f / options.SamplingSteps;
170+
for (int i = 0; i < options.SamplingSteps; i++) {
171+
var t = new DenseTensor<float>(
172+
Enumerable.Repeat(i * step, B).ToArray(), new[] { B });
173+
boundaries = RunSegmenter(xSeg, knownBoundaries, boundaries, t, maskT, language, segThreshold, radius);
174+
}
175+
} else {
176+
boundaries = RunSegmenter(xSeg, knownBoundaries, null, null, maskT, language, segThreshold, radius);
158177
}
159-
} else {
160-
boundaries = RunSegmenter(xSeg, knownBoundaries, null, null, maskT, language, segThreshold, radius);
161-
}
162178

163-
// 3. Boundaries to durations
164-
var (durations, maskN) = RunBd2Dur(boundaries, maskT);
165-
int N = maskN.Dimensions[1];
179+
// 3. Boundaries to durations
180+
var (durations, maskN) = RunBd2Dur(boundaries, maskT);
181+
int N = maskN.Dimensions[1];
166182

167-
// 4. Estimation
168-
var scoreThreshold = new DenseTensor<float>(new[] { options.ScoreThreshold }, Array.Empty<int>());
169-
var (presence, scores) = RunEstimator(xEst, boundaries, maskT, maskN, scoreThreshold);
183+
// 4. Estimation
184+
var scoreThreshold = new DenseTensor<float>(new[] { options.ScoreThreshold }, Array.Empty<int>());
185+
var (presence, scores) = RunEstimator(xEst, boundaries, maskT, maskN, scoreThreshold);
170186

171-
// 5. Split results per batch item
172-
var results = new List<List<TranscribedNote>>(B);
173-
for (int b = 0; b < B; b++) {
174-
var notes = new List<TranscribedNote>(N);
175-
for (int i = 0; i < N; i++) {
176-
if (!maskN[b, i]) break;
177-
notes.Add(new TranscribedNote(durations[b, i], scores[b, i], presence[b, i]));
187+
// 5. Split results per batch item
188+
var results = new List<List<TranscribedNote>>(B);
189+
for (int b = 0; b < B; b++) {
190+
var notes = new List<TranscribedNote>(N);
191+
for (int i = 0; i < N; i++) {
192+
if (!maskN[b, i]) break;
193+
notes.Add(new TranscribedNote(durations[b, i], scores[b, i], presence[b, i]));
194+
}
195+
196+
results.Add(notes);
178197
}
179198

180-
results.Add(notes);
199+
return results;
200+
} catch (OnnxRuntimeException) {
201+
if (runOptions != null && runOptions.Terminate) {
202+
throw new OperationCanceledException();
203+
}
204+
throw;
181205
}
206+
}
182207

183-
return results;
208+
/// <summary>
/// Requests that the current work stop: raises the <c>stopping</c> flag and,
/// when run options exist and this instance has not been disposed, sets
/// <c>Terminate</c> on them.
/// </summary>
public override void Interrupt() {
    // Raise the flag first so the checks in the loading path observe the request.
    stopping = true;

    // Nothing to flag if the run options were never created or were already disposed.
    if (disposed || runOptions == null) {
        return;
    }
    runOptions.Terminate = true;
}
185214

186215
protected override void DisposeManaged() {
216+
if (disposed) return;
217+
disposed = true;
218+
runOptions?.Dispose();
187219
encoderSession?.Dispose();
188220
segmenterSession?.Dispose();
189221
estimatorSession?.Dispose();
@@ -228,7 +260,7 @@ private int ResolveLanguageId(string? languageCode) {
228260
NamedOnnxValue.CreateFromTensor("duration", duration),
229261
};
230262

231-
using var outputs = encoderSession!.Run(inputs);
263+
using var outputs = encoderSession!.Run(inputs, encoderSession.OutputNames, runOptions);
232264

233265
var xSeg = outputs.First(o => o.Name == "x_seg").AsTensor<float>().ToDenseTensor();
234266
var xEst = outputs.First(o => o.Name == "x_est").AsTensor<float>().ToDenseTensor();
@@ -267,7 +299,7 @@ private Tensor<bool> RunSegmenter(
267299
inputs.Add(NamedOnnxValue.CreateFromTensor("threshold", threshold));
268300
inputs.Add(NamedOnnxValue.CreateFromTensor("radius", radius));
269301

270-
using var outputs = segmenterSession!.Run(inputs);
302+
using var outputs = segmenterSession!.Run(inputs, segmenterSession.OutputNames, runOptions);
271303
var boundaries = outputs.First(o => o.Name == "boundaries").AsTensor<bool>().ToDenseTensor();
272304
return boundaries;
273305
}
@@ -282,7 +314,7 @@ private Tensor<bool> RunSegmenter(
282314
NamedOnnxValue.CreateFromTensor("maskT", maskT),
283315
};
284316

285-
using var outputs = bd2durSession!.Run(inputs);
317+
using var outputs = bd2durSession!.Run(inputs, bd2durSession.OutputNames, runOptions);
286318
var durations = outputs.First(o => o.Name == "durations").AsTensor<float>().ToDenseTensor();
287319
var maskN = outputs.First(o => o.Name == "maskN").AsTensor<bool>().ToDenseTensor();
288320

@@ -303,7 +335,7 @@ private Tensor<bool> RunSegmenter(
303335
NamedOnnxValue.CreateFromTensor("threshold", threshold),
304336
};
305337

306-
using var outputs = estimatorSession!.Run(inputs);
338+
using var outputs = estimatorSession!.Run(inputs, estimatorSession.OutputNames, runOptions);
307339
var presence = outputs.First(o => o.Name == "presence").AsTensor<bool>().ToDenseTensor();
308340
var scores = outputs.First(o => o.Name == "scores").AsTensor<float>().ToDenseTensor();
309341
return (presence, scores);

0 commit comments

Comments
 (0)