Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 114 additions & 45 deletions pkg/plugins/gateway/algorithms/pd_disaggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const (
RouterPD types.RoutingAlgorithm = "pd"
VLLMEngine string = "vllm"
SGLangEngine string = "sglang"
TensorRTLLM string = "trtllm"
SGLangBootstrapPort int64 = 8998
SGLangBootstrapPortIdentifier string = "model.aibrix.ai/sglang-bootstrap-port"
LLMEngineIdentifier string = constants.ModelLabelEngine
Expand All @@ -62,11 +63,10 @@ const (
defaultRequestRateHighLoadThreshold = 1.0
defaultRequestRateLowLoadThreshold = 0.25

pdRouteValidateLLMEngineFail = "pd-validate-llm-engine-fail"
pdRouteFilterPrefillDecodePodsFail = "pd-filter-prefill-decode-pods-fail"
pdRoutePrefillRequestError = "pd-do-prefill-request-error"
pdRoutePrefillRequestSuccess = "pd-prefill-request-success"
pdRoutePrefillEmptyKVTransferParams = "pd-prefill-empty-kv-transfer-params"
pdRouteValidateLLMEngineFail = "pd-validate-llm-engine-fail"
pdRouteFilterPrefillDecodePodsFail = "pd-filter-prefill-decode-pods-fail"
pdRoutePrefillRequestError = "pd-do-prefill-request-error"
pdRoutePrefillRequestSuccess = "pd-prefill-request-success"
)

const (
Expand Down Expand Up @@ -522,6 +522,7 @@ func (r *pdRouter) finalPDScore(routingCtx *types.RoutingContext,

return targetPrefillPod, targetDecodePod, nil
}

func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod *v1.Pod, llmEngine string) error {
// Prepare prefill request payload
payload, err := r.preparePrefillPayload(routingCtx, prefillPod, llmEngine)
Expand Down Expand Up @@ -577,53 +578,57 @@ func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod
}()

case VLLMEngine:
defer r.prefillRequestTracker.RemovePrefillRequest(routingCtx.RequestID)

// For vLLM, wait synchronously to get KV transfer params from response
responseData, err := r.executeHTTPRequest(apiURL, routingCtx, payload)
if err != nil {
klog.ErrorS(err, "prefill_request_failed",
"request_id", routingCtx.RequestID,
"llm_engine", llmEngine,
"prefill_pod", prefillPod.Name,
"prefill_pod_ip", prefillPod.Status.PodIP,
"elapsed", routingCtx.Elapsed(time.Now()))
return fmt.Errorf("prefill request failed for request %s, pod %s: %w", routingCtx.RequestID, prefillPod.Name, err)
}
return r.handleSyncPrefill(routingCtx, prefillPod, llmEngine, apiURL, payload, fields, r.updateRoutingContextWithKVTransferParams, "KV transfer params")

// Update routing context with KV transfer params from prefill response
if err := r.updateRoutingContextWithKVTransferParams(routingCtx, responseData, prefillPod); err != nil {
return fmt.Errorf("failed to update routing context with KV transfer params for request %s: %w", routingCtx.RequestID, err)
}

routingCtx.PrefillEndTime = time.Now()
fields = append(fields,
"routing_time_taken", routingCtx.PrefillStartTime.Sub(routingCtx.RequestTime),
"prefill_time_taken", routingCtx.PrefillEndTime.Sub(routingCtx.PrefillStartTime),
"outstanding_prefill_requests", r.prefillRequestTracker.GetPrefillRequestCountsForPod(prefillPod.Name)-1)
klog.InfoS("prefill_request_end", fields...)
case TensorRTLLM:
// For TensorRT-LLM, wait synchronously to get disaggregated_params from response.
// The prefill response contains first_gen_tokens and opaque_state needed by the decode worker.
return r.handleSyncPrefill(routingCtx, prefillPod, llmEngine, apiURL, payload, fields, r.updateRoutingContextWithTRTDisaggParams, "TRT disagg params")

default:
defer r.prefillRequestTracker.RemovePrefillRequest(routingCtx.RequestID)

// For unknown engines, use synchronous approach as a safe default
if _, err := r.executeHTTPRequest(apiURL, routingCtx, payload); err != nil {
klog.ErrorS(err, "prefill_request_failed",
"request_id", routingCtx.RequestID,
"llm_engine", llmEngine,
"prefill_pod", prefillPod.Name,
"prefill_pod_ip", prefillPod.Status.PodIP,
"elapsed", routingCtx.Elapsed(time.Now()))
return fmt.Errorf("prefill request failed for request %s, pod %s: %w", routingCtx.RequestID, prefillPod.Name, err)
return r.handleSyncPrefill(routingCtx, prefillPod, llmEngine, apiURL, payload, fields, nil, "")
}

return nil
}

// handleSyncPrefill executes a synchronous prefill request, optionally calling updateCtxFunc
// to process the response. Pass nil for updateCtxFunc when no response processing is needed.
func (r *pdRouter) handleSyncPrefill(
routingCtx *types.RoutingContext,
prefillPod *v1.Pod,
llmEngine, apiURL string,
payload []byte,
fields []interface{},
updateCtxFunc func(*types.RoutingContext, map[string]any, *v1.Pod) error,
errorContext string) error {
defer r.prefillRequestTracker.RemovePrefillRequest(routingCtx.RequestID)

responseData, err := r.executeHTTPRequest(apiURL, routingCtx, payload)
if err != nil {
klog.ErrorS(err, "prefill_request_failed",
"request_id", routingCtx.RequestID,
"llm_engine", llmEngine,
"prefill_pod", prefillPod.Name,
"prefill_pod_ip", prefillPod.Status.PodIP,
"elapsed", routingCtx.Elapsed(time.Now()))
return fmt.Errorf("prefill request failed for request %s, pod %s: %w", routingCtx.RequestID, prefillPod.Name, err)
}

if updateCtxFunc != nil {
if err := updateCtxFunc(routingCtx, responseData, prefillPod); err != nil {
return fmt.Errorf("failed to update routing context with %s for request %s: %w", errorContext, routingCtx.RequestID, err)
}
routingCtx.PrefillEndTime = time.Now()
fields = append(fields,
"routing_time_taken", routingCtx.PrefillStartTime.Sub(routingCtx.RequestTime),
"prefill_time_taken", routingCtx.PrefillEndTime.Sub(routingCtx.PrefillStartTime),
"outstanding_prefill_requests", r.prefillRequestTracker.GetPrefillRequestCountsForPod(prefillPod.Name)-1)
klog.InfoS("prefill_request_end", fields...)
}

routingCtx.PrefillEndTime = time.Now()
fields = append(fields,
"routing_time_taken", routingCtx.PrefillStartTime.Sub(routingCtx.RequestTime),
"prefill_time_taken", routingCtx.PrefillEndTime.Sub(routingCtx.PrefillStartTime),
"outstanding_prefill_requests", r.prefillRequestTracker.GetPrefillRequestCountsForPod(prefillPod.Name)-1)
klog.InfoS("prefill_request_end", fields...)
return nil
}

Expand Down Expand Up @@ -662,9 +667,22 @@ func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *
}
}

if llmEngine == TensorRTLLM {
// Signal to TensorRT-LLM that this is a context-only (prefill) request.
// The prefill response will return disaggregated_params containing
// first_gen_tokens and opaque_state, which are injected into the decode request.
completionRequest["disaggregated_params"] = map[string]any{
"request_type": "context_only",
}
}

// Set prefill-specific parameters
completionRequest["max_tokens"] = 1
completionRequest["max_completion_tokens"] = 1
if llmEngine == TensorRTLLM {
delete(completionRequest, "max_completion_tokens")
} else {
completionRequest["max_completion_tokens"] = 1
}
completionRequest["stream"] = false
delete(completionRequest, "stream_options")

Expand Down Expand Up @@ -788,6 +806,57 @@ func (r *pdRouter) updateRoutingContextWithKVTransferParams(routingCtx *types.Ro
return nil
}

func (r *pdRouter) updateRoutingContextWithTRTDisaggParams(routingCtx *types.RoutingContext, responseData map[string]any, prefillPod *v1.Pod) error {
// Parse the original request body
var originalRequest map[string]any
if err := sonic.Unmarshal(routingCtx.ReqBody, &originalRequest); err != nil {
return fmt.Errorf("failed to unmarshal original request body: %w", err)
}

// Extract disaggregated_params from prefill response.
// TRT-LLM may return it at the top level or inside choices[0].
var disaggParams any
var exists bool

disaggParams, exists = responseData["disaggregated_params"]
if !exists {
// Fallback: check choices[0] (TRT-LLM serializes handler output as a choice)
if choices, ok := responseData["choices"].([]any); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]any); ok {
disaggParams, exists = choice["disaggregated_params"]
}
}
}

if !exists {
klog.InfoS("no disaggregated_params in TRT prefill response", "request_id", routingCtx.RequestID)
return nil
}

disaggParamsMap, ok := disaggParams.(map[string]any)
if !ok {
return fmt.Errorf("disaggregated_params has unexpected type %T, expected map[string]any", disaggParams)
}

// Override request_type to generation_only for the decode request
disaggParamsMap["request_type"] = "generation_only"
originalRequest["disaggregated_params"] = disaggParamsMap

updatedReqBody, err := sonic.Marshal(originalRequest)
if err != nil {
return fmt.Errorf("failed to marshal updated request body: %w", err)
}

routingCtx.ReqBody = updatedReqBody

klog.InfoS("updated routing context with disaggregated_params (TensorRT-LLM)",
"request_id", routingCtx.RequestID,
"prefill_pod", prefillPod.Name,
"prefill_host", prefillPod.Status.PodIP)

return nil
}

func (r *pdRouter) SubscribedMetrics() []string {
return []string{}
}
Expand Down
Loading
Loading