|
| 1 | +"""Engine-neutral KV-cache-event → proto conversion and ZMQ streaming. |
| 2 | +
|
| 3 | +Shared by every engine bridge (vLLM, TokenSpeed, ...). Imports only stdlib + |
| 4 | +the generated proto, and dispatches engine events by class name (BlockStored / |
| 5 | +BlockRemoved / AllBlocksCleared), so it needs no engine import and is |
| 6 | +unit-testable without any engine installed. |
| 7 | +
|
| 8 | +Each engine package keeps its own ``resolve_kv_events_config`` (the only |
| 9 | +engine-specific seam); everything here is wire-format-only. |
| 10 | +""" |
| 11 | + |
| 12 | +import logging |
| 13 | +from collections.abc import AsyncIterator, Awaitable, Callable |
| 14 | + |
| 15 | +from smg_grpc_proto.generated import common_pb2 |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | +_U64_MASK = 0xFFFFFFFFFFFFFFFF |
| 20 | +_I64_SIGN_BIT = 0x8000000000000000 |
| 21 | +_U64_MODULUS = 0x10000000000000000 |
| 22 | + |
| 23 | + |
| 24 | +def to_int64(value: int | bytes) -> int: |
| 25 | + """Reduce an engine block hash to a signed int64 for the proto block_hash field. |
| 26 | +
|
| 27 | + An engine's block hash may be ``int | bytes`` (sha256 bytes when int hashes |
| 28 | + are disabled); bytes are read big-endian. SMG uses the hash only as a node |
| 29 | + identity, so the 64-bit reduction is safe as long as it stays deterministic. |
| 30 | + """ |
| 31 | + if isinstance(value, (bytes, bytearray)): |
| 32 | + value = int.from_bytes(value, "big") |
| 33 | + masked = value & _U64_MASK |
| 34 | + if masked >= _I64_SIGN_BIT: |
| 35 | + masked -= _U64_MODULUS |
| 36 | + return masked |
| 37 | + |
| 38 | + |
| 39 | +def endpoint_for_rank(endpoint: str, dp_rank: int) -> str: |
| 40 | + """Resolve a KV-events PUB endpoint to a connectable SUB address. |
| 41 | +
|
| 42 | + Bind wildcards (``*``, ``0.0.0.0``) are rewritten to ``127.0.0.1`` (the |
| 43 | + latter is not connectable on macOS/Windows). For data-parallel deployments |
| 44 | + each rank publishes on ``base_port + dp_rank``; non-tcp endpoints (ipc://, |
| 45 | + inproc://) get the wildcard substituted but no port arithmetic. |
| 46 | + """ |
| 47 | + resolved = endpoint.replace("*", "127.0.0.1").replace("0.0.0.0", "127.0.0.1") |
| 48 | + if resolved.startswith("tcp://") and dp_rank: |
| 49 | + host, sep, port = resolved.rpartition(":") |
| 50 | + if sep and port.isdigit(): |
| 51 | + return f"{host}:{int(port) + dp_rank}" |
| 52 | + return resolved |
| 53 | + |
| 54 | + |
| 55 | +def convert_event(event: object, event_id: int) -> common_pb2.KvCacheEvent | None: |
| 56 | + """Convert one decoded engine event to a proto KvCacheEvent (or None if unknown).""" |
| 57 | + name = type(event).__name__ |
| 58 | + |
| 59 | + if name == "BlockStored": |
| 60 | + block_size = int(event.block_size) |
| 61 | + blocks = [] |
| 62 | + for i, block_hash in enumerate(event.block_hashes): |
| 63 | + start = i * block_size |
| 64 | + end = start + block_size |
| 65 | + block = common_pb2.KvBlock( |
| 66 | + block_hash=to_int64(block_hash), |
| 67 | + token_ids=list(event.token_ids[start:end]), |
| 68 | + block_size=block_size, |
| 69 | + ) |
| 70 | + lora_id = getattr(event, "lora_id", None) |
| 71 | + if lora_id is not None: |
| 72 | + block.lora_id = to_int64(lora_id) |
| 73 | + blocks.append(block) |
| 74 | + stored = common_pb2.KvBlocksStored(blocks=blocks) |
| 75 | + parent = getattr(event, "parent_block_hash", None) |
| 76 | + if parent is not None: |
| 77 | + stored.parent_block_hash = to_int64(parent) |
| 78 | + return common_pb2.KvCacheEvent(event_id=event_id, stored=stored) |
| 79 | + |
| 80 | + if name == "BlockRemoved": |
| 81 | + return common_pb2.KvCacheEvent( |
| 82 | + event_id=event_id, |
| 83 | + removed=common_pb2.KvBlocksRemoved( |
| 84 | + block_hashes=[to_int64(h) for h in event.block_hashes] |
| 85 | + ), |
| 86 | + ) |
| 87 | + |
| 88 | + if name == "AllBlocksCleared": |
| 89 | + return common_pb2.KvCacheEvent(event_id=event_id, cleared=common_pb2.KvCacheCleared()) |
| 90 | + |
| 91 | + logger.debug("Unknown KV event type %r, skipping", name) |
| 92 | + return None |
| 93 | + |
| 94 | + |
| 95 | +def convert_batch( |
| 96 | + raw_batch: object, seq_num: int, event_id_start: int |
| 97 | +) -> tuple[common_pb2.KvEventBatch, int]: |
| 98 | + """Convert a decoded engine KVEventBatch to a proto KvEventBatch. |
| 99 | +
|
| 100 | + Returns the proto batch and the new event-id counter. The counter advances |
| 101 | + once per input event (even if unconvertible) so ids stay monotonic. |
| 102 | +
|
| 103 | + The DP rank is read from ``data_parallel_rank`` (vLLM) or ``attn_dp_rank`` |
| 104 | + (TokenSpeed); engines that carry neither leave the proto field unset. |
| 105 | + """ |
| 106 | + proto = common_pb2.KvEventBatch(sequence_number=seq_num, timestamp=raw_batch.ts) |
| 107 | + dp_rank = getattr(raw_batch, "data_parallel_rank", None) |
| 108 | + if dp_rank is None: |
| 109 | + dp_rank = getattr(raw_batch, "attn_dp_rank", None) |
| 110 | + if dp_rank is not None: |
| 111 | + proto.dp_rank = dp_rank |
| 112 | + |
| 113 | + event_id = event_id_start |
| 114 | + for event in raw_batch.events: |
| 115 | + event_id += 1 |
| 116 | + proto_event = convert_event(event, event_id) |
| 117 | + if proto_event is not None: |
| 118 | + proto.events.append(proto_event) |
| 119 | + return proto, event_id |
| 120 | + |
| 121 | + |
| 122 | +async def stream_kv_events( |
| 123 | + sub_socket: object, |
| 124 | + decode: Callable[[bytes], object], |
| 125 | + send_initial_metadata: Callable[[], Awaitable[None]], |
| 126 | + is_cancelled: Callable[[], bool], |
| 127 | + *, |
| 128 | + recv_timeout: float = 1.0, |
| 129 | +) -> AsyncIterator[common_pb2.KvEventBatch]: |
| 130 | + """Core ZMQ→proto streaming loop, decoupled from any engine and gRPC types. |
| 131 | +
|
| 132 | + Args: |
| 133 | + sub_socket: a connected ``zmq.asyncio`` SUB socket (duck-typed; only |
| 134 | + ``poll()`` and ``recv_multipart()`` are used). The caller owns the |
| 135 | + socket lifecycle (this function never closes it). |
| 136 | + decode: bytes → decoded engine batch (e.g. ``msgspec.msgpack.Decoder(KVEventBatch).decode``). |
| 137 | + send_initial_metadata: awaitable called once before the first recv so the |
| 138 | + gRPC client's ``subscribe_kv_events().await`` resolves promptly. |
| 139 | + is_cancelled: returns True when the RPC is cancelled; loop then exits. |
| 140 | + recv_timeout: poll timeout so cancellation is observed even when idle. |
| 141 | +
|
| 142 | + Yields proto KvEventBatch using the ZMQ publisher's native sequence numbers. |
| 143 | + """ |
| 144 | + await send_initial_metadata() |
| 145 | + event_id = 0 |
| 146 | + while not is_cancelled(): |
| 147 | + # poll() before recv: cancelling a zmq.asyncio recv future does not |
| 148 | + # cancel the in-flight ZMQ recv and can drop an already-dequeued message. |
| 149 | + if not await sub_socket.poll(timeout=int(recv_timeout * 1000)): |
| 150 | + continue |
| 151 | + frames = await sub_socket.recv_multipart() |
| 152 | + |
| 153 | + # ZMQ multipart: [topic, 8-byte big-endian seq, msgpack payload]. |
| 154 | + if len(frames) < 3: |
| 155 | + continue |
| 156 | + zmq_seq = int.from_bytes(frames[1], "big") |
| 157 | + try: |
| 158 | + raw_batch = decode(frames[2]) |
| 159 | + except Exception as e: # noqa: BLE001 - one bad frame must not kill the stream |
| 160 | + logger.warning("Failed to decode KV event batch: %s", e) |
| 161 | + continue |
| 162 | + |
| 163 | + proto_batch, event_id = convert_batch(raw_batch, zmq_seq, event_id) |
| 164 | + yield proto_batch |
0 commit comments