Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,40 @@ load(
"@org_tensorflow//tensorflow:pytype.default.bzl",
"pytype_strict_library",
)
load("@org_tensorflow//tensorflow:tensorflow.default.bzl", "pybind_extension")
# [Google-internal load of `cc_binary`]

load("@rules_python//python:defs.bzl", "py_test")

# Pure-Python abstract interfaces (ABCs) for the LiteRT LM Python API;
# the nanobind extension types are registered against these at import time.
pytype_strict_library(
    name = "interfaces",
    srcs = ["interfaces.py"],
)

pybind_extension(
name = "litert_lm_ext",
# Public Python package: `__init__.py` re-exports the interfaces together
# with the nanobind extension, which is shipped as runtime data
# (litert_lm_ext.so) and fronted by the litert_lm_ext.py shim for type
# checking and build-system visibility.
pytype_strict_library(
    name = "python",
    srcs = [
        "__init__.py",
        "litert_lm_ext.py",
    ],
    data = [":litert_lm_ext.so"],
    deps = [
        ":interfaces",
    ],
)

cc_binary(
name = "litert_lm_ext.so",
srcs = ["litert_lm.cc"],
copts = [
"-fexceptions",
"-fvisibility=hidden",
],
features = ["-use_header_modules"],
linkopts = select({
"@platforms//os:macos": ["-Wl,-undefined,dynamic_lookup"],
"//conditions:default": ["-Wl,-Bsymbolic"],
}),
linkshared = True,
deps = [
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/log:globals",
Expand All @@ -40,7 +64,19 @@ pybind_extension(
"//runtime/engine:engine_impl_selected", # buildcleaner: keep
"//runtime/engine:engine_interface",
"//runtime/engine:engine_settings",
"@rules_python//python/cc:current_py_cc_headers",
"@litert//tflite:minimal_logging",
"@litert//tflite/core/c:private_c_api_types",
],
)

# End-to-end test of the Python bindings; runs the bundled test model
# from //runtime/testdata on CPU.
py_test(
    name = "engine_test",
    srcs = ["engine_test.py"],
    data = ["//runtime/testdata"],
    deps = [
        "//python",
        "@absl_py//absl/flags",
        "@absl_py//absl/testing:absltest",
    ],
)
40 changes: 39 additions & 1 deletion python/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 The ODML Authors.
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,3 +13,41 @@
# limitations under the License.

"""LiteRT LM is a library for running GenAI models on devices."""

from litert_lm.python.interfaces import AbstractBenchmark
from litert_lm.python.interfaces import AbstractConversation
from litert_lm.python.interfaces import AbstractEngine
from litert_lm.python.interfaces import Backend
from litert_lm.python.interfaces import BenchmarkInfo
from litert_lm.python.litert_lm_ext import _Benchmark # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import _Engine # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import Benchmark # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import BenchmarkInfo as _BenchmarkInfo # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import Conversation # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import Engine # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import LogSeverity # pytype: disable=import-error
from litert_lm.python.litert_lm_ext import set_min_log_severity # pytype: disable=import-error

# Because the C++ classes are created by nanobind and the Python
# interfaces are standard ABCs, they cannot easily share a formal
# inheritance tree across the C++/Python boundary. Instead, we call
# ABCMeta.register() in the package's entry point to declare the
# nanobind types as virtual subclasses, so isinstance() checks against
# the abstract interfaces succeed for extension objects.
AbstractEngine.register(_Engine)
AbstractConversation.register(Conversation)
AbstractBenchmark.register(_Benchmark)
# The extension's BenchmarkInfo is imported as _BenchmarkInfo to avoid
# clobbering the interface class of the same name re-exported here.
BenchmarkInfo.register(_BenchmarkInfo)


# Explicit public API; the underscore-prefixed nanobind types stay internal.
__all__ = (
    "AbstractBenchmark",
    "AbstractConversation",
    "AbstractEngine",
    "Backend",
    "Benchmark",
    "BenchmarkInfo",
    "Conversation",
    "Engine",
    "LogSeverity",
    "set_min_log_severity",
)
161 changes: 161 additions & 0 deletions python/engine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib

from absl import flags
from absl.testing import absltest

from litert_lm import python as litert_lm

FLAGS = flags.FLAGS


class EngineTest(absltest.TestCase):
  """End-to-end tests for the nanobind-based LiteRT LM Python bindings.

  All tests run the bundled `test_lm.litertlm` toy model on CPU. The model
  deterministically emits `_EXPECTED_RESPONSE`, streamed as 6 text pieces
  (see `test_conversation_send_message_async`).
  """

  _EXPECTED_RESPONSE = "TarefaByte دارایेत्र investigaciónప్రదేశ"

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Maximize log verbosity so test failures carry full runtime diagnostics.
    litert_lm.set_min_log_severity(litert_lm.LogSeverity.VERBOSE)

  def setUp(self):
    super().setUp()
    self.model_path = str(
        pathlib.Path(FLAGS.test_srcdir)
        / "litert_lm/runtime/testdata/test_lm.litertlm"
    )

  def _create_engine(self, max_num_tokens=10):
    """Builds a CPU engine for the test model with model caching disabled."""
    return litert_lm.Engine(
        self.model_path,
        litert_lm.Backend.CPU,
        max_num_tokens=max_num_tokens,
        cache_dir=":nocache",
    )

  @staticmethod
  def _extract_text(stream):
    """Drains a streamed response, returning the list of text pieces."""
    text_pieces = []
    for chunk in stream:
      for item in chunk.get("content", []):
        if item.get("type") == "text":
          text_pieces.append(item.get("text", ""))
    return text_pieces

  def test_conversation_send_message(self):
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      self.assertIsNotNone(engine)
      self.assertIsNotNone(conversation)
      user_message = {"role": "user", "content": "Hello world!"}
      message = conversation.send_message(user_message)

      expected_message = {
          "role": "assistant",
          "content": [{"type": "text", "text": self._EXPECTED_RESPONSE}],
      }
      self.assertEqual(message, expected_message)

  def test_conversation_send_message_async(self):
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      self.assertIsNotNone(engine)
      self.assertIsNotNone(conversation)
      user_message = {"role": "user", "content": "Hello world!"}
      stream = conversation.send_message_async(user_message)
      text_pieces = self._extract_text(stream)

      self.assertEqual("".join(text_pieces), self._EXPECTED_RESPONSE)
      self.assertLen(text_pieces, 6)

  def test_conversation_send_message_async_cancel(self):
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      user_message = {"role": "user", "content": "Hello world!"}
      stream = conversation.send_message_async(user_message)

      text_pieces = []
      cancelled = False
      for chunk in stream:
        for item in chunk.get("content", []):
          if item.get("type") == "text":
            text_pieces.append(item.get("text", ""))
        if not cancelled:
          # Cancel exactly once, after the first chunk arrives; keep
          # draining so we observe the stream terminating early instead
          # of producing the full 6-piece response.
          conversation.cancel_process()
          cancelled = True

      # We only expect the piece(s) produced before cancellation took effect.
      self.assertNotEmpty(text_pieces)
      self.assertLess(len(text_pieces), 6)  # Cancelled before completion

  def test_benchmark_class(self):
    benchmark = litert_lm.Benchmark(
        self.model_path,
        litert_lm.Backend.CPU,
        prefill_tokens=10,
        decode_tokens=10,
        cache_dir=":nocache",
    )
    self.assertIsInstance(benchmark, litert_lm.AbstractBenchmark)
    result = benchmark.run()
    self.assertIsInstance(result, litert_lm.BenchmarkInfo)
    # All timing/throughput fields must be populated and strictly positive.
    self.assertGreater(result.init_time_in_second, 0)
    self.assertGreater(result.time_to_first_token_in_second, 0)
    self.assertGreater(result.last_prefill_token_count, 0)
    self.assertGreater(result.last_prefill_tokens_per_second, 0)
    self.assertGreater(result.last_decode_token_count, 0)
    self.assertGreater(result.last_decode_tokens_per_second, 0)

  def test_engine_abc_inheritance(self):
    # The nanobind engine must register as a virtual subclass of the ABC.
    with self._create_engine() as engine:
      self.assertIsInstance(engine, litert_lm.AbstractEngine)

  def test_conversation_abc_inheritance(self):
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      self.assertIsInstance(conversation, litert_lm.AbstractConversation)

  def test_str_input_support(self):
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      # A bare str must be accepted in place of a message dict.
      message = conversation.send_message("Hello world!")
      self.assertEqual(message["role"], "assistant")

  def test_str_input_support_async(self):
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      # A bare str must be accepted in place of a message dict (async).
      stream = conversation.send_message_async("Hello world!")
      text_pieces = self._extract_text(stream)
      self.assertNotEmpty(text_pieces)


if __name__ == "__main__":
  # absltest.main() parses absl flags before running the suite; the tests
  # above read FLAGS.test_srcdir to locate the bundled model.
  absltest.main()
10 changes: 5 additions & 5 deletions python/litert_lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,10 @@ NB_MODULE(litert_lm_ext, module) {

module.def(
"Engine",
[](absl::string_view model_path, const nb::handle& backend,
int max_num_tokens, absl::string_view cache_dir,
[](std::string_view model_path, const nb::handle& backend,
int max_num_tokens, std::string_view cache_dir,
const nb::handle& vision_backend, const nb::handle& audio_backend,
absl::string_view input_prompt_as_hint) {
std::string_view input_prompt_as_hint) {
Backend main_backend = ParseBackend(backend);
std::optional<Backend> vision_backend_opt = std::nullopt;
if (!vision_backend.is_none()) {
Expand Down Expand Up @@ -445,8 +445,8 @@ NB_MODULE(litert_lm_ext, module) {

module.def(
"Benchmark",
[](absl::string_view model_path, const nb::handle& backend,
int prefill_tokens, int decode_tokens, absl::string_view cache_dir) {
[](std::string_view model_path, const nb::handle& backend,
int prefill_tokens, int decode_tokens, std::string_view cache_dir) {
auto benchmark = std::make_unique<Benchmark>(
std::string(model_path), ParseBackend(backend), prefill_tokens,
decode_tokens, std::string(cache_dir));
Expand Down
17 changes: 17 additions & 0 deletions python/litert_lm_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is a shim for the C++ extension 'litert_lm_ext.so'. It is required
# for build system visibility and type checking.
# NOTE(review): the shim shares its module name with the compiled extension;
# presumably the importer resolves the .so ahead of this .py at runtime —
# confirm the packaging/import order.
from litert_lm.python.litert_lm_ext import *
2 changes: 1 addition & 1 deletion schema/py/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@org_tensorflow//tensorflow:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_library", pytype_strict_test = "pytype_strict_contrib_test")
load("@org_tensorflow//tensorflow:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_library", py_test = "pytype_strict_contrib_test")
load("@rules_python//python:defs.bzl", "py_test")

package(
Expand Down
Loading