Skip to content

Commit ac58ba5

Browse files
committed
Add fix for changing snapshots
1 parent a489187 commit ac58ba5

File tree

1 file changed

+28
-5
lines changed
  • tests/model_serving/model_runtime/model_validation

1 file changed

+28
-5
lines changed

tests/model_serving/model_runtime/model_validation/utils.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,38 @@
33

44
from tests.model_serving.model_runtime.vllm.constant import VLLM_SUPPORTED_QUANTIZATION
55

6+
def normalize_output(output):
7+
"""
8+
Recursively normalize model output by removing or masking fields that cause non-deterministic snapshot changes.
9+
Handles nested dicts and lists.
10+
"""
11+
if isinstance(output, dict):
12+
output = output.copy()
13+
volatile_keys = ["timestamp", "created_at", "updated_at", "id", "unique_id", "request_id", "uuid", "run_id"]
14+
for key in volatile_keys:
15+
output.pop(key, None)
16+
for k, v in output.items():
17+
output[k] = normalize_output(v)
18+
elif isinstance(output, list):
19+
output = [normalize_output(item) for item in output]
20+
try:
21+
output = sorted(output, key=lambda x: str(x))
22+
except Exception:
23+
pass
24+
elif isinstance(output, str):
25+
import re
26+
output = re.sub(r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b', '[MASKED_UUID]', output)
27+
return output
628

729
def validate_supported_quantization_schema(q_type: str) -> None:
830
if q_type not in VLLM_SUPPORTED_QUANTIZATION:
931
raise ValueError(f"Unsupported quantization type: {q_type}")
1032

11-
12-
def validate_inference_output(*args: tuple[str, ...] | list[Any], response_snapshot: Any) -> None:
13-
for data in args:
14-
assert data == response_snapshot, f"output mismatch for {data}"
33+
def validate_inference_output(*args: tuple[Any, ...], response_snapshot: Any) -> None:
34+
normalized_args = [normalize_output(data) for data in args]
35+
normalized_snapshot = normalize_output(response_snapshot)
36+
for data in normalized_args:
37+
assert data == normalized_snapshot, f"output mismatch for {data}"
1538

1639

1740
def safe_k8s_name(model_name: str, max_length: int = 20) -> str:
@@ -49,4 +72,4 @@ def safe_k8s_name(model_name: str, max_length: int = 20) -> str:
4972
if not safe_name:
5073
return "model"
5174

52-
return safe_name
75+
return safe_name

0 commit comments

Comments
 (0)