Skip to content

Commit efcc534

Browse files
committed
Added a pod_type categorical feature ('' = monolithic, 'prefill', 'decode')
to both the TTFT and TPOT prediction models.
1 parent 14598d2 commit efcc534

File tree

4 files changed

+314
-49
lines changed

4 files changed

+314
-49
lines changed

latencypredictor/prediction_server.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,25 @@ def is_ready(self) -> bool:
219219
def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str) -> pd.DataFrame:
220220
"""
221221
Prepare features with interaction terms to match training server.
222-
223222
Args:
224223
df: DataFrame with raw features
225224
model_type: 'ttft' or 'tpot'
226-
227225
Returns:
228226
DataFrame with engineered features including interactions
229227
"""
228+
# Encode pod_type as categorical (common for both TTFT and TPOT)
229+
# Convert to categorical with known categories for consistent encoding
230+
if 'pod_type' in df.columns:
231+
df['pod_type'] = df['pod_type'].fillna('') # Handle NaN
232+
df['pod_type_cat'] = pd.Categorical(
233+
df['pod_type'],
234+
categories=['', 'prefill', 'decode'], # '' = monolithic, prefill, decode
235+
ordered=False
236+
)
237+
else:
238+
# If pod_type column doesn't exist, create it as empty (monolithic)
239+
df['pod_type_cat'] = pd.Categorical([''] * len(df), categories=['', 'prefill', 'decode'], ordered=False)
240+
230241
if model_type == "ttft":
231242
# Create interaction: prefix score * input length
232243
df['effective_input_tokens'] = (1-df['prefix_cache_score']) * df['input_token_length']
@@ -238,31 +249,32 @@ def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str)
238249

239250
# make it categorical for tree models (safe for LGB, XGB with enable_categorical)
240251
df['prefill_score_bucket'] = pd.Categorical(df['prefill_score_bucket'], categories=[0,1,2,3], ordered=True)
241-
242-
243-
# Return TTFT features with interaction
252+
253+
254+
# Return TTFT features with interaction and pod_type
244255
feature_cols = [
245256
'kv_cache_percentage',
246257
'input_token_length',
247258
'num_request_waiting',
248259
'num_request_running',
249260
'prefix_cache_score',
250261
'effective_input_tokens',
251-
'prefill_score_bucket'
262+
'prefill_score_bucket',
263+
'pod_type_cat'
252264
]
253-
265+
254266
return df[feature_cols]
255-
267+
256268
else: # tpot
257269
# TPOT doesn't use prefix_cache_score, so no interaction needed
258270
feature_cols = [
259271
'kv_cache_percentage',
260272
'input_token_length',
261273
'num_request_waiting',
262274
'num_request_running',
263-
'num_tokens_generated'
275+
'num_tokens_generated',
276+
'pod_type_cat'
264277
]
265-
266278
return df[feature_cols]
267279

268280
def load_models(self) -> bool:
@@ -333,10 +345,17 @@ def predict(self, features: dict) -> Tuple[float, float]:
333345
#df_tpot = pd.DataFrame([tpot_raw_data])
334346

335347
if self.model_type == ModelType.BAYESIAN_RIDGE:
336-
348+
# Bayesian Ridge can't handle categorical features directly
349+
# Drop categorical bucket, but one-hot encode pod_type
337350
ttft_for_scale = df_ttft.drop(columns=['prefill_score_bucket'], errors='ignore')
351+
if 'pod_type_cat' in ttft_for_scale.columns:
352+
ttft_for_scale = pd.get_dummies(ttft_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
338353
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
339-
tpot_scaled = self.tpot_scaler.transform(df_tpot)
354+
355+
tpot_for_scale = df_tpot.copy()
356+
if 'pod_type_cat' in tpot_for_scale.columns:
357+
tpot_for_scale = pd.get_dummies(tpot_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
358+
tpot_scaled = self.tpot_scaler.transform(tpot_for_scale)
340359

341360
ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
342361
tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True)
@@ -416,9 +435,17 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
416435
#df_tpot_batch = pd.DataFrame(tpot_raw_data)
417436

418437
if self.model_type == ModelType.BAYESIAN_RIDGE:
438+
# Bayesian Ridge can't handle categorical features directly
439+
# Drop categorical bucket, but one-hot encode pod_type
419440
ttft_for_scale = df_ttft_batch.drop(columns=['prefill_score_bucket'], errors='ignore')
441+
if 'pod_type_cat' in ttft_for_scale.columns:
442+
ttft_for_scale = pd.get_dummies(ttft_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
420443
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
421-
tpot_scaled = self.tpot_scaler.transform(df_tpot_batch)
444+
445+
tpot_for_scale = df_tpot_batch.copy()
446+
if 'pod_type_cat' in tpot_for_scale.columns:
447+
tpot_for_scale = pd.get_dummies(tpot_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
448+
tpot_scaled = self.tpot_scaler.transform(tpot_for_scale)
422449

423450
ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
424451
tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True)
@@ -471,6 +498,7 @@ class PredictionRequest(BaseModel):
471498
num_request_running: int = Field(..., ge=0)
472499
num_tokens_generated: int = Field(..., ge=0)
473500
prefix_cache_score: float = Field(..., ge=0.0, le=1.0, description="Prefix cache hit ratio score (0.0 to 1.0)")
501+
pod_type: Optional[str] = Field(default="", description="Pod type: 'prefill', 'decode', or '' for monolithic")
474502

475503

476504
class PredictionResponse(BaseModel):

latencypredictor/test_dual_server_client.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,202 @@ def test_prediction_missing_prefix_cache_score():
484484
print("✓ Prediction correctly failed when prefix_cache_score was missing")
485485

486486

487+
def test_prediction_with_pod_type_prefill():
    """Test predictions with pod_type='prefill' parameter."""
    print("Testing prediction with pod_type='prefill'...")

    # A prefill pod only processes the prompt, hence zero generated tokens.
    payload = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        "num_tokens_generated": 0,
        "prefix_cache_score": 0.7,
        "pod_type": "prefill",
    }

    resp = requests.post(f"{PREDICTION_URL}/predict", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    for key in ("ttft_ms", "tpot_ms"):
        assert key in body
    assert body["ttft_ms"] > 0
    # TPOT may legitimately be zero for a prefill pod, so only require non-negative.
    assert body["tpot_ms"] >= 0

    print(f"✓ Prefill prediction: TTFT={body['ttft_ms']:.2f}ms, TPOT={body['tpot_ms']:.2f}ms")
511+
512+
513+
def test_prediction_with_pod_type_decode():
    """Test predictions with pod_type='decode' parameter."""
    print("Testing prediction with pod_type='decode'...")

    # A decode pod is mid-generation, so num_tokens_generated is positive.
    payload = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        "num_tokens_generated": 10,
        "prefix_cache_score": 0.7,
        "pod_type": "decode",
    }

    resp = requests.post(f"{PREDICTION_URL}/predict", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    for key in ("ttft_ms", "tpot_ms"):
        assert key in body
    assert body["ttft_ms"] > 0
    # Keep the same non-negativity bound the monolithic path uses.
    assert body["tpot_ms"] >= 0

    print(f"✓ Decode prediction: TTFT={body['ttft_ms']:.2f}ms, TPOT={body['tpot_ms']:.2f}ms")
537+
538+
539+
def test_bulk_prediction_with_pod_type():
    """Test bulk predictions with mixed pod types."""
    print("Testing bulk prediction with pod_type...")

    # One request per deployment mode: prefill, decode, and a legacy
    # (monolithic) request that omits pod_type entirely.
    batch = [
        {
            "kv_cache_percentage": 0.5,
            "input_token_length": 200,
            "num_request_waiting": 4,
            "num_request_running": 1,
            "num_tokens_generated": 0,
            "prefix_cache_score": 0.7,
            "pod_type": "prefill",
        },
        {
            "kv_cache_percentage": 0.3,
            "input_token_length": 150,
            "num_request_waiting": 2,
            "num_request_running": 1,
            "num_tokens_generated": 10,
            "prefix_cache_score": 0.5,
            "pod_type": "decode",
        },
        {
            "kv_cache_percentage": 0.6,
            "input_token_length": 300,
            "num_request_waiting": 3,
            "num_request_running": 2,
            "num_tokens_generated": 5,
            "prefix_cache_score": 0.8,
        },
    ]

    resp = requests.post(f"{PREDICTION_URL}/predict/bulk/strict", json={"requests": batch})
    assert resp.status_code == 200

    body = resp.json()
    assert body["total_requests"] == 3
    assert body["successful_predictions"] == 3
    assert body["failed_predictions"] == 0

    preds = body["predictions"]

    # Prefill may report a zero TPOT; decode and legacy must be strictly positive.
    for label, idx, tpot_strict in (("Prefill", 0, False), ("Decode", 1, True), ("Legacy", 2, True)):
        pred = preds[idx]
        assert pred["ttft_ms"] > 0
        if tpot_strict:
            assert pred["tpot_ms"] > 0
        else:
            assert pred["tpot_ms"] >= 0
        print(f"  {label}: TTFT={pred['ttft_ms']:.2f}ms, TPOT={pred['tpot_ms']:.2f}ms")

    print("✓ Bulk prediction with mixed pod types passed")
606+
607+
608+
def test_training_data_with_pod_type():
    """Test that training server accepts pod_type in training data."""
    print("Testing training data with pod_type...")

    # Prefill samples: no tokens generated, so TPOT is pinned to 0.
    prefill_entries = [
        {
            "kv_cache_percentage": 0.5,
            "input_token_length": 200 + i * 10,
            "num_request_waiting": i % 5,
            "num_request_running": 1,
            "actual_ttft_ms": 100.0 + i * 5,
            "actual_tpot_ms": 0.0,
            "num_tokens_generated": 0,
            "prefix_cache_score": 0.7,
            "pod_type": "prefill",
        }
        for i in range(10)
    ]

    # Decode samples: carry both TTFT and TPOT observations.
    decode_entries = [
        {
            "kv_cache_percentage": 0.5,
            "input_token_length": 200 + i * 10,
            "num_request_waiting": i % 5,
            "num_request_running": 1,
            "actual_ttft_ms": 100.0 + i * 5,
            "actual_tpot_ms": 10.0 + i * 2,
            "num_tokens_generated": 5 + i,
            "prefix_cache_score": 0.7,
            "pod_type": "decode",
        }
        for i in range(10)
    ]

    all_entries = prefill_entries + decode_entries

    resp = requests.post(f"{TRAINING_URL}/add_training_data_bulk", json={"entries": all_entries})
    assert resp.status_code == 202
    assert resp.json().get("message") == f"Accepted {len(all_entries)} training samples."

    print(f"✓ Successfully sent {len(all_entries)} training samples with pod_type")
652+
653+
654+
def test_invalid_pod_type():
    """Test that invalid pod_type values are handled correctly."""
    print("Testing invalid pod_type handling...")

    payload = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        "num_tokens_generated": 10,
        "prefix_cache_score": 0.7,
        "pod_type": "invalid_type",
    }

    resp = requests.post(f"{PREDICTION_URL}/predict", json=payload)
    status = resp.status_code

    # The server may either reject an unknown pod_type outright (422) or
    # accept it and fall back to the legacy/monolithic path (200); anything
    # else is a failure.
    if status == 422:
        print("✓ Invalid pod_type rejected with validation error (strict validation)")
        return
    if status == 200:
        body = resp.json()
        assert body["ttft_ms"] > 0
        assert body["tpot_ms"] >= 0
        print("✓ Invalid pod_type accepted with fallback behavior (permissive validation)")
        return
    assert False, f"Unexpected status code {resp.status_code} for invalid pod_type"
681+
682+
487683
def test_training_server_metrics():
488684
"""Test training server metrics endpoint."""
489685
r = requests.get(f"{TRAINING_URL}/metrics")

0 commit comments

Comments
 (0)