Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion runtime/executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -635,10 +635,11 @@ cc_test(
":llm_executor_io_types",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/types:span",
"@litert//litert/c:litert_tensor_buffer_types",
"@litert//litert/cc:litert_element_type",
"@litert//litert/cc:litert_environment",
"@litert//litert/cc:litert_layout",
"@litert//litert/cc:litert_model",
"@litert//litert/cc:litert_ranked_tensor_type",
"@litert//litert/cc:litert_tensor_buffer",
"//runtime/components/constrained_decoding:constrained_decoder",
"//runtime/components/constrained_decoding:fake_constraint",
Expand Down
39 changes: 33 additions & 6 deletions runtime/executor/kv_cache_interface.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
// Copyright 2026 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_KV_CACHE_INTERFACE_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_KV_CACHE_INTERFACE_H_

#include <cstddef>
#include <string>

#include "absl/status/status.h" // from @com_google_absl
Expand All @@ -15,19 +28,33 @@ class KVCacheInterface {
public:
virtual ~KVCacheInterface() = default;

// Resizes the KV cache to the specified number of entries.
// Note: If the requested `num_entries` is smaller than the current number
// of entries, the cache will be trimmed to the requested size.
virtual absl::Status Resize(size_t num_entries) = 0;

// Returns the total number of entries in the KV cache per block.
virtual int GetNumEntries() const = 0;

// Returns the batch size of the KV cache.
virtual int GetBatchSize() const = 0;

// Serializes the KV cache to a byte string.
virtual absl::StatusOr<std::string> Serialize() const = 0;

// Loads the KV cache from a serialized byte string.
virtual absl::Status Load(absl::string_view serialized_kv_cache) = 0;

// Selects a single batch from the other KV cache and copies it to this KV
// cache.
// Example:
// This has shape [1, ...] and other has shape [3, ...]. Then we can select
// batch x from other and copy it to this
// (i.e., other[x, :, ...] -> this[0, :, ...]).
virtual absl::Status SelectAndCopyFrom(KVCacheInterface& other,
int batch_index) = 0;

// Broadcasts the source KV with batch size 1 to this KV cache with batch size
// > 1.
// Example:
// This has shape [3, ...] and other has shape [1, ...]. Then we can copy
// other[0, :, ...] -> this[0, :, ...], this[1, :, ...], this[2, :, ...].
virtual absl::Status BroadcastAndCopyFrom(KVCacheInterface& other) = 0;
};

} // namespace litert::lm
Expand Down
229 changes: 229 additions & 0 deletions runtime/executor/litert/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [Google-internal load of `cc_library`]

package(
default_hdrs_check = "strict",
default_visibility = [
"//visibility:public",
],
)

licenses(["notice"])

cc_library(
name = "debug_utils",
srcs = ["debug_utils.cc"],
hdrs = ["debug_utils.h"],
deps = [
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"//runtime/util:convert_tensor_buffer",
] + select({
"@litert//litert:litert_link_capi_so": [
"@litert//litert/cc:litert_api_with_dynamic_runtime",
],
"//conditions:default": [
"@litert//litert/cc:litert_tensor_buffer",
],
}),
)

cc_test(
name = "debug_utils_test",
srcs = ["debug_utils_test.cc"],
deps = [
":debug_utils",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/log:scoped_mock_log",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@litert//litert/cc:litert_tensor_buffer",
"@litert//litert/test:matchers",
"//runtime/util:convert_tensor_buffer",
],
)

cc_library(
name = "kv_cache",
srcs = ["kv_cache.cc"],
hdrs = ["kv_cache.h"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"//runtime/executor:common_utils",
"//runtime/executor:kv_cache_interface",
"//runtime/executor:litert_compiled_model_executor_utils",
"//runtime/util:litert_status_util",
] + select({
"@litert//litert:litert_link_capi_so": [
"@litert//litert/cc:litert_api_with_dynamic_runtime",
],
"//conditions:default": [
"@litert//litert/cc:litert_compiled_model",
"@litert//litert/cc:litert_element_type",
"@litert//litert/cc:litert_environment",
"@litert//litert/cc:litert_layout",
"@litert//litert/cc:litert_macros",
"@litert//litert/cc:litert_model",
"@litert//litert/cc:litert_model_types",
"@litert//litert/cc:litert_options",
"@litert//litert/cc:litert_ranked_tensor_type",
"@litert//litert/cc:litert_tensor_buffer",
"@litert//litert/cc:litert_tensor_buffer_types",
],
}),
)

cc_test(
name = "kv_cache_test",
srcs = ["kv_cache_test.cc"],
data = [
"//runtime/testdata",
],
deps = [
":kv_cache",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@litert//litert/cc:litert_common",
"@litert//litert/cc:litert_compiled_model",
"@litert//litert/cc:litert_environment",
"@litert//litert/cc:litert_model",
"@litert//litert/cc:litert_options",
"@litert//litert/test:matchers",
"//runtime/components:model_resources",
"//runtime/components:model_resources_litert_lm",
"//runtime/util:convert_tensor_buffer",
"//runtime/util:litert_lm_loader",
"//runtime/util:scoped_file",
"//runtime/util:test_utils",
],
)

cc_library(
name = "llm_executor",
srcs = ["llm_executor.cc"],
hdrs = ["llm_executor.h"],
deps = [
":kv_cache",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"//runtime/components:model_resources",
"//runtime/components/embedding_lookup:embedding_lookup_manager",
"//runtime/executor:executor_settings_base",
"//runtime/executor:kv_cache_interface",
"//runtime/executor:litert_compiled_model_executor_utils",
"//runtime/executor:llm_executor_interface",
"//runtime/executor:llm_executor_io_types",
"//runtime/executor:llm_executor_settings",
"//runtime/executor:llm_litert_compiled_model_cache_utils",
"//runtime/util:convert_tensor_buffer",
"//runtime/util:file_util",
"//runtime/util:litert_status_util",
"//runtime/util:lora_util",
"//runtime/util:scoped_file",
"@litert//tflite/delegates/xnnpack:xnnpack_delegate",
] + select({
"@litert//litert:litert_link_capi_so": [
"@litert//litert/cc:litert_api_with_dynamic_runtime",
],
"//conditions:default": [
"@litert//litert/cc:litert_common",
"@litert//litert/cc:litert_compiled_model",
"@litert//litert/cc:litert_element_type",
"@litert//litert/cc:litert_environment",
"@litert//litert/cc:litert_expected",
"@litert//litert/cc:litert_layout",
"@litert//litert/cc:litert_macros",
"@litert//litert/cc:litert_model",
"@litert//litert/cc:litert_model_types",
"@litert//litert/cc:litert_options",
"@litert//litert/cc:litert_ranked_tensor_type",
"@litert//litert/cc:litert_tensor_buffer",
"@litert//litert/cc:litert_tensor_buffer_types",
"@litert//litert/cc/options:litert_cpu_options",
"@litert//litert/cc/options:litert_gpu_options",
"@litert//litert/cc/options:litert_runtime_options",
],
}),
)

cc_test(
name = "llm_executor_cpu_test",
srcs = ["llm_executor_cpu_test.cc"],
data = [
"//runtime/testdata",
],
deps = [
":llm_executor",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@litert//litert/cc:litert_environment",
"@litert//litert/test:matchers",
"//runtime/components:model_resources",
"//runtime/components:model_resources_litert_lm",
"//runtime/executor:executor_settings_base",
"//runtime/executor:llm_executor_io_types",
"//runtime/executor:llm_executor_settings",
"//runtime/util:convert_tensor_buffer",
"//runtime/util:litert_lm_loader",
"//runtime/util:scoped_file",
"//runtime/util:test_utils",
],
)

cc_test(
name = "llm_executor_gpu_test",
srcs = ["llm_executor_gpu_test.cc"],
data = [
"//runtime/testdata",
],
tags = ["requires-gpu-nvidia"],
deps = [
":llm_executor",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@litert//litert/cc:litert_environment",
"@litert//litert/test:matchers",
"//runtime/components:model_resources",
"//runtime/components:model_resources_litert_lm",
"//runtime/executor:executor_settings_base",
"//runtime/executor:llm_executor_io_types",
"//runtime/executor:llm_executor_settings",
"//runtime/util:convert_tensor_buffer",
"//runtime/util:litert_lm_loader",
"//runtime/util:scoped_file",
"//runtime/util:test_utils",
],
)
66 changes: 66 additions & 0 deletions runtime/executor/litert/debug_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2026 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/executor/litert/debug_utils.h"

#include <cstddef>

#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/strings/str_join.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "runtime/util/convert_tensor_buffer.h"

namespace litert::lm {

void LogValues(absl::Span<const float> values, size_t num_values_to_log,
absl::string_view debug) {
constexpr size_t kNumExtraValuesToLog = 10;
if (num_values_to_log * 3 + kNumExtraValuesToLog >= values.size()) {
ABSL_LOG(INFO) << debug << "(size=" << values.size()
<< "): " << absl::StrJoin(values, ", ");
return;
}

size_t end_offset = values.size() - num_values_to_log;
size_t mid_offset = end_offset / 2;
ABSL_LOG(INFO) << debug << "(size=" << values.size() << "): "
<< absl::StrJoin(values.subspan(0, num_values_to_log), ", ")
<< " ... "
<< absl::StrJoin(values.subspan(mid_offset, num_values_to_log),
", ")
<< " ... " << absl::StrJoin(values.subspan(end_offset), ", ");
}

void LogTensor(TensorBuffer& tensor, size_t num_values_to_log,
absl::string_view debug) {
// Try to get the reference if tensor is in CPU memory.
auto values_span = ReferTensorBufferAsSpan<float>(tensor);
if (values_span) {
LogValues(*values_span, num_values_to_log, debug);
return;
}

// Otherwise, copy the logits from the tensor buffer to a vector.
auto values_vector = CopyFromTensorBuffer<float>(tensor);
if (values_vector) {
LogValues(*values_vector, num_values_to_log, debug);
return;
}

ABSL_LOG(ERROR) << debug << ": Failed to log logits.";
}

} // namespace litert::lm
Loading
Loading