Skip to content

Commit bdbb80e

Browse files
ai-edge-bot authored and copybara-github committed
Add backend support to LiteRT-LM end-to-end tests.
This change allows specifying the backend (e.g., cpu, gpu) for each test case in the JSON data, enabling testing across different hardware configurations. The `run_engine` fixture now accepts and uses the backend parameter.

PiperOrigin-RevId: 881666398
1 parent 0eb141b commit bdbb80e

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

tools/test/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
def pytest_addoption(parser: pytest.Parser) -> None:
3131
"""Adds custom command line arguments to pytest."""
32+
3233
parser.addoption(
3334
"--model-path",
3435
action="store",
@@ -108,10 +109,10 @@ def run_engine(
108109
model_path: Path to the model file.
109110
"""
110111

111-
def _run(prompt: str, timeout: int = 120) -> str:
112+
def _run(backend: str, prompt: str, timeout: int = 120) -> str:
112113
cmd = [
113114
engine_binary,
114-
"--backend=cpu",
115+
f"--backend={backend}",
115116
f"--model_path={model_path}",
116117
f"--input_prompt={prompt}",
117118
]

tools/test/test_data/test_e2e_sanity_checks.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
{
22
"test_model_sanity": [
33
{
4-
"id": "basic_question",
4+
"id": "basic_question_cpu",
5+
"backend": "cpu",
6+
"prompt": "What's the tallest building in the world?",
7+
"response": "Burj Khalifa"
8+
},
9+
{
10+
"id": "basic_question_gpu",
11+
"backend": "gpu",
512
"prompt": "What's the tallest building in the world?",
613
"response": "Burj Khalifa"
714
}

tools/test/test_e2e_sanity_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ def test_model_sanity(
4747
"response".
4848
"""
4949

50+
backend = test_case["backend"]
5051
prompt = test_case["prompt"]
5152
expected_response = test_case["response"]
5253

5354
# Execute
54-
output = run_engine(prompt)
55+
output = run_engine(backend, prompt)
5556
clean_output = output.replace("\n", " ")
5657

5758
# Validate

0 commit comments

Comments (0)