@@ -129,18 +129,30 @@ class GraphImplOrt::ComputeResources {
129129 }
130130
131131 // Save EP Context model
132- void SaveCompiledModel (std::string key, std::string compute_resource_info) {
132+ void SaveCompiledModel (std::string key,
133+ std::string compute_resource_info,
134+ base::OnceCallback<void (mojom::ErrorPtr)> callback) {
133135 TRACE_EVENT0 (" gpu" ,
134136 " ort::GraphImplOrt::ComputeResources::SaveCompiledModel" );
135137
136138 auto cache_dir = base::FilePath::FromASCII (
137139 base::StrCat ({" .\\ EpContextModelCache\\ " , key}));
138- CHECK (base::CreateDirectory (cache_dir));
140+ if (!base::CreateDirectory (cache_dir)) {
141+ std::move (callback).Run (
142+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
143+ " Failed to create a cache directory." ));
144+ return ;
145+ }
139146
140147 std::string compiled_model_path =
141148 cache_dir.AppendASCII (" model.onnx" ).MaybeAsASCII ();
142- CALL_ORT_FUNC (GetOrtApi ()->SaveEpContextModel (session_->GetSession (),
143- compiled_model_path.c_str ()));
149+ if (ORT_CALL_FAILED (GetOrtApi ()->SaveEpContextModel (
150+ session_->GetSession (), compiled_model_path.c_str ()))) {
151+ std::move (callback).Run (
152+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
153+ " Failed to save EPContext Model." ));
154+ return ;
155+ }
144156
145157 base::Value::Dict input_names_dict;
146158 for (const auto & [operand_input, onnx_input] :
@@ -152,7 +164,12 @@ class GraphImplOrt::ComputeResources {
152164
153165 base::FilePath input_names_dict_path =
154166 cache_dir.AppendASCII (" input_names_dict.json" );
155- base::WriteFile (input_names_dict_path, input_names_str);
167+ if (!base::WriteFile (input_names_dict_path, input_names_str)) {
168+ std::move (callback).Run (
169+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
170+ " Failed to write input names dict." ));
171+ return ;
172+ }
156173
157174 base::Value::Dict output_names_dict;
158175 for (const auto & [operand_output, onnx_output] :
@@ -164,11 +181,21 @@ class GraphImplOrt::ComputeResources {
164181
165182 base::FilePath output_names_dict_path =
166183 cache_dir.AppendASCII (" output_names_dict.json" );
167- base::WriteFile (output_names_dict_path, output_names_str);
184+ if (!base::WriteFile (output_names_dict_path, output_names_str)) {
185+ std::move (callback).Run (
186+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
187+ " Failed to write output names dict." ));
188+ return ;
189+ }
168190
169191 base::FilePath compute_resource_info_path =
170192 cache_dir.AppendASCII (" compute_resource_info.txt" );
171- base::WriteFile (compute_resource_info_path, compute_resource_info);
193+ if (!base::WriteFile (compute_resource_info_path, compute_resource_info)) {
194+ std::move (callback).Run (mojom::Error::New (
195+ mojom::Error::Code::kUnknownError ,
196+ " Failed to write compute resources info into disk." ));
197+ return ;
198+ }
172199 }
173200
174201 private:
@@ -318,20 +345,26 @@ void GraphImplOrt::LoadAndBuild(
318345 std::move (wrapped_callback), std::move (scoped_trace)));
319346}
320347
348+ GraphImplOrt::ComputeResourcesAndInfo::ComputeResourcesAndInfo (
349+ ComputeResourceInfo compute_resource_info,
350+ std::unique_ptr<ComputeResources> compute_resources)
351+ : compute_resource_info(std::move(compute_resource_info)),
352+ compute_resources(std::move(compute_resources)) {}
353+ GraphImplOrt::ComputeResourcesAndInfo::~ComputeResourcesAndInfo () = default ;
354+
321355// static
322356void GraphImplOrt::LoadAndBuildOnBackgroundThread (
323357 std::string key,
324358 scoped_refptr<SessionOptions> session_options,
325359 base::OnceCallback<
326- void (ComputeResourceInfo compute_resource_info,
327- base::expected<std::unique_ptr<GraphImplOrt::ComputeResources>,
360+ void (base::expected<std::unique_ptr<ComputeResourcesAndInfo>,
328361 mojom::ErrorPtr>)> callback,
329362 ScopedTrace scoped_trace) {
330363 scoped_trace.AddStep (" Create Env" );
331364
332365 const OrtApi* ort_api = GetOrtApi ();
333366 ScopedOrtEnv env;
334- CHECK (IsSuccess (ort_api->CreateEnv (ORT_LOGGING_LEVEL_WARNING, " test " ,
367+ CHECK (IsSuccess (ort_api->CreateEnv (ORT_LOGGING_LEVEL_WARNING, " WebNN " ,
335368 ScopedOrtEnv::Receiver (env).get ())));
336369
337370 scoped_trace.AddStep (" Load compiled model" );
@@ -342,95 +375,138 @@ void GraphImplOrt::LoadAndBuildOnBackgroundThread(
342375 cache_dir.AppendASCII (" model.onnx" ).value ();
343376
344377 ScopedOrtSession session;
378+
345379 // disable EP context
346380 CHECK (IsSuccess (ort_api->AddSessionConfigEntry (
347- session_options->get (), kOrtSessionOptionEpContextEnable ,
348- /* config_value=*/ " 0" )));
349- CHECK (IsSuccess (ort_api->CreateSession (
350- env.get (), compiled_model_path.c_str (), session_options->get (),
351- ScopedOrtSession::Receiver (session).get ())));
381+ session_options->get (), kOrtSessionOptionEpContextEnable ,
382+ /* config_value=*/ " 0" )));
383+
384+ if (ORT_CALL_FAILED (ort_api->CreateSession (
385+ env.get (), compiled_model_path.c_str (), session_options->get (),
386+ ScopedOrtSession::Receiver (session).get ()))) {
387+ std::move (callback).Run (base::unexpected (
388+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
389+ " Failed to load the compiled model." )));
390+ return ;
391+ }
352392
353393 scoped_trace.AddStep (" Get compute resource info" );
354394
355395 base::FilePath input_names_dict_path =
356396 cache_dir.AppendASCII (" input_names_dict.json" );
357397 std::string input_names_str;
358- base::ReadFileToString (input_names_dict_path, &input_names_str);
359- base::Value::Dict input_names_dict =
360- base::JSONReader::ReadDict (input_names_str).value ();
398+ if (!base::ReadFileToString (input_names_dict_path, &input_names_str)) {
399+ std::move (callback).Run (base::unexpected (
400+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
401+ " Failed to read input_name_dict.json." )));
402+ return ;
403+ }
404+
405+ auto input_names_dict = base::JSONReader::ReadDict (input_names_str);
406+ if (!input_names_dict.has_value ()) {
407+ std::move (callback).Run (base::unexpected (mojom::Error::New (
408+ mojom::Error::Code::kUnknownError , " Failed to get input names dict." )));
409+ return ;
410+ }
361411
362412 base::flat_map<std::string, std::string>
363413 operand_input_name_to_onnx_input_name;
364- for (auto current = input_names_dict. begin ();
365- current != input_names_dict. end (); ++current) {
414+ for (auto current = input_names_dict-> begin ();
415+ current != input_names_dict-> end (); ++current) {
366416 operand_input_name_to_onnx_input_name.emplace (current->first ,
367417 current->second .GetString ());
368418 }
369419
370420 base::FilePath output_names_dict_path =
371421 cache_dir.AppendASCII (" output_names_dict.json" );
372422 std::string output_names_str;
373- base::ReadFileToString (output_names_dict_path, &output_names_str);
374- base::Value::Dict output_names_dict =
375- base::JSONReader::ReadDict (output_names_str).value ();
423+ if (!base::ReadFileToString (output_names_dict_path, &output_names_str)) {
424+ std::move (callback).Run (base::unexpected (
425+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
426+ " Failed to read output_name_dict.json." )));
427+ return ;
428+ }
429+
430+ auto output_names_dict = base::JSONReader::ReadDict (output_names_str);
431+ if (!output_names_dict.has_value ()) {
432+ std::move (callback).Run (base::unexpected (
433+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
434+ " Failed to read the output_name_dict.json." )));
435+ return ;
436+ }
376437
377438 base::flat_map<std::string, std::string>
378439 operand_output_name_to_onnx_output_name;
379- for (auto current = output_names_dict. begin ();
380- current != output_names_dict. end (); ++current) {
440+ for (auto current = output_names_dict-> begin ();
441+ current != output_names_dict-> end (); ++current) {
381442 operand_output_name_to_onnx_output_name.emplace (
382443 current->first , current->second .GetString ());
383444 }
384445
385446 base::FilePath compute_resource_info_path =
386447 cache_dir.AppendASCII (" compute_resource_info.txt" );
387448 std::string compute_resource_info_str;
388- base::ReadFileToString (compute_resource_info_path,
389- &compute_resource_info_str);
449+ if (!base::ReadFileToString (compute_resource_info_path,
450+ &compute_resource_info_str)) {
451+ std::move (callback).Run (base::unexpected (
452+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
453+ " Failed to read the compute resource info." )));
454+ return ;
455+ }
390456
391457 auto compute_resource_info =
392458 ComputeResourceInfo::ParseFromString (compute_resource_info_str);
459+ if (!compute_resource_info.has_value ()) {
460+ std::move (callback).Run (
461+ base::unexpected (std::move (compute_resource_info.error ())));
462+ return ;
463+ }
393464
394465 scoped_trace.AddStep (" Create compute resources" );
395466
396467 auto compute_session =
397468 base::WrapUnique (new Session (std::move (env), std::move (session),
398469 std::vector<base::HeapArray<uint8_t >>{}));
399470
400- std::move (callback).Run (
401- std::move (compute_resource_info),
471+ std::move (callback).Run (base::WrapUnique ( new ComputeResourcesAndInfo (
472+ std::move (compute_resource_info. value () ),
402473 base::WrapUnique (new GraphImplOrt::ComputeResources (
403474 std::move (compute_session),
404475 std::move (operand_input_name_to_onnx_input_name),
405- std::move (operand_output_name_to_onnx_output_name))));
476+ std::move (operand_output_name_to_onnx_output_name)))))) ;
406477}
407478
408479// static
409480void GraphImplOrt::DidLoadAndBuild (
410481 mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
411482 base::WeakPtr<WebNNContextImpl> context,
412483 WebNNContextImpl::LoadGraphImplCallback callback,
413- ComputeResourceInfo compute_resource_info,
414- base::expected<std::unique_ptr<GraphImplOrt::ComputeResources>,
415- mojom::ErrorPtr> result) {
416- // if (!result.has_value()) {
417- // std::move(callback).Run(base::unexpected(std::move(result.error())));
418- // return;
419- // }
484+ base::expected<std::unique_ptr<ComputeResourcesAndInfo>, mojom::ErrorPtr>
485+ result) {
486+ if (!result.has_value ()) {
487+ std::move (callback).Run (base::unexpected (std::move (result.error ())));
488+ return ;
489+ }
420490
421- // if (!context) {
422- // std::move(callback).Run(base::unexpected(mojom::Error::New(
423- // mojom::Error::Code::kUnknownError, "Context was destroyed.")));
424- // return;
425- // }
491+ if (!context) {
492+ std::move (callback).Run (base::unexpected (mojom::Error::New (
493+ mojom::Error::Code::kUnknownError , " Context was destroyed." )));
494+ return ;
495+ }
426496
497+ auto input_names_to_descriptors =
498+ result.value ()->compute_resource_info .input_names_to_descriptors ;
499+ auto output_names_to_descriptors =
500+ result.value ()->compute_resource_info .output_names_to_descriptors ;
427501 std::move (callback).Run (
428- base::WrapUnique (new GraphImplOrt (
429- std::move (receiver), std::move (compute_resource_info),
430- std::move (result.value ()),
431- static_cast <ContextImplOrt*>(context.get ()))),
432- compute_resource_info.input_names_to_descriptors ,
433- compute_resource_info.output_names_to_descriptors );
502+ base::WrapUnique (new WebNNContextImpl::LoadGraphResult (
503+ base::WrapUnique (
504+ new GraphImplOrt (std::move (receiver),
505+ std::move (result.value ()->compute_resource_info ),
506+ std::move (result.value ()->compute_resources ),
507+ static_cast <ContextImplOrt*>(context.get ()))),
508+ std::move (input_names_to_descriptors),
509+ std::move (output_names_to_descriptors))));
434510}
435511
436512GraphImplOrt::~GraphImplOrt () = default ;
@@ -527,15 +603,25 @@ void GraphImplOrt::DispatchImpl(
527603 task->Enqueue ();
528604}
529605
530- void GraphImplOrt::SaveGraphImpl (std::string_view key) {
606+ void GraphImplOrt::SaveGraphImpl (
607+ std::string_view key,
608+ base::OnceCallback<void (mojom::ErrorPtr)> callback) {
531609 TRACE_EVENT0 (" gpu" , " ort::GraphImplOrt::SaveGraphImpl" );
532610
533611 std::vector<scoped_refptr<QueueableResourceStateBase>> exclusive_resources;
534612 exclusive_resources.reserve (1 );
535613 exclusive_resources.push_back (compute_resources_state_);
536614
537615 std::string compute_resource_info;
538- this ->compute_resource_info ().SerializeToString (compute_resource_info);
616+ if (!this ->compute_resource_info ().SerializeToString (compute_resource_info)) {
617+ std::move (callback).Run (
618+ mojom::Error::New (mojom::Error::Code::kUnknownError ,
619+ " Failed to serialize compute resources info." ));
620+ return ;
621+ }
622+
623+ auto save_graph_callback =
624+ base::BindPostTaskToCurrentDefault (std::move (callback));
539625
540626 auto task = base::MakeRefCounted<ResourceTask>(
541627 std::vector<scoped_refptr<QueueableResourceStateBase>>{},
@@ -544,6 +630,7 @@ void GraphImplOrt::SaveGraphImpl(std::string_view key) {
544630 [](scoped_refptr<QueueableResourceState<ComputeResources>>
545631 compute_resources_state,
546632 std::string key, std::string compute_resource_info,
633+ base::OnceCallback<void (mojom::ErrorPtr)> save_graph_callback,
547634 base::OnceClosure completion_closure) {
548635 ComputeResources* raw_compute_resources =
549636 compute_resources_state->GetExclusivelyLockedResource ();
@@ -557,12 +644,12 @@ void GraphImplOrt::SaveGraphImpl(std::string_view key) {
557644 base::MayBlock ()},
558645 base::BindOnce (&ComputeResources::SaveCompiledModel,
559646 base::Unretained (raw_compute_resources),
560- std::move (key),
561- std::move (compute_resource_info )),
647+ std::move (key), std::move (compute_resource_info),
648+ std::move (save_graph_callback )),
562649 std::move (completion_closure));
563650 },
564651 compute_resources_state_, std::string (key),
565- std::move (compute_resource_info)));
652+ std::move (compute_resource_info), std::move (save_graph_callback) ));
566653
567654 task->Enqueue ();
568655}
0 commit comments