@@ -128,13 +128,13 @@ TEST_P(InferenceTest, Simple){
128128 ring_buffer.push_sample (0 , data_reference.at ((repeat*buffer_size)+i));
129129 }
130130
131- size_t prev_samples = inference_handler.get_num_received_samples (0 );
131+ size_t prev_samples = inference_handler.get_available_samples (0 );
132132
133133 inference_handler.process (test_buffer.get_array_of_write_pointers (), buffer_size);
134134
135135 // wait until the block was properly processed
136136 auto start = std::chrono::system_clock::now ();
137- while (inference_handler.get_num_received_samples (0 ) != prev_samples){
137+ while (inference_handler.get_available_samples (0 ) != prev_samples){
138138 if (std::chrono::system_clock::now () > start + std::chrono::duration<long int >(INFERENCE_TIMEOUT_S )){
139139 FAIL () << " Timeout while waiting for block to be processed" ;
140140 }
@@ -215,13 +215,13 @@ TEST_P(InferenceTest, WithCustomLatency){
215215 ring_buffer.push_sample (0 , data_reference.at ((repeat*buffer_size)+i));
216216 }
217217
218- size_t prev_samples = inference_handler.get_num_received_samples (0 );
218+ size_t prev_samples = inference_handler.get_available_samples (0 );
219219
220220 inference_handler.push_data (test_buffer.get_array_of_read_pointers (), buffer_size);
221221
222222 // wait until the block was properly processed
223223 auto start = std::chrono::system_clock::now ();
224- while (inference_handler.get_num_received_samples (0 ) != prev_samples + buffer_size){
224+ while (inference_handler.get_available_samples (0 ) != prev_samples + buffer_size){
225225 if (std::chrono::system_clock::now () > start + std::chrono::duration<long int >(INFERENCE_TIMEOUT_S )){
226226 FAIL () << " Timeout while waiting for block to be processed" ;
227227 }
@@ -245,6 +245,137 @@ TEST_P(InferenceTest, WithCustomLatency){
245245 }
246246}
247247
248+ TEST_P (InferenceTest, Reset){
249+
250+ auto const & test_params = GetParam ();
251+ auto const & buffer_size = test_params.host_config .m_buffer_size ;
252+ auto const & reference_offset = test_params.reference_data_offset ;
253+
254+ // read reference data
255+ std::vector<float > data_input;
256+ std::vector<float > data_reference;
257+
258+ read_wav (test_params.input_data_path , data_input);
259+ read_wav (test_params.reference_data_path , data_reference);
260+
261+ ASSERT_TRUE (data_input.size () > 0 );
262+ ASSERT_TRUE (data_reference.size () > 0 );
263+
264+ // setup inference
265+ ContextConfig anira_context_config;
266+ InferenceConfig inference_config = hybridnn_config;
267+ HybridNNPrePostProcessor pp_processor (inference_config);
268+ HybridNNBypassProcessor bypass_processor (inference_config);
269+
270+ // This test requires the buffer size to be a multiple of the preprocess input size
271+ if (static_cast <size_t >(buffer_size) % inference_config.get_preprocess_input_size ()[0 ] != 0 ){
272+ GTEST_SKIP () << " Test requires the preprocess_input_size to be a multiple of the buffer size." ;
273+ return ;
274+ }
275+
276+ // Create an InferenceHandler instance
277+ InferenceHandler inference_handler (pp_processor, inference_config, bypass_processor, anira_context_config);
278+
279+ // Allocate memory for audio processing
280+ inference_handler.prepare (test_params.host_config );
281+ // Select the inference backend
282+ inference_handler.set_inference_backend (test_params.backend );
283+
284+ int latency_offset = inference_handler.get_latency (); // The 0th tensor is the audio data tensor, so we only need the first element of the latency vector
285+
286+ BufferF test_buffer (1 , buffer_size);
287+ RingBuffer ring_buffer;
288+ ring_buffer.initialize_with_positions (1 , latency_offset + buffer_size + reference_offset);
289+
290+ // fill the buffer with zeroes to compensate for the latency
291+ for (size_t i = 0 ; i < latency_offset + reference_offset; i++){
292+ ring_buffer.push_sample (0 , 0 );
293+ }
294+
295+ // First, process some data to "contaminate" the internal state
296+ for (size_t repeat = 0 ; repeat < 50 ; repeat++){
297+ for (size_t i = 0 ; i < buffer_size; i++){
298+ test_buffer.set_sample (0 , i, data_input.at ((repeat*buffer_size)+i));
299+ ring_buffer.push_sample (0 , data_reference.at ((repeat*buffer_size)+i));
300+ }
301+
302+ size_t prev_samples = inference_handler.get_available_samples (0 );
303+ inference_handler.process (test_buffer.get_array_of_write_pointers (), buffer_size);
304+
305+ // wait until the block was properly processed
306+ auto start = std::chrono::system_clock::now ();
307+ while (inference_handler.get_available_samples (0 ) != prev_samples){
308+ if (std::chrono::system_clock::now () > start + std::chrono::duration<long int >(INFERENCE_TIMEOUT_S )){
309+ FAIL () << " Timeout while waiting for block to be processed" ;
310+ }
311+ std::this_thread::sleep_for (std::chrono::nanoseconds (10 ));
312+ }
313+
314+ for (size_t i = 0 ; i < buffer_size; i++){
315+ float reference = ring_buffer.pop_sample (0 );
316+ float processed = test_buffer.get_sample (0 , i);
317+
318+ if (repeat*buffer_size + i < latency_offset + reference_offset){
319+ ASSERT_FLOAT_EQ (reference, 0 );
320+ } else {
321+ // calculate epsilon on the fly
322+ float epsilon = max (abs (reference), abs (processed)) * test_params.epsilon_rel + test_params.epsilon_abs ;
323+ ASSERT_NEAR (reference, processed, epsilon) << " repeat=" << repeat << " , i=" << i << " , total sample nr: " << repeat*buffer_size + i << std::endl;
324+ }
325+ }
326+ }
327+
328+ // Now reset the inference handler
329+ inference_handler.reset ();
330+
331+ // Verify that the available samples count is reset
332+ EXPECT_EQ (inference_handler.get_available_samples (0 ), latency_offset) << " Available samples should be " << latency_offset << " after reset" ;
333+
334+ // Reset the ring buffer to restart from the beginning of reference data
335+ ring_buffer.clear_with_positions ();
336+ ring_buffer.initialize_with_positions (1 , latency_offset + buffer_size + reference_offset);
337+
338+ // Fill the buffer with zeroes to compensate for the latency
339+ for (size_t i = 0 ; i < latency_offset + reference_offset; i++){
340+ ring_buffer.push_sample (0 , 0 );
341+ }
342+
343+ // Process data again and verify that output matches reference from the beginning
344+ for (size_t repeat = 0 ; repeat < 150 ; repeat++){
345+
346+ for (size_t i = 0 ; i < buffer_size; i++){
347+ test_buffer.set_sample (0 , i, data_input.at ((repeat*buffer_size)+i));
348+ ring_buffer.push_sample (0 , data_reference.at ((repeat*buffer_size)+i));
349+ }
350+
351+ size_t prev_samples = inference_handler.get_available_samples (0 );
352+
353+ inference_handler.process (test_buffer.get_array_of_write_pointers (), buffer_size);
354+
355+ // wait until the block was properly processed
356+ auto start = std::chrono::system_clock::now ();
357+ while (inference_handler.get_available_samples (0 ) != prev_samples){
358+ if (std::chrono::system_clock::now () > start + std::chrono::duration<long int >(INFERENCE_TIMEOUT_S )){
359+ FAIL () << " Timeout while waiting for block to be processed" ;
360+ }
361+ std::this_thread::sleep_for (std::chrono::nanoseconds (10 ));
362+ }
363+
364+ for (size_t i = 0 ; i < buffer_size; i++){
365+ float reference = ring_buffer.pop_sample (0 );
366+ float processed = test_buffer.get_sample (0 , i);
367+
368+ if (repeat*buffer_size + i < latency_offset + reference_offset){
369+ ASSERT_FLOAT_EQ (reference, 0 );
370+ } else {
371+ // calculate epsilon on the fly
372+ float epsilon = max (abs (reference), abs (processed)) * test_params.epsilon_rel + test_params.epsilon_abs ;
373+ ASSERT_NEAR (reference, processed, epsilon) << " After reset: repeat=" << repeat << " , i=" << i << " , total sample nr: " << repeat*buffer_size + i << std::endl;
374+ }
375+ }
376+ }
377+ }
378+
248379std::string build_test_name (const testing::TestParamInfo<InferenceTest::ParamType>& info){
249380 std::stringstream ss_sample_rate, ss_buffer_size;
250381
0 commit comments