Skip to content

Commit 2247a8b

Browse files
committed
fixed bugs and added test for inference_handler.reset method
1 parent f861119 commit 2247a8b

8 files changed

Lines changed: 154 additions & 19 deletions

File tree

docs/sphinx/architecture.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Component Responsibilities
6767
--------------------------
6868

6969
:cpp:class:`anira::InferenceHandler`
70-
~~~~~~~~~~~~~~~~
70+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7171

7272
The primary interface for users, handling the overall integration of neural network inference into audio processing workflows.
7373

@@ -77,7 +77,7 @@ The primary interface for users, handling the overall integration of neural netw
7777
* Reports latency information
7878

7979
:cpp:class:`anira::InferenceConfig`
80-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
80+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8181

8282
Stores configuration data for models and processing parameters.
8383

@@ -87,7 +87,7 @@ Stores configuration data for models and processing parameters.
8787
* Memory management settings
8888

8989
:cpp:class:`anira::PrePostProcessor`
90-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
90+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9191

9292
Handles data formatting between audio buffers and neural network tensors.
9393

@@ -96,7 +96,7 @@ Handles data formatting between audio buffers and neural network tensors.
9696
* Manages intermediate buffers
9797

9898
:cpp:class:`anira::InferenceManager`
99-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
99+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
100100

101101
Coordinates the thread pool and inference scheduling.
102102

include/anira/InferenceHandler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class ANIRA_API InferenceHandler {
228228
* @param channel Channel index to query (default: 0)
229229
* @return Number of samples currently available in the receive buffer for the specified tensor and channel
230230
*/
231-
size_t get_num_received_samples(size_t tensor_index, size_t channel = 0) const;
231+
size_t get_available_samples(size_t tensor_index, size_t channel = 0) const;
232232

233233
/**
234234
* @brief Configures the handler for non-real-time operation

include/anira/scheduler/InferenceManager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class ANIRA_API InferenceManager {
162162
* @param channel Channel index to query
163163
* @return Number of samples currently available in the receive buffer for the specified tensor and channel
164164
*/
165-
size_t get_num_received_samples(size_t tensor_index, size_t channel) const;
165+
size_t get_available_samples(size_t tensor_index, size_t channel) const;
166166

167167
/**
168168
* @brief Gets a const reference to the inference context (for unit testing)

src/InferenceHandler.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ std::vector<unsigned int> InferenceHandler::get_latency_vector() const {
129129
return m_inference_manager.get_latency();
130130
}
131131

132-
size_t InferenceHandler::get_num_received_samples(size_t tensor_index, size_t channel) const {
133-
return m_inference_manager.get_num_received_samples(tensor_index, channel);
132+
size_t InferenceHandler::get_available_samples(size_t tensor_index, size_t channel) const {
133+
return m_inference_manager.get_available_samples(tensor_index, channel);
134134
}
135135

136136
void InferenceHandler::set_non_realtime(bool is_non_realtime) {

src/benchmark/ProcessBlockFixture.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ ProcessBlockFixture::~ProcessBlockFixture() {
1313
}
1414

1515
void ProcessBlockFixture::initialize_iteration() {
16-
m_prev_num_received_samples = m_inference_handler->get_num_received_samples(0);
16+
m_prev_num_received_samples = m_inference_handler->get_available_samples(0);
1717
}
1818

1919
void ProcessBlockFixture::initialize_repetition(const InferenceConfig& inference_config, const HostConfig& host_config, const InferenceBackend& inference_backend, bool sleep_after_repetition) {
@@ -72,7 +72,7 @@ void ProcessBlockFixture::initialize_repetition(const InferenceConfig& inference
7272
}
7373

7474
bool ProcessBlockFixture::buffer_processed() {
75-
return m_inference_handler->get_num_received_samples(0) >= m_prev_num_received_samples;
75+
return m_inference_handler->get_available_samples(0) >= m_prev_num_received_samples;
7676
}
7777

7878
void ProcessBlockFixture::push_random_samples_in_buffer(anira::HostConfig host_config) {

src/scheduler/Context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ void Context::release_thread_pool() {
9393
}
9494

9595
void Context::release_session(std::shared_ptr<SessionElement> session) {
96-
session->m_initialized.store(false, std::memory_order::acquire);
96+
session->m_initialized.store(false, std::memory_order::release);
9797

9898
drain_inference_queue(session);
9999

@@ -134,7 +134,7 @@ void Context::release_session(std::shared_ptr<SessionElement> session) {
134134
}
135135

136136
void Context::prepare_session(std::shared_ptr<SessionElement> session, HostConfig new_config, std::vector<long> custom_latency) {
137-
session->m_initialized.store(false, std::memory_order::acquire);
137+
session->m_initialized.store(false, std::memory_order::release);
138138

139139
drain_inference_queue(session);
140140

@@ -321,7 +321,7 @@ template <typename T> void Context::release_processor(InferenceConfig& inference
321321
}
322322

323323
void Context::reset_session(std::shared_ptr<SessionElement> session) {
324-
session->m_initialized.store(false, std::memory_order::acquire);
324+
session->m_initialized.store(false, std::memory_order::release);
325325

326326
drain_inference_queue(session);
327327

src/scheduler/InferenceManager.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,13 @@ const Context& InferenceManager::get_context() const {
147147
return *m_context;
148148
}
149149

150-
size_t InferenceManager::get_num_received_samples(size_t tensor_index, size_t channel) const {
150+
size_t InferenceManager::get_available_samples(size_t tensor_index, size_t channel) const {
151151
m_context->new_data_request(m_session, 0.);
152-
return m_session->m_receive_buffer[tensor_index].get_available_samples(channel);
152+
if (m_inference_config.get_postprocess_output_size()[tensor_index] > 0) {
153+
return m_session->m_receive_buffer[tensor_index].get_available_samples(channel);
154+
} else {
155+
return 0;
156+
}
153157
}
154158

155159
int InferenceManager::get_session_id() const {

test/test_InferenceHandler.cpp

Lines changed: 135 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ TEST_P(InferenceTest, Simple){
128128
ring_buffer.push_sample(0, data_reference.at((repeat*buffer_size)+i));
129129
}
130130

131-
size_t prev_samples = inference_handler.get_num_received_samples(0);
131+
size_t prev_samples = inference_handler.get_available_samples(0);
132132

133133
inference_handler.process(test_buffer.get_array_of_write_pointers(), buffer_size);
134134

135135
// wait until the block was properly processed
136136
auto start = std::chrono::system_clock::now();
137-
while (inference_handler.get_num_received_samples(0) != prev_samples){
137+
while (inference_handler.get_available_samples(0) != prev_samples){
138138
if (std::chrono::system_clock::now() > start + std::chrono::duration<long int>(INFERENCE_TIMEOUT_S)){
139139
FAIL() << "Timeout while waiting for block to be processed";
140140
}
@@ -215,13 +215,13 @@ TEST_P(InferenceTest, WithCustomLatency){
215215
ring_buffer.push_sample(0, data_reference.at((repeat*buffer_size)+i));
216216
}
217217

218-
size_t prev_samples = inference_handler.get_num_received_samples(0);
218+
size_t prev_samples = inference_handler.get_available_samples(0);
219219

220220
inference_handler.push_data(test_buffer.get_array_of_read_pointers(), buffer_size);
221221

222222
// wait until the block was properly processed
223223
auto start = std::chrono::system_clock::now();
224-
while (inference_handler.get_num_received_samples(0) != prev_samples + buffer_size){
224+
while (inference_handler.get_available_samples(0) != prev_samples + buffer_size){
225225
if (std::chrono::system_clock::now() > start + std::chrono::duration<long int>(INFERENCE_TIMEOUT_S)){
226226
FAIL() << "Timeout while waiting for block to be processed";
227227
}
@@ -245,6 +245,137 @@ TEST_P(InferenceTest, WithCustomLatency){
245245
}
246246
}
247247

248+
TEST_P(InferenceTest, Reset){
249+
250+
auto const& test_params = GetParam();
251+
auto const& buffer_size = test_params.host_config.m_buffer_size;
252+
auto const& reference_offset = test_params.reference_data_offset;
253+
254+
// read reference data
255+
std::vector<float> data_input;
256+
std::vector<float> data_reference;
257+
258+
read_wav(test_params.input_data_path, data_input);
259+
read_wav(test_params.reference_data_path, data_reference);
260+
261+
ASSERT_TRUE(data_input.size() > 0);
262+
ASSERT_TRUE(data_reference.size() > 0);
263+
264+
// setup inference
265+
ContextConfig anira_context_config;
266+
InferenceConfig inference_config = hybridnn_config;
267+
HybridNNPrePostProcessor pp_processor(inference_config);
268+
HybridNNBypassProcessor bypass_processor(inference_config);
269+
270+
// This test requires the buffer size to be a multiple of the preprocess input size
271+
if (static_cast<size_t>(buffer_size) % inference_config.get_preprocess_input_size()[0] != 0){
272+
GTEST_SKIP() << "Test requires the buffer size to be a multiple of the preprocess_input_size.";
273+
return;
274+
}
275+
276+
// Create an InferenceHandler instance
277+
InferenceHandler inference_handler(pp_processor, inference_config, bypass_processor, anira_context_config);
278+
279+
// Allocate memory for audio processing
280+
inference_handler.prepare(test_params.host_config);
281+
// Select the inference backend
282+
inference_handler.set_inference_backend(test_params.backend);
283+
284+
int latency_offset = inference_handler.get_latency(); // The 0th tensor is the audio data tensor, so we only need the first element of the latency vector
285+
286+
BufferF test_buffer(1, buffer_size);
287+
RingBuffer ring_buffer;
288+
ring_buffer.initialize_with_positions(1, latency_offset + buffer_size + reference_offset);
289+
290+
//fill the buffer with zeroes to compensate for the latency
291+
for (size_t i = 0; i < latency_offset + reference_offset; i++){
292+
ring_buffer.push_sample(0, 0);
293+
}
294+
295+
// First, process some data to "contaminate" the internal state
296+
for (size_t repeat = 0; repeat < 50; repeat++){
297+
for (size_t i = 0; i < buffer_size; i++){
298+
test_buffer.set_sample(0, i, data_input.at((repeat*buffer_size)+i));
299+
ring_buffer.push_sample(0, data_reference.at((repeat*buffer_size)+i));
300+
}
301+
302+
size_t prev_samples = inference_handler.get_available_samples(0);
303+
inference_handler.process(test_buffer.get_array_of_write_pointers(), buffer_size);
304+
305+
// wait until the block was properly processed
306+
auto start = std::chrono::system_clock::now();
307+
while (inference_handler.get_available_samples(0) != prev_samples){
308+
if (std::chrono::system_clock::now() > start + std::chrono::duration<long int>(INFERENCE_TIMEOUT_S)){
309+
FAIL() << "Timeout while waiting for block to be processed";
310+
}
311+
std::this_thread::sleep_for(std::chrono::nanoseconds (10));
312+
}
313+
314+
for (size_t i = 0; i < buffer_size; i++){
315+
float reference = ring_buffer.pop_sample(0);
316+
float processed = test_buffer.get_sample(0, i);
317+
318+
if (repeat*buffer_size + i < latency_offset + reference_offset){
319+
ASSERT_FLOAT_EQ(reference, 0);
320+
} else {
321+
// calculate epsilon on the fly
322+
float epsilon = max(abs(reference), abs(processed)) * test_params.epsilon_rel + test_params.epsilon_abs;
323+
ASSERT_NEAR(reference, processed, epsilon) << "repeat=" << repeat << ", i=" << i << ", total sample nr: " << repeat*buffer_size + i << std::endl;
324+
}
325+
}
326+
}
327+
328+
// Now reset the inference handler
329+
inference_handler.reset();
330+
331+
// Verify that the available samples count is reset
332+
EXPECT_EQ(inference_handler.get_available_samples(0), latency_offset) << "Available samples should be " << latency_offset << " after reset";
333+
334+
// Reset the ring buffer to restart from the beginning of reference data
335+
ring_buffer.clear_with_positions();
336+
ring_buffer.initialize_with_positions(1, latency_offset + buffer_size + reference_offset);
337+
338+
// Fill the buffer with zeroes to compensate for the latency
339+
for (size_t i = 0; i < latency_offset + reference_offset; i++){
340+
ring_buffer.push_sample(0, 0);
341+
}
342+
343+
// Process data again and verify that output matches reference from the beginning
344+
for (size_t repeat = 0; repeat < 150; repeat++){
345+
346+
for (size_t i = 0; i < buffer_size; i++){
347+
test_buffer.set_sample(0, i, data_input.at((repeat*buffer_size)+i));
348+
ring_buffer.push_sample(0, data_reference.at((repeat*buffer_size)+i));
349+
}
350+
351+
size_t prev_samples = inference_handler.get_available_samples(0);
352+
353+
inference_handler.process(test_buffer.get_array_of_write_pointers(), buffer_size);
354+
355+
// wait until the block was properly processed
356+
auto start = std::chrono::system_clock::now();
357+
while (inference_handler.get_available_samples(0) != prev_samples){
358+
if (std::chrono::system_clock::now() > start + std::chrono::duration<long int>(INFERENCE_TIMEOUT_S)){
359+
FAIL() << "Timeout while waiting for block to be processed";
360+
}
361+
std::this_thread::sleep_for(std::chrono::nanoseconds (10));
362+
}
363+
364+
for (size_t i = 0; i < buffer_size; i++){
365+
float reference = ring_buffer.pop_sample(0);
366+
float processed = test_buffer.get_sample(0, i);
367+
368+
if (repeat*buffer_size + i < latency_offset + reference_offset){
369+
ASSERT_FLOAT_EQ(reference, 0);
370+
} else {
371+
// calculate epsilon on the fly
372+
float epsilon = max(abs(reference), abs(processed)) * test_params.epsilon_rel + test_params.epsilon_abs;
373+
ASSERT_NEAR(reference, processed, epsilon) << "After reset: repeat=" << repeat << ", i=" << i << ", total sample nr: " << repeat*buffer_size + i << std::endl;
374+
}
375+
}
376+
}
377+
}
378+
248379
std::string build_test_name(const testing::TestParamInfo<InferenceTest::ParamType>& info){
249380
std::stringstream ss_sample_rate, ss_buffer_size;
250381

0 commit comments

Comments
 (0)