diff --git a/examples/helper/events.go b/examples/helper/events.go index cb50cf22..1e445464 100644 --- a/examples/helper/events.go +++ b/examples/helper/events.go @@ -22,7 +22,6 @@ import ( "github.com/llm-d/llm-d-kv-cache/examples/testdata" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" - "github.com/llm-d/llm-d-kv-cache/pkg/utils" "github.com/vmihailenco/msgpack/v5" "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -33,23 +32,27 @@ func SimulateProduceEvent(ctx context.Context, publisher *Publisher) error { logger := log.FromContext(ctx) logger.Info("@@@ Simulating vLLM engine publishing BlockStored events...") medium := "GPU" - blockStoredEvent := kvevents.BlockStored{ - BlockHashes: utils.SliceMap(testdata.PromptHashes, func(h uint64) any { return h }), - ParentBlockHash: nil, - TokenIds: []uint32{1, 2, 3}, - BlockSize: 256, - LoraID: nil, - Medium: &medium, - LoraName: nil, + + // Create event in vLLM msgpack array format: [tag, hashes, parent, tokens, blockSize, loraID, medium, loraName] + blockStoredEvent := []any{ + "BlockStored", // Tag + testdata.PromptHashes, // BlockHashes (already []uint64) + nil, // ParentBlockHash + []uint32{1, 2, 3}, // TokenIds + 256, // BlockSize + nil, // LoraID + medium, // Medium + nil, // LoraName } //nolint // won't fail - blockStoredPayload, _ := msgpack.Marshal(blockStoredEvent.ToTaggedUnion()) + blockStoredPayload, _ := msgpack.Marshal(blockStoredEvent) - eventBatch := kvevents.EventBatch{ - TS: float64(time.Now().UnixNano()) / 1e9, - Events: []msgpack.RawMessage{blockStoredPayload}, - DataParallelRank: nil, + // Create vLLM msgpack event batch in array format: [timestamp, [event1, event2, ...], data_parallel_rank] + eventBatch := []any{ + float64(time.Now().UnixNano()) / 1e9, // Timestamp + [][]byte{blockStoredPayload}, // Events array + nil, // DataParallelRank } if err := publisher.PublishEvent(ctx, topic, eventBatch); err != nil { @@ -70,17 +73,21 @@ func 
SimulateProduceEvent(ctx context.Context, publisher *Publisher) error { func SimulateRemoveEvent(ctx context.Context, publisher *Publisher) error { logger := log.FromContext(ctx) logger.Info("@@@ Simulating vLLM engine removing some blocks...") - blockRemovedEvent := kvevents.BlockRemoved{ - BlockHashes: []any{testdata.PromptHashes[2], testdata.PromptHashes[3]}, + + // Create event in vLLM msgpack array format: [tag, hashes] + blockRemovedEvent := []any{ + "BlockRemoved", + []uint64{testdata.PromptHashes[2], testdata.PromptHashes[3]}, } //nolint // won't fail - blockRemovedPayload, _ := msgpack.Marshal(blockRemovedEvent.ToTaggedUnion()) + blockRemovedPayload, _ := msgpack.Marshal(blockRemovedEvent) - removeEventBatch := kvevents.EventBatch{ - TS: float64(time.Now().UnixNano()) / 1e9, - Events: []msgpack.RawMessage{blockRemovedPayload}, - DataParallelRank: nil, + // Create vLLM msgpack event batch in array format: [timestamp, [event1, event2, ...], data_parallel_rank] + removeEventBatch := []any{ + float64(time.Now().UnixNano()) / 1e9, + [][]byte{blockRemovedPayload}, + nil, } if err := publisher.PublishEvent(ctx, topic, removeEventBatch); err != nil { diff --git a/examples/kv_events/pod_reconciler/pod_reconciler.go b/examples/kv_events/pod_reconciler/pod_reconciler.go index d43aaf29..df6b8fa3 100644 --- a/examples/kv_events/pod_reconciler/pod_reconciler.go +++ b/examples/kv_events/pod_reconciler/pod_reconciler.go @@ -32,6 +32,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/engineadapter" "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" ) @@ -44,8 +45,10 @@ type PodReconcilerConfig struct { PodNamespace string // TopicFilter is the ZMQ subscription filter (e.g., "kv@"). TopicFilter string - // SocketPort is the port where vLLM pods expose ZMQ (default: 5557). + // SocketPort is the port where LLM pods expose ZMQ (default: 5557). 
SocketPort string + // EngineType specifies which LLM engine type this reconciler manages. + EngineType string } // NewPodReconcilerConfig creates a PodReconcilerConfig from kvevents.PodDiscoveryConfig. @@ -71,6 +74,7 @@ func NewPodReconcilerConfig(cfg *kvevents.PodDiscoveryConfig, topicFilter string PodNamespace: cfg.PodNamespace, TopicFilter: topicFilter, SocketPort: fmt.Sprintf("%d", socketPort), + EngineType: cfg.EngineType, }, nil } @@ -118,13 +122,17 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R podIdentifier := req.String() endpoint := r.buildEndpoint(&pod) + // Get engine type from config (currently vLLM only) + engineType := engineadapter.EngineType(r.Config.EngineType) + debugLogger.Info("Ensuring subscriber for pod", "pod", req, "endpoint", endpoint, - "podIP", pod.Status.PodIP) + "podIP", pod.Status.PodIP, + "engineType", engineType) - if err := r.SubscriberManager.EnsureSubscriber(ctx, podIdentifier, endpoint, r.Config.TopicFilter, true); err != nil { - debugLogger.Error(err, "Failed to ensure subscriber for pod", "pod", req) + if err := r.SubscriberManager.EnsureSubscriber(ctx, podIdentifier, endpoint, r.Config.TopicFilter, engineType, true); err != nil { + debugLogger.Error(err, "Failed to ensure subscriber for pod", "pod", req, "engineType", engineType) return ctrl.Result{}, err } diff --git a/examples/kv_events/vllm/vllm_kv_cache_demo.py b/examples/kv_events/vllm/vllm_kv_cache_demo.py index b197c9f6..3d220c53 100644 --- a/examples/kv_events/vllm/vllm_kv_cache_demo.py +++ b/examples/kv_events/vllm/vllm_kv_cache_demo.py @@ -78,7 +78,7 @@ def create_llm(): disable_hybrid_kv_cache_manager=True, kv_events_config=kv_events_config, block_size=16, - prefix_caching_hash_algo="sha256_cbor", + prefix_caching_hash_algo="sha256_cbor_64bit", enable_lora=True, max_model_len=4096, ) diff --git a/pkg/kvevents/decoder/decoder.go b/pkg/kvevents/decoder/decoder.go new file mode 100644 index 00000000..755edd6d --- /dev/null +++ 
b/pkg/kvevents/decoder/decoder.go @@ -0,0 +1,26 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package decoder + +// Decoder defines the interface for encoding and decoding raw bytes. +type Decoder interface { + // Decode unmarshals data into the provided value. + Decode(data []byte, v interface{}) error + + // Encode marshals the provided value into bytes. + Encode(v interface{}) ([]byte, error) +} diff --git a/pkg/kvevents/decoder/msgpack.go b/pkg/kvevents/decoder/msgpack.go new file mode 100644 index 00000000..a963b307 --- /dev/null +++ b/pkg/kvevents/decoder/msgpack.go @@ -0,0 +1,48 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package decoder + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" +) + +// MsgpackDecoder implements Decoder for MessagePack format. +type MsgpackDecoder struct{} + +// NewMsgpackDecoder creates a new msgpack decoder. 
+func NewMsgpackDecoder() *MsgpackDecoder { + return &MsgpackDecoder{} +} + +// Decode unmarshals msgpack data into the provided value. +func (m *MsgpackDecoder) Decode(data []byte, v interface{}) error { + if err := msgpack.Unmarshal(data, v); err != nil { + return fmt.Errorf("failed to decode msgpack: %w", err) + } + return nil +} + +// Encode marshals the provided value into msgpack bytes. +func (m *MsgpackDecoder) Encode(v interface{}) ([]byte, error) { + data, err := msgpack.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to encode msgpack: %w", err) + } + return data, nil +} diff --git a/pkg/kvevents/engineadapter/adapter.go b/pkg/kvevents/engineadapter/adapter.go new file mode 100644 index 00000000..da0f0ab5 --- /dev/null +++ b/pkg/kvevents/engineadapter/adapter.go @@ -0,0 +1,75 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engineadapter + +import ( + "context" + "fmt" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/decoder" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/events" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/transport" +) + +// EngineType represents the type of LLM engine. +type EngineType string + +const ( + // EngineTypeVLLM represents the vLLM engine. + EngineTypeVLLM EngineType = "vllm" +) + +// NewAdapter creates a new engine adapter based on the engine type. 
+func NewAdapter(engineType EngineType) (EngineAdapter, error) { + // It looks useless right now but we're preparing for future support of other engines ;) + switch engineType { + case EngineTypeVLLM: + return NewVLLMAdapter() + default: + return nil, fmt.Errorf("unknown engine type: %s", engineType) + } +} + +// EngineAdapter defines the interface for engine-specific adapters. +// Each inference engine has its own adapter implementation that handles +// engine-specific operations. +type EngineAdapter interface { + // Transport returns the transport layer for receiving messages. + Transport() transport.Transport + + // Decoder returns the decoder for parsing message payloads. + Decoder() decoder.Decoder + + // getHashAsUint64 converts engine-specific hash formats to uint64. + getHashAsUint64(raw any) (uint64, error) + + // ReceiveAndDecode receives a message from the transport, parses it, + // decodes the payload, and returns a batch of generic events. + ReceiveAndDecode(ctx context.Context) (*events.EventBatch, error) + + // Connect establishes a connection to a remote endpoint. + Connect(ctx context.Context, endpoint string) error + + // Bind listens on a local endpoint for incoming connections. + Bind(ctx context.Context, endpoint string) error + + // SubscribeToTopic sets the topic filter for receiving messages. + SubscribeToTopic(topicFilter string) error + + // Close closes the adapter and releases all resources. + Close() error +} diff --git a/pkg/kvevents/engineadapter/vllm_adapter.go b/pkg/kvevents/engineadapter/vllm_adapter.go new file mode 100644 index 00000000..97d89560 --- /dev/null +++ b/pkg/kvevents/engineadapter/vllm_adapter.go @@ -0,0 +1,358 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engineadapter + +import ( + "context" + "encoding/binary" + "fmt" + "strings" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/decoder" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/events" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/transport" + "github.com/vmihailenco/msgpack/v5" +) + +const ( + // vLLM event type tags + eventTagBlockStored = "BlockStored" + eventTagBlockRemoved = "BlockRemoved" + eventTagAllBlocksCleared = "AllBlocksCleared" + + defaultDeviceTier = "gpu" +) + +// VLLMAdapter implements the EngineAdapter interface for vLLM engines. +type VLLMAdapter struct { + transport transport.Transport + decoder decoder.Decoder + eventConverters map[string]func([]byte) (events.GenericEvent, error) +} + +// NewVLLMAdapter creates a new vLLM adapter with ZMQ transport and msgpack decoder. +// Returns an error if the transport cannot be created. +func NewVLLMAdapter() (*VLLMAdapter, error) { + trans, err := transport.NewZMQTransport() + if err != nil { + return nil, fmt.Errorf("failed to create ZMQ transport: %w", err) + } + + adapter := &VLLMAdapter{ + transport: trans, + decoder: decoder.NewMsgpackDecoder(), + } + + // Initialize event converters map + adapter.eventConverters = map[string]func([]byte) (events.GenericEvent, error){ + eventTagBlockStored: adapter.convertBlockStoredEvent, + eventTagBlockRemoved: adapter.convertBlockRemovedEvent, + eventTagAllBlocksCleared: adapter.convertAllBlocksClearedEvent, + } + + return adapter, nil +} + +// Transport returns the Transport. 
+func (v *VLLMAdapter) Transport() transport.Transport { + return v.transport +} + +// Decoder returns the Decoder. +func (v *VLLMAdapter) Decoder() decoder.Decoder { + return v.decoder +} + +// getHashAsUint64 converts vLLM hash formats (uint64 or []byte) to uint64. +// This handles both legacy uint64 hashes and new []byte hashes by taking +// the last 8 bytes and interpreting them as a big-endian integer. +func (v *VLLMAdapter) getHashAsUint64(raw any) (uint64, error) { + switch val := raw.(type) { + case uint64: + return val, nil + case int64: + // msgpack can decode small integers as int64 + //nolint:gosec // int64 to uint64 conversion is safe here + return uint64(val), nil + case []byte: + if len(val) == 0 { + return 0, fmt.Errorf("hash byte slice is empty") + } + // If the slice is 8 bytes or longer, use the last 8 bytes + if len(val) >= 8 { + return binary.BigEndian.Uint64(val[len(val)-8:]), nil + } + // If the slice is shorter than 8 bytes, pad it with leading zeros + padded := make([]byte, 8) + copy(padded[8-len(val):], val) + return binary.BigEndian.Uint64(padded), nil + default: + return 0, fmt.Errorf("unsupported hash type: %T", val) + } +} + +// vLLM msgpack-specific event structures +// These structs are designed for msgpack array encoding and match vLLM's format +type msgpackVLLMEventBatch struct { + _ struct{} `msgpack:",array"` + TS float64 + Events [][]byte // Raw event bytes (decoder-agnostic) + DataParallelRank *int `msgpack:",omitempty"` +} + +type msgpackVLLMBlockStoredEvent struct { + _ struct{} `msgpack:",array"` + Tag string + BlockHashes []any // Changed from []uint64 + ParentBlockHash any // Changed from *uint64 + TokenIds []uint32 + BlockSize int + LoraID *int `msgpack:",omitempty"` + Medium *string `msgpack:",omitempty"` + LoraName *string `msgpack:",omitempty"` + ExtraKeys []any `msgpack:",omitempty"` // New field for vLLM - currently not supported +} + +type msgpackVLLMBlockRemovedEvent struct { + _ struct{} `msgpack:",array"` + Tag 
string + BlockHashes []any + Medium *string `msgpack:",omitempty"` +} + +type msgpackVLLMAllBlocksClearedEvent struct { + _ struct{} `msgpack:",array"` +} + +// vllmMessage represents a parsed vLLM 3-part message. +type vllmMessage struct { + Topic string + Sequence uint64 + Payload []byte +} + +// ReceiveAndDecode receives a message from the transport, parses the vLLM +// 3-part message structure, decodes the payload using the decoder, and returns +// a batch of generic events. +func (v *VLLMAdapter) ReceiveAndDecode(ctx context.Context) (*events.EventBatch, error) { + // Receive raw message parts from transport + parts, err := v.transport.Receive(ctx) + if err != nil { + return nil, fmt.Errorf("failed to receive message: %w", err) + } + + // Parse vLLM 3-part message structure + msg, err := parseVLLMMessage(parts) + if err != nil { + return nil, err + } + + // Extract metadata from topic + podID, modelName := parseVLLMTopic(msg.Topic) + + // Decode the payload into vLLM event batch using the decoder + var vllmBatch msgpackVLLMEventBatch + if err := v.decoder.Decode(msg.Payload, &vllmBatch); err != nil { + return nil, fmt.Errorf("failed to decode vLLM event batch: %w", err) + } + + // Convert vLLM events to generic events + genericEvents := make([]events.GenericEvent, len(vllmBatch.Events)) + for i, rawEventBytes := range vllmBatch.Events { + genericEvent, err := v.decodeVLLMEvent(rawEventBytes) + if err != nil { + return nil, fmt.Errorf("failed to decode vLLM event: %w", err) + } + genericEvents[i] = genericEvent + } + + return &events.EventBatch{ + Metadata: events.Metadata{ + Topic: msg.Topic, + PodID: podID, + ModelName: modelName, + Sequence: msg.Sequence, + Engine: "vllm", + }, + Timestamp: vllmBatch.TS, + Events: genericEvents, + }, nil +} + +// parseVLLMMessage validates and parses a vLLM 3-part message structure. +// vLLM sends messages as: [topic, sequence, payload] +// Returns an error if the message structure is invalid. 
+func parseVLLMMessage(parts [][]byte) (*vllmMessage, error) { + if len(parts) != 3 { + return nil, fmt.Errorf("expected 3 message parts from vLLM, got %d", len(parts)) + } + + topic := string(parts[0]) + sequenceBytes := parts[1] + payload := parts[2] + + // Parse sequence number (8 bytes, big-endian uint64) + if len(sequenceBytes) != 8 { + return nil, fmt.Errorf("invalid sequence bytes length: %d", len(sequenceBytes)) + } + sequence := binary.BigEndian.Uint64(sequenceBytes) + + return &vllmMessage{ + Topic: topic, + Sequence: sequence, + Payload: payload, + }, nil +} + +// parseVLLMTopic extracts pod ID and model name from vLLM topic format. +// Expected format: "pod_id@model_name" +// TODO: Find a way to avoid it +func parseVLLMTopic(topic string) (podID, modelName string) { + parts := strings.SplitN(topic, "@", 2) + if len(parts) == 2 { + return parts[0], parts[1] + } + // Fallback if format is unexpected + return topic, "" +} + +// decodeVLLMEvent decodes a single vLLM event using the decoder and converts it to a generic event. +// vLLM events are tagged unions: [tag, ...fields] +func (v *VLLMAdapter) decodeVLLMEvent(rawEventBytes []byte) (events.GenericEvent, error) { + // First decode to extract just the tag + var taggedUnion []any + if err := v.decoder.Decode(rawEventBytes, &taggedUnion); err != nil { + return nil, fmt.Errorf("failed to decode tagged union: %w", err) + } + + if len(taggedUnion) < 1 { + return nil, fmt.Errorf("malformed tagged union: no tag") + } + + // Extract the event type tag + tag, ok := taggedUnion[0].(string) + if !ok { + return nil, fmt.Errorf("event tag is not a string: %T", taggedUnion[0]) + } + + // Dispatch to appropriate converter + converter, exists := v.eventConverters[tag] + if !exists { + return nil, fmt.Errorf("unknown vLLM event tag: %s", tag) + } + + return converter(rawEventBytes) +} + +// convertBlockStoredEvent decodes and converts a msgpack vLLM BlockStored event to a generic event. 
+// Parses all hashes from engine-specific formats to uint64. +func (v *VLLMAdapter) convertBlockStoredEvent(rawEventBytes []byte) (events.GenericEvent, error) { + var vllmEvent msgpackVLLMBlockStoredEvent + if err := msgpack.Unmarshal(rawEventBytes, &vllmEvent); err != nil { + return nil, fmt.Errorf("failed to decode BlockStored event: %w", err) + } + + deviceTier := defaultDeviceTier + if vllmEvent.Medium != nil { + deviceTier = strings.ToLower(*vllmEvent.Medium) + } + + // Parse block hashes + blockHashes := make([]uint64, 0, len(vllmEvent.BlockHashes)) + for _, rawHash := range vllmEvent.BlockHashes { + hash, err := v.getHashAsUint64(rawHash) + if err != nil { + return nil, fmt.Errorf("failed to parse block hash: %w", err) + } + blockHashes = append(blockHashes, hash) + } + + // Parse parent hash + var parentHash uint64 + if vllmEvent.ParentBlockHash != nil { + hash, err := v.getHashAsUint64(vllmEvent.ParentBlockHash) + if err != nil { + return nil, fmt.Errorf("failed to parse parent hash: %w", err) + } + parentHash = hash + } + + return &events.BlockStoredEvent{ + BlockHashes: blockHashes, + Tokens: vllmEvent.TokenIds, + ParentHash: parentHash, + DeviceTier: deviceTier, + LoraID: vllmEvent.LoraID, + LoraName: vllmEvent.LoraName, + }, nil +} + +// convertBlockRemovedEvent decodes and converts a msgpack vLLM BlockRemoved event to a generic event. +// Parses all hashes from engine-specific formats to uint64. 
+func (v *VLLMAdapter) convertBlockRemovedEvent(rawEventBytes []byte) (events.GenericEvent, error) { + var vllmEvent msgpackVLLMBlockRemovedEvent + if err := msgpack.Unmarshal(rawEventBytes, &vllmEvent); err != nil { + return nil, fmt.Errorf("failed to decode BlockRemoved event: %w", err) + } + + deviceTier := defaultDeviceTier + if vllmEvent.Medium != nil { + deviceTier = strings.ToLower(*vllmEvent.Medium) + } + + // Parse block hashes + blockHashes := make([]uint64, 0, len(vllmEvent.BlockHashes)) + for _, rawHash := range vllmEvent.BlockHashes { + hash, err := v.getHashAsUint64(rawHash) + if err != nil { + return nil, fmt.Errorf("failed to parse block hash: %w", err) + } + blockHashes = append(blockHashes, hash) + } + + return &events.BlockRemovedEvent{ + BlockHashes: blockHashes, + DeviceTier: deviceTier, + }, nil +} + +// convertAllBlocksClearedEvent converts an AllBlocksCleared event. +func (v *VLLMAdapter) convertAllBlocksClearedEvent(rawEventBytes []byte) (events.GenericEvent, error) { + return &events.AllBlocksClearedEvent{}, nil +} + +// TODO: not sure if it best to keep or remove these + +// Connect establishes a connection to a remote vLLM endpoint. +func (v *VLLMAdapter) Connect(ctx context.Context, endpoint string) error { + return v.transport.Connect(ctx, endpoint) +} + +// Bind listens on a local endpoint for incoming vLLM connections. +func (v *VLLMAdapter) Bind(ctx context.Context, endpoint string) error { + return v.transport.Bind(ctx, endpoint) +} + +// SubscribeToTopic sets the topic filter for receiving vLLM messages. +func (v *VLLMAdapter) SubscribeToTopic(topicFilter string) error { + return v.transport.Subscribe(topicFilter) +} + +// Close closes the adapter and releases all resources. 
+func (v *VLLMAdapter) Close() error { + return v.transport.Close() +} diff --git a/pkg/kvevents/engineadapter/vllm_adapter_test.go b/pkg/kvevents/engineadapter/vllm_adapter_test.go new file mode 100644 index 00000000..cc6e9029 --- /dev/null +++ b/pkg/kvevents/engineadapter/vllm_adapter_test.go @@ -0,0 +1,330 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engineadapter + +import ( + "encoding/binary" + "testing" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/events" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" +) + +// TestParseVLLMMessage_Valid tests parsing a valid 3-part vLLM message. +func TestParseVLLMMessage_Valid(t *testing.T) { + topic := []byte("pod-123@llama-2-7b") + sequence := make([]byte, 8) + binary.BigEndian.PutUint64(sequence, 42) + payload := []byte("test payload") + + parts := [][]byte{topic, sequence, payload} + + msg, err := parseVLLMMessage(parts) + require.NoError(t, err) + assert.Equal(t, "pod-123@llama-2-7b", msg.Topic) + assert.Equal(t, uint64(42), msg.Sequence) + assert.Equal(t, payload, msg.Payload) +} + +// TestParseVLLMMessage_InvalidParts tests error handling for messages with invalid parts number. 
+func TestParseVLLMMessage_TooFewParts(t *testing.T) { + parts := [][]byte{ + []byte("topic"), + []byte("sequence"), + } + + msg, err := parseVLLMMessage(parts) + assert.Error(t, err) + assert.Nil(t, msg) + assert.Contains(t, err.Error(), "expected 3 message parts") +} + +// TestParseVLLMMessage_TooManyParts tests error handling for messages with more than 3 parts. +func TestParseVLLMMessage_TooManyParts(t *testing.T) { + sequence := make([]byte, 8) + binary.BigEndian.PutUint64(sequence, 1) + + parts := [][]byte{ + []byte("topic"), + sequence, + []byte("payload"), + []byte("extra"), + } + + msg, err := parseVLLMMessage(parts) + assert.Error(t, err) + assert.Nil(t, msg) + assert.Contains(t, err.Error(), "expected 3 message parts") +} + +// TestParseVLLMMessage_InvalidSequenceLength tests error handling for invalid sequence byte length. +// vLLM sends sequence numbers as 8-byte big-endian uint64. This test verifies that parseVLLMMessage +// correctly rejects messages with sequence byte lengths that are not exactly 8 bytes (empty, too short, or too long). +func TestParseVLLMMessage_InvalidSequenceLength(t *testing.T) { + testCases := []struct { + name string + sequenceLength int + }{ + {"empty", 0}, + {"too short", 4}, + {"too long", 16}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parts := [][]byte{ + []byte("topic"), + make([]byte, tc.sequenceLength), + []byte("payload"), + } + + msg, err := parseVLLMMessage(parts) + assert.Error(t, err) + assert.Nil(t, msg) + assert.Contains(t, err.Error(), "invalid sequence bytes length") + }) + } +} + +// TestParseVLLMMessage_MaxSequence tests parsing with maximum sequence number. 
+func TestParseVLLMMessage_MaxSequence(t *testing.T) { + sequence := make([]byte, 8) + binary.BigEndian.PutUint64(sequence, ^uint64(0)) // Max uint64 + + parts := [][]byte{ + []byte("topic"), + sequence, + []byte("payload"), + } + + msg, err := parseVLLMMessage(parts) + require.NoError(t, err) + assert.Equal(t, ^uint64(0), msg.Sequence) +} + +// TestDecodeVLLMEvent_BlockStored tests decoding a valid BlockStored event without LoRA. +func TestDecodeVLLMEvent_BlockStored(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Create a BlockStored event without LoRA + // Note: With array encoding, all fields must be present (nil for unused LoRA fields) + vllmEvent := []any{ + "BlockStored", // Tag + []any{uint64(100), uint64(101)}, // BlockHashes + uint64(99), // ParentBlockHash + []uint32{1, 2, 3}, // TokenIds + 16, // BlockSize + nil, // LoraID (nil when not using LoRA) + "gpu", // Medium + nil, // LoraName (nil when not using LoRA) + nil, // ExtraKeys + } + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + require.NoError(t, err) + require.NotNil(t, event) + + blockStored, ok := event.(*events.BlockStoredEvent) + require.True(t, ok, "expected BlockStoredEvent") + assert.Equal(t, []uint64{100, 101}, blockStored.BlockHashes) + assert.Equal(t, uint64(99), blockStored.ParentHash) + assert.Equal(t, []uint32{1, 2, 3}, blockStored.Tokens) + assert.Equal(t, "gpu", blockStored.DeviceTier) + assert.Nil(t, blockStored.LoraID) + assert.Nil(t, blockStored.LoraName) +} + +// TestDecodeVLLMEvent_BlockStoredWithLora tests decoding a valid BlockStored event +func TestDecodeVLLMEvent_BlockStoredWithLora(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Create a BlockStored event with LoRA fields + vllmEvent := []any{ + "BlockStored", // Tag (not part of struct) + []any{uint64(200), uint64(201)}, // 
BlockHashes + uint64(199), // ParentBlockHash + []uint32{4, 5, 6}, // TokenIds + 32, // BlockSize + 42, // LoraID + "gpu", // Medium + "test-lora", // LoraName + nil, // ExtraKeys + } + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + require.NoError(t, err) + require.NotNil(t, event) + + blockStored, ok := event.(*events.BlockStoredEvent) + require.True(t, ok, "expected BlockStoredEvent") + assert.Equal(t, []uint64{200, 201}, blockStored.BlockHashes) + assert.Equal(t, uint64(199), blockStored.ParentHash) + assert.Equal(t, []uint32{4, 5, 6}, blockStored.Tokens) + assert.Equal(t, "gpu", blockStored.DeviceTier) + require.NotNil(t, blockStored.LoraID) + assert.Equal(t, 42, *blockStored.LoraID) + require.NotNil(t, blockStored.LoraName) + assert.Equal(t, "test-lora", *blockStored.LoraName) +} + +// TestDecodeVLLMEvent_BlockStoredMissingLoraName tests decoding with missing field. +func TestDecodeVLLMEvent_BlockStoredMissingLoraName(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Create a BlockStored event with LoraID but no LoraName (6 fields - invalid) + // vLLM should send either all LoRA fields or none + vllmEvent := []any{ + "BlockStored", // Tag (not part of struct) + []any{uint64(300), uint64(301)}, // BlockHashes + uint64(299), // ParentBlockHash + []uint32{7, 8, 9}, // TokenIds + 64, // BlockSize + 123, // LoraID + "gpu", // Medium + } + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + assert.Error(t, err) + assert.Nil(t, event) +} + +// TestDecodeVLLMEvent_BlockRemoved tests decoding a valid BlockRemoved event. 
+func TestDecodeVLLMEvent_BlockRemoved(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Create a BlockRemoved event in vLLM format + // The struct has 2 fields: BlockHashes, Medium + medium := "cpu" + vllmEvent := []any{ + "BlockRemoved", // Tag (not part of struct) + []any{uint64(200), uint64(201), uint64(202)}, // BlockHashes + &medium, // Medium (optional) + } + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + require.NoError(t, err) + require.NotNil(t, event) + + blockRemoved, ok := event.(*events.BlockRemovedEvent) + require.True(t, ok, "expected BlockRemovedEvent") + assert.Equal(t, []uint64{200, 201, 202}, blockRemoved.BlockHashes) + assert.Equal(t, "cpu", blockRemoved.DeviceTier) +} + +// TestDecodeVLLMEvent_AllBlocksCleared tests decoding a valid AllBlocksCleared event. +func TestDecodeVLLMEvent_AllBlocksCleared(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Create an AllBlocksCleared event in vLLM format + vllmEvent := []any{"AllBlocksCleared"} + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + require.NoError(t, err) + require.NotNil(t, event) + + _, ok := event.(*events.AllBlocksClearedEvent) + require.True(t, ok, "expected AllBlocksClearedEvent") +} + +// TestDecodeVLLMEvent_UnknownTag tests error handling for unknown event tags. 
+func TestDecodeVLLMEvent_UnknownTag(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + vllmEvent := []any{"UnknownEventType", "some", "data"} + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + assert.Error(t, err) + assert.Nil(t, event) + assert.Contains(t, err.Error(), "unknown vLLM event tag") +} + +// TestDecodeVLLMEvent_MalformedPayload tests error handling for malformed msgpack data. +func TestDecodeVLLMEvent_MalformedPayload(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Invalid msgpack data + rawBytes := []byte{0xFF, 0xFF, 0xFF} + + event, err := adapter.decodeVLLMEvent(rawBytes) + assert.Error(t, err) + assert.Nil(t, event) +} + +// TestDecodeVLLMEvent_EmptyPayload tests error handling for empty event bytes. +func TestDecodeVLLMEvent_EmptyPayload(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + rawBytes := []byte{} + + event, err := adapter.decodeVLLMEvent(rawBytes) + assert.Error(t, err) + assert.Nil(t, event) +} + +// TestDecodeVLLMEvent_MissingTag tests error handling for events without a tag. +func TestDecodeVLLMEvent_MissingTag(t *testing.T) { + adapter, err := NewVLLMAdapter() + require.NoError(t, err) + defer adapter.Close() + + // Empty array - no tag + vllmEvent := []any{} + + rawBytes, err := msgpack.Marshal(vllmEvent) + require.NoError(t, err) + + event, err := adapter.decodeVLLMEvent(rawBytes) + assert.Error(t, err) + assert.Nil(t, event) + assert.Contains(t, err.Error(), "malformed tagged union") +} diff --git a/pkg/kvevents/events.go b/pkg/kvevents/events.go deleted file mode 100644 index 94ad2795..00000000 --- a/pkg/kvevents/events.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2025 The llm-d Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvevents - -import ( - "fmt" - - "github.com/vmihailenco/msgpack/v5" -) - -const ( - // BlockStoredEventTag is the tag for BlockStored events. - BlockStoredEventTag = "BlockStored" - // BlockRemovedEventTag is the tag for BlockRemoved events. - BlockRemovedEventTag = "BlockRemoved" - // AllBlocksClearedEventTag is the tag for AllBlocksCleared events. - AllBlocksClearedEventTag = "AllBlocksCleared" -) - -// event is a marker interface for KV-cache events. -type event interface { - isEvent() - ToTaggedUnion() []any -} - -// EventBatch represents a batch of events. -// It is encoded as an array to match vLLM's format. -type EventBatch struct { - _ struct{} `msgpack:",array"` - TS float64 - Events []msgpack.RawMessage - DataParallelRank *int `msgpack:",omitempty"` -} - -// BlockStored event. -// The BlockHashes and ParentBlockHash fields are `any` to handle -// both the legacy uint64 format and the new []byte format from vLLM. -type BlockStored struct { - _ struct{} `msgpack:",array"` - BlockHashes []any // Changed from []uint64 - ParentBlockHash any // Changed from *uint64 - TokenIds []uint32 - BlockSize int - LoraID *int `msgpack:",omitempty"` - Medium *string `msgpack:",omitempty"` - LoraName *string `msgpack:",omitempty"` -} - -// ToTaggedUnion converts the BlockStored event to a tagged union format. 
-// -//nolint:gocritic // Keeping the receiver as a value -func (bs BlockStored) ToTaggedUnion() []any { - return []any{ - BlockStoredEventTag, - bs.BlockHashes, - bs.ParentBlockHash, - bs.TokenIds, - bs.BlockSize, - bs.LoraID, - bs.Medium, - bs.LoraName, - } -} - -func (BlockStored) isEvent() {} - -// BlockRemoved event. -// The BlockHashes field is `any` to handle both uint64 and []byte formats. -type BlockRemoved struct { - _ struct{} `msgpack:",array"` - BlockHashes []any - Medium *string `msgpack:",omitempty"` -} - -func (br BlockRemoved) ToTaggedUnion() []any { - return []any{ - BlockRemovedEventTag, - br.BlockHashes, - br.Medium, - } -} - -func (BlockRemoved) isEvent() {} - -// AllBlocksCleared event. -type AllBlocksCleared struct { - _ struct{} `msgpack:",array"` -} - -func (ac AllBlocksCleared) ToTaggedUnion() []any { - return []any{ - AllBlocksClearedEventTag, - } -} - -func (AllBlocksCleared) isEvent() {} - -// UnmarshalKVEvent unmarshals a raw msgpack event into the event interface. 
-func UnmarshalKVEvent(rawEvent msgpack.RawMessage) (event, error) { - var taggedUnion []msgpack.RawMessage - if err := msgpack.Unmarshal(rawEvent, &taggedUnion); err != nil { - return nil, fmt.Errorf("failed to unmarshal tagged union: %w", err) - } - - if len(taggedUnion) < 1 { - return nil, fmt.Errorf("malformed tagged union: no tag") - } - - var tag string - if err := msgpack.Unmarshal(taggedUnion[0], &tag); err != nil { - return nil, fmt.Errorf("failed to unmarshal tag: %w", err) - } - - payloadBytes, err := msgpack.Marshal(taggedUnion[1:]) - if err != nil { - return nil, fmt.Errorf("failed to re-marshal payload parts: %w", err) - } - - var unmarshalErr error - switch tag { - case BlockStoredEventTag: - var bs BlockStored - unmarshalErr = msgpack.Unmarshal(payloadBytes, &bs) - return bs, unmarshalErr - - case BlockRemovedEventTag: - var br BlockRemoved - unmarshalErr = msgpack.Unmarshal(payloadBytes, &br) - return br, unmarshalErr - - case AllBlocksClearedEventTag: - var ac AllBlocksCleared - unmarshalErr = msgpack.Unmarshal(payloadBytes, &ac) - return ac, unmarshalErr - - default: - return nil, fmt.Errorf("unknown event tag: %s", tag) - } -} diff --git a/pkg/kvevents/events/all_blocks_cleared.go b/pkg/kvevents/events/all_blocks_cleared.go new file mode 100644 index 00000000..6bd0ac42 --- /dev/null +++ b/pkg/kvevents/events/all_blocks_cleared.go @@ -0,0 +1,50 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package events
+
+import (
+	"context"
+
+	"github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+// AllBlocksClearedEvent represents all blocks being cleared from a pod's cache.
+type AllBlocksClearedEvent struct {
+	DeviceTier string
+}
+
+// Type returns the event type.
+func (e *AllBlocksClearedEvent) Type() EventType {
+	return EventTypeAllBlocksCleared
+}
+
+// Process handles the AllBlocksCleared event.
+// NOTE: index cleanup is not yet implemented; currently this only logs the
+// event rather than removing the pod's entries from the index.
+func (e *AllBlocksClearedEvent) Process(ctx context.Context, index kvblock.Index,
+	tokenProcessor kvblock.TokenProcessor, podIdentifier, modelName string) error {
+
+	logger := log.FromContext(ctx)
+
+	// For now, we just log the event.
+	logger.Info("All blocks cleared event received",
+		"podIdentifier", podIdentifier,
+		"deviceTier", e.DeviceTier,
+		"modelName", modelName)
+
+	return nil
+}
diff --git a/pkg/kvevents/events/block_removed.go b/pkg/kvevents/events/block_removed.go
new file mode 100644
index 00000000..21b363d7
--- /dev/null
+++ b/pkg/kvevents/events/block_removed.go
@@ -0,0 +1,61 @@
+/*
+Copyright 2025 The llm-d Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ +package events + +import ( + "context" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// BlockRemovedEvent represents blocks being evicted from the cache. +// All hashes are already parsed to uint64 by the adapter. +type BlockRemovedEvent struct { + BlockHashes []uint64 + DeviceTier string +} + +// Type returns the event type. +func (e *BlockRemovedEvent) Type() EventType { + return EventTypeBlockRemoved +} + +// Process processes the BlockRemoved event and updates the index. +func (e *BlockRemovedEvent) Process(ctx context.Context, index kvblock.Index, + tokenProcessor kvblock.TokenProcessor, podIdentifier, modelName string) error { + + debugLogger := log.FromContext(ctx).V(logging.DEBUG) + + // Create PodEntry for this event's device tier + podEntries := []kvblock.PodEntry{{ + PodIdentifier: podIdentifier, + DeviceTier: e.DeviceTier, + }} + + // Evict each block + for _, hash := range e.BlockHashes { + engineKey := kvblock.BlockHash(hash) + if err := index.Evict(ctx, engineKey, podEntries); err != nil { + debugLogger.Error(err, "Failed to evict block from index", + "engineKey", engineKey, "podIdentifier", podIdentifier) + // Continue processing other blocks even if one fails + } + } + + return nil +} diff --git a/pkg/kvevents/events/block_stored.go b/pkg/kvevents/events/block_stored.go new file mode 100644 index 00000000..3eb7f1e5 --- /dev/null +++ b/pkg/kvevents/events/block_stored.go @@ -0,0 +1,92 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package events + +import ( + "context" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// BlockStoredEvent represents blocks being added to the cache. +type BlockStoredEvent struct { + BlockHashes []uint64 + Tokens []uint32 + ParentHash uint64 + DeviceTier string + LoraID *int + LoraName *string +} + +// Type returns the event type. +func (e *BlockStoredEvent) Type() EventType { + return EventTypeBlockStored +} + +// Process processes the BlockStored event and updates the index. 
+func (e *BlockStoredEvent) Process(ctx context.Context, index kvblock.Index, + tokenProcessor kvblock.TokenProcessor, podIdentifier, modelName string) error { + + debugLogger := log.FromContext(ctx).V(logging.DEBUG) + + // Use LoRA name as model identifier if available, otherwise fall back to base model name + effectiveModelName := modelName + if e.LoraName != nil && *e.LoraName != "" { + effectiveModelName = *e.LoraName + } + + // Create PodEntry for this event's device tier + podEntries := []kvblock.PodEntry{{ + PodIdentifier: podIdentifier, + DeviceTier: e.DeviceTier, + }} + + // Convert block hashes to BlockHash type + engineKeys := make([]kvblock.BlockHash, len(e.BlockHashes)) + for i, hash := range e.BlockHashes { + engineKeys[i] = kvblock.BlockHash(hash) + } + + // Get parent request key if parent hash exists + parentRequestKey := kvblock.EmptyBlockHash + if e.ParentHash != 0 { + parentEngineKey := kvblock.BlockHash(e.ParentHash) + key, err := index.GetRequestKey(ctx, parentEngineKey) + if err != nil { + debugLogger.Error(err, "Failed to get request key for parent block", + "parentEngineKey", parentEngineKey) + } else { + parentRequestKey = key + } + } + + // Compute request keys from tokens using effective model name + requestKeys := tokenProcessor.TokensToKVBlockKeys(parentRequestKey, e.Tokens, effectiveModelName) + + // Only proceed if we have valid keys to add. + if len(engineKeys) > 0 { + if err := index.Add(ctx, engineKeys, requestKeys, podEntries); err != nil { + debugLogger.Error(err, "Failed to add blocks to index", + "podIdentifier", podIdentifier, "modelName", modelName) + return err + } + } + + return nil +} diff --git a/pkg/kvevents/events/event.go b/pkg/kvevents/events/event.go new file mode 100644 index 00000000..97f4fe1b --- /dev/null +++ b/pkg/kvevents/events/event.go @@ -0,0 +1,68 @@ +/* +Copyright 2025 The llm-d Authors. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package events
+
+import (
+	"context"
+
+	"github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
+)
+
+// EventType represents the type of KV-cache event.
+type EventType string
+
+// EventType values, used primarily in logs.
+const (
+	// EventTypeBlockStored indicates blocks were added to cache.
+	EventTypeBlockStored EventType = "block_stored"
+	// EventTypeBlockRemoved indicates blocks were evicted from cache.
+	EventTypeBlockRemoved EventType = "block_removed"
+	// EventTypeAllBlocksCleared indicates entire cache was cleared.
+	EventTypeAllBlocksCleared EventType = "all_blocks_cleared"
+)
+
+// GenericEvent represents a KV-cache event containing already-parsed data.
+type GenericEvent interface {
+	// Type returns the event type.
+	Type() EventType
+
+	// Process processes the event and updates the index.
+	Process(ctx context.Context, index kvblock.Index, tokenProcessor kvblock.TokenProcessor,
+		podIdentifier, modelName string) error
+}
+
+// Metadata contains information about the source of an event batch.
+type Metadata struct {
+	// Topic is the original transport topic.
+	Topic string
+	// PodID identifies the pod that generated these events.
+	PodID string
+	// ModelName is the model associated with these events.
+	ModelName string
+	// Sequence is the message sequence number from the transport.
+	Sequence uint64
+	// Engine identifies which inference engine generated these events.
+ Engine string +} + +// EventBatch represents a batch of generic events from an inference engine. +// This is the primary data structure passed from adapters to the pool for processing. +type EventBatch struct { + Metadata Metadata + Timestamp float64 + Events []GenericEvent +} diff --git a/pkg/kvevents/pool.go b/pkg/kvevents/pool.go index 53db3200..53a6a46a 100644 --- a/pkg/kvevents/pool.go +++ b/pkg/kvevents/pool.go @@ -16,17 +16,15 @@ package kvevents import ( "context" - "encoding/binary" - "fmt" "hash/fnv" - "strings" "sync" - "github.com/vmihailenco/msgpack/v5" "k8s.io/client-go/util/workqueue" "sigs.k8s.io/controller-runtime/pkg/log" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/engineadapter" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/events" "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" ) @@ -60,17 +58,23 @@ type PodDiscoveryConfig struct { // PodNamespace limits the reconciler to watch pods in a specific namespace. // If empty, watches all namespaces (requires appropriate RBAC). PodNamespace string `json:"podNamespace,omitempty"` - // SocketPort is the port number where vLLM pods expose their ZMQ socket. + // SocketPort is the port number where LLM pods expose their ZMQ socket. // The reconciler will connect to tcp://: // Default: 5557 SocketPort int `json:"socketPort"` + // EngineType specifies which LLM engine type this reconciler manages. + // This determines which adapter will be used for subscribers. + // Default: "vllm" + EngineType string `json:"engineType"` } // DefaultPodReconcilerConfig returns a default configuration for the pod reconciler. +// Defaults to vLLM engine type. 
func DefaultPodReconcilerConfig() *PodDiscoveryConfig { return &PodDiscoveryConfig{ PodLabelSelector: defaultPodSelector, SocketPort: 5557, + EngineType: string(engineadapter.EngineTypeVLLM), } } @@ -84,23 +88,10 @@ func DefaultConfig() *Config { } } -// Message represents a message that is read from a ZMQ topic. -type Message struct { - Topic string - Payload []byte - // Sequence number of the message - Seq uint64 - // PodIdentifier is the identifier of the pod that sent the event. - // This will be extracted from the ZMQ topic. - PodIdentifier string - // ModelName is the name of the model that is associated with this event. - ModelName string -} - -// Pool is a sharded worker pool that processes events from ZMQ subscribers. +// Pool is a sharded worker pool that processes event batches from engine adapters. // It ensures that events for the same PodIdentifier are processed in order. type Pool struct { - queues []workqueue.TypedRateLimitingInterface[*Message] + queues []workqueue.TypedRateLimitingInterface[*events.EventBatch] concurrency int // can replace use with len(queues) index kvblock.Index tokenProcessor kvblock.TokenProcessor @@ -116,14 +107,14 @@ func NewPool(cfg *Config, index kvblock.Index, tokenProcessor kvblock.TokenProce } p := &Pool{ - queues: make([]workqueue.TypedRateLimitingInterface[*Message], cfg.Concurrency), + queues: make([]workqueue.TypedRateLimitingInterface[*events.EventBatch], cfg.Concurrency), concurrency: cfg.Concurrency, index: index, tokenProcessor: tokenProcessor, } for i := 0; i < p.concurrency; i++ { - p.queues[i] = workqueue.NewTypedRateLimitingQueue(workqueue.DefaultTypedControllerRateLimiter[*Message]()) + p.queues[i] = workqueue.NewTypedRateLimitingQueue(workqueue.DefaultTypedControllerRateLimiter[*events.EventBatch]()) } return p @@ -155,42 +146,42 @@ func (p *Pool) Shutdown(ctx context.Context) { logger.Info("event processing pool shut down.") } -// AddTask is called by the subscriber to add a message to the processing 
queue. -// It hashes the PodIdentifier to select a queue, ensuring messages for the +// AddTask is called by the subscriber to add an event batch to the processing queue. +// It hashes the PodID to select a queue, ensuring events for the // same pod always go to the same worker (ordered queue). -func (p *Pool) AddTask(task *Message) { +func (p *Pool) AddTask(batch *events.EventBatch) { // Use an FNV-1a hash to deterministically select a queue. // TODO: round-robin or simpler approach could be good enough h := fnv.New32a() - _, err := h.Write([]byte(task.PodIdentifier)) + _, err := h.Write([]byte(batch.Metadata.PodID)) if err != nil { return } //nolint:gosec // if concurrency overflows then the world is in trouble anyway queueIndex := h.Sum32() % uint32(p.concurrency) - p.queues[queueIndex].Add(task) + p.queues[queueIndex].Add(batch) } // worker is the main processing loop for a single worker goroutine. -// It processes messages from its dedicated queue using the workqueue pattern. +// It processes event batches from its dedicated queue using the workqueue pattern. // TODO: profile and benchmark cases like backpressure, slow processing (profile), etc. func (p *Pool) worker(ctx context.Context, workerIndex int) { defer p.wg.Done() queue := p.queues[workerIndex] for { - task, shutdown := queue.Get() + batch, shutdown := queue.Get() if shutdown { return } // Use a nested func to ensure Done is always called. - func(task *Message) { - defer queue.Done(task) - p.processEvent(ctx, task) + func(batch *events.EventBatch) { + defer queue.Done(batch) + p.processEventBatch(ctx, batch) // Task succeeded, remove it from the queue. - queue.Forget(task) - }(task) + queue.Forget(batch) + }(batch) // Check if context was cancelled after processing a task. select { @@ -201,163 +192,26 @@ func (p *Pool) worker(ctx context.Context, workerIndex int) { } } -// processEvent deserializes the message payload and calls the appropriate -// index method based on the event type. 
It returns an error to trigger retries. -func (p *Pool) processEvent(ctx context.Context, msg *Message) { - debugLogger := log.FromContext(ctx).V(logging.DEBUG) - debugLogger.V(logging.TRACE).Info("Processing event", "topic", msg.Topic, "seq", msg.Seq) - - var eventBatch EventBatch - if err := msgpack.Unmarshal(msg.Payload, &eventBatch); err != nil { - // This is likely a "poison pill" message that can't be unmarshalled. - // We log the error but return nil to prevent it from being retried indefinitely. - debugLogger.Error(err, "Failed to unmarshal event batch, dropping message") - return - } - - events := make([]event, 0, len(eventBatch.Events)) - for _, rawEvent := range eventBatch.Events { - event, err := UnmarshalKVEvent(rawEvent) - if err != nil { - debugLogger.Error(err, "Failed to unmarshal event, skipping") - continue - } - events = append(events, event) - } - - podIdentifier := msg.PodIdentifier - modelName := msg.ModelName - p.digestEvents(ctx, podIdentifier, modelName, events) -} - -func (p *Pool) digestEvents(ctx context.Context, podIdentifier, modelName string, - events []event, -) { +// processEventBatch processes a batch of generic events by calling each event's Process method. +func (p *Pool) processEventBatch(ctx context.Context, batch *events.EventBatch) { debugLogger := log.FromContext(ctx).V(logging.DEBUG) - debugLogger.V(logging.TRACE).Info("Digesting events", "count", len(events)) - - // Process each event in the batch - for _, event := range events { - switch ev := event.(type) { - case BlockStored: - // Default to gpu. - // For non-gpu events, vLLM KV event has a non-empty Medium field. - deviceTier := defaultEventSourceDeviceTier - if ev.Medium != nil { - deviceTier = strings.ToLower(*ev.Medium) - } - - // Use LoRA name as model identifier if available, otherwise fall back to base model name. 
- effectiveModelName := modelName - if ev.LoraName != nil && *ev.LoraName != "" { - effectiveModelName = *ev.LoraName - } - - // Create PodEntry for this specific event's device tier - podEntries := []kvblock.PodEntry{{PodIdentifier: podIdentifier, DeviceTier: deviceTier}} - - // Create a slice to hold the processed keys. - engineKeys := make([]kvblock.BlockHash, 0, len(ev.BlockHashes)) - - // Iterate over the hashes, convert each one to uint64, and create a key. - for _, rawHash := range ev.BlockHashes { - hash, err := getHashAsUint64(rawHash) - if err != nil { - debugLogger.Error(err, "Failed to convert block hash for BlockStored event", "rawHash", rawHash) - continue - } - engineKeys = append(engineKeys, kvblock.BlockHash(hash)) - } - - var parentRequestKey kvblock.BlockHash - if ev.ParentBlockHash != nil { - hash, err := getHashAsUint64(ev.ParentBlockHash) - if err != nil { - debugLogger.Error(err, "Failed to convert parent block hash for BlockStored event", - "rawHash", ev.ParentBlockHash) - continue - } - - parentEngineKey := kvblock.BlockHash(hash) - - key, err := p.index.GetRequestKey(ctx, parentEngineKey) - if err != nil { - debugLogger.Error(err, "Failed to get request key for parent block", - "parentEngineKey", parentEngineKey, "effectiveModelName", effectiveModelName) - continue - } - parentRequestKey = key - } - - requestKeys := p.tokenProcessor.TokensToKVBlockKeys(parentRequestKey, ev.TokenIds, effectiveModelName) - - // Only proceed if we have valid keys to add. - if len(engineKeys) > 0 { - if err := p.index.Add(ctx, engineKeys, requestKeys, podEntries); err != nil { - debugLogger.Error(err, "Failed to add event to index", - "podIdentifier", podIdentifier, "event", ev) - continue // Continue processing other events even if one fails - } - } - - case BlockRemoved: - // Default to gpu. - // For non-gpu events, vLLM KV event has a non-empty Medium field. 
- deviceTier := defaultEventSourceDeviceTier - if ev.Medium != nil { - deviceTier = strings.ToLower(*ev.Medium) - } - - // Create PodEntry for this specific event's device tier - podEntries := []kvblock.PodEntry{{PodIdentifier: podIdentifier, DeviceTier: deviceTier}} - - // Iterate over the hashes, convert each one to uint64, and evict the key. - for _, rawHash := range ev.BlockHashes { - hash, err := getHashAsUint64(rawHash) - if err != nil { - debugLogger.Error(err, "Failed to convert block hash for BlockRemoved event", "rawHash", rawHash) - continue - } - engineKey := kvblock.BlockHash(hash) - if err := p.index.Evict(ctx, engineKey, podEntries); err != nil { - debugLogger.Error(err, "Failed to remove event from index", - "podIdentifier", podIdentifier, "event", ev) - continue // Continue processing other events even if one fails - } - } - case AllBlocksCleared: - continue - default: - debugLogger.Info("Unknown event", "podIdentifier", podIdentifier, "event", ev) - } - } -} - -// getHashAsUint64 converts a block hash from an `any` type to a uint64. -// It handles legacy uint64 hashes and new []byte hashes by taking the last 8 bytes -// and interpreting them as a big-endian integer, matching vLLM's compatibility logic. -func getHashAsUint64(hash any) (uint64, error) { - switch val := hash.(type) { - case uint64: - // Hash is already in the target format. - return val, nil - case int64: - // msgpack can decode small integers as int64. - //nolint:gosec // int64 to uint64 conversion is safe here - return uint64(val), nil - case []byte: - if len(val) == 0 { - return 0, fmt.Errorf("hash byte slice is empty") - } - // If the slice is 8 bytes or longer, use the last 8 bytes. 
- if len(val) >= 8 { - return binary.BigEndian.Uint64(val[len(val)-8:]), nil + debugLogger.V(logging.TRACE).Info("Processing event batch", + "podID", batch.Metadata.PodID, + "modelName", batch.Metadata.ModelName, + "engine", batch.Metadata.Engine, + "eventCount", len(batch.Events)) + + podIdentifier := batch.Metadata.PodID + modelName := batch.Metadata.ModelName + + // Process each generic event in the batch + for _, genericEvent := range batch.Events { + if err := genericEvent.Process(ctx, p.index, p.tokenProcessor, podIdentifier, modelName); err != nil { + debugLogger.Error(err, "Failed to process event", + "eventType", genericEvent.Type(), + "podIdentifier", podIdentifier, + "modelName", modelName) + // Continue processing other events even if one fails } - // If the slice is shorter than 8 bytes, pad it with leading zeros. - padded := make([]byte, 8) - copy(padded[8-len(val):], val) - return binary.BigEndian.Uint64(padded), nil - default: - return 0, fmt.Errorf("unsupported hash type: %T", val) } } diff --git a/pkg/kvevents/process_event_test.go b/pkg/kvevents/process_event_test.go deleted file mode 100644 index 97d336e3..00000000 --- a/pkg/kvevents/process_event_test.go +++ /dev/null @@ -1,103 +0,0 @@ -/* -Copyright 2025 The llm-d Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package kvevents_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - "github.com/vmihailenco/msgpack/v5" - - . 
"github.com/llm-d/llm-d-kv-cache/pkg/kvevents" -) - -// Helper function to create BlockStored raw msgpack message. -func createBlockStoredRaw(t *testing.T, fields []any) msgpack.RawMessage { - t.Helper() - data, err := msgpack.Marshal(fields) - if err != nil { - t.Fatalf("Failed to marshal fields: %v", err) - } - return msgpack.RawMessage(data) -} - -func TestBlockStoredMissingLoraName(t *testing.T) { - rawMsg := createBlockStoredRaw(t, []any{ - BlockStoredEventTag, // Event tag - []any{uint64(1001), uint64(1002)}, // BlockHashes - nil, // ParentBlockHash - []uint32{1, 2, 3}, // TokenIds - 256, // BlockSize - 42, // LoraID - "GPU", // Medium - // LoraName is missing - }) - - _, err := UnmarshalKVEvent(rawMsg) - - // Expect error due to missing LoraName - require.Error(t, err) -} - -func TestBlockStoredAllFieldsPresent(t *testing.T) { - rawMsg := createBlockStoredRaw(t, []any{ - BlockStoredEventTag, // Event tag - []any{uint64(1001), uint64(1002)}, // BlockHashes - nil, // ParentBlockHash - []uint32{1, 2, 3}, // TokenIds - 256, // BlockSize - 42, // LoraID - "GPU", // Medium - "test-lora", // LoraName - }) - - event, err := UnmarshalKVEvent(rawMsg) - - require.NoError(t, err, "Expected no error during unmarshaling") - require.NotNil(t, event, "Expected event to be non-nil") - - blockStored, ok := event.(BlockStored) - require.True(t, ok, "Expected event to be of type BlockStored") - - if blockStored.Medium == nil || *blockStored.Medium != "GPU" { - t.Errorf("Expected Medium to be 'GPU', got %v", blockStored.Medium) - } - require.NotNil(t, blockStored.Medium, "Expected Medium to be non-nil") - require.Equal(t, "GPU", *blockStored.Medium, "Expected Medium to be 'GPU'") - - require.NotNil(t, blockStored.LoraName, "Expected LoraName to be non-nil") - require.Equal(t, "test-lora", *blockStored.LoraName, "Expected LoraName to be 'test-lora'") -} - -func TestUnmarshalKVEventErrors(t *testing.T) { - // Test unknown event tag - rawMsg := createBlockStoredRaw(t, []any{ - 
BlockStoredEventTag, // Event tag - []any{uint64(1001), uint64(1002)}, // BlockHashes - nil, // ParentBlockHash - []uint32{1, 2, 3}, // TokenIds - }) - - var err error - _, err = UnmarshalKVEvent(rawMsg) - require.Error(t, err, "Expected error for incomplete BlockStored event") - - // Test malformed union (empty array) - emptyRawMsg := createBlockStoredRaw(t, []any{}) - _, err = UnmarshalKVEvent(emptyRawMsg) - require.Error(t, err, "Expected error for malformed tagged union") -} diff --git a/pkg/kvevents/subscriber.go b/pkg/kvevents/subscriber.go new file mode 100644 index 00000000..0014160f --- /dev/null +++ b/pkg/kvevents/subscriber.go @@ -0,0 +1,145 @@ +// Copyright 2025 The llm-d Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvevents + +import ( + "context" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/engineadapter" + "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" +) + +const ( + // How long to wait before retrying to connect. + retryInterval = 5 * time.Second +) + +// subscriber connects to an engine via an adapter and forwards messages to a pool. +type subscriber struct { + pool *Pool + adapter engineadapter.EngineAdapter + endpoint string + remote bool + topicFilter string +} + +// newSubscriber creates a new generic subscriber. 
+func newSubscriber(pool *Pool, adapter engineadapter.EngineAdapter, endpoint, topicFilter string, remote bool) *subscriber {
+	return &subscriber{
+		pool:        pool,
+		adapter:     adapter,
+		endpoint:    endpoint,
+		remote:      remote,
+		topicFilter: topicFilter,
+	}
+}
+
+// Start connects to an engine publisher, receives and decodes event
+// batches, and pushes them into the pool.
+// This loop will run until the provided context is canceled.
+func (s *subscriber) Start(ctx context.Context) {
+	logger := log.FromContext(ctx).WithName("subscriber")
+
+	for {
+		select {
+		case <-ctx.Done():
+			logger.Info("shutting down subscriber")
+			return
+		default:
+			// We run the subscriber in a separate function to handle
+			// setup/teardown and connection retries cleanly.
+			s.runSubscriber(ctx)
+			// wait before retrying, unless the context has been canceled.
+			select {
+			case <-time.After(retryInterval):
+				logger.Info("retrying subscriber")
+			case <-ctx.Done():
+				logger.Info("shutting down subscriber")
+				return
+			}
+		}
+	}
+}
+
+// runSubscriber connects to the engine, subscribes to the topic filter,
+// and listens for messages.
+func (s *subscriber) runSubscriber(ctx context.Context) { + logger := log.FromContext(ctx).WithName("subscriber") + debugLogger := logger.V(logging.DEBUG) + + // Connect or bind based on mode + var err error + if s.remote { + err = s.adapter.Connect(ctx, s.endpoint) + if err != nil { + logger.Error(err, "Failed to connect to endpoint", "endpoint", s.endpoint) + return + } + logger.Info("Connected to endpoint", "endpoint", s.endpoint) + } else { + err = s.adapter.Bind(ctx, s.endpoint) + if err != nil { + logger.Error(err, "Failed to bind to endpoint", "endpoint", s.endpoint) + return + } + logger.Info("Bound to endpoint", "endpoint", s.endpoint) + } + + // Ensure cleanup + defer func() { + if err := s.adapter.Close(); err != nil { + logger.Error(err, "Failed to close adapter") + } + }() + + // Subscribe to topic filter + if err := s.adapter.SubscribeToTopic(s.topicFilter); err != nil { + logger.Error(err, "Failed to subscribe to topic filter", "topic", s.topicFilter) + return + } + + // Receive messages in a loop + for { + select { + case <-ctx.Done(): + return + default: + } + + // Receive and decode message from adapter + eventBatch, err := s.adapter.ReceiveAndDecode(ctx) + if err != nil { + if ctx.Err() != nil { + // Context was canceled, exit gracefully + return + } + debugLogger.Error(err, "Failed to receive and decode message", "endpoint", s.endpoint) + break // Exit on receive error to reconnect + } + + debugLogger.V(logging.TRACE).Info("Received event batch", + "topic", eventBatch.Metadata.Topic, + "seq", eventBatch.Metadata.Sequence, + "podIdentifier", eventBatch.Metadata.PodID, + "modelName", eventBatch.Metadata.ModelName, + "eventCount", len(eventBatch.Events)) + + // Push event batch directly to pool + s.pool.AddTask(eventBatch) + } +} diff --git a/pkg/kvevents/subscriber_manager.go b/pkg/kvevents/subscriber_manager.go index 6aca63f0..fa690336 100644 --- a/pkg/kvevents/subscriber_manager.go +++ b/pkg/kvevents/subscriber_manager.go @@ -18,13 +18,15 @@ 
package kvevents import ( "context" + "fmt" "sync" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/engineadapter" "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" "sigs.k8s.io/controller-runtime/pkg/log" ) -// SubscriberManager manages multiple ZMQ subscribers, one per LLM engine. +// SubscriberManager manages multiple subscribers, one per LLM engine pod. type SubscriberManager struct { pool *Pool subscribers map[string]*subscriberEntry @@ -33,7 +35,7 @@ type SubscriberManager struct { // subscriberEntry represents a single subscriber and its cancellation. type subscriberEntry struct { - subscriber *zmqSubscriber + subscriber *subscriber cancel context.CancelFunc endpoint string } @@ -50,7 +52,7 @@ func NewSubscriberManager(pool *Pool) *SubscriberManager { // If the subscriber already exists with the same endpoint, it's a no-op. // If the endpoint changed, the old subscriber is removed and a new one is created. func (sm *SubscriberManager) EnsureSubscriber(ctx context.Context, podIdentifier, endpoint, topicFilter string, - remoteSocket bool, + engineType engineadapter.EngineType, remoteSocket bool, ) error { debugLogger := log.FromContext(ctx).V(logging.DEBUG) @@ -73,9 +75,19 @@ func (sm *SubscriberManager) EnsureSubscriber(ctx context.Context, podIdentifier delete(sm.subscribers, podIdentifier) } - // Create new subscriber - debugLogger.Info("Creating new subscriber", "podIdentifier", podIdentifier, "endpoint", endpoint) - subscriber := newZMQSubscriber(sm.pool, endpoint, topicFilter, remoteSocket) + // Create new subscriber with specified engine adapter + debugLogger.Info("Creating new subscriber", + "podIdentifier", podIdentifier, + "endpoint", endpoint, + "engineType", engineType) + + // Create adapter based on engine type + adapter, err := engineadapter.NewAdapter(engineType) + if err != nil { + return fmt.Errorf("failed to create %s adapter: %w", engineType, err) + } + + subscriber := newSubscriber(sm.pool, adapter, endpoint, topicFilter, remoteSocket) // 
Create a context and start subscriber subCtx, cancel := context.WithCancel(ctx) @@ -88,7 +100,10 @@ func (sm *SubscriberManager) EnsureSubscriber(ctx context.Context, podIdentifier endpoint: endpoint, } - debugLogger.Info("Subscriber created and started", "podIdentifier", podIdentifier, "endpoint", endpoint) + debugLogger.Info("Subscriber created and started", + "podIdentifier", podIdentifier, + "endpoint", endpoint, + "engineType", engineType) return nil } diff --git a/pkg/kvevents/subscriber_manager_test.go b/pkg/kvevents/subscriber_manager_test.go index e89e1b6a..2b541f2d 100644 --- a/pkg/kvevents/subscriber_manager_test.go +++ b/pkg/kvevents/subscriber_manager_test.go @@ -23,6 +23,7 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/engineadapter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -46,8 +47,9 @@ func TestSubscriberManager_EnsureSubscriber(t *testing.T) { podID := "default/test-pod-0" endpoint := "tcp://127.0.0.1:5557" topicFilter := "kv@" + engineType := engineadapter.EngineTypeVLLM - err = sm.EnsureSubscriber(ctx, podID, endpoint, topicFilter, true) + err = sm.EnsureSubscriber(ctx, podID, endpoint, topicFilter, engineType, true) assert.NoError(t, err) // Verify subscriber was added @@ -57,7 +59,7 @@ func TestSubscriberManager_EnsureSubscriber(t *testing.T) { assert.Contains(t, endpoints, endpoint) // Ensure with same endpoint should be no-op - err = sm.EnsureSubscriber(ctx, podID, endpoint, topicFilter, true) + err = sm.EnsureSubscriber(ctx, podID, endpoint, topicFilter, engineType, true) assert.NoError(t, err) identifiers, _ = sm.GetActiveSubscribers() assert.Len(t, identifiers, 1) @@ -87,8 +89,9 @@ func TestSubscriberManager_RemoveSubscriber(t *testing.T) { podID := "default/test-pod-0" endpoint := "tcp://127.0.0.1:5557" topicFilter := "kv@" + engineType := engineadapter.EngineTypeVLLM - err = 
sm.EnsureSubscriber(ctx, podID, endpoint, topicFilter, true) + err = sm.EnsureSubscriber(ctx, podID, endpoint, topicFilter, engineType, true) require.NoError(t, err) // Remove subscriber @@ -117,6 +120,8 @@ func TestSubscriberManager_MultipleSubscribers(t *testing.T) { // Create subscriber manager sm := kvevents.NewSubscriberManager(pool) + engineType := engineadapter.EngineTypeVLLM + // Add multiple subscribers pods := []struct { id string @@ -128,7 +133,7 @@ func TestSubscriberManager_MultipleSubscribers(t *testing.T) { } for _, pod := range pods { - err := sm.EnsureSubscriber(ctx, pod.id, pod.endpoint, "kv@", true) + err := sm.EnsureSubscriber(ctx, pod.id, pod.endpoint, "kv@", engineType, true) require.NoError(t, err) } @@ -170,18 +175,20 @@ func TestSubscriberManager_EndpointChange(t *testing.T) { // Create subscriber manager sm := kvevents.NewSubscriberManager(pool) + engineType := engineadapter.EngineTypeVLLM + podID := "default/test-pod-0" endpoint1 := "tcp://10.0.0.1:5557" endpoint2 := "tcp://10.0.0.2:5557" // Add subscriber with first endpoint - err = sm.EnsureSubscriber(ctx, podID, endpoint1, "kv@", true) + err = sm.EnsureSubscriber(ctx, podID, endpoint1, "kv@", engineType, true) require.NoError(t, err) identifiers, _ := sm.GetActiveSubscribers() assert.Len(t, identifiers, 1) // Change endpoint - err = sm.EnsureSubscriber(ctx, podID, endpoint2, "kv@", true) + err = sm.EnsureSubscriber(ctx, podID, endpoint2, "kv@", engineType, true) require.NoError(t, err) // Should still have one subscriber (old was removed, new was added) @@ -212,6 +219,8 @@ func TestSubscriberManager_ConcurrentOperations(t *testing.T) { // Create subscriber manager sm := kvevents.NewSubscriberManager(pool) + engineType := engineadapter.EngineTypeVLLM + // Concurrently add subscribers done := make(chan bool, 10) for i := 0; i < 10; i++ { @@ -219,7 +228,7 @@ func TestSubscriberManager_ConcurrentOperations(t *testing.T) { defer func() { done <- true }() podID := "default/pod-" + 
string(rune('0'+id)) endpoint := "tcp://10.0.0." + string(rune('0'+id)) + ":5557" - if err := sm.EnsureSubscriber(ctx, podID, endpoint, "kv@", true); err != nil { + if err := sm.EnsureSubscriber(ctx, podID, endpoint, "kv@", engineType, true); err != nil { t.Errorf("failed to add subscriber %s: %v", podID, err) } }(i) diff --git a/pkg/kvevents/transport/transport.go b/pkg/kvevents/transport/transport.go new file mode 100644 index 00000000..0e7dabba --- /dev/null +++ b/pkg/kvevents/transport/transport.go @@ -0,0 +1,44 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package transport + +import "context" + +// Transport defines the interface for receiving raw bytes from different +// transport protocols (ZMQ, HTTP, gRPC, etc.). +type Transport interface { + // Connect establishes a connection to a remote endpoint. + // Used for per-pod subscriber mode where we connect to specific pods. + Connect(ctx context.Context, endpoint string) error + + // Bind listens on a local endpoint for incoming connections. + // Used for global subscriber mode where multiple pods publish to us. + Bind(ctx context.Context, endpoint string) error + + // Subscribe sets the topic filter for receiving messages. + // The filter format depends on the transport implementation. + Subscribe(topicFilter string) error + + // Receive blocks until a message is received or context is canceled. + // Returns raw message parts from the transport. 
For protocols that support + // multi-part messages (like ZMQ), this returns multiple byte slices. + // For single-part protocols (like HTTP), this returns a slice with one element. + Receive(ctx context.Context) ([][]byte, error) + + // Close closes the transport connection and releases resources. + Close() error +} diff --git a/pkg/kvevents/transport/zmq.go b/pkg/kvevents/transport/zmq.go new file mode 100644 index 00000000..933c5eb1 --- /dev/null +++ b/pkg/kvevents/transport/zmq.go @@ -0,0 +1,113 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package transport + +import ( + "context" + "fmt" + "time" + + zmq "github.com/pebbe/zmq4" +) + +const ( + // pollTimeout is how often the poller should time out to check for context cancellation. + pollTimeout = 250 * time.Millisecond +) + +// ZMQTransport implements the Transport interface using ZeroMQ PUB/SUB pattern. +type ZMQTransport struct { + socket *zmq.Socket + poller *zmq.Poller +} + +// NewZMQTransport creates a new ZMQ transport instance. +func NewZMQTransport() (*ZMQTransport, error) { + socket, err := zmq.NewSocket(zmq.SUB) + if err != nil { + return nil, fmt.Errorf("failed to create ZMQ SUB socket: %w", err) + } + + return &ZMQTransport{ + socket: socket, + poller: zmq.NewPoller(), + }, nil +} + +// Connect establishes a connection to a remote ZMQ PUB endpoint. 
+func (z *ZMQTransport) Connect(ctx context.Context, endpoint string) error { + if err := z.socket.Connect(endpoint); err != nil { + return fmt.Errorf("failed to connect to endpoint %s: %w", endpoint, err) + } + z.poller.Add(z.socket, zmq.POLLIN) + return nil +} + +// Bind listens on a local endpoint for incoming ZMQ PUB connections. +func (z *ZMQTransport) Bind(ctx context.Context, endpoint string) error { + if err := z.socket.Bind(endpoint); err != nil { + return fmt.Errorf("failed to bind to endpoint %s: %w", endpoint, err) + } + z.poller.Add(z.socket, zmq.POLLIN) + return nil +} + +// Subscribe sets the topic filter for receiving messages. +func (z *ZMQTransport) Subscribe(topicFilter string) error { + if err := z.socket.SetSubscribe(topicFilter); err != nil { + return fmt.Errorf("failed to subscribe to topic filter %s: %w", topicFilter, err) + } + return nil +} + +// Receive blocks until a message is received or context is canceled. +// Returns the raw multi-part ZMQ message as a slice of byte slices. +func (z *ZMQTransport) Receive(ctx context.Context) ([][]byte, error) { + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Poll with timeout to allow checking context cancellation + polled, err := z.poller.Poll(pollTimeout) + if err != nil { + return nil, fmt.Errorf("failed to poll ZMQ socket: %w", err) + } + + if len(polled) == 0 { + // Timeout, continue to check context + continue + } + + parts, err := z.socket.RecvMessageBytes(0) + if err != nil { + return nil, fmt.Errorf("failed to receive ZMQ message: %w", err) + } + + return parts, nil + } +} + +// Close closes the ZMQ socket and releases resources. 
+func (z *ZMQTransport) Close() error { + if z.socket != nil { + return z.socket.Close() + } + return nil +} diff --git a/pkg/kvevents/zmq_subscriber.go b/pkg/kvevents/zmq_subscriber.go deleted file mode 100644 index 6ac27944..00000000 --- a/pkg/kvevents/zmq_subscriber.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2025 The llm-d Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvevents - -import ( - "context" - "encoding/binary" - "strings" - "time" - - zmq "github.com/pebbe/zmq4" - "sigs.k8s.io/controller-runtime/pkg/log" - - "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" -) - -const ( - // How long to wait before retrying to connect. - retryInterval = 5 * time.Second - // How often the poller should time out to check for context cancellation. - pollTimeout = 250 * time.Millisecond -) - -// zmqSubscriber connects to a ZMQ publisher and forwards messages to a pool. -type zmqSubscriber struct { - pool *Pool - endpoint string - remote bool - topicFilter string -} - -// newZMQSubscriber creates a new ZMQ subscriber. -func newZMQSubscriber(pool *Pool, endpoint, topicFilter string, remote bool) *zmqSubscriber { - return &zmqSubscriber{ - pool: pool, - endpoint: endpoint, - remote: remote, - topicFilter: topicFilter, - } -} - -// Start connects to a ZMQ PUB socket as a SUB, receives messages, -// wraps them in Message structs, and pushes them into the pool. -// This loop will run until the provided context is canceled. 
-func (z *zmqSubscriber) Start(ctx context.Context) { - logger := log.FromContext(ctx).WithName("zmq-subscriber") - - for { - select { - case <-ctx.Done(): - logger.Info("shutting down zmq-subscriber") - return - default: - // We run the subscriber in a separate function to handle socket - // setup/teardown and connection retries cleanly. - z.runSubscriber(ctx) - // wait before retrying, unless the context has been canceled. - select { - case <-time.After(retryInterval): - logger.Info("retrying zmq-subscriber") - case <-ctx.Done(): - logger.Info("shutting down zmq-subscriber") - return - } - } - } -} - -// runSubscriber connects to the ZMQ PUB socket, subscribes to the topic filter, -// and listens for messages. -func (z *zmqSubscriber) runSubscriber(ctx context.Context) { - logger := log.FromContext(ctx).WithName("zmq-subscriber") - sub, err := zmq.NewSocket(zmq.SUB) - if err != nil { - logger.Error(err, "Failed to create subscriber socket") - return - } - defer sub.Close() - - // Bind for local endpoints, connect for remote ones. 
- if !z.remote { - if err := sub.Bind(z.endpoint); err != nil { - logger.Error(err, "Failed to bind subscriber socket", "endpoint", z.endpoint) - return - } - logger.Info("Bound subscriber socket", "endpoint", z.endpoint) - } else { - if err := sub.Connect(z.endpoint); err != nil { - logger.Error(err, "Failed to connect subscriber socket", "endpoint", z.endpoint) - return - } - logger.Info("Connected subscriber socket", "endpoint", z.endpoint) - } - - if err := sub.SetSubscribe(z.topicFilter); err != nil { - logger.Error(err, "Failed to subscribe to topic filter", "topic", z.topicFilter) - return - } - - poller := zmq.NewPoller() - poller.Add(sub, zmq.POLLIN) - debugLogger := logger.V(logging.DEBUG) - - for { - select { - case <-ctx.Done(): - return - default: - } - - polled, err := poller.Poll(pollTimeout) - if err != nil { - debugLogger.Error(err, "Failed to poll zmq subscriber", "endpoint", z.endpoint) - break // Exit on poll error to reconnect - } - - if len(polled) > 0 { - parts, err := sub.RecvMessageBytes(0) - if err != nil { - debugLogger.Error(err, "Failed to receive message from zmq subscriber", "endpoint", z.endpoint) - break // Exit on receive error to reconnect - } - if len(parts) != 3 { - debugLogger.Error(err, "Failed to receive message from zmq subscriber", "endpoint", z.endpoint) - continue - } - topic := string(parts[0]) - seqBytes := parts[1] - payload := parts[2] - - seq := binary.BigEndian.Uint64(seqBytes) - - // Extract pod identifier from topic, assuming "kv@@" format - // TODO: optimize this to not occur for every message - topicParts := strings.Split(topic, "@") - var podIdentifier, modelName string - if len(topicParts) == 3 { - podIdentifier = topicParts[1] - modelName = topicParts[2] - } else { - debugLogger.Error(nil, "Failed to extract identifiers from topic, expected format kv@@", "topic", topic) - continue // Useless if we can't extract pod identifier - } - - debugLogger.V(logging.TRACE).Info("Received message from zmq subscriber", - 
"topic", topic, - "seq", seq, - "podIdentifier", podIdentifier, - "modelName", modelName, - "payloadSize", len(payload)) - - z.pool.AddTask(&Message{ - Topic: topic, - Payload: payload, - Seq: seq, - PodIdentifier: podIdentifier, - ModelName: modelName, - }) - } - } -} diff --git a/tests/integration/kv_events_test.go b/tests/integration/kv_events_test.go index c3977807..22f3d4fa 100644 --- a/tests/integration/kv_events_test.go +++ b/tests/integration/kv_events_test.go @@ -22,6 +22,7 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents/engineadapter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,7 +59,7 @@ func TestPoolWithSubscriberManager_Integration(t *testing.T) { } for _, pod := range pods { - err := subscriberManager.EnsureSubscriber(ctx, pod.id, pod.endpoint, "kv@", true) + err := subscriberManager.EnsureSubscriber(ctx, pod.id, pod.endpoint, "kv@", engineadapter.EngineTypeVLLM, true) require.NoError(t, err) } @@ -76,7 +77,7 @@ func TestPoolWithSubscriberManager_Integration(t *testing.T) { // 7. 
Simulate pod update with endpoint change newEndpoint := "tcp://10.0.0.10:5557" - err = subscriberManager.EnsureSubscriber(ctx, "default/vllm-pod-1", newEndpoint, "kv@", true) + err = subscriberManager.EnsureSubscriber(ctx, "default/vllm-pod-1", newEndpoint, "kv@", engineadapter.EngineTypeVLLM, true) require.NoError(t, err) // Still one subscriber, but with new endpoint @@ -113,7 +114,7 @@ func TestSubscriberLifecycle(t *testing.T) { // Lifecycle: Creation t.Run("Creation", func(t *testing.T) { - err := sm.EnsureSubscriber(ctx, podID, endpoint, "kv@", true) + err := sm.EnsureSubscriber(ctx, podID, endpoint, "kv@", engineadapter.EngineTypeVLLM, true) assert.NoError(t, err) identifiers, _ := sm.GetActiveSubscribers() assert.Contains(t, identifiers, podID) @@ -121,7 +122,7 @@ func TestSubscriberLifecycle(t *testing.T) { // Lifecycle: Idempotent creation (same endpoint) t.Run("IdempotentCreation", func(t *testing.T) { - err := sm.EnsureSubscriber(ctx, podID, endpoint, "kv@", true) + err := sm.EnsureSubscriber(ctx, podID, endpoint, "kv@", engineadapter.EngineTypeVLLM, true) assert.NoError(t, err) identifiers, endpoints := sm.GetActiveSubscribers() assert.Contains(t, identifiers, podID) @@ -131,7 +132,7 @@ func TestSubscriberLifecycle(t *testing.T) { // Lifecycle: Update (different endpoint) t.Run("Update", func(t *testing.T) { newEndpoint := "tcp://127.0.0.1:5558" - err := sm.EnsureSubscriber(ctx, podID, newEndpoint, "kv@", true) + err := sm.EnsureSubscriber(ctx, podID, newEndpoint, "kv@", engineadapter.EngineTypeVLLM, true) assert.NoError(t, err) identifiers, endpoints := sm.GetActiveSubscribers() assert.Contains(t, identifiers, podID) diff --git a/vllm-setup-helm/templates/deployment.yaml b/vllm-setup-helm/templates/deployment.yaml index e2629bcc..82f00918 100644 --- a/vllm-setup-helm/templates/deployment.yaml +++ b/vllm-setup-helm/templates/deployment.yaml @@ -82,7 +82,7 @@ spec: --block-size {{ .Values.vllm.blockSize }} \ {{- if .Values.kvCacheManager.enabled }} 
--kv-events-config "{\"enable_kv_cache_events\":{{ .Values.kvCacheManager.enabled }},\"publisher\":\"zmq\",\"endpoint\":\"{{ include "chart.kvCacheManagerServiceUrl" . }}\",\"topic\":\"kv@${POD_IP}@{{ .Values.vllm.model.name }}\"}" \ - --prefix-caching-hash-algo sha256_cbor \ + --prefix-caching-hash-algo sha256_cbor_64bit \ {{- end }} ports: - name: http