33
44from tests .model_serving .model_runtime .vllm .constant import VLLM_SUPPORTED_QUANTIZATION
55
# Dict keys that are dropped during normalization because their values are
# expected to differ between otherwise identical runs.
_VOLATILE_KEYS = frozenset(
    {
        "timestamp",
        "created_at",
        "updated_at",
        "id",
        "unique_id",
        "request_id",
        "uuid",
        "run_id",
    }
)


def normalize_output(output):
    """Return a normalized copy of *output* suitable for snapshot comparison.

    The value is processed recursively:

    * dict -- entries whose key is in ``_VOLATILE_KEYS`` are dropped and the
      remaining values are normalized; the input dict is not modified.
    * list -- every element is normalized, then the elements are sorted by
      their ``str()`` representation.
    * str -- every UUID-shaped substring (8-4-4-4-12 hex digits, upper or
      lower case) is replaced with ``[MASKED_UUID]``.
    * anything else -- returned unchanged.

    Args:
        output: The value to normalize.

    Returns:
        The normalized value. Dicts and lists are returned as new objects.
    """
    if isinstance(output, dict):
        # A filtered comprehension builds the copy in one pass, so the
        # caller's dict is never touched and nothing is mutated mid-iteration.
        return {
            key: normalize_output(value)
            for key, value in output.items()
            if key not in _VOLATILE_KEYS
        }

    if isinstance(output, list):
        items = [normalize_output(item) for item in output]
        # NOTE(review): sorting makes the comparison order-insensitive for
        # every list, including ones whose order may be meaningful -- confirm
        # this is intended for all compared payloads.
        try:
            return sorted(items, key=str)
        except Exception:
            # Deliberate best-effort: if an element's __str__ raises, keep the
            # normalized elements in their original order instead of failing.
            return items

    if isinstance(output, str):
        # Function-scope import keeps this block self-contained.
        import re

        # IGNORECASE: UUIDs may be rendered with upper-case hex digits too.
        return re.sub(
            r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b',
            '[MASKED_UUID]',
            output,
            flags=re.IGNORECASE,
        )

    # NOTE(review): tuples (and any other container type) fall through here
    # and are returned without their contents being normalized.
    return output
628
def validate_supported_quantization_schema(q_type: str) -> None:
    """Check that *q_type* is a supported quantization type.

    Args:
        q_type: Quantization type name to look up in
            ``VLLM_SUPPORTED_QUANTIZATION``.

    Raises:
        ValueError: If *q_type* is not in ``VLLM_SUPPORTED_QUANTIZATION``.
    """
    if q_type not in VLLM_SUPPORTED_QUANTIZATION:
        raise ValueError(f"Unsupported quantization type: {q_type}")
1032
11-
def validate_inference_output(*args: Any, response_snapshot: Any) -> None:
    """Assert that every positional value matches *response_snapshot*.

    Each positional argument and the snapshot are passed through
    ``normalize_output`` before being compared with ``==``.

    Args:
        *args: Values to compare; each one is normalized and checked on its
            own against the snapshot.
        response_snapshot: Expected value (keyword-only). It is normalized
            once, after all positional values have been normalized.

    Raises:
        AssertionError: If any normalized value differs from the normalized
            snapshot.
    """
    normalized_args = [normalize_output(data) for data in args]
    normalized_snapshot = normalize_output(response_snapshot)
    for data in normalized_args:
        assert data == normalized_snapshot, f"output mismatch for {data}"
1538
1639
1740def safe_k8s_name (model_name : str , max_length : int = 20 ) -> str :
@@ -49,4 +72,4 @@ def safe_k8s_name(model_name: str, max_length: int = 20) -> str:
4972 if not safe_name :
5073 return "model"
5174
52- return safe_name
75+ return safe_name
0 commit comments