Skip to content

Commit e99ae9e

Browse files
perf(tokenspeed): optimize multimodal tensor handoff
Pack TokenSpeed encoder inputs into offset SHM segments, preserve placeholder spans for faster worker handoff, and default video tensor transport to auto. Signed-off-by: yechank-nvidia <161688079+yechank-nvidia@users.noreply.github.com>
1 parent 8386968 commit e99ae9e

5 files changed

Lines changed: 1375 additions & 340 deletions

File tree

docs/reference/configuration.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,11 +942,17 @@ smg \
942942
These env-only variables tune how the router ships preprocessed multimodal
943943
tensors (image/video encoder inputs) to a TokenSpeed worker. They do not affect
944944
accuracy — the inline and shared-memory paths produce byte-identical tensors.
945+
SHM handles include offsets; multi-item TokenSpeed encoder inputs may share one
946+
packed segment while preserving the same byte-exact tensor payloads and reducing
947+
per-tensor file lifecycle overhead.
945948

946949
| Environment Variable | Default | Description |
947950
|---------------------|---------|-------------|
948-
| `SMG_TOKENSPEED_MM_TENSOR_TRANSPORT` | `inline` | Transport for large MM tensors: `inline` (gRPC bytes), `shm` (always use `/dev/shm`), or `auto` (use `/dev/shm` only when the worker is *verified* to share it). In `auto`, the router compares the worker's advertised `/dev/shm` namespace token (`GetServerInfo`) to its own and uses SHM only on a match; otherwise it falls back to inline. No locality configuration is needed. |
951+
| `SMG_TOKENSPEED_MM_TENSOR_TRANSPORT` | image/audio: `inline`; video: `auto` | Transport for large MM tensors: `inline` (gRPC bytes), `shm` (always use `/dev/shm`), or `auto` (use `/dev/shm` only when the worker is *verified* to share it). When unset, image/audio stay inline while video uses `auto` to avoid the high-throughput video gRPC byte-copy path on colocated workers without hurting image TTFT. In `auto`, the router compares the worker's advertised `/dev/shm` namespace token (`GetServerInfo`) to its own and uses SHM only on a match; otherwise it falls back to inline. No locality configuration is needed. |
949952
| `SMG_TOKENSPEED_MM_SHM_MIN_BYTES` | `65536` | Minimum tensor size (bytes) before the SHM path is used; smaller tensors stay inline. |
953+
| `SMG_MM_PREPROCESS_PAR_MIN_BYTES` | `524288` | Minimum output size (bytes) before CPU image/video preprocessing splits work across helper threads. |
954+
| `SMG_MM_PREPROCESS_PAR_MIN_ROWS` | `32` | Minimum output rows or block bands per helper thread for CPU multimodal preprocessing. |
955+
| `SMG_MM_PREPROCESS_PAR_MAX_THREADS` | `8` | Maximum helper threads spawned per image/video preprocessing pass. Raise for large single requests; keep lower for high-concurrency TTFT. |
950956
| `SMG_LOG_MM_TIMING` | `false` | Log per-stage multimodal preprocessing/assembly timing at `INFO`. Accepts `1`/`true`/`yes`. |
951957

952958
The TokenSpeed gRPC servicer (worker side) reads two companion variables:

grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import dataclasses
1414
import json
1515
import logging
16+
import math
1617
import os
1718
import re
1819
import time
@@ -994,23 +995,26 @@ def _mm_inputs_from_itemized_proto(
994995
for name, tensor_data in item_proto.model_specific_tensors.items()
995996
}
996997
model_elapsed_ms = (
997-
(time.perf_counter() - model_started) * 1000 if model_started is not None else None
998+
(time.perf_counter() - model_started) * 1000
999+
if model_started is not None
1000+
else None
9981001
)
9991002
self._validate_item_tensor_consistency(modality, model_specific_data)
10001003

1001-
if not item_proto.placeholders:
1002-
raise ValueError("MultimodalItem carried no placeholders")
1003-
if any(p.length <= 0 for p in item_proto.placeholders):
1004-
raise ValueError("MultimodalItem.placeholders.length must be > 0")
1005-
offsets = [(p.offset, p.offset + p.length - 1) for p in item_proto.placeholders]
1004+
offsets, token_count, offset_ends, offset_prefix = (
1005+
self._offsets_from_proto_placeholders(item_proto.placeholders)
1006+
)
10061007

10071008
content_hash = bytes(item_proto.content_hash)
10081009
mm_item = MultimodalDataItem(
10091010
modality=modality,
10101011
feature=feature,
10111012
model_specific_data=model_specific_data,
10121013
offsets=offsets,
1014+
token_count=token_count,
10131015
hash=int.from_bytes(content_hash[:8], "little") if content_hash else None,
1016+
offset_ends=offset_ends,
1017+
offset_prefix=offset_prefix,
10141018
)
10151019
mm_item.set_pad_value()
10161020
items.append(mm_item)
@@ -1055,6 +1059,7 @@ def _mm_inputs_from_itemized_proto(
10551059
mm_items=items,
10561060
im_token_id=im_token_id,
10571061
video_token_id=video_token_id,
1062+
pad_values_ready=True,
10581063
)
10591064

10601065
@staticmethod
@@ -1092,6 +1097,44 @@ def _validate_item_tensor_consistency(
10921097
if modality == Modality.VIDEO and not has_video_grid:
10931098
raise ValueError("VIDEO MultimodalItem must carry video_grid_thw")
10941099

1100+
@staticmethod
1101+
def _offsets_from_proto_placeholders(
1102+
placeholders,
1103+
) -> tuple[list[tuple[int, int]], int, list[int] | None, list[int] | None]:
1104+
if len(placeholders) == 1:
1105+
placeholder = placeholders[0]
1106+
length = int(placeholder.length)
1107+
if length <= 0:
1108+
raise ValueError("MultimodalItem.placeholders.length must be > 0")
1109+
start = int(placeholder.offset)
1110+
end = start + length - 1
1111+
return [(start, end)], length, [end], [0, length]
1112+
1113+
offsets = []
1114+
offset_ends = []
1115+
offset_prefix = [0]
1116+
sorted_non_overlapping = True
1117+
prev_end = -1
1118+
token_count = 0
1119+
for placeholder in placeholders:
1120+
length = int(placeholder.length)
1121+
if length <= 0:
1122+
raise ValueError("MultimodalItem.placeholders.length must be > 0")
1123+
start = int(placeholder.offset)
1124+
end = start + length - 1
1125+
if start <= prev_end:
1126+
sorted_non_overlapping = False
1127+
offsets.append((start, end))
1128+
offset_ends.append(end)
1129+
token_count += length
1130+
offset_prefix.append(token_count)
1131+
prev_end = end
1132+
if not offsets:
1133+
raise ValueError("MultimodalItem carried no placeholders")
1134+
if not sorted_non_overlapping:
1135+
return offsets, token_count, None, None
1136+
return offsets, token_count, offset_ends, offset_prefix
1137+
10951138
@staticmethod
10961139
def _tensor_from_proto(
10971140
tensor_data: tokenspeed_scheduler_pb2.TensorData,
@@ -1107,7 +1150,7 @@ def _tensor_from_proto(
11071150

11081151
if tensor_data.dtype == "bfloat16":
11091152
# numpy has no bfloat16 — read the raw bits as uint16, reinterpret.
1110-
expected = int(np.prod(shape, dtype=np.int64)) * np.dtype(np.uint16).itemsize
1153+
expected = math.prod(shape) * np.dtype(np.uint16).itemsize
11111154
if len(raw) != expected:
11121155
raise ValueError(
11131156
f"TensorData byte length mismatch for bfloat16 shape={shape}: "
@@ -1118,7 +1161,7 @@ def _tensor_from_proto(
11181161
)
11191162
else:
11201163
dtype = np.dtype(tensor_data.dtype)
1121-
expected = int(np.prod(shape, dtype=np.int64)) * dtype.itemsize
1164+
expected = math.prod(shape) * dtype.itemsize
11221165
if len(raw) != expected:
11231166
raise ValueError(
11241167
f"TensorData byte length mismatch for dtype={tensor_data.dtype}, "
@@ -1146,27 +1189,32 @@ def _feature_from_proto(
11461189
return TokenSpeedSchedulerServicer._tensor_from_proto(tensor_data, cast_to=cast_to)
11471190

11481191
dtype = TokenSpeedSchedulerServicer._torch_dtype_from_proto(tensor_data.dtype)
1149-
if (
1150-
cast_to is not None
1151-
and dtype != cast_to
1152-
and torch.is_floating_point(torch.empty((), dtype=dtype))
1153-
):
1154-
return TokenSpeedSchedulerServicer._tensor_from_proto(tensor_data, cast_to=cast_to)
1155-
1156-
shm = tensor_data.shm
1157-
if shm.offset != 0:
1192+
if cast_to is not None and dtype != cast_to:
11581193
return TokenSpeedSchedulerServicer._tensor_from_proto(tensor_data, cast_to=cast_to)
11591194

11601195
shape = tuple(int(dim) for dim in tensor_data.shape)
1161-
expected = int(np.prod(shape, dtype=np.int64)) * torch.empty((), dtype=dtype).element_size()
1162-
if int(shm.nbytes) != expected:
1196+
expected = math.prod(shape) * TokenSpeedSchedulerServicer._torch_dtype_size(dtype)
1197+
shm = tensor_data.shm
1198+
offset = int(shm.offset)
1199+
nbytes = int(shm.nbytes)
1200+
if offset < 0:
1201+
raise ValueError(
1202+
f"TensorData.shm offset must be non-negative for shape={list(shape)}: {offset}"
1203+
)
1204+
if nbytes != expected:
11631205
raise ValueError(
11641206
f"TensorData.shm byte length mismatch for dtype={tensor_data.dtype}, "
1165-
f"shape={list(shape)}: expected {expected}, got {int(shm.nbytes)}"
1207+
f"shape={list(shape)}: expected {expected}, got {nbytes}"
11661208
)
11671209

11681210
name = TokenSpeedSchedulerServicer._validated_shm_name(shm.name)
1169-
return ShmTensorHandle(shm_name=name, shape=shape, dtype=dtype)
1211+
return ShmTensorHandle(
1212+
shm_name=name,
1213+
shape=shape,
1214+
dtype=dtype,
1215+
offset=offset,
1216+
nbytes=nbytes,
1217+
)
11701218

11711219
@staticmethod
11721220
def _tensor_payload_bytes(tensor_data: tokenspeed_scheduler_pb2.TensorData) -> bytes:
@@ -1223,6 +1271,14 @@ def _torch_dtype_from_proto(dtype: str) -> torch.dtype:
12231271
return torch.float32
12241272
raise ValueError(f"Unsupported TensorData dtype for SHM feature: {dtype!r}")
12251273

1274+
@staticmethod
1275+
def _torch_dtype_size(dtype: torch.dtype) -> int:
1276+
if dtype is torch.float32:
1277+
return 4
1278+
if dtype is torch.float16 or dtype is torch.bfloat16:
1279+
return 2
1280+
return torch.empty((), dtype=dtype).element_size()
1281+
12261282
@staticmethod
12271283
def _torch_dtype_to_proto(dtype: torch.dtype | None) -> str:
12281284
if dtype is torch.bfloat16:
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
import torch
3+
from smg_grpc_proto.generated import tokenspeed_scheduler_pb2
4+
from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer
5+
from tokenspeed.runtime.multimodal.shm_transport import ShmTensorHandle
6+
7+
8+
def test_feature_from_proto_preserves_offset_shm_handle():
    """An SHM-backed TensorData with a non-zero offset becomes a ShmTensorHandle
    that carries the same name, shape, dtype, offset and byte length."""
    payload_bytes = 3 * 4 * 4  # 3x4 elements, 4 bytes per float32
    shm_proto = tokenspeed_scheduler_pb2.ShmHandle(
        name="smg-tokenspeed-test",
        offset=128,
        nbytes=payload_bytes,
        owner_id="smg:test",
    )
    tensor_proto = tokenspeed_scheduler_pb2.TensorData(
        shape=[3, 4],
        dtype="float32",
        shm=shm_proto,
    )

    handle = TokenSpeedSchedulerServicer._feature_from_proto(tensor_proto)

    assert isinstance(handle, ShmTensorHandle)
    assert handle.shm_name == "smg-tokenspeed-test"
    assert handle.shape == (3, 4)
    assert handle.dtype is torch.float32
    assert handle.offset == 128
    assert handle.nbytes == payload_bytes
28+
29+
30+
def test_feature_from_proto_rejects_offset_shm_length_mismatch():
    """An SHM handle whose nbytes disagrees with shape x dtype size is rejected."""
    # A 3x4 float32 tensor needs 48 bytes; the handle advertises only 4.
    undersized_handle = tokenspeed_scheduler_pb2.ShmHandle(
        name="smg-tokenspeed-test",
        offset=128,
        nbytes=4,
        owner_id="smg:test",
    )
    tensor_proto = tokenspeed_scheduler_pb2.TensorData(
        shape=[3, 4],
        dtype="float32",
        shm=undersized_handle,
    )

    with pytest.raises(ValueError, match="byte length mismatch"):
        TokenSpeedSchedulerServicer._feature_from_proto(tensor_proto)
44+
45+
46+
def test_offsets_from_proto_placeholders_validates_and_builds_offsets_once():
    """Multi-placeholder inputs yield spans, a token count and lookup tables.

    Sorted, non-overlapping placeholders return ``offset_ends`` and
    ``offset_prefix``. Unsorted or overlapping placeholders keep their spans
    and token count but return ``None`` for both tables; that fallback branch
    is covered here as well.
    """
    convert = TokenSpeedSchedulerServicer._offsets_from_proto_placeholders
    make_range = tokenspeed_scheduler_pb2.PlaceholderRange

    placeholders = [
        make_range(offset=10, length=3),
        make_range(offset=20, length=1),
    ]
    assert convert(placeholders) == ([(10, 12), (20, 20)], 4, [12, 20], [0, 3, 4])

    # Out-of-order spans: the tables are withheld, spans stay in input order.
    unsorted = [
        make_range(offset=20, length=1),
        make_range(offset=10, length=3),
    ]
    assert convert(unsorted) == ([(20, 20), (10, 12)], 4, None, None)

    # A span starting on the previous span's last token counts as overlapping.
    overlapping = [
        make_range(offset=10, length=3),
        make_range(offset=12, length=2),
    ]
    assert convert(overlapping) == ([(10, 12), (12, 13)], 5, None, None)
55+
56+
57+
def test_offsets_from_proto_placeholders_single_placeholder_fast_path():
    """A lone placeholder yields one span and single-entry lookup tables."""
    lone_range = tokenspeed_scheduler_pb2.PlaceholderRange(offset=10, length=3)

    result = TokenSpeedSchedulerServicer._offsets_from_proto_placeholders([lone_range])

    expected = ([(10, 12)], 3, [12], [0, 3])
    assert result == expected
65+
66+
67+
def test_offsets_from_proto_placeholders_rejects_empty_and_non_positive_lengths():
    """Empty input and zero-length placeholders raise ValueError.

    The zero-length check is exercised on both code paths: the
    single-placeholder fast path and the multi-placeholder loop.
    """
    convert = TokenSpeedSchedulerServicer._offsets_from_proto_placeholders
    make_range = tokenspeed_scheduler_pb2.PlaceholderRange

    with pytest.raises(ValueError, match="no placeholders"):
        convert([])

    # Single-placeholder fast path.
    with pytest.raises(ValueError, match="length must be > 0"):
        convert([make_range(offset=10, length=0)])

    # Multi-placeholder loop: a valid first entry followed by a bad one.
    with pytest.raises(ValueError, match="length must be > 0"):
        convert(
            [
                make_range(offset=10, length=3),
                make_range(offset=20, length=0),
            ]
        )

0 commit comments

Comments
 (0)