Skip to content

Commit 45ec926

Browse files
authored
test: Add mooncake trace integration test (#603)
1 parent 99bb3ee commit 45ec926

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Integration tests for Mooncake trace custom dataset type."""
5+
6+
from pathlib import Path
7+
8+
import pytest
9+
10+
from tests.harness.utils import AIPerfCLI, AIPerfMockServer
11+
from tests.integration.conftest import IntegrationTestDefaults as defaults
12+
from tests.integration.utils import create_mooncake_trace_file
13+
14+
15+
@pytest.mark.integration
@pytest.mark.asyncio
class TestMooncakeTraceIntegration:
    """Integration tests for mooncake_trace dataset loader."""

    async def _profile_and_verify(
        self,
        cli: AIPerfCLI,
        aiperf_mock_server: AIPerfMockServer,
        trace_file: Path,
        request_count: int,
    ) -> None:
        """Run `aiperf profile` on *trace_file* and verify all requests completed.

        Shared driver for every test in this class: the CLI invocation and the
        result assertions are identical across scenarios, only the trace file
        contents differ.

        Args:
            cli: AIPerf CLI harness fixture.
            aiperf_mock_server: Mock inference server fixture.
            trace_file: Path to the mooncake-format JSONL trace file.
            request_count: Expected number of requests (one per trace line).
        """
        result = await cli.run(
            f"""
            aiperf profile \
                --model {defaults.model} \
                --url {aiperf_mock_server.url} \
                --endpoint-type chat \
                --input-file {trace_file} \
                --custom-dataset-type mooncake_trace \
                --request-count {request_count} \
                --fixed-schedule \
                --workers-max {defaults.workers_max} \
                --ui {defaults.ui}
            """
        )

        assert result.request_count == request_count
        assert result.has_all_outputs

    async def test_basic_mooncake_trace_with_input_length(
        self,
        cli: AIPerfCLI,
        aiperf_mock_server: AIPerfMockServer,
        tmp_path: Path,
    ):
        """Test basic Mooncake trace with input_length, output_length, and hash_ids."""
        # Real trace data from mooncake_trace.jsonl (first 5 lines)
        traces = [
            {"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]},
            {"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]},
            {"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]},
            {"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]},
            {"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]},
        ]  # fmt: skip
        trace_file = create_mooncake_trace_file(tmp_path, traces)

        await self._profile_and_verify(
            cli, aiperf_mock_server, trace_file, request_count=len(traces)
        )

    async def test_mooncake_trace_with_text_input(
        self,
        cli: AIPerfCLI,
        aiperf_mock_server: AIPerfMockServer,
        tmp_path: Path,
    ):
        """Test Mooncake trace with literal text inputs instead of input_length."""
        # Each trace is a single-turn conversation; timestamp required for --fixed-schedule
        traces = [
            {"timestamp": 0, "text_input": "What is the capital of France?", "output_length": 20},
            {"timestamp": 100, "text_input": "Explain quantum computing briefly.", "output_length": 30},
            {"timestamp": 200, "text_input": "Write a haiku about programming.", "output_length": 25},
            {"timestamp": 300, "text_input": "What is machine learning?", "output_length": 40},
            {"timestamp": 400, "text_input": "Describe the solar system.", "output_length": 35},
        ]  # fmt: skip
        trace_file = create_mooncake_trace_file(tmp_path, traces)

        await self._profile_and_verify(
            cli, aiperf_mock_server, trace_file, request_count=len(traces)
        )

    async def test_mooncake_trace_multi_turn_with_session_id(
        self,
        cli: AIPerfCLI,
        aiperf_mock_server: AIPerfMockServer,
        tmp_path: Path,
    ):
        """Test Mooncake trace with session_id for multi-turn conversations."""
        # First turn of each session needs timestamp; subsequent turns use delay
        traces = [
            # Session 1: Two-turn conversation (starts at t=0)
            {"session_id": "session-1", "timestamp": 0, "input_length": 100, "output_length": 40},
            {"session_id": "session-1", "delay": 500, "input_length": 150, "output_length": 50},
            # Session 2: Single-turn (starts at t=100)
            {"session_id": "session-2", "timestamp": 100, "input_length": 200, "output_length": 60},
            # Session 3: Three-turn conversation (starts at t=200)
            {"session_id": "session-3", "timestamp": 200, "input_length": 80, "output_length": 30},
            {"session_id": "session-3", "delay": 300, "input_length": 120, "output_length": 45},
            {"session_id": "session-3", "delay": 400, "input_length": 90, "output_length": 35},
        ]  # fmt: skip
        trace_file = create_mooncake_trace_file(tmp_path, traces)

        # Each turn is a request
        await self._profile_and_verify(
            cli, aiperf_mock_server, trace_file, request_count=len(traces)
        )

tests/integration/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@
1414
logger = AIPerfLogger(__name__)
1515

1616

17+
def create_mooncake_trace_file(
    tmp_path: Path,
    traces: list[dict],
    filename: str = "traces.jsonl",
) -> Path:
    """Write *traces* as a Mooncake-format JSONL trace file for testing.

    Args:
        tmp_path: Directory in which to create the file
        traces: Trace records to serialize, one JSON object per line
        filename: Basename for the generated trace file

    Returns:
        Path to the created trace file
    """
    out_path = tmp_path / filename
    # orjson.dumps returns bytes, so the whole payload is assembled and
    # written in binary mode — one newline-terminated JSON object per trace.
    payload = b"".join(orjson.dumps(record) + b"\n" for record in traces)
    out_path.write_bytes(payload)
    return out_path
37+
38+
1739
def create_rankings_dataset(tmp_path: Path, num_entries: int) -> Path:
1840
"""Create a rankings dataset for testing.
1941

0 commit comments

Comments
 (0)