Skip to content

Commit b7141e7

Browse files
committed
feat: add forward_type to MNNR_Config and implement backend conversion in Rust
1 parent e2478d4 commit b7141e7

3 files changed

Lines changed: 33 additions & 18 deletions

File tree

cpp/include/mnn_wrapper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ extern "C"
4141
int32_t precision_mode; // 0=Normal, 1=Low(faster), 2=High(accurate)
4242
bool use_cache; // Whether to use cache file
4343
int32_t data_format; // Input/Output data format
44+
int32_t forward_type; // MNNForwardType: 0=CPU, 1=Metal, 2=CUDA, 3=OpenCL, 6=OpenGL, 7=Vulkan, 5=CoreML/NNAPI
4445
} MNNR_Config;
4546

4647
// ============== Version & Info ==============

cpp/src/mnn_wrapper.cpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,17 @@ struct MNN_SessionPool
7979

8080
// ============== Helper Functions ==============
8181

82-
static MNN::ScheduleConfig create_schedule_config(const MNNR_Config *config)
82+
// Initialize schedule and backend configs from MNNR_Config.
83+
// Caller must ensure `schedule` and `backend` outlive any use of schedule.backendConfig.
84+
static void init_schedule_config(MNN::ScheduleConfig &schedule, MNN::BackendConfig &backend, const MNNR_Config *config)
8385
{
84-
MNN::ScheduleConfig schedule;
85-
schedule.type = MNN_FORWARD_CPU;
86+
schedule.type = (config) ? static_cast<MNNForwardType>(config->forward_type) : MNN_FORWARD_CPU;
8687
schedule.numThread = config ? config->thread_count : 4;
8788
if (schedule.numThread <= 0)
8889
{
8990
schedule.numThread = 4;
9091
}
9192

92-
MNN::BackendConfig backend;
9393
if (config)
9494
{
9595
switch (config->precision_mode)
@@ -106,8 +106,6 @@ static MNN::ScheduleConfig create_schedule_config(const MNNR_Config *config)
106106
}
107107
}
108108
schedule.backendConfig = &backend;
109-
110-
return schedule;
111109
}
112110

113111
static bool init_engine_tensors(MNN_InferenceEngine *engine)
@@ -165,7 +163,7 @@ MNN_SharedRuntime *mnnr_create_runtime(const MNNR_Config *config)
165163

166164
runtime->precision_mode = config ? config->precision_mode : 0;
167165

168-
runtime->schedule_config.type = MNN_FORWARD_CPU;
166+
runtime->schedule_config.type = (config) ? static_cast<MNNForwardType>(config->forward_type) : MNN_FORWARD_CPU;
169167
runtime->schedule_config.numThread = runtime->thread_count;
170168

171169
switch (runtime->precision_mode)
@@ -214,7 +212,9 @@ MNN_InferenceEngine *mnnr_create_engine(
214212
}
215213

216214
// Create default session
217-
MNN::ScheduleConfig schedule = create_schedule_config(config);
215+
MNN::ScheduleConfig schedule;
216+
MNN::BackendConfig backend;
217+
init_schedule_config(schedule, backend, config);
218218
engine->default_session = engine->interpreter->createSession(schedule);
219219
if (!engine->default_session)
220220
{
@@ -405,7 +405,9 @@ MNN_SessionPool *mnnr_create_session_pool(
405405
auto pool = new MNN_SessionPool();
406406
pool->engine = engine;
407407

408-
MNN::ScheduleConfig schedule = create_schedule_config(config);
408+
MNN::ScheduleConfig schedule;
409+
MNN::BackendConfig backend;
410+
init_schedule_config(schedule, backend, config);
409411

410412
// Create sessions
411413
for (size_t i = 0; i < pool_size; i++)
@@ -418,14 +420,8 @@ MNN_SessionPool *mnnr_create_session_pool(
418420
{
419421
engine->interpreter->releaseSession(s);
420422
}
421-
for (auto t : pool->input_tensors)
422-
{
423-
delete t;
424-
}
425-
for (auto t : pool->output_tensors)
426-
{
427-
delete t;
428-
}
423+
// Note: input/output tensors are owned by MNN sessions, not by us.
424+
// They will be freed when sessions are released above.
429425
delete pool;
430426
return nullptr;
431427
}
@@ -557,7 +553,9 @@ MNN_SingleSession *mnnr_create_session(
557553
auto session = new MNN_SingleSession();
558554
session->engine = engine;
559555

560-
MNN::ScheduleConfig schedule = create_schedule_config(config);
556+
MNN::ScheduleConfig schedule;
557+
MNN::BackendConfig backend;
558+
init_schedule_config(schedule, backend, config);
561559
session->session = engine->interpreter->createSession(schedule);
562560

563561
if (!session->session)

src/mnn/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ mod normal_impl {
117117
CoreML,
118118
}
119119

120+
impl Backend {
121+
/// Convert to MNNForwardType integer value
122+
fn to_forward_type(self) -> i32 {
123+
match self {
124+
Backend::CPU => 0, // MNN_FORWARD_CPU
125+
Backend::Metal => 1, // MNN_FORWARD_METAL
126+
Backend::CUDA => 2, // MNN_FORWARD_CUDA
127+
Backend::OpenCL => 3, // MNN_FORWARD_OPENCL
128+
Backend::CoreML => 5, // MNN_FORWARD_NN
129+
Backend::OpenGL => 6, // MNN_FORWARD_OPENGL
130+
Backend::Vulkan => 7, // MNN_FORWARD_VULKAN
131+
}
132+
}
133+
}
134+
120135
/// Inference configuration
121136
#[derive(Debug, Clone)]
122137
pub struct InferenceConfig {
@@ -180,6 +195,7 @@ mod normal_impl {
180195
precision_mode: self.precision_mode as i32,
181196
use_cache: self.use_cache,
182197
data_format: self.data_format as i32,
198+
forward_type: self.backend.to_forward_type(),
183199
}
184200
}
185201
}

0 commit comments

Comments
 (0)