Skip to content

Commit 5f56c3b

Browse files
author
wejoncy
committed
better hash
1 parent f492fee commit 5f56c3b

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

onnxruntime/core/providers/coreml/builders/model_builder.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ void CreateEmptyFile(const std::string& filename) {
391391
#endif // defined(COREML_ENABLE_MLPROGRAM)
392392

393393
std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
394-
const std::vector<std::string>& onnx_input_names) {
394+
const std::string& graph_name) {
395395
std::string path;
396396
if (coreml_options.ModelCachePath().empty()) {
397397
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
@@ -400,14 +400,11 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
400400
path += ".model.mlmodel";
401401
}
402402
} else {
403-
// input names in onnx are unique. so we can use them as the key in the cache.
404-
std::string inputs_collections = std::accumulate(
405-
onnx_input_names.begin(), onnx_input_names.end(), std::string(),
406-
[](const std::string& a, const std::string& b) { return a + "," + b; });
407-
std::hash<std::string> hasher;
408-
// different subgraph has different folders. so we need to hash the inputs.
409-
path = std::string(coreml_options.ModelCachePath()) +
410-
"/" + std::to_string(hasher(inputs_collections));
403+
// graph_name is uniquely generated by
404+
// onnxruntime/core/providers/coreml/coreml_execution_provider.cc::gen_metadef_name
405+
// int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
406+
// MakeString(COREML, "_", model_hash, "_", metadef_id);.
407+
path = std::string(coreml_options.ModelCachePath()) + "/" + graph_name;
411408
if (!coreml_options.CreateMLProgram()) {
412409
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
413410
path += "/mlmodel";
@@ -427,7 +424,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
427424
coreml_version_(coreml_version),
428425
coreml_options_(coreml_options),
429426
create_ml_program_(coreml_options.CreateMLProgram()),
430-
model_output_path_(GetModelOutputPath(coreml_options, onnx_input_names)),
427+
model_output_path_(GetModelOutputPath(coreml_options, graph_viewer.Name())),
431428
onnx_input_names_(std::move(onnx_input_names)),
432429
onnx_output_names_(std::move(onnx_output_names)),
433430
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {

onnxruntime/core/providers/coreml/coreml_execution_provider.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "core/providers/coreml/model/host_utils.h"
1919
#include "core/providers/coreml/model/model.h"
2020
#include "core/providers/coreml/shape_utils.h"
21+
#include "core/graph/model.h"
2122

2223
namespace onnxruntime {
2324

@@ -57,7 +58,11 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
5758
[&]() {
5859
HashValue model_hash;
5960
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
60-
return MakeString(COREML, "_", model_hash, "_", metadef_id);
61+
std::string user_provide_hash;
62+
if (graph_viewer.GetGraph().GetModel().MetaData().count("model_hash") > 0) {
63+
user_provide_hash = graph_viewer.GetGraph().GetModel().MetaData().at("model_hash");
64+
}
65+
return MakeString(user_provide_hash, "_", COREML, "_", model_hash, "_", metadef_id);
6166
};
6267

6368
result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},

0 commit comments

Comments
 (0)