diff --git a/runtime/onert/api/nnfw/src/nnfw_session.cc b/runtime/onert/api/nnfw/src/nnfw_session.cc index eaeb2999718..7a559d8a2ca 100644 --- a/runtime/onert/api/nnfw/src/nnfw_session.cc +++ b/runtime/onert/api/nnfw/src/nnfw_session.cc @@ -113,6 +113,13 @@ std::string trim(std::string_view value) return std::string(value.substr(begin, end - begin + 1)); } +std::string inferModelType(const std::string &filename) +{ + std::filesystem::path file_path(filename); + std::string ext = file_path.extension().string(); + return ext.empty() ? "" : ext.substr(1); +} + bool loadConfigure(const std::string cfgfile, onert::util::CfgKeyValues &keyValues) { std::ifstream ifs(cfgfile); @@ -377,7 +384,21 @@ NNFW_STATUS nnfw_session::load_model_from_path(const char *path) for (uint16_t i = 0; i < num_models; ++i) { const auto model_file_path = package_dir / models[i].asString(); - const auto model_type = model_types[i].asString(); + std::string model_type; + + // Use model-types if available and not empty, otherwise infer from file extension + if (!model_types.empty() && i < model_types.size()) + model_type = model_types[i].asString(); + else + model_type = inferModelType(models[i].asString()); + if (model_type.empty()) + { + std::cerr << "Error: Cannot determine model type for '" << models[i].asString() << "'." + << "Please specify model-types in MANIFEST or use a file with valid extension." + << std::endl; + return NNFW_STATUS_ERROR; + } + auto model = loadModel(model_file_path.string(), model_type); if (model == nullptr) return NNFW_STATUS_ERROR;