@@ -484,6 +484,202 @@ def test_prediction_missing_prefix_cache_score():
484484 print ("✓ Prediction correctly failed when prefix_cache_score was missing" )
485485
486486
487+ def test_prediction_with_pod_type_prefill ():
488+ """Test predictions with pod_type='prefill' parameter."""
489+ print ("Testing prediction with pod_type='prefill'..." )
490+
491+ features = {
492+ "kv_cache_percentage" : 0.5 ,
493+ "input_token_length" : 200 ,
494+ "num_request_waiting" : 4 ,
495+ "num_request_running" : 1 ,
496+ "num_tokens_generated" : 0 , # Prefill doesn't generate tokens
497+ "prefix_cache_score" : 0.7 ,
498+ "pod_type" : "prefill" ,
499+ }
500+
501+ r = requests .post (f"{ PREDICTION_URL } /predict" , json = features )
502+ assert r .status_code == 200
503+
504+ data = r .json ()
505+ assert "ttft_ms" in data
506+ assert "tpot_ms" in data
507+ assert data ["ttft_ms" ] > 0
508+ assert data ["tpot_ms" ] >= 0 # Non-negative
509+
510+ print (f"✓ Prefill prediction: TTFT={ data ['ttft_ms' ]:.2f} ms, TPOT={ data ['tpot_ms' ]:.2f} ms" )
511+
512+
513+ def test_prediction_with_pod_type_decode ():
514+ """Test predictions with pod_type='decode' parameter."""
515+ print ("Testing prediction with pod_type='decode'..." )
516+
517+ features = {
518+ "kv_cache_percentage" : 0.5 ,
519+ "input_token_length" : 200 ,
520+ "num_request_waiting" : 4 ,
521+ "num_request_running" : 1 ,
522+ "num_tokens_generated" : 10 ,
523+ "prefix_cache_score" : 0.7 ,
524+ "pod_type" : "decode" ,
525+ }
526+
527+ r = requests .post (f"{ PREDICTION_URL } /predict" , json = features )
528+ assert r .status_code == 200
529+
530+ data = r .json ()
531+ assert "ttft_ms" in data
532+ assert "tpot_ms" in data
533+ assert data ["ttft_ms" ] > 0
534+ assert data ["tpot_ms" ] >= 0 # Non-negative
535+
536+ print (f"✓ Decode prediction: TTFT={ data ['ttft_ms' ]:.2f} ms, TPOT={ data ['tpot_ms' ]:.2f} ms" )
537+
538+
539+ def test_bulk_prediction_with_pod_type ():
540+ """Test bulk predictions with mixed pod types."""
541+ print ("Testing bulk prediction with pod_type..." )
542+
543+ requests_data = [
544+ # Prefill pod request
545+ {
546+ "kv_cache_percentage" : 0.5 ,
547+ "input_token_length" : 200 ,
548+ "num_request_waiting" : 4 ,
549+ "num_request_running" : 1 ,
550+ "num_tokens_generated" : 0 ,
551+ "prefix_cache_score" : 0.7 ,
552+ "pod_type" : "prefill" ,
553+ },
554+ # Decode pod request
555+ {
556+ "kv_cache_percentage" : 0.3 ,
557+ "input_token_length" : 150 ,
558+ "num_request_waiting" : 2 ,
559+ "num_request_running" : 1 ,
560+ "num_tokens_generated" : 10 ,
561+ "prefix_cache_score" : 0.5 ,
562+ "pod_type" : "decode" ,
563+ },
564+ # Legacy request (no pod_type)
565+ {
566+ "kv_cache_percentage" : 0.6 ,
567+ "input_token_length" : 300 ,
568+ "num_request_waiting" : 3 ,
569+ "num_request_running" : 2 ,
570+ "num_tokens_generated" : 5 ,
571+ "prefix_cache_score" : 0.8 ,
572+ }
573+ ]
574+
575+ bulk_request = {"requests" : requests_data }
576+
577+ r = requests .post (f"{ PREDICTION_URL } /predict/bulk/strict" , json = bulk_request )
578+ assert r .status_code == 200
579+
580+ data = r .json ()
581+ assert data ["total_requests" ] == 3
582+ assert data ["successful_predictions" ] == 3
583+ assert data ["failed_predictions" ] == 0
584+
585+ predictions = data ["predictions" ]
586+
587+ # Check prefill prediction (index 0)
588+ prefill_pred = predictions [0 ]
589+ assert prefill_pred ["ttft_ms" ] > 0
590+ assert prefill_pred ["tpot_ms" ] >= 0 # Relaxed constraint for prefill
591+ print (f" Prefill: TTFT={ prefill_pred ['ttft_ms' ]:.2f} ms, TPOT={ prefill_pred ['tpot_ms' ]:.2f} ms" )
592+
593+ # Check decode prediction (index 1)
594+ decode_pred = predictions [1 ]
595+ assert decode_pred ["ttft_ms" ] > 0
596+ assert decode_pred ["tpot_ms" ] > 0 # Should be positive for decode
597+ print (f" Decode: TTFT={ decode_pred ['ttft_ms' ]:.2f} ms, TPOT={ decode_pred ['tpot_ms' ]:.2f} ms" )
598+
599+ # Check legacy prediction (index 2)
600+ legacy_pred = predictions [2 ]
601+ assert legacy_pred ["ttft_ms" ] > 0
602+ assert legacy_pred ["tpot_ms" ] > 0
603+ print (f" Legacy: TTFT={ legacy_pred ['ttft_ms' ]:.2f} ms, TPOT={ legacy_pred ['tpot_ms' ]:.2f} ms" )
604+
605+ print ("✓ Bulk prediction with mixed pod types passed" )
606+
607+
608+ def test_training_data_with_pod_type ():
609+ """Test that training server accepts pod_type in training data."""
610+ print ("Testing training data with pod_type..." )
611+
612+ # Generate training samples with pod_type
613+ prefill_entries = []
614+ decode_entries = []
615+
616+ # Prefill training samples (TPOT should be 0)
617+ for i in range (10 ):
618+ prefill_entries .append ({
619+ "kv_cache_percentage" : 0.5 ,
620+ "input_token_length" : 200 + i * 10 ,
621+ "num_request_waiting" : i % 5 ,
622+ "num_request_running" : 1 ,
623+ "actual_ttft_ms" : 100.0 + i * 5 ,
624+ "actual_tpot_ms" : 0.0 , # Prefill doesn't produce tokens
625+ "num_tokens_generated" : 0 ,
626+ "prefix_cache_score" : 0.7 ,
627+ "pod_type" : "prefill" ,
628+ })
629+
630+ # Decode training samples (both TTFT and TPOT)
631+ for i in range (10 ):
632+ decode_entries .append ({
633+ "kv_cache_percentage" : 0.5 ,
634+ "input_token_length" : 200 + i * 10 ,
635+ "num_request_waiting" : i % 5 ,
636+ "num_request_running" : 1 ,
637+ "actual_ttft_ms" : 100.0 + i * 5 ,
638+ "actual_tpot_ms" : 10.0 + i * 2 ,
639+ "num_tokens_generated" : 5 + i ,
640+ "prefix_cache_score" : 0.7 ,
641+ "pod_type" : "decode" ,
642+ })
643+
644+ all_entries = prefill_entries + decode_entries
645+ payload = {"entries" : all_entries }
646+
647+ r = requests .post (f"{ TRAINING_URL } /add_training_data_bulk" , json = payload )
648+ assert r .status_code == 202
649+ assert r .json ().get ("message" ) == f"Accepted { len (all_entries )} training samples."
650+
651+ print (f"✓ Successfully sent { len (all_entries )} training samples with pod_type" )
652+
653+
654+ def test_invalid_pod_type ():
655+ """Test that invalid pod_type values are handled correctly."""
656+ print ("Testing invalid pod_type handling..." )
657+
658+ features = {
659+ "kv_cache_percentage" : 0.5 ,
660+ "input_token_length" : 200 ,
661+ "num_request_waiting" : 4 ,
662+ "num_request_running" : 1 ,
663+ "num_tokens_generated" : 10 ,
664+ "prefix_cache_score" : 0.7 ,
665+ "pod_type" : "invalid_type" , # Invalid pod type
666+ }
667+
668+ r = requests .post (f"{ PREDICTION_URL } /predict" , json = features )
669+
670+ # Should either accept it (treating as legacy) or reject with validation error
671+ if r .status_code == 422 :
672+ print ("✓ Invalid pod_type rejected with validation error (strict validation)" )
673+ elif r .status_code == 200 :
674+ data = r .json ()
675+ # If accepted, should still return valid predictions
676+ assert data ["ttft_ms" ] > 0
677+ assert data ["tpot_ms" ] >= 0
678+ print ("✓ Invalid pod_type accepted with fallback behavior (permissive validation)" )
679+ else :
680+ assert False , f"Unexpected status code { r .status_code } for invalid pod_type"
681+
682+
487683def test_training_server_metrics ():
488684 """Test training server metrics endpoint."""
489685 r = requests .get (f"{ TRAINING_URL } /metrics" )
0 commit comments