Skip to content

Commit de35928

Browse files
authored
[onert] Improve model type inference with filesystem path handling (#16258)
This commit updates inferModelType function to accept filesystem::path directly and handle extension in a case-insensitive way. It updates load_model_from_path function to use inferModelType. ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
1 parent 1abac20 commit de35928

1 file changed

Lines changed: 23 additions & 9 deletions

File tree

runtime/onert/api/nnfw/src/Session.cc

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
#include <misc/string_helpers.h>
3535

36+
#include <algorithm>
37+
#include <cctype>
3638
#include <fstream>
3739
#include <iostream>
3840
#include <string>
@@ -113,11 +115,15 @@ std::string trim(std::string_view value)
113115
return std::string(value.substr(begin, end - begin + 1));
114116
}
115117

116-
std::string inferModelType(const std::string &filename)
118+
std::string inferModelType(const std::filesystem::path &file_path)
117119
{
118-
std::filesystem::path file_path(filename);
119-
std::string ext = file_path.extension().string();
120-
return ext.empty() ? "" : ext.substr(1);
120+
if (!file_path.has_extension())
121+
return "";
122+
123+
auto type = file_path.extension().string().substr(1);
124+
std::transform(type.begin(), type.end(), type.begin(),
125+
[](unsigned char c) { return std::tolower(c); });
126+
return type;
121127
}
122128

123129
bool loadConfigure(const std::string cfgfile, onert::util::CfgKeyValues &keyValues)
@@ -332,10 +338,17 @@ NNFW_STATUS Session::load_model_from_path(const char *path)
332338
try
333339
{
334340
std::filesystem::path filename{path};
335-
if (!std::filesystem::is_directory(filename) && filename.has_extension())
341+
if (!std::filesystem::is_directory(filename))
336342
{
337-
std::string model_type = filename.extension().string().substr(1); // + 1 to exclude dot
338-
return loadModelFile(filename, model_type);
343+
std::string model_type = inferModelType(filename);
344+
if (model_type.empty())
345+
{
346+
std::cerr << "Error: Cannot determine model type for '" << filename << "'."
347+
<< "Please use a file with valid extension." << std::endl;
348+
return NNFW_STATUS_ERROR;
349+
}
350+
else
351+
return loadModelFile(filename, model_type);
339352
}
340353

341354
const auto &package_dir = filename;
@@ -386,14 +399,15 @@ NNFW_STATUS Session::load_model_from_path(const char *path)
386399

387400
for (uint16_t i = 0; i < num_models; ++i)
388401
{
389-
const auto model_file_path = package_dir / models[i].asString();
402+
const auto model_file_name = std::filesystem::path(models[i].asString());
403+
const auto model_file_path = package_dir / model_file_name;
390404
std::string model_type;
391405

392406
// Use model-types if available and not empty, otherwise infer from file extension
393407
if (!model_types.empty() && i < model_types.size())
394408
model_type = model_types[i].asString();
395409
else
396-
model_type = inferModelType(models[i].asString());
410+
model_type = inferModelType(model_file_name);
397411
if (model_type.empty())
398412
{
399413
std::cerr << "Error: Cannot determine model type for '" << models[i].asString() << "'."

0 commit comments

Comments
 (0)