|
33 | 33 |
|
34 | 34 | #include <misc/string_helpers.h> |
35 | 35 |
|
| 36 | +#include <algorithm> |
| 37 | +#include <cctype> |
36 | 38 | #include <fstream> |
37 | 39 | #include <iostream> |
38 | 40 | #include <string> |
@@ -113,11 +115,15 @@ std::string trim(std::string_view value) |
113 | 115 | return std::string(value.substr(begin, end - begin + 1)); |
114 | 116 | } |
115 | 117 |
|
116 | | -std::string inferModelType(const std::string &filename) |
| 118 | +std::string inferModelType(const std::filesystem::path &file_path) |
117 | 119 | { |
118 | | - std::filesystem::path file_path(filename); |
119 | | - std::string ext = file_path.extension().string(); |
120 | | - return ext.empty() ? "" : ext.substr(1); |
| 120 | + if (!file_path.has_extension()) |
| 121 | + return ""; |
| 122 | + |
| 123 | + auto type = file_path.extension().string().substr(1); |
| 124 | + std::transform(type.begin(), type.end(), type.begin(), |
| 125 | + [](unsigned char c) { return std::tolower(c); }); |
| 126 | + return type; |
121 | 127 | } |
122 | 128 |
|
123 | 129 | bool loadConfigure(const std::string cfgfile, onert::util::CfgKeyValues &keyValues) |
@@ -332,10 +338,17 @@ NNFW_STATUS Session::load_model_from_path(const char *path) |
332 | 338 | try |
333 | 339 | { |
334 | 340 | std::filesystem::path filename{path}; |
335 | | - if (!std::filesystem::is_directory(filename) && filename.has_extension()) |
| 341 | + if (!std::filesystem::is_directory(filename)) |
336 | 342 | { |
337 | | - std::string model_type = filename.extension().string().substr(1); // + 1 to exclude dot |
338 | | - return loadModelFile(filename, model_type); |
| 343 | + std::string model_type = inferModelType(filename); |
| 344 | + if (model_type.empty()) |
| 345 | + { |
| 346 | + std::cerr << "Error: Cannot determine model type for '" << filename << "'." |
| 347 | + << "Please use a file with valid extension." << std::endl; |
| 348 | + return NNFW_STATUS_ERROR; |
| 349 | + } |
| 350 | + else |
| 351 | + return loadModelFile(filename, model_type); |
339 | 352 | } |
340 | 353 |
|
341 | 354 | const auto &package_dir = filename; |
@@ -386,14 +399,15 @@ NNFW_STATUS Session::load_model_from_path(const char *path) |
386 | 399 |
|
387 | 400 | for (uint16_t i = 0; i < num_models; ++i) |
388 | 401 | { |
389 | | - const auto model_file_path = package_dir / models[i].asString(); |
| 402 | + const auto model_file_name = std::filesystem::path(models[i].asString()); |
| 403 | + const auto model_file_path = package_dir / model_file_name; |
390 | 404 | std::string model_type; |
391 | 405 |
|
392 | 406 | // Use model-types if available and not empty, otherwise infer from file extension |
393 | 407 | if (!model_types.empty() && i < model_types.size()) |
394 | 408 | model_type = model_types[i].asString(); |
395 | 409 | else |
396 | | - model_type = inferModelType(models[i].asString()); |
| 410 | + model_type = inferModelType(model_file_name); |
397 | 411 | if (model_type.empty()) |
398 | 412 | { |
399 | 413 | std::cerr << "Error: Cannot determine model type for '" << models[i].asString() << "'." |
|
0 commit comments