Skip to content

Commit 36a4462

Browse files
committed
add error handling
1 parent 5a5d736 commit 36a4462

File tree

20 files changed

+301
-110
lines changed

20 files changed

+301
-110
lines changed

services/webnn/coreml/graph_impl_coreml.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ class API_AVAILABLE(macos(14.0)) GraphImplCoreml final : public WebNNGraphImpl {
133133
const base::flat_map<std::string_view, WebNNTensorImpl*>& named_outputs)
134134
override;
135135

136-
void SaveGraphImpl(std::string_view key) override;
136+
void SaveGraphImpl(
137+
std::string_view key,
138+
base::OnceCallback<void(mojom::ErrorPtr)> callback) override;
137139

138140
private:
139141
class ComputeResources;

services/webnn/coreml/graph_impl_coreml.mm

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,9 @@ void DidDispatch(base::ElapsedTimer model_predict_timer,
630630
task->Enqueue();
631631
}
632632

633-
void GraphImplCoreml::SaveGraphImpl(std::string_view key) {
633+
void GraphImplCoreml::SaveGraphImpl(
634+
std::string_view key,
635+
base::OnceCallback<void(mojom::ErrorPtr)> callback) {
634636
NOTIMPLEMENTED();
635637
}
636638

services/webnn/dml/graph_impl_dml.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7130,7 +7130,9 @@ void GraphImplDml::OnDispatchComplete(
71307130
}
71317131
}
71327132

7133-
void GraphImplDml::SaveGraphImpl(std::string_view key) {
7133+
void GraphImplDml::SaveGraphImpl(
7134+
std::string_view key,
7135+
base::OnceCallback<void(mojom::ErrorPtr)> callback) {
71347136
NOTIMPLEMENTED();
71357137
}
71367138

services/webnn/dml/graph_impl_dml.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,9 @@ class GraphImplDml final : public WebNNGraphImpl {
287287
const base::flat_map<std::string_view, WebNNTensorImpl*>& named_outputs)
288288
override;
289289

290-
void SaveGraphImpl(std::string_view key) override;
290+
void SaveGraphImpl(
291+
std::string_view key,
292+
base::OnceCallback<void(mojom::ErrorPtr)> callback) override;
291293

292294
// The persistent resource is allocated after the compilation work is
293295
// completed for the graph initialization and will be used for the following

services/webnn/ort/graph_impl_ort.cc

Lines changed: 140 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,30 @@ class GraphImplOrt::ComputeResources {
129129
}
130130

131131
// Save EP Context model
132-
void SaveCompiledModel(std::string key, std::string compute_resource_info) {
132+
void SaveCompiledModel(std::string key,
133+
std::string compute_resource_info,
134+
base::OnceCallback<void(mojom::ErrorPtr)> callback) {
133135
TRACE_EVENT0("gpu",
134136
"ort::GraphImplOrt::ComputeResources::SaveCompiledModel");
135137

136138
auto cache_dir = base::FilePath::FromASCII(
137139
base::StrCat({".\\EpContextModelCache\\", key}));
138-
CHECK(base::CreateDirectory(cache_dir));
140+
if (!base::CreateDirectory(cache_dir)) {
141+
std::move(callback).Run(
142+
mojom::Error::New(mojom::Error::Code::kUnknownError,
143+
"Failed to create a cache directory."));
144+
return;
145+
}
139146

140147
std::string compiled_model_path =
141148
cache_dir.AppendASCII("model.onnx").MaybeAsASCII();
142-
CALL_ORT_FUNC(GetOrtApi()->SaveEpContextModel(session_->GetSession(),
143-
compiled_model_path.c_str()));
149+
if (ORT_CALL_FAILED(GetOrtApi()->SaveEpContextModel(
150+
session_->GetSession(), compiled_model_path.c_str()))) {
151+
std::move(callback).Run(
152+
mojom::Error::New(mojom::Error::Code::kUnknownError,
153+
"Failed to save EPContext Model."));
154+
return;
155+
}
144156

145157
base::Value::Dict input_names_dict;
146158
for (const auto& [operand_input, onnx_input] :
@@ -152,7 +164,12 @@ class GraphImplOrt::ComputeResources {
152164

153165
base::FilePath input_names_dict_path =
154166
cache_dir.AppendASCII("input_names_dict.json");
155-
base::WriteFile(input_names_dict_path, input_names_str);
167+
if (!base::WriteFile(input_names_dict_path, input_names_str)) {
168+
std::move(callback).Run(
169+
mojom::Error::New(mojom::Error::Code::kUnknownError,
170+
"Failed to write input names dict."));
171+
return;
172+
}
156173

157174
base::Value::Dict output_names_dict;
158175
for (const auto& [operand_output, onnx_output] :
@@ -164,11 +181,21 @@ class GraphImplOrt::ComputeResources {
164181

165182
base::FilePath output_names_dict_path =
166183
cache_dir.AppendASCII("output_names_dict.json");
167-
base::WriteFile(output_names_dict_path, output_names_str);
184+
if (!base::WriteFile(output_names_dict_path, output_names_str)) {
185+
std::move(callback).Run(
186+
mojom::Error::New(mojom::Error::Code::kUnknownError,
187+
"Failed to write output names dict."));
188+
return;
189+
}
168190

169191
base::FilePath compute_resource_info_path =
170192
cache_dir.AppendASCII("compute_resource_info.txt");
171-
base::WriteFile(compute_resource_info_path, compute_resource_info);
193+
if (!base::WriteFile(compute_resource_info_path, compute_resource_info)) {
194+
std::move(callback).Run(mojom::Error::New(
195+
mojom::Error::Code::kUnknownError,
196+
"Failed to write compute resources info into disk."));
197+
return;
198+
}
172199
}
173200

174201
private:
@@ -318,20 +345,26 @@ void GraphImplOrt::LoadAndBuild(
318345
std::move(wrapped_callback), std::move(scoped_trace)));
319346
}
320347

348+
GraphImplOrt::ComputeResourcesAndInfo::ComputeResourcesAndInfo(
349+
ComputeResourceInfo compute_resource_info,
350+
std::unique_ptr<ComputeResources> compute_resources)
351+
: compute_resource_info(std::move(compute_resource_info)),
352+
compute_resources(std::move(compute_resources)) {}
353+
GraphImplOrt::ComputeResourcesAndInfo::~ComputeResourcesAndInfo() = default;
354+
321355
// static
322356
void GraphImplOrt::LoadAndBuildOnBackgroundThread(
323357
std::string key,
324358
scoped_refptr<SessionOptions> session_options,
325359
base::OnceCallback<
326-
void(ComputeResourceInfo compute_resource_info,
327-
base::expected<std::unique_ptr<GraphImplOrt::ComputeResources>,
360+
void(base::expected<std::unique_ptr<ComputeResourcesAndInfo>,
328361
mojom::ErrorPtr>)> callback,
329362
ScopedTrace scoped_trace) {
330363
scoped_trace.AddStep("Create Env");
331364

332365
const OrtApi* ort_api = GetOrtApi();
333366
ScopedOrtEnv env;
334-
CHECK(IsSuccess(ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test",
367+
CHECK(IsSuccess(ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "WebNN",
335368
ScopedOrtEnv::Receiver(env).get())));
336369

337370
scoped_trace.AddStep("Load compiled model");
@@ -342,95 +375,138 @@ void GraphImplOrt::LoadAndBuildOnBackgroundThread(
342375
cache_dir.AppendASCII("model.onnx").value();
343376

344377
ScopedOrtSession session;
378+
345379
// disable EP context
346380
CHECK(IsSuccess(ort_api->AddSessionConfigEntry(
347-
session_options->get(), kOrtSessionOptionEpContextEnable,
348-
/*config_value=*/"0")));
349-
CHECK(IsSuccess(ort_api->CreateSession(
350-
env.get(), compiled_model_path.c_str(), session_options->get(),
351-
ScopedOrtSession::Receiver(session).get())));
381+
session_options->get(), kOrtSessionOptionEpContextEnable,
382+
/*config_value=*/"0")));
383+
384+
if (ORT_CALL_FAILED(ort_api->CreateSession(
385+
env.get(), compiled_model_path.c_str(), session_options->get(),
386+
ScopedOrtSession::Receiver(session).get()))) {
387+
std::move(callback).Run(base::unexpected(
388+
mojom::Error::New(mojom::Error::Code::kUnknownError,
389+
"Failed to load the compiled model.")));
390+
return;
391+
}
352392

353393
scoped_trace.AddStep("Get compute resource info");
354394

355395
base::FilePath input_names_dict_path =
356396
cache_dir.AppendASCII("input_names_dict.json");
357397
std::string input_names_str;
358-
base::ReadFileToString(input_names_dict_path, &input_names_str);
359-
base::Value::Dict input_names_dict =
360-
base::JSONReader::ReadDict(input_names_str).value();
398+
if (!base::ReadFileToString(input_names_dict_path, &input_names_str)) {
399+
std::move(callback).Run(base::unexpected(
400+
mojom::Error::New(mojom::Error::Code::kUnknownError,
401+
"Failed to read input_name_dict.json.")));
402+
return;
403+
}
404+
405+
auto input_names_dict = base::JSONReader::ReadDict(input_names_str);
406+
if (!input_names_dict.has_value()) {
407+
std::move(callback).Run(base::unexpected(mojom::Error::New(
408+
mojom::Error::Code::kUnknownError, "Failed to get input names dict.")));
409+
return;
410+
}
361411

362412
base::flat_map<std::string, std::string>
363413
operand_input_name_to_onnx_input_name;
364-
for (auto current = input_names_dict.begin();
365-
current != input_names_dict.end(); ++current) {
414+
for (auto current = input_names_dict->begin();
415+
current != input_names_dict->end(); ++current) {
366416
operand_input_name_to_onnx_input_name.emplace(current->first,
367417
current->second.GetString());
368418
}
369419

370420
base::FilePath output_names_dict_path =
371421
cache_dir.AppendASCII("output_names_dict.json");
372422
std::string output_names_str;
373-
base::ReadFileToString(output_names_dict_path, &output_names_str);
374-
base::Value::Dict output_names_dict =
375-
base::JSONReader::ReadDict(output_names_str).value();
423+
if (!base::ReadFileToString(output_names_dict_path, &output_names_str)) {
424+
std::move(callback).Run(base::unexpected(
425+
mojom::Error::New(mojom::Error::Code::kUnknownError,
426+
"Failed to read output_name_dict.json.")));
427+
return;
428+
}
429+
430+
auto output_names_dict = base::JSONReader::ReadDict(output_names_str);
431+
if (!output_names_dict.has_value()) {
432+
std::move(callback).Run(base::unexpected(
433+
mojom::Error::New(mojom::Error::Code::kUnknownError,
434+
"Failed to read the output_name_dict.json.")));
435+
return;
436+
}
376437

377438
base::flat_map<std::string, std::string>
378439
operand_output_name_to_onnx_output_name;
379-
for (auto current = output_names_dict.begin();
380-
current != output_names_dict.end(); ++current) {
440+
for (auto current = output_names_dict->begin();
441+
current != output_names_dict->end(); ++current) {
381442
operand_output_name_to_onnx_output_name.emplace(
382443
current->first, current->second.GetString());
383444
}
384445

385446
base::FilePath compute_resource_info_path =
386447
cache_dir.AppendASCII("compute_resource_info.txt");
387448
std::string compute_resource_info_str;
388-
base::ReadFileToString(compute_resource_info_path,
389-
&compute_resource_info_str);
449+
if (!base::ReadFileToString(compute_resource_info_path,
450+
&compute_resource_info_str)) {
451+
std::move(callback).Run(base::unexpected(
452+
mojom::Error::New(mojom::Error::Code::kUnknownError,
453+
"Failed to read the compute resource info.")));
454+
return;
455+
}
390456

391457
auto compute_resource_info =
392458
ComputeResourceInfo::ParseFromString(compute_resource_info_str);
459+
if (!compute_resource_info.has_value()) {
460+
std::move(callback).Run(
461+
base::unexpected(std::move(compute_resource_info.error())));
462+
return;
463+
}
393464

394465
scoped_trace.AddStep("Create compute resources");
395466

396467
auto compute_session =
397468
base::WrapUnique(new Session(std::move(env), std::move(session),
398469
std::vector<base::HeapArray<uint8_t>>{}));
399470

400-
std::move(callback).Run(
401-
std::move(compute_resource_info),
471+
std::move(callback).Run(base::WrapUnique(new ComputeResourcesAndInfo(
472+
std::move(compute_resource_info.value()),
402473
base::WrapUnique(new GraphImplOrt::ComputeResources(
403474
std::move(compute_session),
404475
std::move(operand_input_name_to_onnx_input_name),
405-
std::move(operand_output_name_to_onnx_output_name))));
476+
std::move(operand_output_name_to_onnx_output_name))))));
406477
}
407478

408479
// static
409480
void GraphImplOrt::DidLoadAndBuild(
410481
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
411482
base::WeakPtr<WebNNContextImpl> context,
412483
WebNNContextImpl::LoadGraphImplCallback callback,
413-
ComputeResourceInfo compute_resource_info,
414-
base::expected<std::unique_ptr<GraphImplOrt::ComputeResources>,
415-
mojom::ErrorPtr> result) {
416-
// if (!result.has_value()) {
417-
// std::move(callback).Run(base::unexpected(std::move(result.error())));
418-
// return;
419-
// }
484+
base::expected<std::unique_ptr<ComputeResourcesAndInfo>, mojom::ErrorPtr>
485+
result) {
486+
if (!result.has_value()) {
487+
std::move(callback).Run(base::unexpected(std::move(result.error())));
488+
return;
489+
}
420490

421-
// if (!context) {
422-
// std::move(callback).Run(base::unexpected(mojom::Error::New(
423-
// mojom::Error::Code::kUnknownError, "Context was destroyed.")));
424-
// return;
425-
// }
491+
if (!context) {
492+
std::move(callback).Run(base::unexpected(mojom::Error::New(
493+
mojom::Error::Code::kUnknownError, "Context was destroyed.")));
494+
return;
495+
}
426496

497+
auto input_names_to_descriptors =
498+
result.value()->compute_resource_info.input_names_to_descriptors;
499+
auto output_names_to_descriptors =
500+
result.value()->compute_resource_info.output_names_to_descriptors;
427501
std::move(callback).Run(
428-
base::WrapUnique(new GraphImplOrt(
429-
std::move(receiver), std::move(compute_resource_info),
430-
std::move(result.value()),
431-
static_cast<ContextImplOrt*>(context.get()))),
432-
compute_resource_info.input_names_to_descriptors,
433-
compute_resource_info.output_names_to_descriptors);
502+
base::WrapUnique(new WebNNContextImpl::LoadGraphResult(
503+
base::WrapUnique(
504+
new GraphImplOrt(std::move(receiver),
505+
std::move(result.value()->compute_resource_info),
506+
std::move(result.value()->compute_resources),
507+
static_cast<ContextImplOrt*>(context.get()))),
508+
std::move(input_names_to_descriptors),
509+
std::move(output_names_to_descriptors))));
434510
}
435511

436512
GraphImplOrt::~GraphImplOrt() = default;
@@ -527,15 +603,25 @@ void GraphImplOrt::DispatchImpl(
527603
task->Enqueue();
528604
}
529605

530-
void GraphImplOrt::SaveGraphImpl(std::string_view key) {
606+
void GraphImplOrt::SaveGraphImpl(
607+
std::string_view key,
608+
base::OnceCallback<void(mojom::ErrorPtr)> callback) {
531609
TRACE_EVENT0("gpu", "ort::GraphImplOrt::SaveGraphImpl");
532610

533611
std::vector<scoped_refptr<QueueableResourceStateBase>> exclusive_resources;
534612
exclusive_resources.reserve(1);
535613
exclusive_resources.push_back(compute_resources_state_);
536614

537615
std::string compute_resource_info;
538-
this->compute_resource_info().SerializeToString(compute_resource_info);
616+
if (!this->compute_resource_info().SerializeToString(compute_resource_info)) {
617+
std::move(callback).Run(
618+
mojom::Error::New(mojom::Error::Code::kUnknownError,
619+
"Failed to serialize compute resources info."));
620+
return;
621+
}
622+
623+
auto save_graph_callback =
624+
base::BindPostTaskToCurrentDefault(std::move(callback));
539625

540626
auto task = base::MakeRefCounted<ResourceTask>(
541627
std::vector<scoped_refptr<QueueableResourceStateBase>>{},
@@ -544,6 +630,7 @@ void GraphImplOrt::SaveGraphImpl(std::string_view key) {
544630
[](scoped_refptr<QueueableResourceState<ComputeResources>>
545631
compute_resources_state,
546632
std::string key, std::string compute_resource_info,
633+
base::OnceCallback<void(mojom::ErrorPtr)> save_graph_callback,
547634
base::OnceClosure completion_closure) {
548635
ComputeResources* raw_compute_resources =
549636
compute_resources_state->GetExclusivelyLockedResource();
@@ -557,12 +644,12 @@ void GraphImplOrt::SaveGraphImpl(std::string_view key) {
557644
base::MayBlock()},
558645
base::BindOnce(&ComputeResources::SaveCompiledModel,
559646
base::Unretained(raw_compute_resources),
560-
std::move(key),
561-
std::move(compute_resource_info)),
647+
std::move(key), std::move(compute_resource_info),
648+
std::move(save_graph_callback)),
562649
std::move(completion_closure));
563650
},
564651
compute_resources_state_, std::string(key),
565-
std::move(compute_resource_info)));
652+
std::move(compute_resource_info), std::move(save_graph_callback)));
566653

567654
task->Enqueue();
568655
}

0 commit comments

Comments
 (0)