Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions services/webnn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ component("webnn_service") {
"//gpu/command_buffer/service:gles2",
"//gpu/config",
"//mojo/public/cpp/bindings",
"//services/webnn:webnn_proto",
"//services/webnn/public/mojom",
]

Expand Down Expand Up @@ -260,3 +261,13 @@ mojolpm_fuzzer_test("webnn_graph_mojolpm_fuzzer") {
"//third_party/libprotobuf-mutator",
]
}

proto_library("webnn_proto") {
sources = [ "compute_resource_info.proto" ]

# deps = [
# "//services/webnn/public/mojom:mojom_mojolpm",
# ]

cc_generator_options = "lite"
}
42 changes: 42 additions & 0 deletions services/webnn/compute_resource_info.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// Message format for the MojoLPM fuzzer for the webnn service interface.

syntax = "proto2";

package services.webnn.proto;

enum OperandDataType {
kFloat32 = 10;
kFloat16 = 1;
kInt32 = 2;
kUint32 = 3;
kInt64 = 4;
kUint64 = 5;
kInt8 = 6;
kUint8 = 7;
kInt4 = 8;
kUint4 = 9;
}

message OperandDescriptor {
required OperandDataType data_type = 1;
repeated uint32 dim = 2;
}

message OperationIds {
repeated uint64 id = 1;
}

message ComputeResourceInfo {
map<string, OperandDescriptor> input_names_to_descriptors = 1;
map<string, OperandDescriptor> output_names_to_descriptors = 2;
map<uint64, OperationIds> operand_to_dependent_operations = 3;
}

// //
// message NameMapping {
// map<string, string> ori_name_to_real_name = 1;
// }
5 changes: 5 additions & 0 deletions services/webnn/coreml/context_impl_coreml.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class API_AVAILABLE(macos(14.0)) ContextImplCoreml final
mojom::TensorInfoPtr tensor_info,
CreateTensorImplCallback callback) override;

void LoadGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
std::string key,
LoadGraphImplCallback callback) override;

base::WeakPtrFactory<ContextImplCoreml> weak_factory_{this};
};

Expand Down
7 changes: 7 additions & 0 deletions services/webnn/coreml/context_impl_coreml.mm
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,11 @@
std::move(tensor_info)));
}

void ContextImplCoreml::LoadGraphImpl(
mojo::PendingAssociatedRemote<mojom::WebNNGraph> remote,
std::string key,
LoadGraphImplCallback callback) {
NOTIMPLEMENTED();
}

} // namespace webnn::coreml
4 changes: 4 additions & 0 deletions services/webnn/coreml/graph_impl_coreml.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class API_AVAILABLE(macos(14.0)) GraphImplCoreml final : public WebNNGraphImpl {
const base::flat_map<std::string_view, WebNNTensorImpl*>& named_outputs)
override;

void SaveGraphImpl(
std::string_view key,
base::OnceCallback<void(mojom::ErrorPtr)> callback) override;

private:
class ComputeResources;

Expand Down
6 changes: 6 additions & 0 deletions services/webnn/coreml/graph_impl_coreml.mm
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,12 @@ void DidDispatch(base::ElapsedTimer model_predict_timer,
task->Enqueue();
}

void GraphImplCoreml::SaveGraphImpl(
std::string_view key,
base::OnceCallback<void(mojom::ErrorPtr)> callback) {
NOTIMPLEMENTED();
}

GraphImplCoreml::Params::Params(
ComputeResourceInfo compute_resource_info,
base::flat_map<std::string, std::string> coreml_name_to_operand_name)
Expand Down
7 changes: 7 additions & 0 deletions services/webnn/dml/context_impl_dml.cc
Original file line number Diff line number Diff line change
Expand Up @@ -944,4 +944,11 @@ void ContextImplDml::RemoveDeviceForTesting() {
d3d12_device_5->RemoveDevice();
}

void ContextImplDml::LoadGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
std::string key,
LoadGraphImplCallback callback) {
NOTIMPLEMENTED();
}

} // namespace webnn::dml
5 changes: 5 additions & 0 deletions services/webnn/dml/context_impl_dml.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) ContextImplDml final
mojom::TensorInfoPtr tensor_info,
CreateTensorImplCallback callback) override;

void LoadGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
std::string key,
LoadGraphImplCallback callback) override;

// Begins recording commands needed for context operations.
// If recording failed, calling this function will recreate the recorder to
// allow recording to start again.
Expand Down
7 changes: 7 additions & 0 deletions services/webnn/dml/graph_impl_dml.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7129,4 +7129,11 @@ void GraphImplDml::OnDispatchComplete(
graph_resources_ = std::move(graph_resources);
}
}

void GraphImplDml::SaveGraphImpl(
std::string_view key,
base::OnceCallback<void(mojom::ErrorPtr)> callback) {
NOTIMPLEMENTED();
}

} // namespace webnn::dml
4 changes: 4 additions & 0 deletions services/webnn/dml/graph_impl_dml.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ class GraphImplDml final : public WebNNGraphImpl {
const base::flat_map<std::string_view, WebNNTensorImpl*>& named_outputs)
override;

void SaveGraphImpl(
std::string_view key,
base::OnceCallback<void(mojom::ErrorPtr)> callback) override;

// The persistent resource is allocated after the compilation work is
// completed for the graph initialization and will be used for the following
// graph executions. It could be nullptr which means it isn't required by the
Expand Down
8 changes: 8 additions & 0 deletions services/webnn/ort/context_impl_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,4 +497,12 @@ void ContextImplOrt::CreateTensorImpl(
TensorImplOrt::Create(std::move(receiver), this, std::move(tensor_info)));
}

void ContextImplOrt::LoadGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
std::string key,
LoadGraphImplCallback callback) {
GraphImplOrt::LoadAndBuild(std::move(receiver), std::move(key), this,
std::move(callback));
}

} // namespace webnn::ort
7 changes: 6 additions & 1 deletion services/webnn/ort/context_impl_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SessionOptions final : public base::RefCountedThreadSafe<SessionOptions> {
SessionOptions(const SessionOptions&) = delete;
SessionOptions& operator=(const SessionOptions&) = delete;

const OrtSessionOptions* get() const { return session_options_.get(); }
OrtSessionOptions* get() { return session_options_.get(); }

private:
friend class base::RefCountedThreadSafe<SessionOptions>;
Expand Down Expand Up @@ -73,6 +73,11 @@ class ContextImplOrt final : public WebNNContextImpl {
mojom::TensorInfoPtr tensor_info,
CreateTensorImplCallback callback) override;

void LoadGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
std::string key,
LoadGraphImplCallback callback) override;

ScopedOrtEnv env_;

// The session options are shared among all the sessions created by this
Expand Down
Loading