Skip to content

Commit fd2d71b

Browse files
authored
Added pod_type categorical feature (''=monolithic, 'prefill', 'decode') (#1993)
to both TTFT and TPOT prediction models - Add pod_type field to PredictionRequest and TrainingEntry models - Encode pod_type as categorical in _prepare_features_with_interaction - Handle pod_type_cat in both TTFT and TPOT feature columns - One-hot encode pod_type_cat for Bayesian Ridge models - Add pod_type to XGBoost/LightGBM feature orders with monotone constraints - Add comprehensive tests for pod_type functionality - Update Go types to include PodType field
1 parent fec7fd0 commit fd2d71b

File tree

4 files changed

+341
-60
lines changed

4 files changed

+341
-60
lines changed

latencypredictor/prediction_server.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,25 @@ def is_ready(self) -> bool:
219219
def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str) -> pd.DataFrame:
220220
"""
221221
Prepare features with interaction terms to match training server.
222-
223222
Args:
224223
df: DataFrame with raw features
225224
model_type: 'ttft' or 'tpot'
226-
227225
Returns:
228226
DataFrame with engineered features including interactions
229227
"""
228+
# Encode pod_type as categorical (common for both TTFT and TPOT)
229+
# Convert to categorical with known categories for consistent encoding
230+
if 'pod_type' in df.columns:
231+
df['pod_type'] = df['pod_type'].fillna('') # Handle NaN
232+
df['pod_type_cat'] = pd.Categorical(
233+
df['pod_type'],
234+
categories=['', 'prefill', 'decode'], # '' = monolithic, prefill, decode
235+
ordered=False
236+
)
237+
else:
238+
# If pod_type column doesn't exist, create it as empty (monolithic)
239+
df['pod_type_cat'] = pd.Categorical([''] * len(df), categories=['', 'prefill', 'decode'], ordered=False)
240+
230241
if model_type == "ttft":
231242
# Create interaction: prefix score * input length
232243
df['effective_input_tokens'] = (1-df['prefix_cache_score']) * df['input_token_length']
@@ -238,31 +249,33 @@ def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str)
238249

239250
# make it categorical for tree models (safe for LGB, XGB with enable_categorical)
240251
df['prefill_score_bucket'] = pd.Categorical(df['prefill_score_bucket'], categories=[0,1,2,3], ordered=True)
241-
242-
243-
# Return TTFT features with interaction
252+
253+
254+
# Return TTFT features with interaction and pod_type
244255
feature_cols = [
245256
'kv_cache_percentage',
246257
'input_token_length',
247258
'num_request_waiting',
248259
'num_request_running',
249260
'prefix_cache_score',
250261
'effective_input_tokens',
251-
'prefill_score_bucket'
262+
'prefill_score_bucket',
263+
'pod_type_cat'
252264
]
253-
265+
254266
return df[feature_cols]
255-
267+
256268
else: # tpot
257269
# TPOT doesn't use prefix_cache_score, so no interaction needed
258270
feature_cols = [
259271
'kv_cache_percentage',
260272
'input_token_length',
261273
'num_request_waiting',
262274
'num_request_running',
263-
'num_tokens_generated'
275+
'num_tokens_generated',
276+
'pod_type_cat'
264277
]
265-
278+
266279
return df[feature_cols]
267280

268281
def load_models(self) -> bool:
@@ -314,29 +327,42 @@ def predict(self, features: dict) -> Tuple[float, float]:
314327
'num_request_running': features['num_request_running'],
315328
'prefix_cache_score': features['prefix_cache_score']
316329
}
317-
330+
318331
tpot_raw_data = {
319332
'kv_cache_percentage': features['kv_cache_percentage'],
320333
'input_token_length': features['input_token_length'],
321334
'num_request_waiting': features['num_request_waiting'],
322335
'num_request_running': features['num_request_running'],
323336
'num_tokens_generated': features['num_tokens_generated']
324337
}
325-
338+
326339
# Prepare features with interactions
327340
df_ttft_raw = pd.DataFrame([ttft_raw_data])
341+
# Add pod_type if present
342+
if 'pod_type' in features:
343+
df_ttft_raw['pod_type'] = features['pod_type']
328344
df_ttft = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
329-
330-
345+
346+
331347
df_tpot_raw = pd.DataFrame([tpot_raw_data])
348+
# Add pod_type if present
349+
if 'pod_type' in features:
350+
df_tpot_raw['pod_type'] = features['pod_type']
332351
df_tpot = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
333352
#df_tpot = pd.DataFrame([tpot_raw_data])
334353

335354
if self.model_type == ModelType.BAYESIAN_RIDGE:
336-
355+
# Bayesian Ridge can't handle categorical features directly
356+
# Drop categorical bucket, but one-hot encode pod_type
337357
ttft_for_scale = df_ttft.drop(columns=['prefill_score_bucket'], errors='ignore')
358+
if 'pod_type_cat' in ttft_for_scale.columns:
359+
ttft_for_scale = pd.get_dummies(ttft_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
338360
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
339-
tpot_scaled = self.tpot_scaler.transform(df_tpot)
361+
362+
tpot_for_scale = df_tpot.copy()
363+
if 'pod_type_cat' in tpot_for_scale.columns:
364+
tpot_for_scale = pd.get_dummies(tpot_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
365+
tpot_scaled = self.tpot_scaler.transform(tpot_for_scale)
340366

341367
ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
342368
tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True)
@@ -388,37 +414,53 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
388414
# Create raw feature data (without interaction)
389415
ttft_raw_data = []
390416
tpot_raw_data = []
391-
417+
392418
for features in features_list:
393-
ttft_raw_data.append({
419+
ttft_entry = {
394420
'kv_cache_percentage': features['kv_cache_percentage'],
395421
'input_token_length': features['input_token_length'],
396422
'num_request_waiting': features['num_request_waiting'],
397423
'num_request_running': features['num_request_running'],
398424
'prefix_cache_score': features['prefix_cache_score']
399-
})
400-
401-
tpot_raw_data.append({
425+
}
426+
# Add pod_type if present
427+
if 'pod_type' in features:
428+
ttft_entry['pod_type'] = features['pod_type']
429+
ttft_raw_data.append(ttft_entry)
430+
431+
tpot_entry = {
402432
'kv_cache_percentage': features['kv_cache_percentage'],
403433
'input_token_length': features['input_token_length'],
404434
'num_request_waiting': features['num_request_waiting'],
405435
'num_request_running': features['num_request_running'],
406436
'num_tokens_generated': features['num_tokens_generated']
407-
})
408-
437+
}
438+
# Add pod_type if present
439+
if 'pod_type' in features:
440+
tpot_entry['pod_type'] = features['pod_type']
441+
tpot_raw_data.append(tpot_entry)
442+
409443
# Prepare features with interactions
410444
df_ttft_raw = pd.DataFrame(ttft_raw_data)
411445
df_ttft_batch = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
412446
#df_ttft_batch = pd.DataFrame(ttft_raw_data)
413-
447+
414448
df_tpot_raw = pd.DataFrame(tpot_raw_data)
415449
df_tpot_batch = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
416450
#df_tpot_batch = pd.DataFrame(tpot_raw_data)
417451

418452
if self.model_type == ModelType.BAYESIAN_RIDGE:
453+
# Bayesian Ridge can't handle categorical features directly
454+
# Drop categorical bucket, but one-hot encode pod_type
419455
ttft_for_scale = df_ttft_batch.drop(columns=['prefill_score_bucket'], errors='ignore')
456+
if 'pod_type_cat' in ttft_for_scale.columns:
457+
ttft_for_scale = pd.get_dummies(ttft_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
420458
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
421-
tpot_scaled = self.tpot_scaler.transform(df_tpot_batch)
459+
460+
tpot_for_scale = df_tpot_batch.copy()
461+
if 'pod_type_cat' in tpot_for_scale.columns:
462+
tpot_for_scale = pd.get_dummies(tpot_for_scale, columns=['pod_type_cat'], prefix='pod_type', drop_first=False)
463+
tpot_scaled = self.tpot_scaler.transform(tpot_for_scale)
422464

423465
ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
424466
tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True)
@@ -471,6 +513,7 @@ class PredictionRequest(BaseModel):
471513
num_request_running: int = Field(..., ge=0)
472514
num_tokens_generated: int = Field(..., ge=0)
473515
prefix_cache_score: float = Field(..., ge=0.0, le=1.0, description="Prefix cache hit ratio score (0.0 to 1.0)")
516+
pod_type: Optional[str] = Field(default="", description="Pod type: 'prefill', 'decode', or '' for monolithic")
474517

475518

476519
class PredictionResponse(BaseModel):

latencypredictor/test_dual_server_client.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,202 @@ def test_prediction_missing_prefix_cache_score():
484484
print("✓ Prediction correctly failed when prefix_cache_score was missing")
485485

486486

487+
def test_prediction_with_pod_type_prefill():
    """Test predictions with pod_type='prefill' parameter."""
    print("Testing prediction with pod_type='prefill'...")

    payload = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        # Prefill pods do not generate output tokens, so this stays at zero.
        "num_tokens_generated": 0,
        "prefix_cache_score": 0.7,
        "pod_type": "prefill",
    }

    resp = requests.post(f"{PREDICTION_URL}/predict", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    # Both latency fields must be present in the response.
    for key in ("ttft_ms", "tpot_ms"):
        assert key in body
    assert body["ttft_ms"] > 0
    # TPOT may legitimately be zero for a prefill-only pod.
    assert body["tpot_ms"] >= 0

    print(f"✓ Prefill prediction: TTFT={body['ttft_ms']:.2f}ms, TPOT={body['tpot_ms']:.2f}ms")
def test_prediction_with_pod_type_decode():
    """Test predictions with pod_type='decode' parameter."""
    print("Testing prediction with pod_type='decode'...")

    payload = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        "num_tokens_generated": 10,
        "prefix_cache_score": 0.7,
        "pod_type": "decode",
    }

    resp = requests.post(f"{PREDICTION_URL}/predict", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    # Both latency fields must be present in the response.
    for key in ("ttft_ms", "tpot_ms"):
        assert key in body
    assert body["ttft_ms"] > 0
    # Non-negative TPOT is the only hard requirement here.
    assert body["tpot_ms"] >= 0

    print(f"✓ Decode prediction: TTFT={body['ttft_ms']:.2f}ms, TPOT={body['tpot_ms']:.2f}ms")
def test_bulk_prediction_with_pod_type():
    """Test bulk predictions with mixed pod types."""
    print("Testing bulk prediction with pod_type...")

    # One request per pod flavor: prefill, decode, and a legacy request
    # that omits pod_type entirely (monolithic behavior).
    prefill_req = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        "num_tokens_generated": 0,
        "prefix_cache_score": 0.7,
        "pod_type": "prefill",
    }
    decode_req = {
        "kv_cache_percentage": 0.3,
        "input_token_length": 150,
        "num_request_waiting": 2,
        "num_request_running": 1,
        "num_tokens_generated": 10,
        "prefix_cache_score": 0.5,
        "pod_type": "decode",
    }
    legacy_req = {
        "kv_cache_percentage": 0.6,
        "input_token_length": 300,
        "num_request_waiting": 3,
        "num_request_running": 2,
        "num_tokens_generated": 5,
        "prefix_cache_score": 0.8,
    }

    resp = requests.post(
        f"{PREDICTION_URL}/predict/bulk/strict",
        json={"requests": [prefill_req, decode_req, legacy_req]},
    )
    assert resp.status_code == 200

    body = resp.json()
    assert body["total_requests"] == 3
    assert body["successful_predictions"] == 3
    assert body["failed_predictions"] == 0

    preds = body["predictions"]

    # Check prefill prediction (index 0)
    prefill_pred = preds[0]
    assert prefill_pred["ttft_ms"] > 0
    # Relaxed constraint: prefill may legitimately report zero TPOT.
    assert prefill_pred["tpot_ms"] >= 0
    print(f" Prefill: TTFT={prefill_pred['ttft_ms']:.2f}ms, TPOT={prefill_pred['tpot_ms']:.2f}ms")

    # Check decode prediction (index 1)
    decode_pred = preds[1]
    assert decode_pred["ttft_ms"] > 0
    # Decode generates tokens, so TPOT must be strictly positive.
    assert decode_pred["tpot_ms"] > 0
    print(f" Decode: TTFT={decode_pred['ttft_ms']:.2f}ms, TPOT={decode_pred['tpot_ms']:.2f}ms")

    # Check legacy prediction (index 2)
    legacy_pred = preds[2]
    assert legacy_pred["ttft_ms"] > 0
    assert legacy_pred["tpot_ms"] > 0
    print(f" Legacy: TTFT={legacy_pred['ttft_ms']:.2f}ms, TPOT={legacy_pred['tpot_ms']:.2f}ms")

    print("✓ Bulk prediction with mixed pod types passed")
def test_training_data_with_pod_type():
    """Test that training server accepts pod_type in training data."""
    print("Testing training data with pod_type...")

    # Prefill training samples: TPOT is 0 because prefill produces no tokens.
    prefill_entries = [
        {
            "kv_cache_percentage": 0.5,
            "input_token_length": 200 + i * 10,
            "num_request_waiting": i % 5,
            "num_request_running": 1,
            "actual_ttft_ms": 100.0 + i * 5,
            "actual_tpot_ms": 0.0,
            "num_tokens_generated": 0,
            "prefix_cache_score": 0.7,
            "pod_type": "prefill",
        }
        for i in range(10)
    ]

    # Decode training samples: both TTFT and TPOT are populated.
    decode_entries = [
        {
            "kv_cache_percentage": 0.5,
            "input_token_length": 200 + i * 10,
            "num_request_waiting": i % 5,
            "num_request_running": 1,
            "actual_ttft_ms": 100.0 + i * 5,
            "actual_tpot_ms": 10.0 + i * 2,
            "num_tokens_generated": 5 + i,
            "prefix_cache_score": 0.7,
            "pod_type": "decode",
        }
        for i in range(10)
    ]

    all_entries = prefill_entries + decode_entries

    resp = requests.post(
        f"{TRAINING_URL}/add_training_data_bulk", json={"entries": all_entries}
    )
    assert resp.status_code == 202
    assert resp.json().get("message") == f"Accepted {len(all_entries)} training samples."

    print(f"✓ Successfully sent {len(all_entries)} training samples with pod_type")
def test_invalid_pod_type():
    """Test that invalid pod_type values are handled correctly.

    The server is allowed either contract: reject an unknown pod_type with a
    422 validation error (strict), or accept it and fall back to legacy /
    monolithic behavior while still returning valid predictions (permissive).
    Any other status code is a failure.
    """
    print("Testing invalid pod_type handling...")

    features = {
        "kv_cache_percentage": 0.5,
        "input_token_length": 200,
        "num_request_waiting": 4,
        "num_request_running": 1,
        "num_tokens_generated": 10,
        "prefix_cache_score": 0.7,
        "pod_type": "invalid_type",  # Invalid pod type
    }

    r = requests.post(f"{PREDICTION_URL}/predict", json=features)

    # Should either accept it (treating as legacy) or reject with validation error
    if r.status_code == 422:
        print("✓ Invalid pod_type rejected with validation error (strict validation)")
    elif r.status_code == 200:
        data = r.json()
        # If accepted, should still return valid predictions
        assert data["ttft_ms"] > 0
        assert data["tpot_ms"] >= 0
        print("✓ Invalid pod_type accepted with fallback behavior (permissive validation)")
    else:
        # Raise explicitly rather than `assert False`: asserts are stripped
        # under `python -O`, which would silently pass this failure branch.
        raise AssertionError(f"Unexpected status code {r.status_code} for invalid pod_type")
487683
def test_training_server_metrics():
488684
"""Test training server metrics endpoint."""
489685
r = requests.get(f"{TRAINING_URL}/metrics")

0 commit comments

Comments
 (0)