Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ struct Router {
max_idle_secs: u64,
assignment_mode: String,
max_payload_size: usize,
multimodal_tensor_transport: Option<String>,
multimodal_shm_min_bytes: Option<usize>,
dp_aware: bool,
dp_minimum_tokens_scheduler: bool,
api_key: Option<String>,
Expand Down Expand Up @@ -691,6 +693,8 @@ impl Router {
.health_check_port(self.health_check_port)
.connection_mode(self.connection_mode)
.max_payload_size(self.max_payload_size)
.multimodal_tensor_transport(self.multimodal_tensor_transport.clone())
.multimodal_shm_min_bytes(self.multimodal_shm_min_bytes)
.request_timeout_secs(self.request_timeout_secs)
.worker_startup_timeout_secs(self.worker_startup_timeout_secs)
.worker_startup_check_interval_secs(self.worker_startup_check_interval)
Expand Down Expand Up @@ -890,6 +894,8 @@ impl Router {
// positional argument keeps its index for callers that construct
// `_Router(...)` positionally. See the struct-field note above.
health_check_port = None,
multimodal_tensor_transport = None,
multimodal_shm_min_bytes = None,
))]
#[expect(clippy::too_many_arguments)]
#[expect(
Expand Down Expand Up @@ -1009,6 +1015,8 @@ impl Router {
// Appended last to match the `#[pyo3(signature)]` order above and
// preserve positional-argument compatibility.
health_check_port: Option<u16>,
multimodal_tensor_transport: Option<String>,
multimodal_shm_min_bytes: Option<usize>,
) -> PyResult<Self> {
let mut all_urls = worker_urls.clone();

Expand Down Expand Up @@ -1047,6 +1055,8 @@ impl Router {
max_idle_secs,
assignment_mode,
max_payload_size,
multimodal_tensor_transport,
multimodal_shm_min_bytes,
dp_aware,
dp_minimum_tokens_scheduler,
api_key,
Expand Down
28 changes: 28 additions & 0 deletions bindings/python/src/smg/router_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
logger = logging.getLogger(__name__)


def _non_negative_int(value: str) -> int:
parsed = int(value)
if parsed < 0:
raise argparse.ArgumentTypeError(f"must be >= 0, got {parsed}")
return parsed


COMMON_POLICY_CHOICES = [
"random",
"round_robin",
Expand Down Expand Up @@ -64,6 +71,9 @@ class RouterArgs:
max_idle_secs: int = 4 * 3600
assignment_mode: str = "random" # Mode for manual policy new routing key assignment
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
# Multimodal tensor transport (None = use env/default)
multimodal_tensor_transport: str | None = None
multimodal_shm_min_bytes: int | None = None
bucket_adjust_interval_secs: int = 5
dp_aware: bool = False
dp_minimum_tokens_scheduler: bool = False
Expand Down Expand Up @@ -484,6 +494,24 @@ def add_cli_args(
help="Enable IGW (Inference-Gateway) mode for multi-model support",
)

# Multimodal arguments
multimodal_group = parser.add_argument_group(
"Multimodal", "Multimodal tensor transport configuration"
)
multimodal_group.add_argument(
f"--{prefix}multimodal-tensor-transport",
type=str,
choices=["inline", "shm", "auto"],
default=None,
help="Transport for large multimodal tensors (inline|shm|auto)",
)
multimodal_group.add_argument(
f"--{prefix}multimodal-shm-min-bytes",
type=_non_negative_int,
default=None,
help="Minimum multimodal tensor size (bytes) before the SHM transport is used",
Comment thread
coderabbitai[bot] marked this conversation as resolved.
)

# PD-specific arguments
pd_group.add_argument(
f"--{prefix}pd-disaggregation",
Expand Down
23 changes: 23 additions & 0 deletions crates/grpc_client/proto/common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,26 @@ message ProfileResponse {
bool success = 1;
string message = 2;
}

// =====================
// Multimodal Tensor Transport
// =====================
//
// Out-of-band transport descriptors for large multimodal tensor payloads,
// shared by every engine's TensorData. The raw tensor bytes live in exactly
// one transport; the engine's TensorData carries shape and dtype.

// Descriptor for a tensor payload that lives in shared memory instead of being
// carried inline in the gRPC message.
message ShmHandle {
// Identifier of the shared-memory segment.
// NOTE(review): presumably a segment name under /dev/shm — confirm against the producer.
string name = 1;
// NOTE(review): presumably the payload's starting byte offset within the segment — confirm.
uint64 offset = 2;
// Payload size in bytes (NOTE(review): presumably counted from `offset` — confirm).
uint64 nbytes = 3;
// Producer/lifetime owner. Used by implementations to coordinate cleanup.
string owner_id = 4;
}

// Descriptor for a tensor payload held by a cross-node or non-shared-memory
// transport instead of being carried inline in the gRPC message.
message RemoteTensorHandle {
// Examples: "nixl", "ucx", "object_store".
string transport = 1;
// Transport-specific descriptor bytes.
// NOTE(review): format is presumably defined by the transport named above — confirm.
bytes descriptor = 2;
// Payload size in bytes (NOTE(review): inferred from the field name — confirm).
uint64 nbytes = 3;
}
19 changes: 2 additions & 17 deletions crates/grpc_client/proto/tokenspeed_scheduler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -125,28 +125,13 @@ message TensorData {
bytes inline = 3;
// Same-host CPU shared memory path. This is the preferred large-payload
// transport when SMG and TokenSpeed share /dev/shm.
ShmHandle shm = 4;
smg.grpc.common.ShmHandle shm = 4;
// Cross-node or non-shared-memory transport descriptor. NIXL is the
// expected remote transport for distributed multimodal tensor payloads.
RemoteTensorHandle remote = 5;
smg.grpc.common.RemoteTensorHandle remote = 5;
}
}

message ShmHandle {
string name = 1;
uint64 offset = 2;
uint64 nbytes = 3;
// Producer/lifetime owner. Used by implementations to coordinate cleanup.
string owner_id = 4;
}

message RemoteTensorHandle {
// Examples: "nixl", "ucx", "object_store".
string transport = 1;
bytes descriptor = 2;
uint64 nbytes = 3;
}

// Where a multimodal item's tokens sit inside input_ids.
message PlaceholderRange {
uint32 offset = 1;
Expand Down
26 changes: 24 additions & 2 deletions crates/grpc_client/proto/vllm_engine.proto
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,22 @@ message TokenizedInput {
repeated uint32 input_ids = 2; // Actual token IDs to process
}

// A typed tensor: raw little-endian bytes + shape + dtype.
// A typed tensor descriptor. The raw bytes live in exactly one payload
// transport; shape and dtype describe how the receiver interprets them.
message TensorData {
bytes data = 1; // Raw little-endian bytes (f32/i64/u32)
repeated uint32 shape = 2; // Dimension sizes
string dtype = 3; // "float32", "int64", "uint32"

oneof payload {
// Current path: raw little-endian bytes carried in the gRPC message
// (field 1, formerly `bytes data`, kept for wire compatibility).
bytes inline = 1;
// Same-host CPU shared memory path; preferred for large payloads when
// SMG and the vLLM worker share /dev/shm.
smg.grpc.common.ShmHandle shm = 4;
// Cross-node or non-shared-memory transport (e.g. NIXL). Not implemented.
smg.grpc.common.RemoteTensorHandle remote = 5;
}
}

message PlaceholderRange {
Expand Down Expand Up @@ -140,6 +151,12 @@ message MultimodalInputs {
// Tensor keys that should remain on CPU (not transferred to GPU).
// Maps to vLLM's MultiModalFieldConfig keep_on_cpu flag.
repeated string keep_on_cpu_keys = 9;

// Whether this payload is video (vs image). `false` (default) keeps wire
// compatibility with existing image-only senders; `true` routes the encoder
// tensor to vLLM's `pixel_values_videos` input with `video` field configs on
// the servicer. vLLM multimodal is image|video only, so a bool suffices.
bool is_video = 10;
}

// =====================
Expand Down Expand Up @@ -339,6 +356,11 @@ message GetServerInfoResponse {
string kv_engine_id = 8; // kv_transfer_config.engine_id, "" if not configured

int32 data_parallel_size = 9; // parallel_config.data_parallel_size (1 when DP is off)

// This worker's /dev/shm tmpfs identity (<boot_id>:<st_dev>), advertised so
// the router can verify a shared /dev/shm before using the SHM tensor
// transport under `auto`. Empty when it can't be determined.
string shm_namespace_id = 10;
}

// =====================
Expand Down
2 changes: 1 addition & 1 deletion crates/grpc_client/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "smg-grpc-proto"
version = "0.4.11"
version = "0.4.12"
description = "SMG gRPC proto definitions for vLLM, TRT-LLM, MLX, TokenSpeed, and SGLang"
requires-python = ">=3.10"
dependencies = [
Expand Down
14 changes: 14 additions & 0 deletions crates/protocols/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,18 @@ pub struct WorkerSpec {
/// Falls back to the global `load_monitor_interval_secs` from router config.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub load_monitor_interval_secs: Option<u64>,

/// Per-worker multimodal tensor transport override (`inline` | `shm` | `auto`).
/// When set, overrides the router-level `multimodal_tensor_transport` default
/// for this worker (e.g. force `shm` for a co-located worker, `inline` for a
/// remote one).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multimodal_tensor_transport: Option<String>,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Nit: This field accepts any string without validation. A typo like "smh" instead of "shm" silently falls back to inline (the other arm in resolve_mm_shm_enabled logs a warning, but only once via OnceLock — so the second misconfigured worker is entirely silent). Consider either a serde deserialize validation or an enum type to catch invalid values at worker registration time rather than at request time. The CLI flag already validates with value_parser = ["inline", "shm", "auto"]; the API path doesn't get the same protection.


/// Per-worker minimum multimodal tensor size (bytes) before the SHM transport
/// is used. When set, overrides the router-level `multimodal_shm_min_bytes`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multimodal_shm_min_bytes: Option<usize>,
}

impl WorkerSpec {
Expand Down Expand Up @@ -694,6 +706,8 @@ impl WorkerSpec {
resilience: ResilienceUpdate::default(),
max_connection_attempts: default_max_connection_attempts(),
load_monitor_interval_secs: None,
multimodal_tensor_transport: None,
multimodal_shm_min_bytes: None,
}
}
}
Expand Down
44 changes: 28 additions & 16 deletions docs/reference/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -937,19 +937,31 @@ smg \
| `JWT_JWKS_URI` | `--jwt-jwks-uri` | JWKS URI |
| `CONTROL_PLANE_API_KEYS` | `--control-plane-api-keys` | Control plane API keys |

### TokenSpeed Multimodal Tensor Transport

These env-only variables tune how the router ships preprocessed multimodal
tensors (image/video encoder inputs) to a TokenSpeed worker. They do not affect
accuracy — the inline and shared-memory paths produce byte-identical tensors.

| Environment Variable | Default | Description |
|---------------------|---------|-------------|
| `SMG_TOKENSPEED_MM_TENSOR_TRANSPORT` | `inline` | Transport for large MM tensors: `inline` (gRPC bytes), `shm` (always use `/dev/shm`), or `auto` (use `/dev/shm` only when the worker is *verified* to share it). In `auto`, the router compares the worker's advertised `/dev/shm` namespace token (`GetServerInfo`) to its own and uses SHM only on a match; otherwise it falls back to inline. No locality configuration is needed. |
| `SMG_TOKENSPEED_MM_SHM_MIN_BYTES` | `65536` | Minimum tensor size (bytes) before the SHM path is used; smaller tensors stay inline. |
| `SMG_LOG_MM_TIMING` | `false` | Log per-stage multimodal preprocessing/assembly timing at `INFO`. Accepts `1`/`true`/`yes`. |

The TokenSpeed gRPC servicer (worker side) reads two companion variables:
`TOKENSPEED_UNLINK_MM_SHM_AFTER_READ` (default on — unlink each `/dev/shm`
segment after the worker reads it) and `TOKENSPEED_LOG_MM_TIMING` (worker-side
timing logs).
### Multimodal Tensor Transport

Controls how the router ships preprocessed multimodal tensors (image/video
encoder inputs) to a worker. Supported on **TokenSpeed** and **vLLM** workers. It
does not affect accuracy — the inline and shared-memory paths produce
byte-identical tensors.

Resolution precedence (highest wins): per-worker `WorkerSpec` override → router
config (CLI flag or YAML) → environment variable → built-in default.

| CLI Flag | Environment Variable | Default | Description |
|----------|---------------------|---------|-------------|
| `--multimodal-tensor-transport` | `SMG_MM_TENSOR_TRANSPORT` | `inline` | Transport for large MM tensors: `inline` (gRPC bytes), `shm` (always use `/dev/shm`), or `auto` (use `/dev/shm` only when the worker is *verified* to share it). In `auto`, the router compares the worker's advertised `/dev/shm` namespace token (`GetServerInfo`) to its own and uses SHM only on a match; otherwise it falls back to inline. No locality configuration is needed. |
| `--multimodal-shm-min-bytes` | `SMG_MM_SHM_MIN_BYTES` | `65536` | Minimum tensor size (bytes) before the SHM path is used; smaller tensors stay inline. |
| — | `SMG_LOG_MM_TIMING` | `false` | Log per-stage multimodal preprocessing/assembly timing at `INFO`. Accepts `1`/`true`/`yes`. |

The legacy env names `SMG_TOKENSPEED_MM_TENSOR_TRANSPORT` and
`SMG_TOKENSPEED_MM_SHM_MIN_BYTES` are still honored as fallback aliases.

**Per-worker override.** A worker created via `/workers` (or discovered) may set
`multimodal_tensor_transport` and `multimodal_shm_min_bytes` in its `WorkerSpec`
to override the router-level defaults for that worker — e.g. force `shm` for a
co-located worker and `inline` for a remote one.

The worker-side gRPC servicer reads `SMG_UNLINK_MM_SHM_AFTER_READ`, which is on
by default and unlinks each `/dev/shm` segment after the worker reads it; the
legacy `TOKENSPEED_UNLINK_MM_SHM_AFTER_READ` is still honored as a fallback. For
TokenSpeed, it also reads `TOKENSPEED_LOG_MM_TIMING` (worker-side timing logs).
Loading
Loading