Skip to content

Commit 4435fd7

Browse files
committed
[fix] fix waveform duration calculation and add mask-all-zero fallback in forwardWithKnownDurations
1 parent 0a2545f commit 4435fd7

1 file changed

Lines changed: 69 additions & 9 deletions

File tree

src/infer/game-infer/src/GameModel.cpp

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,32 @@ namespace Game
232232

233233
InferenceInput input;
234234
input.waveform = waveform_data;
235-
// Match Python: waveform_duration = known_durations.sum(dim=1)
236-
input.duration = std::accumulate(known_durations.begin(), known_durations.end(), 0.0f);
237-
if (input.duration <= 0.0f) {
238-
msg = "forwardWithKnownDurations: total duration is zero or negative: " +
239-
std::to_string(input.duration);
235+
input.known_durations = known_durations;
236+
input.language = m_language;
237+
238+
const float waveformDuration = static_cast<float>(waveform_data.size()) / static_cast<float>(sampleRate);
239+
if (waveformDuration <= 0.0f) {
240+
msg = "forwardWithKnownDurations: waveform duration is zero or negative: " +
241+
std::to_string(waveformDuration);
240242
{
241243
std::lock_guard lock(m_runMutex);
242244
m_activeRunOptions = nullptr;
243245
}
244246
return false;
245247
}
246-
input.known_durations = known_durations;
247-
input.language = m_language;
248+
249+
const float knownDurSum = std::accumulate(known_durations.begin(), known_durations.end(), 0.0f);
250+
if (knownDurSum <= 0.0f) {
251+
bool allZero = true;
252+
for (const auto d : known_durations) {
253+
if (d > 0.0f) { allZero = false; break; }
254+
}
255+
if (allZero) {
256+
input.known_durations = {};
257+
}
258+
}
259+
260+
input.duration = waveformDuration;
248261

249262
int T = static_cast<int>(std::ceil(input.duration / m_timestep));
250263
if (T <= 0)
@@ -797,9 +810,11 @@ namespace Game
797810
}
798811

799812
std::vector<uint8_t> maskTBool(T);
813+
size_t maskTTrueCount = 0;
800814
for (int64_t i = 0; i < T; ++i) {
801815
if (maskTData) {
802816
maskTBool[i] = maskTData[i] ? 1 : 0;
817+
if (maskTBool[i]) ++maskTTrueCount;
803818
} else {
804819
maskTBool[i] = false;
805820
}
@@ -811,7 +826,7 @@ namespace Game
811826
}
812827

813828
std::vector<uint8_t> knownBoundaries;
814-
if (input.known_durations.size() > 1) {
829+
if (knownDurations.size() > 1) {
815830
knownBoundaries = formatBoundaries(knownDurations, T, maskTBool);
816831
} else {
817832
knownBoundaries.resize(T, 0);
@@ -823,9 +838,54 @@ namespace Game
823838
boundaries.resize(T, 0);
824839
}
825840

841+
const float wavDurFallback = static_cast<float>(input.waveform.size()) / static_cast<float>(sampleRate);
842+
if (maskTTrueCount == 0 && T > 0 && wavDurFallback > 0.0f &&
843+
std::fabs(input.duration - wavDurFallback) > 1e-3f) {
844+
auto [xSegVal2, xEstVal2, maskTVal2] = runEncoder(input.waveform, wavDurFallback, input.language);
845+
if (xSegVal2 && maskTVal2) {
846+
auto maskTShape2 = maskTVal2.GetTensorTypeAndShapeInfo().GetShape();
847+
int64_t T2 = maskTShape2[1];
848+
const bool *d2 = maskTVal2.GetTensorData<bool>();
849+
size_t count2 = 0;
850+
maskTBool.resize(T2);
851+
for (int64_t i = 0; i < T2; ++i) {
852+
maskTBool[i] = d2[i] ? 1 : 0;
853+
if (maskTBool[i]) ++count2;
854+
}
855+
if (count2 > 0) {
856+
auto xSegShape2 = xSegVal2.GetTensorTypeAndShapeInfo().GetShape();
857+
auto xSegD2 = xSegVal2.GetTensorData<float>();
858+
size_t sc2 = xSegVal2.GetTensorTypeAndShapeInfo().GetElementCount();
859+
xSegClean.assign(xSegD2, xSegD2 + sc2);
860+
xSegCleanVal = Ort::Value::CreateTensor<float>(m_memoryInfo, xSegClean.data(), xSegClean.size(),
861+
xSegShape2.data(), xSegShape2.size());
862+
auto xEstD2 = xEstVal2.GetTensorData<float>();
863+
size_t ec2 = xEstVal2.GetTensorTypeAndShapeInfo().GetElementCount();
864+
xEstClean.assign(xEstD2, xEstD2 + ec2);
865+
xEstCleanVal = Ort::Value::CreateTensor<float>(m_memoryInfo, xEstClean.data(), xEstClean.size(),
866+
xSegShape2.data(), xSegShape2.size());
867+
T = T2;
868+
if (knownDurations.size() > 1) {
869+
knownBoundaries = formatBoundaries(knownDurations, T, maskTBool);
870+
} else {
871+
knownBoundaries.resize(T, 0);
872+
}
873+
boundaries = runSegmenterWithConfig(xSegCleanVal, knownBoundaries, knownBoundaries,
874+
input.language, maskTVal2, segThreshold, segRadius, d3pmTs);
875+
if (boundaries.empty())
876+
boundaries.resize(T, 0);
877+
maskTVal = std::move(maskTVal2);
878+
}
879+
}
880+
}
881+
826882
auto [durations, maskN] = boundariesToDurations(boundaries, maskTBool);
827883
if (durations.empty() || maskN.empty()) {
828-
throw std::runtime_error("inferSlice: boundariesToDurations returned empty result");
884+
throw std::runtime_error("inferSlice: boundariesToDurations returned empty result"
885+
" (T=" + std::to_string(T) + ", bd_sz=" + std::to_string(boundaries.size()) +
886+
", mT_sz=" + std::to_string(maskTBool.size()) +
887+
", mT_true=" + std::to_string(maskTTrueCount) +
888+
", dur=" + std::to_string(input.duration) + ")");
829889
}
830890

831891
std::vector<float> presence, scores;

0 commit comments

Comments
 (0)