@@ -232,19 +232,32 @@ namespace Game
232232
233233 InferenceInput input;
234234 input.waveform = waveform_data;
235- // Match Python: waveform_duration = known_durations.sum(dim=1)
236- input.duration = std::accumulate (known_durations.begin (), known_durations.end (), 0 .0f );
237- if (input.duration <= 0 .0f ) {
238- msg = " forwardWithKnownDurations: total duration is zero or negative: " +
239- std::to_string (input.duration );
235+ input.known_durations = known_durations;
236+ input.language = m_language;
237+
238+ const float waveformDuration = static_cast <float >(waveform_data.size ()) / static_cast <float >(sampleRate);
239+ if (waveformDuration <= 0 .0f ) {
240+ msg = " forwardWithKnownDurations: waveform duration is zero or negative: " +
241+ std::to_string (waveformDuration);
240242 {
241243 std::lock_guard lock (m_runMutex);
242244 m_activeRunOptions = nullptr ;
243245 }
244246 return false ;
245247 }
246- input.known_durations = known_durations;
247- input.language = m_language;
248+
249+ const float knownDurSum = std::accumulate (known_durations.begin (), known_durations.end (), 0 .0f );
250+ if (knownDurSum <= 0 .0f ) {
251+ bool allZero = true ;
252+ for (const auto d : known_durations) {
253+ if (d > 0 .0f ) { allZero = false ; break ; }
254+ }
255+ if (allZero) {
256+ input.known_durations = {};
257+ }
258+ }
259+
260+ input.duration = waveformDuration;
248261
249262 int T = static_cast <int >(std::ceil (input.duration / m_timestep));
250263 if (T <= 0 )
@@ -797,9 +810,11 @@ namespace Game
797810 }
798811
799812 std::vector<uint8_t > maskTBool (T);
813+ size_t maskTTrueCount = 0 ;
800814 for (int64_t i = 0 ; i < T; ++i) {
801815 if (maskTData) {
802816 maskTBool[i] = maskTData[i] ? 1 : 0 ;
817+ if (maskTBool[i]) ++maskTTrueCount;
803818 } else {
804819 maskTBool[i] = false ;
805820 }
@@ -811,7 +826,7 @@ namespace Game
811826 }
812827
813828 std::vector<uint8_t > knownBoundaries;
814- if (input. known_durations .size () > 1 ) {
829+ if (knownDurations .size () > 1 ) {
815830 knownBoundaries = formatBoundaries (knownDurations, T, maskTBool);
816831 } else {
817832 knownBoundaries.resize (T, 0 );
@@ -823,9 +838,54 @@ namespace Game
823838 boundaries.resize (T, 0 );
824839 }
825840
841+ const float wavDurFallback = static_cast <float >(input.waveform .size ()) / static_cast <float >(sampleRate);
842+ if (maskTTrueCount == 0 && T > 0 && wavDurFallback > 0 .0f &&
843+ std::fabs (input.duration - wavDurFallback) > 1e-3f ) {
844+ auto [xSegVal2, xEstVal2, maskTVal2] = runEncoder (input.waveform , wavDurFallback, input.language );
845+ if (xSegVal2 && maskTVal2) {
846+ auto maskTShape2 = maskTVal2.GetTensorTypeAndShapeInfo ().GetShape ();
847+ int64_t T2 = maskTShape2[1 ];
848+ const bool *d2 = maskTVal2.GetTensorData <bool >();
849+ size_t count2 = 0 ;
850+ maskTBool.resize (T2);
851+ for (int64_t i = 0 ; i < T2; ++i) {
852+ maskTBool[i] = d2[i] ? 1 : 0 ;
853+ if (maskTBool[i]) ++count2;
854+ }
855+ if (count2 > 0 ) {
856+ auto xSegShape2 = xSegVal2.GetTensorTypeAndShapeInfo ().GetShape ();
857+ auto xSegD2 = xSegVal2.GetTensorData <float >();
858+ size_t sc2 = xSegVal2.GetTensorTypeAndShapeInfo ().GetElementCount ();
859+ xSegClean.assign (xSegD2, xSegD2 + sc2);
860+ xSegCleanVal = Ort::Value::CreateTensor<float >(m_memoryInfo, xSegClean.data (), xSegClean.size (),
861+ xSegShape2.data (), xSegShape2.size ());
862+ auto xEstD2 = xEstVal2.GetTensorData <float >();
863+ size_t ec2 = xEstVal2.GetTensorTypeAndShapeInfo ().GetElementCount ();
864+ xEstClean.assign (xEstD2, xEstD2 + ec2);
865+ xEstCleanVal = Ort::Value::CreateTensor<float >(m_memoryInfo, xEstClean.data (), xEstClean.size (),
866+ xSegShape2.data (), xSegShape2.size ());
867+ T = T2;
868+ if (knownDurations.size () > 1 ) {
869+ knownBoundaries = formatBoundaries (knownDurations, T, maskTBool);
870+ } else {
871+ knownBoundaries.resize (T, 0 );
872+ }
873+ boundaries = runSegmenterWithConfig (xSegCleanVal, knownBoundaries, knownBoundaries,
874+ input.language , maskTVal2, segThreshold, segRadius, d3pmTs);
875+ if (boundaries.empty ())
876+ boundaries.resize (T, 0 );
877+ maskTVal = std::move (maskTVal2);
878+ }
879+ }
880+ }
881+
826882 auto [durations, maskN] = boundariesToDurations (boundaries, maskTBool);
827883 if (durations.empty () || maskN.empty ()) {
828- throw std::runtime_error (" inferSlice: boundariesToDurations returned empty result" );
884+ throw std::runtime_error (" inferSlice: boundariesToDurations returned empty result"
885+ " (T=" + std::to_string (T) + " , bd_sz=" + std::to_string (boundaries.size ()) +
886+ " , mT_sz=" + std::to_string (maskTBool.size ()) +
887+ " , mT_true=" + std::to_string (maskTTrueCount) +
888+ " , dur=" + std::to_string (input.duration ) + " )" );
829889 }
830890
831891 std::vector<float > presence, scores;
0 commit comments