diff --git a/services/webnn/BUILD.gn b/services/webnn/BUILD.gn index 8ef2a8d24f5e65..a30b56863e179c 100644 --- a/services/webnn/BUILD.gn +++ b/services/webnn/BUILD.gn @@ -55,6 +55,7 @@ component("webnn_service") { "//gpu/command_buffer/service:gles2", "//gpu/config", "//mojo/public/cpp/bindings", + "//services/webnn:webnn_proto", "//services/webnn/public/mojom", ] @@ -260,3 +261,13 @@ mojolpm_fuzzer_test("webnn_graph_mojolpm_fuzzer") { "//third_party/libprotobuf-mutator", ] } + +proto_library("webnn_proto") { + sources = [ "compute_resource_info.proto" ] + + # deps = [ + # "//services/webnn/public/mojom:mojom_mojolpm", + # ] + + cc_generator_options = "lite" +} diff --git a/services/webnn/compute_resource_info.proto b/services/webnn/compute_resource_info.proto new file mode 100644 index 00000000000000..78d871530eb229 --- /dev/null +++ b/services/webnn/compute_resource_info.proto @@ -0,0 +1,42 @@ +// Copyright 2025 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Message format for the MojoLPM fuzzer for the webnn service interface. + +syntax = "proto2"; + +package services.webnn.proto; + +enum OperandDataType { + kFloat32 = 10; + kFloat16 = 1; + kInt32 = 2; + kUint32 = 3; + kInt64 = 4; + kUint64 = 5; + kInt8 = 6; + kUint8 = 7; + kInt4 = 8; + kUint4 = 9; +} + +message OperandDescriptor { + required OperandDataType data_type = 1; + repeated uint32 dim = 2; +} + +message OperationIds { + repeated uint64 id = 1; +} + +message ComputeResourceInfo { + map input_names_to_descriptors = 1; + map output_names_to_descriptors = 2; + map operand_to_dependent_operations = 3; +} + +// // +// message NameMapping { +// map ori_name_to_real_name = 1; +// } diff --git a/services/webnn/coreml/context_impl_coreml.h b/services/webnn/coreml/context_impl_coreml.h index 9b25f500e1c1cc..53a824a8fcd423 100644 --- a/services/webnn/coreml/context_impl_coreml.h +++ b/services/webnn/coreml/context_impl_coreml.h @@ -53,6 +53,11 @@ class API_AVAILABLE(macos(14.0)) ContextImplCoreml final mojom::TensorInfoPtr tensor_info, CreateTensorImplCallback callback) override; + void LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) override; + base::WeakPtrFactory weak_factory_{this}; }; diff --git a/services/webnn/coreml/context_impl_coreml.mm b/services/webnn/coreml/context_impl_coreml.mm index b94755ef8ab546..bb7501abe2e90d 100644 --- a/services/webnn/coreml/context_impl_coreml.mm +++ b/services/webnn/coreml/context_impl_coreml.mm @@ -55,4 +55,11 @@ std::move(tensor_info))); } +void ContextImplCoreml::LoadGraphImpl( + mojo::PendingAssociatedRemote remote, + std::string key, + LoadGraphImplCallback callback) { + NOTIMPLEMENTED(); +} + } // namespace webnn::coreml diff --git a/services/webnn/coreml/graph_impl_coreml.h b/services/webnn/coreml/graph_impl_coreml.h index 1cc16e111d771c..90d8c9b9ea48a8 100644 --- a/services/webnn/coreml/graph_impl_coreml.h +++ b/services/webnn/coreml/graph_impl_coreml.h @@ -133,6 +133,10 @@ class API_AVAILABLE(macos(14.0)) GraphImplCoreml final : public WebNNGraphImpl { const base::flat_map& named_outputs) override; + void SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) override; + private: class ComputeResources; diff --git a/services/webnn/coreml/graph_impl_coreml.mm b/services/webnn/coreml/graph_impl_coreml.mm index 75502f8c4d2163..265a704e7e7355 100644 --- a/services/webnn/coreml/graph_impl_coreml.mm +++ b/services/webnn/coreml/graph_impl_coreml.mm @@ -630,6 +630,12 @@ void DidDispatch(base::ElapsedTimer model_predict_timer, task->Enqueue(); } +void GraphImplCoreml::SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) { + NOTIMPLEMENTED(); +} + GraphImplCoreml::Params::Params( ComputeResourceInfo compute_resource_info, base::flat_map coreml_name_to_operand_name) diff --git a/services/webnn/dml/context_impl_dml.cc b/services/webnn/dml/context_impl_dml.cc index a8279d7a76446c..92db9612f3a0c0 100644 --- a/services/webnn/dml/context_impl_dml.cc +++ b/services/webnn/dml/context_impl_dml.cc @@ -944,4 +944,11 @@ void ContextImplDml::RemoveDeviceForTesting() { d3d12_device_5->RemoveDevice(); } +void ContextImplDml::LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) { + NOTIMPLEMENTED(); +} + } // namespace webnn::dml diff --git a/services/webnn/dml/context_impl_dml.h b/services/webnn/dml/context_impl_dml.h index 4479df957cb23b..9da804c12359bd 100644 --- a/services/webnn/dml/context_impl_dml.h +++ b/services/webnn/dml/context_impl_dml.h @@ -94,6 +94,11 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) ContextImplDml final mojom::TensorInfoPtr tensor_info, CreateTensorImplCallback callback) override; + void LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) override; + // Begins recording commands needed for context operations. // If recording failed, calling this function will recreate the recorder to // allow recording to start again. diff --git a/services/webnn/dml/graph_impl_dml.cc b/services/webnn/dml/graph_impl_dml.cc index 2495b670d6a1cf..f77b5f3254f7ce 100644 --- a/services/webnn/dml/graph_impl_dml.cc +++ b/services/webnn/dml/graph_impl_dml.cc @@ -7129,4 +7129,11 @@ void GraphImplDml::OnDispatchComplete( graph_resources_ = std::move(graph_resources); } } + +void GraphImplDml::SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) { + NOTIMPLEMENTED(); +} + } // namespace webnn::dml diff --git a/services/webnn/dml/graph_impl_dml.h b/services/webnn/dml/graph_impl_dml.h index 6ef1105f04c312..c8dbd8cec626b4 100644 --- a/services/webnn/dml/graph_impl_dml.h +++ b/services/webnn/dml/graph_impl_dml.h @@ -287,6 +287,10 @@ class GraphImplDml final : public WebNNGraphImpl { const base::flat_map& named_outputs) override; + void SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) override; + // The persistent resource is allocated after the compilation work is // completed for the graph initialization and will be used for the following // graph executions. It could be nullptr which means it isn't required by the diff --git a/services/webnn/ort/context_impl_ort.cc b/services/webnn/ort/context_impl_ort.cc index bd31f6a4d2e2f6..22817650daea2f 100644 --- a/services/webnn/ort/context_impl_ort.cc +++ b/services/webnn/ort/context_impl_ort.cc @@ -497,4 +497,12 @@ void ContextImplOrt::CreateTensorImpl( TensorImplOrt::Create(std::move(receiver), this, std::move(tensor_info))); } +void ContextImplOrt::LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) { + GraphImplOrt::LoadAndBuild(std::move(receiver), std::move(key), this, + std::move(callback)); +} + } // namespace webnn::ort diff --git a/services/webnn/ort/context_impl_ort.h b/services/webnn/ort/context_impl_ort.h index bb35e264d98f85..df1d7061532ab2 100644 --- a/services/webnn/ort/context_impl_ort.h +++ b/services/webnn/ort/context_impl_ort.h @@ -23,7 +23,7 @@ class SessionOptions final : public base::RefCountedThreadSafe { SessionOptions(const SessionOptions&) = delete; SessionOptions& operator=(const SessionOptions&) = delete; - const OrtSessionOptions* get() const { return session_options_.get(); } + OrtSessionOptions* get() { return session_options_.get(); } private: friend class base::RefCountedThreadSafe; @@ -73,6 +73,11 @@ class ContextImplOrt final : public WebNNContextImpl { mojom::TensorInfoPtr tensor_info, CreateTensorImplCallback callback) override; + void LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) override; + ScopedOrtEnv env_; // The session options are shared among all the sessions created by this diff --git a/services/webnn/ort/graph_impl_ort.cc b/services/webnn/ort/graph_impl_ort.cc index e1d9ba54d4149c..96f11c06e761df 100644 --- a/services/webnn/ort/graph_impl_ort.cc +++ b/services/webnn/ort/graph_impl_ort.cc @@ -5,6 +5,11 @@ #include "services/webnn/ort/graph_impl_ort.h" #include "base/command_line.h" +#include "base/files/file.h" +#include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/json/json_reader.h" +#include "base/json/json_writer.h" #include "base/memory/scoped_refptr.h" #include "base/notimplemented.h" #include "base/task/bind_post_task.h" @@ -25,6 +30,7 @@ #include "services/webnn/resource_task.h" #include "services/webnn/webnn_constant_operand.h" #include "services/webnn/webnn_graph_impl.h" +#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" namespace webnn::ort { @@ -122,6 +128,78 @@ class GraphImplOrt::ComputeResources { output_names.size(), output_tensors.data()))); } + // Save EP Context model + void SaveCompiledModel(std::string key, + std::string compute_resource_info, + base::OnceCallback callback) { + TRACE_EVENT0("gpu", + "ort::GraphImplOrt::ComputeResources::SaveCompiledModel"); + + auto cache_dir = base::FilePath::FromASCII( + base::StrCat({".\\EpContextModelCache\\", key})); + if (!base::CreateDirectory(cache_dir)) { + std::move(callback).Run( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to create a cache directory.")); + return; + } + + std::string compiled_model_path = + cache_dir.AppendASCII("model.onnx").MaybeAsASCII(); + if (ORT_CALL_FAILED(GetOrtApi()->SaveEpContextModel( + session_->GetSession(), compiled_model_path.c_str()))) { + std::move(callback).Run( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to save EPContext Model.")); + return; + } + + base::Value::Dict input_names_dict; + for (const auto& [operand_input, onnx_input] : + operand_input_name_to_onnx_input_name_) { + input_names_dict.Set(operand_input, onnx_input); + } + std::string input_names_str; + base::JSONWriter::Write(input_names_dict, &input_names_str); + + base::FilePath input_names_dict_path = + cache_dir.AppendASCII("input_names_dict.json"); + if (!base::WriteFile(input_names_dict_path, input_names_str)) { + std::move(callback).Run( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to write input names dict.")); + return; + } + + base::Value::Dict output_names_dict; + for (const auto& [operand_output, onnx_output] : + operand_output_name_to_onnx_output_name_) { + output_names_dict.Set(operand_output, onnx_output); + } + std::string output_names_str; + base::JSONWriter::Write(output_names_dict, &output_names_str); + + base::FilePath output_names_dict_path = + cache_dir.AppendASCII("output_names_dict.json"); + if (!base::WriteFile(output_names_dict_path, output_names_str)) { + std::move(callback).Run( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to write output names dict.")); + return; + } + + base::FilePath compute_resource_info_path = + cache_dir.AppendASCII("compute_resource_info.txt"); + if (!base::WriteFile(compute_resource_info_path, compute_resource_info)) { + std::move(callback).Run(mojom::Error::New( + mojom::Error::Code::kUnknownError, + "Failed to write compute resources info into disk.")); + return; + } + + std::move(callback).Run(nullptr); + } + private: base::flat_map operand_input_name_to_onnx_input_name_; @@ -248,6 +326,191 @@ void GraphImplOrt::DidCreateAndBuild( std::move(result.value()), static_cast(context.get())))); } +// static +void GraphImplOrt::LoadAndBuild( + mojo::PendingAssociatedReceiver receiver, + std::string key, + ContextImplOrt* context, + WebNNContextImpl::LoadGraphImplCallback callback) { + ScopedTrace scoped_trace("GraphImplOrt::LoadAndBuild"); + + auto wrapped_callback = base::BindPostTaskToCurrentDefault( + base::BindOnce(&GraphImplOrt::DidLoadAndBuild, std::move(receiver), + context->AsWeakPtr(), std::move(callback))); + + base::ThreadPool::PostTask( + FROM_HERE, + {base::TaskPriority::USER_BLOCKING, + base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN, base::MayBlock()}, + base::BindOnce(&GraphImplOrt::LoadAndBuildOnBackgroundThread, + std::move(key), context->session_options(), + std::move(wrapped_callback), std::move(scoped_trace))); +} + +GraphImplOrt::ComputeResourcesAndInfo::ComputeResourcesAndInfo( + ComputeResourceInfo compute_resource_info, + std::unique_ptr compute_resources) + : compute_resource_info(std::move(compute_resource_info)), + compute_resources(std::move(compute_resources)) {} +GraphImplOrt::ComputeResourcesAndInfo::~ComputeResourcesAndInfo() = default; + +// static +void GraphImplOrt::LoadAndBuildOnBackgroundThread( + std::string key, + scoped_refptr session_options, + base::OnceCallback< + void(base::expected, + mojom::ErrorPtr>)> callback, + ScopedTrace scoped_trace) { + scoped_trace.AddStep("Create Env"); + + const OrtApi* ort_api = GetOrtApi(); + ScopedOrtEnv env; + CHECK(IsSuccess(ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "WebNN", + ScopedOrtEnv::Receiver(env).get()))); + + scoped_trace.AddStep("Load compiled model"); + + auto cache_dir = base::FilePath::FromASCII( + base::StrCat({".\\EpContextModelCache\\", key})); + std::wstring compiled_model_path = + cache_dir.AppendASCII("model.onnx").value(); + + ScopedOrtSession session; + + // disable EP context + CHECK(IsSuccess(ort_api->AddSessionConfigEntry( + session_options->get(), kOrtSessionOptionEpContextEnable, + /*config_value=*/"0"))); + + if (ORT_CALL_FAILED(ort_api->CreateSession( + env.get(), compiled_model_path.c_str(), session_options->get(), + ScopedOrtSession::Receiver(session).get()))) { + std::move(callback).Run(base::unexpected( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to load the compiled model."))); + return; + } + + scoped_trace.AddStep("Get compute resource info"); + + base::FilePath input_names_dict_path = + cache_dir.AppendASCII("input_names_dict.json"); + std::string input_names_str; + if (!base::ReadFileToString(input_names_dict_path, &input_names_str)) { + std::move(callback).Run(base::unexpected( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to read input_name_dict.json."))); + return; + } + + auto input_names_dict = base::JSONReader::ReadDict(input_names_str); + if (!input_names_dict.has_value()) { + std::move(callback).Run(base::unexpected(mojom::Error::New( + mojom::Error::Code::kUnknownError, "Failed to get input names dict."))); + return; + } + + base::flat_map + operand_input_name_to_onnx_input_name; + for (auto current = input_names_dict->begin(); + current != input_names_dict->end(); ++current) { + operand_input_name_to_onnx_input_name.emplace(current->first, + current->second.GetString()); + } + + base::FilePath output_names_dict_path = + cache_dir.AppendASCII("output_names_dict.json"); + std::string output_names_str; + if (!base::ReadFileToString(output_names_dict_path, &output_names_str)) { + std::move(callback).Run(base::unexpected( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to read output_name_dict.json."))); + return; + } + + auto output_names_dict = base::JSONReader::ReadDict(output_names_str); + if (!output_names_dict.has_value()) { + std::move(callback).Run(base::unexpected( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to read the output_name_dict.json."))); + return; + } + + base::flat_map + operand_output_name_to_onnx_output_name; + for (auto current = output_names_dict->begin(); + current != output_names_dict->end(); ++current) { + operand_output_name_to_onnx_output_name.emplace( + current->first, current->second.GetString()); + } + + base::FilePath compute_resource_info_path = + cache_dir.AppendASCII("compute_resource_info.txt"); + std::string compute_resource_info_str; + if (!base::ReadFileToString(compute_resource_info_path, + &compute_resource_info_str)) { + std::move(callback).Run(base::unexpected( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to read the compute resource info."))); + return; + } + + auto compute_resource_info = + ComputeResourceInfo::ParseFromString(compute_resource_info_str); + if (!compute_resource_info.has_value()) { + std::move(callback).Run( + base::unexpected(std::move(compute_resource_info.error()))); + return; + } + + scoped_trace.AddStep("Create compute resources"); + + auto compute_session = + base::WrapUnique(new Session(std::move(env), std::move(session), + std::vector>{})); + + std::move(callback).Run(base::WrapUnique(new ComputeResourcesAndInfo( + std::move(compute_resource_info.value()), + base::WrapUnique(new GraphImplOrt::ComputeResources( + std::move(compute_session), + std::move(operand_input_name_to_onnx_input_name), + std::move(operand_output_name_to_onnx_output_name)))))); +} + +// static +void GraphImplOrt::DidLoadAndBuild( + mojo::PendingAssociatedReceiver receiver, + base::WeakPtr context, + WebNNContextImpl::LoadGraphImplCallback callback, + base::expected, mojom::ErrorPtr> + result) { + if (!result.has_value()) { + std::move(callback).Run(base::unexpected(std::move(result.error()))); + return; + } + + if (!context) { + std::move(callback).Run(base::unexpected(mojom::Error::New( + mojom::Error::Code::kUnknownError, "Context was destroyed."))); + return; + } + + auto input_names_to_descriptors = + result.value()->compute_resource_info.input_names_to_descriptors; + auto output_names_to_descriptors = + result.value()->compute_resource_info.output_names_to_descriptors; + std::move(callback).Run( + base::WrapUnique(new WebNNContextImpl::LoadGraphResult( + base::WrapUnique( + new GraphImplOrt(std::move(receiver), + std::move(result.value()->compute_resource_info), + std::move(result.value()->compute_resources), + static_cast(context.get()))), + std::move(input_names_to_descriptors), + std::move(output_names_to_descriptors)))); +} + GraphImplOrt::~GraphImplOrt() = default; GraphImplOrt::GraphImplOrt( @@ -342,4 +605,55 @@ void GraphImplOrt::DispatchImpl( task->Enqueue(); } +void GraphImplOrt::SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) { + TRACE_EVENT0("gpu", "ort::GraphImplOrt::SaveGraphImpl"); + + std::vector> exclusive_resources; + exclusive_resources.reserve(1); + exclusive_resources.push_back(compute_resources_state_); + + std::string compute_resource_info; + if (!this->compute_resource_info().SerializeToString(compute_resource_info)) { + std::move(callback).Run( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to serialize compute resources info.")); + return; + } + + auto save_graph_callback = + base::BindPostTaskToCurrentDefault(std::move(callback)); + + auto task = base::MakeRefCounted( + std::vector>{}, + std::move(exclusive_resources), + base::BindOnce( + [](scoped_refptr> + compute_resources_state, + std::string key, std::string compute_resource_info, + base::OnceCallback save_graph_callback, + base::OnceClosure completion_closure) { + ComputeResources* raw_compute_resources = + compute_resources_state->GetExclusivelyLockedResource(); + + // Compute tasks can take a significant amount of time, use the + // thread pool to avoid blocking the main thread. + base::ThreadPool::PostTaskAndReply( + FROM_HERE, + {base::TaskPriority::USER_BLOCKING, + base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN, + base::MayBlock()}, + base::BindOnce(&ComputeResources::SaveCompiledModel, + base::Unretained(raw_compute_resources), + std::move(key), std::move(compute_resource_info), + std::move(save_graph_callback)), + std::move(completion_closure)); + }, + compute_resources_state_, std::string(key), + std::move(compute_resource_info), std::move(save_graph_callback))); + + task->Enqueue(); +} + } // namespace webnn::ort diff --git a/services/webnn/ort/graph_impl_ort.h b/services/webnn/ort/graph_impl_ort.h index 5ea62796042af4..68e5e4ef362b4a 100644 --- a/services/webnn/ort/graph_impl_ort.h +++ b/services/webnn/ort/graph_impl_ort.h @@ -47,6 +47,12 @@ class GraphImplOrt final : public WebNNGraphImpl { ContextImplOrt* context, WebNNContextImpl::CreateGraphImplCallback callback); + static void LoadAndBuild( + mojo::PendingAssociatedReceiver receiver, + std::string key, + ContextImplOrt* context, + WebNNContextImpl::LoadGraphImplCallback callback); + GraphImplOrt(const GraphImplOrt&) = delete; GraphImplOrt& operator=(const GraphImplOrt&) = delete; ~GraphImplOrt() override; @@ -76,6 +82,34 @@ class GraphImplOrt final : public WebNNGraphImpl { base::expected, mojom::ErrorPtr> result); + struct ComputeResourcesAndInfo { + ComputeResourcesAndInfo( + ComputeResourceInfo compute_resource_info, + std::unique_ptr compute_resources); + ~ComputeResourcesAndInfo(); + + ComputeResourcesAndInfo(const ComputeResourcesAndInfo&) = delete; + ComputeResourcesAndInfo& operator=(const ComputeResourcesAndInfo&) = delete; + + ComputeResourceInfo compute_resource_info; + std::unique_ptr compute_resources; + }; + + static void LoadAndBuildOnBackgroundThread( + std::string key, + scoped_refptr session_options, + base::OnceCallback< + void(base::expected, + mojom::ErrorPtr>)> callback, + ScopedTrace scoped_trace); + + static void DidLoadAndBuild( + mojo::PendingAssociatedReceiver receiver, + base::WeakPtr context, + WebNNContextImpl::LoadGraphImplCallback callback, + base::expected, mojom::ErrorPtr> + result); + // Execute the compiled platform graph asynchronously. The inputs were // validated in base class so we can use them to compute directly. void DispatchImpl(const base::flat_map& @@ -83,6 +117,10 @@ class GraphImplOrt final : public WebNNGraphImpl { const base::flat_map& named_output_tensors) override; + void SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) override; + // std::map operand_infos_; scoped_refptr> compute_resources_state_; diff --git a/services/webnn/public/cpp/operand_descriptor.cc b/services/webnn/public/cpp/operand_descriptor.cc index f87e60e9441202..41b90646cb800d 100644 --- a/services/webnn/public/cpp/operand_descriptor.cc +++ b/services/webnn/public/cpp/operand_descriptor.cc @@ -133,6 +133,8 @@ size_t OperandDescriptor::GetBitsPerElement(OperandDataType data_type) { } } +OperandDescriptor::OperandDescriptor() {} + OperandDescriptor::OperandDescriptor(mojo::DefaultConstruct::Tag) {} OperandDescriptor::OperandDescriptor(OperandDataType data_type, diff --git a/services/webnn/public/cpp/operand_descriptor.h b/services/webnn/public/cpp/operand_descriptor.h index 98c69b01b20821..339ba467bc8bb3 100644 --- a/services/webnn/public/cpp/operand_descriptor.h +++ b/services/webnn/public/cpp/operand_descriptor.h @@ -56,6 +56,10 @@ class COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) OperandDescriptor { static size_t GetBitsPerElement(OperandDataType data_type); + // Creates an invalid instance for use with Mojo deserialization (for HashMap + // use as EmptyValue()), which requires types to be default-constructible. + OperandDescriptor(); + // Creates an invalid instance for use with Mojo deserialization, which // requires types to be default-constructible. explicit OperandDescriptor(mojo::DefaultConstruct::Tag); diff --git a/services/webnn/public/mojom/webnn_context.mojom b/services/webnn/public/mojom/webnn_context.mojom index 8cf41e1b016a0e..91215ab7067e5c 100644 --- a/services/webnn/public/mojom/webnn_context.mojom +++ b/services/webnn/public/mojom/webnn_context.mojom @@ -9,6 +9,7 @@ import "services/webnn/public/mojom/webnn_tensor.mojom"; import "services/webnn/public/mojom/webnn_error.mojom"; import "services/webnn/public/mojom/webnn_graph_builder.mojom"; import "third_party/blink/public/mojom/tokens/tokens.mojom"; +import "services/webnn/public/mojom/webnn_graph.mojom"; // Represents a successful call to `WebNNContext::CreateTensor()`. struct CreateTensorSuccess { @@ -27,6 +28,17 @@ union CreateTensorResult { Error error; }; +struct LoadGraphSuccess { + pending_associated_remote graph_remote; + map input_constraints; + map output_constraints; +}; + +union LoadGraphResult { + LoadGraphSuccess success; + Error error; +}; + // Represents the `MLContext` object in the WebIDL definition that is a global // state of neural network compute workload and execution processes. This // interface runs in the GPU process and is called from the renderer process. @@ -39,4 +51,6 @@ interface WebNNContext { // creating platform specific tensors, the WebNN tensor will be validated and // created. This method guarantees memory allocation on the device. CreateTensor(TensorInfo tensor_info) => (CreateTensorResult result); -}; \ No newline at end of file + + LoadGraph(string key) => (LoadGraphResult result); +}; diff --git a/services/webnn/public/mojom/webnn_graph.mojom b/services/webnn/public/mojom/webnn_graph.mojom index 8eef8635dd6d69..745350690fce59 100644 --- a/services/webnn/public/mojom/webnn_graph.mojom +++ b/services/webnn/public/mojom/webnn_graph.mojom @@ -5,6 +5,7 @@ module webnn.mojom; import "services/webnn/public/mojom/webnn_context_properties.mojom"; +import "services/webnn/public/mojom/webnn_error.mojom"; import "third_party/blink/public/mojom/tokens/tokens.mojom"; // Represents the `MLOperandType` in the WebIDL definition. @@ -1438,4 +1439,6 @@ interface WebNNGraph { // data. Dispatch(map named_inputs, map named_outputs); + + SaveGraph(string key) => (Error? error); }; diff --git a/services/webnn/tflite/context_impl_tflite.cc b/services/webnn/tflite/context_impl_tflite.cc index ff5b7ecbe60551..f3434b672f5993 100644 --- a/services/webnn/tflite/context_impl_tflite.cc +++ b/services/webnn/tflite/context_impl_tflite.cc @@ -51,4 +51,11 @@ void ContextImplTflite::CreateTensorImpl( std::move(tensor_info))); } +void ContextImplTflite::LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) { + NOTIMPLEMENTED(); +} + } // namespace webnn::tflite diff --git a/services/webnn/tflite/context_impl_tflite.h b/services/webnn/tflite/context_impl_tflite.h index 09060e8dafbe9a..c4de71c1fa5a47 100644 --- a/services/webnn/tflite/context_impl_tflite.h +++ b/services/webnn/tflite/context_impl_tflite.h @@ -45,6 +45,11 @@ class ContextImplTflite final : public WebNNContextImpl { mojom::TensorInfoPtr tensor_info, CreateTensorImplCallback callback) override; + void LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) override; + base::WeakPtrFactory weak_factory_{this}; }; diff --git a/services/webnn/tflite/graph_impl_tflite.cc b/services/webnn/tflite/graph_impl_tflite.cc index bf919d957acefa..97d112c0c5d11b 100644 --- a/services/webnn/tflite/graph_impl_tflite.cc +++ b/services/webnn/tflite/graph_impl_tflite.cc @@ -412,4 +412,10 @@ void GraphImplTflite::DispatchImpl( task->Enqueue(); } +void GraphImplTflite::SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) { + NOTIMPLEMENTED(); +} + } // namespace webnn::tflite diff --git a/services/webnn/tflite/graph_impl_tflite.h b/services/webnn/tflite/graph_impl_tflite.h index 4cac6cdbb7f65c..e24c1797b3af1e 100644 --- a/services/webnn/tflite/graph_impl_tflite.h +++ b/services/webnn/tflite/graph_impl_tflite.h @@ -64,6 +64,10 @@ class GraphImplTflite final : public WebNNGraphImpl { const base::flat_map& named_outputs) override; + void SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) override; + scoped_refptr> compute_resources_state_; base::flat_map input_name_to_index_; diff --git a/services/webnn/webnn_context_impl.cc b/services/webnn/webnn_context_impl.cc index 487d1532dd3305..66afa18817e151 100644 --- a/services/webnn/webnn_context_impl.cc +++ b/services/webnn/webnn_context_impl.cc @@ -27,6 +27,15 @@ namespace webnn { +WebNNContextImpl::LoadGraphResult::LoadGraphResult( + std::unique_ptr graph, + base::flat_map input_constraints, + base::flat_map output_constraints) + : graph(std::move(graph)), + input_constraints(std::move(input_constraints)), + output_constraints(std::move(output_constraints)) {} +WebNNContextImpl::LoadGraphResult::~LoadGraphResult() = default; + WebNNContextImpl::WebNNContextImpl( mojo::PendingReceiver receiver, WebNNContextProviderImpl* context_provider, @@ -122,6 +131,36 @@ void WebNNContextImpl::DidCreateWebNNTensorImpl( tensor_impls_.emplace(*std::move(result)); } +void WebNNContextImpl::LoadGraph( + const std::string& key, + mojom::WebNNContext::LoadGraphCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + mojo::PendingAssociatedRemote remote; + auto receiver = remote.InitWithNewEndpointAndPassReceiver(); + LoadGraphImpl(std::move(receiver), key, + base::BindOnce(&WebNNContextImpl::DidLoadGraphImpl, AsWeakPtr(), + std::move(callback), std::move(remote))); +} + +void WebNNContextImpl::DidLoadGraphImpl( + mojom::WebNNContext::LoadGraphCallback callback, + mojo::PendingAssociatedRemote remote, + base::expected, mojom::ErrorPtr> result) { + if (!result.has_value()) { + std::move(callback).Run( + mojom::LoadGraphResult::NewError(std::move(result.error()))); + return; + } + + auto success = mojom::LoadGraphSuccess::New( + std::move(remote), std::move(result.value()->input_constraints), + std::move(result.value()->output_constraints)); + std::move(callback).Run( + mojom::LoadGraphResult::NewSuccess(std::move(success))); + + graph_impls_.emplace(std::move(result.value()->graph)); +} + void WebNNContextImpl::DisconnectAndDestroyWebNNTensorImpl( const blink::WebNNTensorToken& handle) { const auto it = tensor_impls_.find(handle); diff --git a/services/webnn/webnn_context_impl.h b/services/webnn/webnn_context_impl.h index 24976659b27c96..bdd292c53d475c 100644 --- a/services/webnn/webnn_context_impl.h +++ b/services/webnn/webnn_context_impl.h @@ -23,6 +23,7 @@ #include "mojo/public/cpp/bindings/receiver.h" #include "mojo/public/cpp/bindings/unique_associated_receiver_set.h" #include "services/webnn/public/cpp/context_properties.h" +#include "services/webnn/public/mojom/operand_descriptor_mojom_traits.h" #include "services/webnn/public/mojom/webnn_context.mojom.h" #include "services/webnn/public/mojom/webnn_context_provider.mojom.h" #include "services/webnn/public/mojom/webnn_error.mojom-forward.h" @@ -50,6 +51,24 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNContextImpl using CreateTensorImplCallback = base::OnceCallback, mojom::ErrorPtr>)>; + struct LoadGraphResult { + LoadGraphResult( + std::unique_ptr graph, + base::flat_map input_constraints, + base::flat_map + output_constraints); + ~LoadGraphResult(); + + LoadGraphResult(const LoadGraphResult&) = delete; + LoadGraphResult& operator=(const LoadGraphResult&) = delete; + + std::unique_ptr graph; + base::flat_map input_constraints; + base::flat_map output_constraints; + }; + using LoadGraphImplCallback = base::OnceCallback, mojom::ErrorPtr>)>; + WebNNContextImpl(mojo::PendingReceiver receiver, WebNNContextProviderImpl* context_provider, ContextProperties properties, @@ -141,6 +160,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNContextImpl override; void CreateTensor(mojom::TensorInfoPtr tensor_info, CreateTensorCallback callback) override; + void LoadGraph(const std::string& key, LoadGraphCallback callback) override; // This method will be called by `CreateTensor()` after the tensor info is // validated. A backend subclass should implement this method to create and @@ -155,6 +175,16 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNContextImpl mojo::PendingAssociatedRemote remote, base::expected, mojom::ErrorPtr> result); + virtual void LoadGraphImpl( + mojo::PendingAssociatedReceiver receiver, + std::string key, + LoadGraphImplCallback callback) = 0; + + void DidLoadGraphImpl( + LoadGraphCallback callback, + mojo::PendingAssociatedRemote remote, + base::expected, mojom::ErrorPtr> result); + SEQUENCE_CHECKER(sequence_checker_); mojo::Receiver receiver_; diff --git a/services/webnn/webnn_graph_impl.cc b/services/webnn/webnn_graph_impl.cc index 354e3f47d24d33..7b42c92c9fa7db 100644 --- a/services/webnn/webnn_graph_impl.cc +++ b/services/webnn/webnn_graph_impl.cc @@ -13,6 +13,7 @@ #include "base/dcheck_is_on.h" #include "base/types/optional_ref.h" #include "base/types/pass_key.h" +#include "services/webnn/compute_resource_info.pb.h" #include "services/webnn/error.h" #include "services/webnn/public/cpp/operand_descriptor.h" #include "services/webnn/webnn_context_impl.h" @@ -76,6 +77,16 @@ WebNNGraphImpl::ComputeResourceInfo::ComputeResourceInfo( operand_to_dependent_operations( std::move(operand_to_dependent_operations)) {} +WebNNGraphImpl::ComputeResourceInfo::ComputeResourceInfo( + base::flat_map input_names_to_descriptors, + base::flat_map output_names_to_descriptors, + base::flat_map> + operand_to_dependent_operations) + : input_names_to_descriptors(std::move(input_names_to_descriptors)), + output_names_to_descriptors(std::move(output_names_to_descriptors)), + operand_to_dependent_operations( + std::move(operand_to_dependent_operations)) {} + WebNNGraphImpl::ComputeResourceInfo::ComputeResourceInfo( ComputeResourceInfo&&) = default; WebNNGraphImpl::ComputeResourceInfo& @@ -83,6 +94,155 @@ WebNNGraphImpl::ComputeResourceInfo::operator=(ComputeResourceInfo&&) = default; WebNNGraphImpl::ComputeResourceInfo::~ComputeResourceInfo() = default; +bool WebNNGraphImpl::ComputeResourceInfo::SerializeToString( + std::string& str) const { + services::webnn::proto::ComputeResourceInfo resource_info_proto; + + auto* input_names_to_descriptors_proto = + resource_info_proto.mutable_input_names_to_descriptors(); + auto* output_names_to_descriptors_proto = + resource_info_proto.mutable_output_names_to_descriptors(); + auto* operand_to_dependent_operations_proto = + resource_info_proto.mutable_operand_to_dependent_operations(); + + auto convert_data_type = [](OperandDataType data_type) { + switch (data_type) { + case OperandDataType::kFloat32: + return services::webnn::proto::OperandDataType::kFloat32; + case OperandDataType::kFloat16: + return services::webnn::proto::OperandDataType::kFloat16; + case OperandDataType::kInt32: + return services::webnn::proto::OperandDataType::kInt32; + case OperandDataType::kUint32: + return services::webnn::proto::OperandDataType::kUint32; + case OperandDataType::kInt64: + return services::webnn::proto::OperandDataType::kInt64; + case OperandDataType::kUint64: + return services::webnn::proto::OperandDataType::kUint64; + case OperandDataType::kInt8: + return services::webnn::proto::OperandDataType::kInt8; + case OperandDataType::kUint8: + return services::webnn::proto::OperandDataType::kUint8; + case OperandDataType::kInt4: + return services::webnn::proto::OperandDataType::kInt4; + case OperandDataType::kUint4: + return services::webnn::proto::OperandDataType::kUint4; + } + }; + + for (const auto& [name, descriptor] : input_names_to_descriptors) { + services::webnn::proto::OperandDescriptor descriptor_proto; + for (uint32_t dim : descriptor.shape()) { + descriptor_proto.add_dim(dim); + } + descriptor_proto.set_data_type(convert_data_type(descriptor.data_type())); + input_names_to_descriptors_proto->emplace(name, + std::move(descriptor_proto)); + } + + for (const auto& [name, descriptor] : output_names_to_descriptors) { + services::webnn::proto::OperandDescriptor descriptor_proto; + for (uint32_t dim : descriptor.shape()) { + descriptor_proto.add_dim(dim); + } + descriptor_proto.set_data_type(convert_data_type(descriptor.data_type())); + output_names_to_descriptors_proto->emplace(name, + std::move(descriptor_proto)); + } + + for (const auto& [operand_id, operations] : operand_to_dependent_operations) { + services::webnn::proto::OperationIds operations_ids_proto; + for (size_t operation_id : operations) { + operations_ids_proto.add_id(operation_id); + } + operand_to_dependent_operations_proto->emplace( + operand_id, std::move(operations_ids_proto)); + } + + return resource_info_proto.SerializeToString(&str); +} + +// static +base::expected +WebNNGraphImpl::ComputeResourceInfo::ParseFromString(std::string_view str) { + services::webnn::proto::ComputeResourceInfo resource_info_proto; + if (!resource_info_proto.ParseFromString(str)) { + return base::unexpected( + mojom::Error::New(mojom::Error::Code::kUnknownError, + "Failed to parse the compute resource info.")); + } + + auto convert_data_type = + [](services::webnn::proto::OperandDataType data_type) { + switch (data_type) { + case services::webnn::proto::OperandDataType::kFloat32: + return OperandDataType::kFloat32; + case services::webnn::proto::OperandDataType::kFloat16: + return OperandDataType::kFloat16; + case services::webnn::proto::OperandDataType::kInt32: + return OperandDataType::kInt32; + case services::webnn::proto::OperandDataType::kUint32: + return OperandDataType::kUint32; + case services::webnn::proto::OperandDataType::kInt64: + return OperandDataType::kInt64; + case services::webnn::proto::OperandDataType::kUint64: + return OperandDataType::kUint64; + case services::webnn::proto::OperandDataType::kInt8: + return OperandDataType::kInt8; + case services::webnn::proto::OperandDataType::kUint8: + return OperandDataType::kUint8; + case services::webnn::proto::OperandDataType::kInt4: + return OperandDataType::kInt4; + case services::webnn::proto::OperandDataType::kUint4: + return OperandDataType::kUint4; + } + }; + + base::flat_map input_names_to_descriptors; + for (const auto& [name, descriptor_proto] : + resource_info_proto.input_names_to_descriptors()) { + std::vector shape; + for (uint32_t dim : descriptor_proto.dim()) { + shape.push_back(dim); + } + auto descriptor = + OperandDescriptor::CreateForDeserialization( + convert_data_type(descriptor_proto.data_type()), shape) + .value(); + input_names_to_descriptors.emplace(name, std::move(descriptor)); + } + + base::flat_map output_names_to_descriptors; + for (const auto& [name, descriptor_proto] : + resource_info_proto.output_names_to_descriptors()) { + std::vector shape; + for (uint32_t dim : descriptor_proto.dim()) { + shape.push_back(dim); + } + auto descriptor = + OperandDescriptor::CreateForDeserialization( + convert_data_type(descriptor_proto.data_type()), shape) + .value(); + output_names_to_descriptors.emplace(name, std::move(descriptor)); + } + + base::flat_map> + operand_to_dependent_operations; + for (const auto& [operand_id, operations_proto] : + resource_info_proto.operand_to_dependent_operations()) { + base::flat_set operation_ids; + for (uint64_t operation_id : operations_proto.id()) { + operation_ids.insert(base::checked_cast(operation_id)); + } + operand_to_dependent_operations.emplace(operand_id, + std::move(operation_ids)); + } + + return ComputeResourceInfo{std::move(input_names_to_descriptors), + std::move(output_names_to_descriptors), + std::move(operand_to_dependent_operations)}; +} + WebNNGraphImpl::WebNNGraphImpl( mojo::PendingAssociatedReceiver receiver, WebNNContextImpl* context, @@ -162,4 +322,10 @@ void WebNNGraphImpl::Dispatch( DispatchImpl(name_to_input_tensor_map, name_to_output_tensor_map); } +void WebNNGraphImpl::SaveGraph( + const std::string& key, + base::OnceCallback callback) { + SaveGraphImpl(key, std::move(callback)); +} + } // namespace webnn diff --git a/services/webnn/webnn_graph_impl.h b/services/webnn/webnn_graph_impl.h index 5e3615c120ccf6..5a4a3d878173fd 100644 --- a/services/webnn/webnn_graph_impl.h +++ b/services/webnn/webnn_graph_impl.h @@ -13,6 +13,7 @@ #include "mojo/public/cpp/bindings/associated_receiver.h" #include "mojo/public/cpp/bindings/pending_associated_receiver.h" #include "services/webnn/public/cpp/operand_descriptor.h" +#include "services/webnn/public/mojom/webnn_error.mojom.h" #include "services/webnn/public/mojom/webnn_graph.mojom.h" #include "services/webnn/webnn_object_impl.h" @@ -43,6 +44,17 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNGraphImpl ComputeResourceInfo(ComputeResourceInfo&&); ComputeResourceInfo& operator=(ComputeResourceInfo&&); + ComputeResourceInfo(base::flat_map + input_names_to_descriptors, + base::flat_map + output_names_to_descriptors, + base::flat_map> + operand_to_dependent_operations); + + bool SerializeToString(std::string& str) const; + static base::expected ParseFromString( + std::string_view str); + base::flat_map input_names_to_descriptors; base::flat_map output_names_to_descriptors; base::flat_map> @@ -89,6 +101,13 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNGraphImpl const raw_ptr context_; mojo::AssociatedReceiver receiver_; + + void SaveGraph(const std::string& key, + base::OnceCallback callback) override; + + virtual void SaveGraphImpl( + std::string_view key, + base::OnceCallback callback) = 0; }; } // namespace webnn diff --git a/third_party/blink/renderer/modules/ml/ml_context.cc b/third_party/blink/renderer/modules/ml/ml_context.cc index 686494218ca5f0..a81216de5ffc5c 100644 --- a/third_party/blink/renderer/modules/ml/ml_context.cc +++ b/third_party/blink/renderer/modules/ml/ml_context.cc @@ -137,6 +137,7 @@ void MLContext::Trace(Visitor* visitor) const { visitor->Trace(lost_property_); visitor->Trace(context_remote_); visitor->Trace(pending_resolvers_); + visitor->Trace(pending_graph_resolvers_); visitor->Trace(graphs_); visitor->Trace(graph_builders_); visitor->Trace(tensors_); @@ -213,6 +214,12 @@ void MLContext::OnLost(uint32_t custom_reason, const std::string& description) { "Context is lost."); } pending_resolvers_.clear(); + + for (const auto& resolver : pending_graph_resolvers_) { + resolver->RejectWithDOMException(DOMExceptionCode::kInvalidStateError, + "Context is lost."); + } + pending_graph_resolvers_.clear(); } const MLOpSupportLimits* MLContext::opSupportLimits(ScriptState* script_state) { @@ -1164,6 +1171,89 @@ void MLContext::dispatch(ScriptState* script_state, exception_state); } +ScriptPromise MLContext::saveGraph( + ScriptState* script_state, + String key, + MLGraph* graph, + ExceptionState& exception_state) { + webnn::ScopedTrace scoped_trace("MLContext::saveGraph"); + if (!script_state->ContextIsValid()) { + exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, + "Invalid script state"); + return EmptyPromise(); + } + if (graph->Context() != this) { + exception_state.ThrowTypeError( + "The graph isn't built within this context."); + return EmptyPromise(); + } + + return graph->SaveGraph(std::move(scoped_trace), script_state, std::move(key), + exception_state); +} + +ScriptPromise MLContext::loadGraph(ScriptState* script_state, + String key, + ExceptionState& exception_state) { + webnn::ScopedTrace scoped_trace("MLContext::loadGraph"); + if (!script_state->ContextIsValid()) { + exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, + "Invalid script state"); + return EmptyPromise(); + } + + auto* resolver = MakeGarbageCollected>( + script_state, exception_state.GetContext()); + pending_graph_resolvers_.insert(resolver); + + context_remote_->LoadGraph( + std::move(key), + WTF::BindOnce(&MLContext::DidLoadGraph, WrapPersistent(this), + std::move(scoped_trace), WrapPersistent(resolver))); + + return resolver->Promise(); +} + +void MLContext::DidLoadGraph(webnn::ScopedTrace scoped_trace, + ScriptPromiseResolver* resolver, + webnn::mojom::blink::LoadGraphResultPtr result) { + scoped_trace.AddStep("MLContext::DidLoadGraph"); + + pending_graph_resolvers_.erase(resolver); + + ScriptState* script_state = resolver->GetScriptState(); + if (!script_state->ContextIsValid()) { + return; + } + + if (result->is_error()) { + const auto& load_graph_error = result->get_error(); + resolver->RejectWithDOMException( + WebNNErrorCodeToDOMExceptionCode(load_graph_error->code), + load_graph_error->message); + return; + } + + auto& success = result->get_success(); + + MLGraph::NamedOperandDescriptors input_constraints; + for (const auto& constraint : success->input_constraints) { + input_constraints.insert(constraint.key, constraint.value); + } + MLGraph::NamedOperandDescriptors output_constraints; + for (const auto& constraint : success->output_constraints) { + output_constraints.insert(constraint.key, constraint.value); + } + + auto* graph = MakeGarbageCollected( + resolver->GetExecutionContext(), this, std::move(success->graph_remote), + std::move(input_constraints), std::move(output_constraints), + base::PassKey()); + graphs_.insert(graph); + + resolver->Resolve(graph); +} + void MLContext::DidCreateWebNNTensor( webnn::ScopedTrace scoped_trace, ScriptPromiseResolver* resolver, diff --git a/third_party/blink/renderer/modules/ml/ml_context.h b/third_party/blink/renderer/modules/ml/ml_context.h index 9ae502dc2e6923..0cf7851337c5e4 100644 --- a/third_party/blink/renderer/modules/ml/ml_context.h +++ b/third_party/blink/renderer/modules/ml/ml_context.h @@ -95,6 +95,15 @@ class MODULES_EXPORT MLContext : public ScriptWrappable { const MLNamedTensors& outputs, ExceptionState& exception_state); + ScriptPromise saveGraph(ScriptState* script_state, + String key, + MLGraph* graph, + ExceptionState& exception_state); + + ScriptPromise loadGraph(ScriptState* script_state, + String key, + ExceptionState& exception_state); + MLGraphBuilder* CreateWebNNGraphBuilder(ScriptState* script_state, ExceptionState& exception_state); @@ -108,6 +117,10 @@ class MODULES_EXPORT MLContext : public ScriptWrappable { // Close the `context_remote_` pipe because the context has been lost. void OnLost(uint32_t custom_reason, const std::string& description); + void DidLoadGraph(webnn::ScopedTrace scoped_trace, + ScriptPromiseResolver* resolver, + webnn::mojom::blink::LoadGraphResultPtr result); + void DidCreateWebNNTensor(webnn::ScopedTrace scoped_trace, ScriptPromiseResolver* resolver, webnn::OperandDescriptor validated_descriptor, @@ -131,6 +144,8 @@ class MODULES_EXPORT MLContext : public ScriptWrappable { // rejected when the Mojo pipe is unexpectedly disconnected. HeapHashSet>> pending_resolvers_; + HeapHashSet>> pending_graph_resolvers_; + HeapHashSet> graphs_; HeapHashSet> graph_builders_; HeapHashSet> tensors_; diff --git a/third_party/blink/renderer/modules/ml/ml_context.idl b/third_party/blink/renderer/modules/ml/ml_context.idl index 8433a5c383220d..6f7b68d1a56b2b 100644 --- a/third_party/blink/renderer/modules/ml/ml_context.idl +++ b/third_party/blink/renderer/modules/ml/ml_context.idl @@ -331,4 +331,16 @@ typedef record MLNamedTensors; RuntimeEnabled=MachineLearningNeuralNetwork, CallWith=ScriptState ] MLOpSupportLimits opSupportLimits(); + + [ + RuntimeEnabled=MachineLearningNeuralNetwork, + CallWith=ScriptState, + RaisesException + ] Promise saveGraph(DOMString key, MLGraph graph); + + [ + RuntimeEnabled=MachineLearningNeuralNetwork, + CallWith=ScriptState, + RaisesException + ] Promise loadGraph(DOMString key); }; diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc index a1d29d0e27e2a8..632d121a02bfa9 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc @@ -10,6 +10,7 @@ #include "third_party/blink/renderer/core/execution_context/execution_context.h" #include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h" #include "third_party/blink/renderer/modules/ml/ml_context.h" +#include "third_party/blink/renderer/modules/ml/webnn/ml_error.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_tensor.h" @@ -131,11 +132,31 @@ MLGraph::MLGraph(ExecutionContext* execution_context, WTF::BindOnce(&MLGraph::OnConnectionError, WrapWeakPersistent(this))); } +MLGraph::MLGraph(ExecutionContext* execution_context, + MLContext* context, + mojo::PendingAssociatedRemote + pending_graph_remote, + NamedOperandDescriptors input_constraints, + NamedOperandDescriptors output_constraints, + base::PassKey /*pass_key*/) + : input_constraints_(std::move(input_constraints)), + output_constraints_(std::move(output_constraints)), + ml_context_(context), + remote_graph_(execution_context) { + // Bind the end point of `WebNNGraph` mojo interface in the blink side. + remote_graph_.Bind( + std::move(pending_graph_remote), + execution_context->GetTaskRunner(TaskType::kMachineLearning)); + remote_graph_.set_disconnect_handler( + WTF::BindOnce(&MLGraph::OnConnectionError, WrapWeakPersistent(this))); +} + MLGraph::~MLGraph() = default; void MLGraph::Trace(Visitor* visitor) const { visitor->Trace(ml_context_); visitor->Trace(remote_graph_); + visitor->Trace(pending_resolvers_); ScriptWrappable::Trace(visitor); } @@ -203,12 +224,56 @@ void MLGraph::Dispatch(webnn::ScopedTrace scoped_trace, remote_graph_->Dispatch(std::move(mojo_inputs), std::move(mojo_outputs)); } +ScriptPromise MLGraph::SaveGraph( + webnn::ScopedTrace scoped_trace, + ScriptState* script_state, + String key, + ExceptionState& exception_state) { + if (!remote_graph_.is_bound()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kInvalidStateError, + "Graph has been destroyed or context is lost."); + return EmptyPromise(); + } + + auto* resolver = MakeGarbageCollected>( + script_state, exception_state.GetContext()); + pending_resolvers_.insert(resolver); + + remote_graph_->SaveGraph( + std::move(key), + WTF::BindOnce(&MLGraph::DidSaveGraph, WrapPersistent(this), + std::move(scoped_trace), WrapPersistent(resolver))); + + return resolver->Promise(); +} + +void MLGraph::DidSaveGraph(webnn::ScopedTrace scoped_trace, + ScriptPromiseResolver* resolver, + webnn::mojom::blink::ErrorPtr error) { + pending_resolvers_.erase(resolver); + + if (error) { + resolver->RejectWithDOMException( + WebNNErrorCodeToDOMExceptionCode(error->code), error->message); + return; + } + + resolver->Resolve(); +} + const MLContext* MLGraph::Context() const { return ml_context_.Get(); } void MLGraph::OnConnectionError() { remote_graph_.reset(); + for (const auto& resolver : pending_resolvers_) { + resolver->RejectWithDOMException( + DOMExceptionCode::kInvalidStateError, + "Graph has been destroyed or context is lost."); + } + pending_resolvers_.clear(); } } // namespace blink diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph.h b/third_party/blink/renderer/modules/ml/webnn/ml_graph.h index 5b8c84e6092f52..49bf3ef72a42ba 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph.h +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph.h @@ -13,6 +13,7 @@ #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.h" #include "third_party/blink/renderer/modules/modules_export.h" #include "third_party/blink/renderer/platform/bindings/script_wrappable.h" +#include "third_party/blink/renderer/platform/heap/collection_support/heap_hash_set.h" #include "third_party/blink/renderer/platform/heap/collection_support/heap_vector.h" #include "third_party/blink/renderer/platform/heap/member.h" #include "third_party/blink/renderer/platform/heap/visitor.h" @@ -49,6 +50,14 @@ class MODULES_EXPORT MLGraph : public ScriptWrappable { NamedOperandDescriptors output_constraints, base::PassKey pass_key); + MLGraph(ExecutionContext* execution_context, + MLContext* context, + mojo::PendingAssociatedRemote + pending_graph_remote, + NamedOperandDescriptors input_constraints, + NamedOperandDescriptors output_constraints, + base::PassKey pass_key); + MLGraph(const MLGraph&) = delete; MLGraph& operator=(const MLGraph&) = delete; @@ -72,11 +81,20 @@ class MODULES_EXPORT MLGraph : public ScriptWrappable { const MLNamedTensors& outputs, ExceptionState& exception_state); + ScriptPromise SaveGraph(webnn::ScopedTrace scoped_trace, + ScriptState* script_state, + String key, + ExceptionState& exception_state); + const MLContext* Context() const; private: void OnConnectionError(); + void DidSaveGraph(webnn::ScopedTrace scoped_trace, + ScriptPromiseResolver* resolver, + webnn::mojom::blink::ErrorPtr error); + // Describes the constraints on the inputs or outputs to this graph. // Note that `WTF::HashMap` values must be nullable, but // `webnn::OperandDescriptor` lacks a default constructor, so an optional is @@ -89,6 +107,8 @@ class MODULES_EXPORT MLGraph : public ScriptWrappable { // The `WebNNGraph` is a compiled graph that can be executed by the hardware // accelerated OS machine learning API. HeapMojoAssociatedRemote remote_graph_; + + HeapHashSet>> pending_resolvers_; }; } // namespace blink diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h index a5fff90521cd44..48ebae364b5bfb 100644 --- a/third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4869,6 +4869,9 @@ _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + ORT_API2_STATUS(SaveEpContextModel, _Inout_ OrtSession* session, + _In_ const char* ep_context_path); }; /*