diff --git a/litert/runtime/BUILD b/litert/runtime/BUILD
index c208c341d0..a95ceb7e66 100644
--- a/litert/runtime/BUILD
+++ b/litert/runtime/BUILD
@@ -624,11 +624,13 @@ cc_library(
         "//litert/build_common:litert_disable_cpu": [
             # No support for CPU backend. Builtin ops are still needed for GPU.
            "//tflite/kernels:builtin_ops",
+            "//tflite/kernels:reference_ops",
         ],
         "//conditions:default": [
             ":litert_cpu_options",
             "//litert/runtime/accelerators/xnnpack:xnnpack_accelerator",  # buildcleaner: keep
             "//tflite/kernels:builtin_ops",
+            "//tflite/kernels:reference_ops",
         ],
     }),
 )
diff --git a/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc b/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc
index dde1a4ce79..8be988f446 100644
--- a/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc
+++ b/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc
@@ -91,8 +91,11 @@ class CpuAccelerator final
       return options_data_status;
     }
 
-    // TODO: b/403547017 - Make the CPU accelerator configurable using the
-    // compilation options.
+    if (parsed_options.kernel_mode != kLiteRtCpuKernelModeXnnpack) {
+      *delegate_wrapper = nullptr;
+      return kLiteRtStatusOk;
+    }
+
     auto xnn_options = parsed_options.xnn;
     TfLiteOpaqueDelegate* xnnpack_delegate =
         TfLiteXNNPackDelegateCreate(&xnn_options);
@@ -108,6 +111,9 @@ class CpuAccelerator final
   // Destroys an XNNPack delegate instance.
   static void DestroyDelegate(LiteRtRuntimeContext* runtime_context,
                               LiteRtDelegateWrapper delegate_wrapper) {
+    if (delegate_wrapper == nullptr) {
+      return;
+    }
     TfLiteOpaqueDelegate* xnnpack_delegate;
     runtime_context->unwrap_delegate(delegate_wrapper, &xnnpack_delegate);
     TfLiteXNNPackDelegateDelete(xnnpack_delegate);
diff --git a/litert/runtime/compiled_model.cc b/litert/runtime/compiled_model.cc
index 57ed4e0c97..05308fedc9 100644
--- a/litert/runtime/compiled_model.cc
+++ b/litert/runtime/compiled_model.cc
@@ -29,6 +29,9 @@
 #include
 #include
 
+#include "litert/c/options/litert_cpu_options.h"
+#include "tflite/mutable_op_resolver.h"
+
 #if !defined(LITERT_WINDOWS_OS)
 #include
 #endif  // !defined(LITERT_WINDOWS_OS)
@@ -106,6 +109,7 @@
 #include "tflite/interpreter_options.h"
 #if !defined(LITERT_NO_BUILTIN_OPS)
 #include "tflite/kernels/register.h"
+#include "tflite/kernels/register_ref.h"
 #endif  // LITERT_NO_BUILTIN_OPS
 
 #if defined(LITERT_NO_BUILTIN_OPS)
@@ -241,14 +245,58 @@ void ApplySchedulingInfoOverrides(const LiteRtSchedulingInfo& overrides,
 Expected<void> LiteRtCompiledModelT::InitializeRuntime(
     LiteRtEnvironmentT* env, LiteRtHwAcceleratorSet hardware_accelerators,
     LiteRtOptions jit_compilation_options) {
+  int num_threads = 1;
+  bool use_non_xnnpack_cpu_backend = false;
+  bool use_reference_cpu_kernels = false;
+#if !defined(LITERT_DISABLE_CPU)
+  LiteRtCpuOptionsT cpu_options;
+  if (jit_compilation_options &&
+      (hardware_accelerators & kLiteRtHwAcceleratorCpu)) {
+    auto opaque_options = litert::OpaqueOptions::WrapCObject(
+        jit_compilation_options->options, litert::OwnHandle::kNo);
+    if (auto cpu_options_data = litert::FindOpaqueData(
+            opaque_options, LiteRtCpuOptionsT::Identifier());
+        cpu_options_data) {
+      absl::string_view data_str(*cpu_options_data);
+      if (litert::internal::ParseLiteRtCpuOptions(
+              data_str.data(), data_str.size(), &cpu_options) !=
+          kLiteRtStatusOk) {
+        LITERT_LOG(LITERT_WARNING, "Failed to parse CPU options");
+      } else {
+        num_threads = cpu_options.xnn.num_threads;
+        use_non_xnnpack_cpu_backend =
+            cpu_options.kernel_mode != kLiteRtCpuKernelModeXnnpack;
+        use_reference_cpu_kernels =
+            cpu_options.kernel_mode == kLiteRtCpuKernelModeReference;
+      }
+    }
+  }
+#endif  // !defined(LITERT_DISABLE_CPU)
+
 #ifdef LITERT_NO_BUILTIN_OPS
+  if ((hardware_accelerators & kLiteRtHwAcceleratorCpu) &&
+      use_non_xnnpack_cpu_backend) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Builtin and reference CPU kernel modes require builtin "
+                      "kernels.");
+  }
   // Use StubOpResolver which provides minimal stub implementations for all
   // builtin ops. These stubs allow the model to pass validation, but the
   // actual operations will be handled by LiteRT's accelerator system
   // (NPU > GPU > CPU) through their respective delegates.
-  litert::internal::StubOpResolver resolver;
+  litert::internal::StubOpResolver resolver_storage;
+  tflite::MutableOpResolver* resolver = &resolver_storage;
 #else
-  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
+  std::unique_ptr<tflite::MutableOpResolver> resolver_storage;
+  if ((hardware_accelerators & kLiteRtHwAcceleratorCpu) &&
+      use_reference_cpu_kernels) {
+    resolver_storage =
+        std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
+  } else {
+    resolver_storage = std::make_unique<
+        tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
+  }
+  tflite::MutableOpResolver* resolver = resolver_storage.get();
 #endif  // LITERT_NO_BUILTIN_OPS
 
   // Apply custom ops.
@@ -258,26 +306,30 @@ Expected<void> LiteRtCompiledModelT::InitializeRuntime(
           std::make_unique(option));
       auto* tflite_registration =
           custom_op_dispatchers_.back()->GetTfLiteRegistration();
-      resolver.AddCustom(option.op_name.c_str(), tflite_registration);
+      resolver->AddCustom(option.op_name.c_str(), tflite_registration);
     }
   }
 
   // Add custom ops that are supported by the CPU / GPU accelerators.
   if (hardware_accelerators & kLiteRtHwAcceleratorGpu) {
     const char* accelerator_supported_custom_ops[] = {
-        "Convolution2DTransposeBias", "MaxPoolingWithArgmax2D",
-        "MaxUnpooling2D", "Resampler", "custom_call.GroupNorm",
-        "custom_call.LayerNorm", "custom_call.RmsNorm",
+        "Convolution2DTransposeBias",
+        "MaxPoolingWithArgmax2D",
+        "MaxUnpooling2D",
+        "Resampler",
+        "custom_call.GroupNorm",
+        "custom_call.LayerNorm",
+        "custom_call.RmsNorm",
         "custom_call.PixelShuffle"};
     for (const auto& op_name : accelerator_supported_custom_ops) {
-      resolver.AddCustom(op_name, &sStubRegistration);
+      resolver->AddCustom(op_name, &sStubRegistration);
     }
   } else if (hardware_accelerators & kLiteRtHwAcceleratorCpu) {
     const char* accelerator_supported_custom_ops[] = {
         "Convolution2DTransposeBias", "MaxPoolingWithArgmax2D",
         "MaxUnpooling2D"};
     for (const auto& op_name : accelerator_supported_custom_ops) {
-      resolver.AddCustom(op_name, &sStubRegistration);
+      resolver->AddCustom(op_name, &sStubRegistration);
     }
   }
 #ifdef __EMSCRIPTEN__
@@ -285,14 +337,13 @@ Expected<void> LiteRtCompiledModelT::InitializeRuntime(
     const char* accelerator_supported_custom_ops[] = {
         "Convolution2DTransposeBias"};
     for (const auto& op_name : accelerator_supported_custom_ops) {
-      resolver.AddCustom(op_name, &sStubRegistration);
+      resolver->AddCustom(op_name, &sStubRegistration);
     }
   }
 #endif  // __EMSCRIPTEN__
 
   tflite::InterpreterOptions interpreter_options;
   interpreter_options.SetUseSignatureTensorNames(true);
-  int num_threads = 1;
   if (jit_compilation_options) {
     auto opaque_options = litert::OpaqueOptions::WrapCObject(
         jit_compilation_options->options, litert::OwnHandle::kNo);
@@ -328,26 +379,10 @@ Expected<void> LiteRtCompiledModelT::InitializeRuntime(
       }
     }
   }
-
-#if !defined(LITERT_DISABLE_CPU)
-    if (auto cpu_options_data = litert::FindOpaqueData(
-            opaque_options, LiteRtCpuOptionsT::Identifier());
-        cpu_options_data) {
-      LiteRtCpuOptionsT cpu_options;
-      absl::string_view data_str(*cpu_options_data);
-      if (litert::internal::ParseLiteRtCpuOptions(
-              data_str.data(), data_str.size(), &cpu_options) !=
-          kLiteRtStatusOk) {
-        LITERT_LOG(LITERT_WARNING, "Failed to parse CPU options");
-      } else {
-        num_threads = cpu_options.xnn.num_threads;
-      }
-    }
-#endif  // !defined(LITERT_DISABLE_CPU)
   }
 
   tflite::InterpreterBuilder builder(
-      fb_model_->GetModel(), resolver, error_reporter_.get(),
+      fb_model_->GetModel(), *resolver, error_reporter_.get(),
       &interpreter_options, fb_model_->allocation());
   builder(&interp_);
   if (interp_ == nullptr) {
@@ -408,7 +443,7 @@ Expected<void> LiteRtCompiledModelT::InitializeRuntime(
 #if defined(LITERT_WITH_EXTERNAL_WEIGHT_LOADER)
   std::unique_ptr scoped_weight_source;
   auto* options_impl =
-      reinterpret_cast(jit_compilation_options);
+      reinterpret_cast(jit_compilation_options);
   if (options_impl != nullptr) {
     scoped_weight_source = std::move(options_impl->scoped_weight_source);
   }
@@ -785,6 +820,9 @@ Expected LiteRtCompiledModelT::Create(
     LITERT_RETURN_IF_ERROR(accelerator->CreateDelegate(
         LrtGetRuntimeContext(), env, accelerator.get(), jit_compilation_options,
         &delegate_wrapper));
+    if (delegate_wrapper == nullptr) {
+      continue;
+    }
 
     TfLiteOpaqueDelegate* delegate_ptr = nullptr;
     LrtGetRuntimeContext()->unwrap_delegate(delegate_wrapper, &delegate_ptr);
@@ -1604,7 +1642,8 @@ Expected LiteRtCompiledModelT::RunCApi(
   return result;
 }
 
-Expected<void> LiteRtCompiledModelT::StartMetricsCollection(int detail_level) const {
+Expected<void> LiteRtCompiledModelT::StartMetricsCollection(
+    int detail_level) const {
   if (detail_level < 0) {
     return Unexpected(kLiteRtStatusErrorInvalidArgument,
                       "Detail level must be >= 0");