
Commit d130dcb

fix: correct ISL token count and fix zmq message size (#597) (#602)
Co-authored-by: Pavithra Vijayakrishnan <160681768+pvijayakrish@users.noreply.github.com>
1 parent: 809a07d

File tree

19 files changed: +1016 −528 lines

src/aiperf/common/bootstrap.py

Lines changed: 11 additions & 0 deletions

@@ -79,6 +79,17 @@ async def _run_service():

     ensure_modules_loaded()

+    if service_class.__name__ in ("Worker", "TimingManager"):
+        # Disable garbage collection in child processes to prevent unpredictable latency spikes.
+        # Only required in timing critical services such as Worker and TimingManager.
+        import gc
+
+        for _ in range(3):  # Run 3 times to ensure all objects are collected
+            gc.collect()
+        gc.freeze()
+        gc.set_threshold(0)
+        gc.disable()
+
     # Load and apply custom GPU metrics in child process
     if user_config.gpu_telemetry_metrics_file:
         from aiperf.gpu_telemetry import constants
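
For reference, the same collect/freeze/disable sequence in isolation — a minimal standalone sketch with a hypothetical helper name, not code from this commit:

import gc

def quiesce_gc() -> None:
    """Make the collector silent for the lifetime of a latency-critical process."""
    for _ in range(3):     # repeated passes let multi-step finalizers fully drain
        gc.collect()
    gc.freeze()            # move surviving startup objects to the permanent
                           # generation, so future scans never touch them
    gc.set_threshold(0)    # never trigger automatic collection by allocation count
    gc.disable()           # and switch automatic collection off entirely

quiesce_gc()
print(f"GC enabled: {gc.isenabled()}, frozen objects: {gc.get_freeze_count()}")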

src/aiperf/common/models/dataset_models.py

Lines changed: 40 additions & 0 deletions

@@ -146,6 +146,46 @@ def metadata(self) -> TurnMetadata:
             delay_ms=self.delay,
         )

+    def copy_with_stripped_media(self) -> "Turn":
+        """Create a copy of this turn with multimodal data replaced by placeholders.
+
+        This preserves text data (needed for tokenization) but replaces potentially
+        large image/audio/video contents with small placeholder strings. This is
+        more efficient than a full deep copy followed by stripping.
+
+        Returns:
+            A new Turn with stripped multimodal contents.
+        """
+        return Turn(
+            model=self.model,
+            role=self.role,
+            timestamp=self.timestamp,
+            delay=self.delay,
+            max_tokens=self.max_tokens,
+            texts=[Text(name=t.name, contents=list(t.contents)) for t in self.texts],
+            images=[
+                Image(
+                    name=img.name,
+                    contents=[f"image_{i}" for i in range(len(img.contents))],
+                )
+                for img in self.images
+            ],
+            audios=[
+                Audio(
+                    name=aud.name,
+                    contents=[f"audio_{i}" for i in range(len(aud.contents))],
+                )
+                for aud in self.audios
+            ],
+            videos=[
+                Video(
+                    name=vid.name,
+                    contents=[f"video_{i}" for i in range(len(vid.contents))],
+                )
+                for vid in self.videos
+            ],
+        )
+

 class ConversationMetadata(AIPerfBaseModel):
     """Metadata of a conversation."""

src/aiperf/common/models/record_models.py

Lines changed: 5 additions & 4 deletions

@@ -447,10 +447,6 @@ class RequestRecord(AIPerfBaseModel):
         default=None,
         description="The original request info.",
     )
-    turns: list[Turn] = Field(
-        default_factory=list,
-        description="The actual turns of the request. This will include assistant turns as well as user turns in multi-turn conversations.",
-    )
     request_headers: dict[str, str] | None = Field(
         default=None,
         description="The headers of the request.",
@@ -510,6 +506,11 @@ class RequestRecord(AIPerfBaseModel):
         "Includes detailed timing for connection establishment, DNS resolution, request/response events, etc. "
         "The type of the trace data is determined by the transport and library used.",
     )
+    turns: list[Turn] = Field(
+        default_factory=list,
+        description="Deep copy of the request turns. This is a copy of the turns from request_info, "
+        "made to avoid mutating the original session data when stripping multimodal content.",
+    )

     @field_validator("trace_data", mode="before")
     @classmethod
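
The re-described field pairs with the inference_client.py change further down: assigning request_info.turns directly to the record aliased the live session objects, so stripping media on the record would have corrupted the source dataset. A generic demo of that hazard (plain Python, not aiperf code):

session_turns = [{"role": "user", "images": ["<large base64 payload>"]}]

record_turns = session_turns                       # alias, not a copy
record_turns[0]["images"] = ["image_0"]            # "strip" the record...
assert session_turns[0]["images"] == ["image_0"]   # ...and the session mutates too

# A per-record copy, as copy_with_stripped_media provides, keeps the session intact.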

src/aiperf/dataset/dataset_manager.py

Lines changed: 61 additions & 25 deletions

@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
+import gc
 import time

 import orjson
@@ -22,6 +23,7 @@
 from aiperf.common.factories import (
     ComposerFactory,
     DatasetBackingStoreFactory,
+    DatasetClientStoreFactory,
     EndpointFactory,
     ServiceFactory,
 )
@@ -45,6 +47,7 @@
 )
 from aiperf.common.protocols import (
     DatasetBackingStoreProtocol,
+    DatasetClientStoreProtocol,
     EndpointProtocol,
     ServiceProtocol,
 )
@@ -94,6 +97,7 @@ def __init__(
                 benchmark_id=user_config.benchmark_id,
             )
         )
+        self._dataset_client: DatasetClientStoreProtocol | None = None

     @on_command(CommandType.PROFILE_CONFIGURE)
     async def _profile_configure_command(
@@ -111,9 +115,33 @@ async def _profile_configure_command(
         begin = time.perf_counter()
         await self._configure_dataset()
         await self._generate_inputs_json_file()
+        await self._configure_dataset_client_and_free_memory()
+
         duration = time.perf_counter() - begin
         self.info(lambda: f"Dataset configured in {duration:.2f} seconds")

+    async def _configure_dataset_client_and_free_memory(self) -> None:
+        """Configure the dataset client for serving fallback requests."""
+        # Create dataset client for serving fallback requests, then free in-memory dataset
+        client_metadata = self._backing_store.get_client_metadata()
+        self._dataset_client = DatasetClientStoreFactory.create_instance(
+            client_metadata=client_metadata,
+        )
+        await self._dataset_client.initialize()
+        # Now that the client is ready, signal that fallback requests can be served
+        self.dataset_configured.set()
+        # Free the in-memory dataset now that we have the client to serve fallback requests.
+        # Reassign to new empty containers (not .clear()) to release object references,
+        # then run gc.collect() twice to ensure circular references are cleaned up.
+        conversation_count = len(self.dataset)
+        self.dataset = {}
+        self._conversation_ids_cache = []
+        gc.collect()
+        gc.collect()
+        self.info(
+            f"Dataset client initialized and freed {conversation_count} conversations from memory"
+        )
+
     async def _configure_tokenizer(self) -> None:
         """Configure the tokenizer for the dataset manager."""
         tokenizer_name = self.user_config.tokenizer.name
@@ -304,7 +332,9 @@ async def _configure_dataset(self) -> None:
             f"unique conversations: {len(self.dataset_metadata.conversations)}, "
             f"unique turn count: {self.dataset_metadata.total_turn_count}"
         )
-        self.dataset_configured.set()
+        # Note: dataset_configured event is set in _profile_configure_command after
+        # the dataset client is initialized, to avoid a race condition where fallback
+        # requests arrive before the client is ready.
         await self.publish(
             DatasetConfiguredNotification(
                 service_id=self.service_id,
@@ -317,55 +347,58 @@ async def _configure_dataset(self) -> None:
     async def _handle_conversation_request(
         self, message: ConversationRequestMessage
     ) -> ConversationResponseMessage:
-        """Handle a conversation request."""
+        """Handle a conversation request using the dataset client."""
         self.debug(lambda: f"Handling conversation request: {message}")

         await self._wait_for_dataset_configuration()

-        if not self.dataset:
+        if self._dataset_client is None:
             raise self._service_error(
-                "Dataset is empty and must be configured before handling requests.",
+                "Dataset client is not initialized. Dataset must be configured before handling requests.",
             )

-        return self._return_conversation_by_id(
-            request_id=message.request_id,
-            conversation_id=message.conversation_id,
-        )
-
-    def _return_conversation_by_id(
-        self, request_id: str | None, conversation_id: str
-    ) -> ConversationResponseMessage:
-        """Return a conversation if it exists, otherwise raise an error."""
-
-        if conversation_id not in self.dataset:
-            raise self._service_error(
-                f"Conversation {conversation_id} not found in dataset.",
+        try:
+            conversation = await self._dataset_client.get_conversation(
+                message.conversation_id
             )
+        except KeyError:
+            raise self._service_error(
+                f"Conversation {message.conversation_id} not found in dataset.",
+            ) from None

-        conversation = self.dataset[conversation_id]
         self.trace_or_debug(
             lambda: f"Sending conversation response: {conversation}",
             lambda: f"Sending conversation response with id: {conversation.session_id}",
         )
         return ConversationResponseMessage(
             service_id=self.service_id,
-            request_id=request_id,
+            request_id=message.request_id,
             conversation=conversation,
         )

     @on_request(MessageType.CONVERSATION_TURN_REQUEST)
     async def _handle_conversation_turn_request(
         self, message: ConversationTurnRequestMessage
     ) -> ConversationTurnResponseMessage:
-        """Handle a turn request."""
+        """Handle a turn request using the dataset client."""
         self.debug(lambda: f"Handling turn request: {message}")

-        if message.conversation_id not in self.dataset:
+        await self._wait_for_dataset_configuration()
+
+        if self._dataset_client is None:
             raise self._service_error(
-                f"Conversation {message.conversation_id} not found in dataset.",
+                "Dataset client is not initialized. Dataset must be configured before handling requests.",
+            )
+
+        try:
+            conversation = await self._dataset_client.get_conversation(
+                message.conversation_id
             )
+        except KeyError as e:
+            raise self._service_error(
+                f"Conversation {message.conversation_id} not found in dataset.",
+            ) from e

-        conversation = self.dataset[message.conversation_id]
         if message.turn_index >= len(conversation.turns):
             raise self._service_error(
                 f"Turn index {message.turn_index} is out of range for conversation {message.conversation_id}.",
@@ -395,8 +428,11 @@ async def _wait_for_dataset_configuration(self) -> None:
         )

     @on_stop
-    async def _cleanup_backing_store(self) -> None:
-        """Clean up the backing store and associated mmap files."""
+    async def _cleanup(self) -> None:
+        """Clean up the backing store, dataset client, and associated mmap files."""
+        if self._dataset_client is not None:
+            await self._dataset_client.stop()
+            self.debug("Dataset client cleanup complete")
         if self._backing_store is not None:
             await self._backing_store.stop()
             self.debug("Backing store cleanup complete")

src/aiperf/records/inference_result_parser.py

Lines changed: 16 additions & 13 deletions

@@ -230,29 +230,31 @@ async def compute_input_token_count(
             return None

         tokenizer = await self.get_tokenizer(request_record.model_name)
-        input_token_count = 0
+        prompt_texts: list[str] = []

         # Include system_message if present (shared system prompt)
         if request_record.request_info and request_record.request_info.system_message:
-            input_token_count += len(
-                tokenizer.encode(request_record.request_info.system_message)
-            )
+            prompt_texts.append(request_record.request_info.system_message)

         # Include user_context_message if present (per-conversation user context)
         if (
             request_record.request_info
             and request_record.request_info.user_context_message
         ):
-            input_token_count += len(
-                tokenizer.encode(request_record.request_info.user_context_message)
-            )
+            prompt_texts.append(request_record.request_info.user_context_message)

         # Include all turns' text content
-        # TODO: We need to handle images, audios, videos, etc.
         for turn in turns:
             for text in turn.texts:
-                input_token_count += len(tokenizer.encode("".join(text.contents)))
-        return input_token_count
+                prompt_texts.append("".join(text.contents))
+
+        if not prompt_texts:
+            return None
+
+        # NOTE: We combine all the prompt texts with a space separator to create a single prompt string.
+        # This will get us the most accurate token count for the prompt by avoiding any potential
+        # boundary issues that could occur if we were to tokenize each text individually.
+        return self._compute_token_count(tokenizer, prompt_texts, separator=" ")

     async def _compute_server_token_counts(
         self, responses: list[ParsedResponse]
@@ -317,20 +319,21 @@ def _parse_output_and_reasoning_texts(
         return output_texts, reasoning_texts

     def _compute_token_count(
-        self, tokenizer: Tokenizer, texts: list[str]
+        self, tokenizer: Tokenizer, texts: list[str], separator: str = ""
     ) -> int | None:
-        """Compute the number of tokens in the texts by joining them without any separators and encoding with the tokenizer.
+        """Compute the number of tokens in the texts by joining them with an optional separator (default none) and encoding with the tokenizer.

        Args:
            tokenizer: The tokenizer to use
            texts: List of texts to compute the token count for
+           separator: The separator to use between the texts

        Returns:
            The number of tokens in the texts, or None if the texts are empty
        """
        if not texts:
            return None
-        return len(tokenizer.encode("".join(texts)))
+        return len(tokenizer.encode(separator.join(texts)))

     async def _compute_client_side_token_counts(
         self, request_record: RequestRecord, responses: list[ParsedResponse]
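
The separator NOTE above is the substance of the ISL token-count fix: token counts are not additive across string boundaries, so summing per-fragment counts (the old behavior) can disagree with the count of the real, joined prompt. A standalone illustration (assumes the transformers package; "gpt2" is only an example tokenizer, not necessarily what aiperf loads):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
parts = ["You are a helpful assistant.", "What is the capital of France?"]

per_fragment = sum(len(tok.encode(p)) for p in parts)  # old approach: count each piece
joined = len(tok.encode(" ".join(parts)))              # new approach: count the real prompt

# The two can differ: BPE merges across the join point, and the leading space
# changes how the first word of the second fragment is tokenized.
print(per_fragment, joined)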

src/aiperf/timing/manager.py

Lines changed: 0 additions & 18 deletions

@@ -3,7 +3,6 @@


 import asyncio
-import gc

 from aiperf.common.base_component_service import BaseComponentService
 from aiperf.common.config import ServiceConfig, UserConfig
@@ -129,13 +128,6 @@ async def _on_start_profiling(self, _message: CommandMessage) -> None:
         if not self._phase_orchestrator:
             raise InvalidStateError("No phase orchestrator configured")

-        # Disable GC during profiling to eliminate unpredictable latency spikes.
-        # Collect and freeze first to minimize memory pressure during the benchmark.
-        self.debug("Disabling garbage collection for stable timing")
-        gc.collect()
-        gc.freeze()
-        gc.disable()
-
         # Start event loop health monitoring only during the benchmark
         self.event_loop_monitor.start()

@@ -164,16 +156,6 @@ async def _timing_manager_stop(self) -> None:
         await self._phase_orchestrator.stop()

         self.event_loop_monitor.stop()
-        self._re_enable_gc()
-
-    def _re_enable_gc(self) -> None:
-        """Re-enable garbage collection."""
-        self.debug(
-            "Re-enabling garbage collection to allow the timing manager "
-            "to clean up resources"
-        )
-        gc.unfreeze()
-        gc.enable()


 def main() -> None:

src/aiperf/workers/inference_client.py

Lines changed: 6 additions & 2 deletions

@@ -107,15 +107,13 @@ async def _send_request_internal(
             self.debug(
                 f"pre_send_perf_ns to start_perf_ns latency: {result.start_perf_ns - pre_send_perf_ns} ns"
             )
-            result.turns = request_info.turns
             return result
         except Exception as e:
             self.error(
                 f"Error calling inference server API at {self.model_endpoint.endpoint.base_url}: {e!r}"
             )
             return RequestRecord(
                 request_info=request_info,
-                turns=request_info.turns,
                 timestamp_ns=pre_send_timestamp_ns or time.time_ns(),
                 # Try and use the pre_send_perf_ns if it is available, otherwise use the current time.
                 start_perf_ns=pre_send_perf_ns or time.perf_counter_ns(),
@@ -156,11 +154,17 @@ def _enrich_request_record(
             or self.model_endpoint.primary_model_name
         )
         record.request_info = request_info
+
+        # Copy turns with stripped multimodal data to avoid mutating original session
+        # and reduce memory usage (placeholders instead of large image/audio/video data)
+        record.turns = [turn.copy_with_stripped_media() for turn in request_info.turns]
+
        # If this is the first turn, calculate the credit drop latency
        if request_info.turn_index == 0 and request_info.drop_perf_ns is not None:
            record.credit_drop_latency = (
                record.start_perf_ns - request_info.drop_perf_ns
            )
+
        # Preserve headers set by transport; only use endpoint headers if not set
        if record.request_headers is None:
            record.request_headers = request_info.endpoint_headers
