|
1 | 1 | """Utilities for LLM Deployment (LLMD) resources.""" |
2 | 2 |
|
3 | | -import json |
4 | | -import re |
5 | | -import shlex |
6 | 3 | from collections.abc import Generator |
7 | 4 | from contextlib import contextmanager |
8 | | -from string import Template |
9 | 5 | from typing import Any |
10 | 6 |
|
11 | 7 | from kubernetes.dynamic import DynamicClient |
12 | 8 | from ocp_resources.gateway import Gateway |
13 | 9 | from ocp_resources.llm_inference_service import LLMInferenceService |
14 | | -from pyhelper_utils.shell import run_command |
15 | 10 | from simple_logger.logger import get_logger |
16 | | -from timeout_sampler import TimeoutWatch, retry |
| 11 | +from timeout_sampler import TimeoutWatch |
17 | 12 |
|
18 | | -from utilities.certificates_utils import get_ca_bundle |
19 | | -from utilities.constants import HTTPRequest, Timeout |
20 | | -from utilities.exceptions import InferenceResponseError |
| 13 | +from utilities.constants import Timeout |
21 | 14 | from utilities.infra import get_services_by_isvc_label |
22 | 15 | from utilities.llmd_constants import ( |
23 | 16 | ContainerImages, |
24 | 17 | KServeGateway, |
25 | 18 | LLMDGateway, |
26 | | - LLMEndpoint, |
27 | 19 | ) |
28 | 20 |
|
29 | 21 | LOGGER = get_logger(name=__name__) |
@@ -396,340 +388,3 @@ def get_llm_inference_url(llm_service: LLMInferenceService) -> str: |
396 | 388 | fallback_url = f"http://{llm_service.name}.{llm_service.namespace}.svc.cluster.local" |
397 | 389 | LOGGER.debug(f"Using fallback URL for {llm_service.name}: {fallback_url}") |
398 | 390 | return fallback_url |
399 | | - |
400 | | - |
401 | | -def verify_inference_response_llmd( |
402 | | - llm_service: LLMInferenceService, |
403 | | - inference_config: dict[str, Any], |
404 | | - inference_type: str, |
405 | | - protocol: str, |
406 | | - model_name: str | None = None, |
407 | | - inference_input: Any | None = None, |
408 | | - use_default_query: bool = False, |
409 | | - expected_response_text: str | None = None, |
410 | | - insecure: bool = False, |
411 | | - token: str | None = None, |
412 | | - authorized_user: bool | None = None, |
413 | | -) -> None: |
414 | | - """ |
415 | | - Verify the LLM inference response following the pattern of verify_inference_response. |
416 | | -
|
417 | | - Args: |
418 | | - llm_service: LLMInferenceService resource to test |
419 | | - inference_config: Inference configuration dictionary |
420 | | - inference_type: Type of inference ('infer', 'streaming', etc.) |
421 | | - protocol: Protocol to use ('http', 'grpc') |
422 | | - model_name: Name of the model (defaults to service name) |
423 | | - inference_input: Input for inference (optional) |
424 | | - use_default_query: Whether to use default query from config |
425 | | - expected_response_text: Expected response text for validation |
426 | | - insecure: Whether to use insecure connections |
427 | | - token: Authentication token (optional) |
428 | | - authorized_user: Whether user should be authorized (optional) |
429 | | -
|
430 | | - Raises: |
431 | | - InferenceResponseError: If inference response is invalid |
432 | | - ValueError: If inference response validation fails |
433 | | - """ |
434 | | - |
435 | | - model_name = model_name or llm_service.name |
436 | | - inference = LLMUserInference( |
437 | | - llm_service=llm_service, |
438 | | - inference_config=inference_config, |
439 | | - inference_type=inference_type, |
440 | | - protocol=protocol, |
441 | | - ) |
442 | | - |
443 | | - res = inference.run_inference_flow( |
444 | | - model_name=model_name, |
445 | | - inference_input=inference_input, |
446 | | - use_default_query=use_default_query, |
447 | | - token=token, |
448 | | - insecure=insecure, |
449 | | - ) |
450 | | - |
451 | | - if authorized_user is False: |
452 | | - _validate_unauthorized_response(res=res, token=token, inference=inference) |
453 | | - else: |
454 | | - _validate_authorized_response( |
455 | | - res=res, |
456 | | - inference=inference, |
457 | | - inference_config=inference_config, |
458 | | - inference_type=inference_type, |
459 | | - expected_response_text=expected_response_text, |
460 | | - use_default_query=use_default_query, |
461 | | - model_name=model_name, |
462 | | - ) |
463 | | - |
464 | | - |
465 | | -class LLMUserInference: |
466 | | - """ |
467 | | - LLM-specific inference handler following the pattern of UserInference. |
468 | | - """ |
469 | | - |
470 | | - STREAMING = "streaming" |
471 | | - INFER = "infer" |
472 | | - |
473 | | - def __init__( |
474 | | - self, |
475 | | - llm_service: LLMInferenceService, |
476 | | - inference_config: dict[str, Any], |
477 | | - inference_type: str, |
478 | | - protocol: str, |
479 | | - ) -> None: |
480 | | - self.llm_service = llm_service |
481 | | - self.inference_config = inference_config |
482 | | - self.inference_type = inference_type |
483 | | - self.protocol = protocol |
484 | | - self.runtime_config = self.get_runtime_config() |
485 | | - |
486 | | - def get_runtime_config(self) -> dict[str, Any]: |
487 | | - """Get runtime config from inference config based on inference type and protocol.""" |
488 | | - if inference_type_config := self.inference_config.get(self.inference_type): |
489 | | - protocol = "http" if self.protocol.lower() in ["http", "https"] else self.protocol |
490 | | - if data := inference_type_config.get(protocol): |
491 | | - return data |
492 | | - else: |
493 | | - raise ValueError(f"Protocol {protocol} not supported for inference type {self.inference_type}") |
494 | | - else: |
495 | | - raise ValueError(f"Inference type {self.inference_type} not supported in config") |
496 | | - |
497 | | - @property |
498 | | - def inference_response_text_key_name(self) -> str | None: |
499 | | - """Get inference response text key name from runtime config.""" |
500 | | - return self.runtime_config.get("response_fields_map", {}).get("response_output") |
501 | | - |
502 | | - @property |
503 | | - def inference_response_key_name(self) -> str: |
504 | | - """Get inference response key name from runtime config.""" |
505 | | - return self.runtime_config.get("response_fields_map", {}).get("response", "output") |
506 | | - |
507 | | - def get_inference_body( |
508 | | - self, |
509 | | - model_name: str, |
510 | | - inference_input: Any | None = None, |
511 | | - use_default_query: bool = False, |
512 | | - ) -> str: |
513 | | - """Get inference body for LLM request.""" |
514 | | - if not use_default_query and inference_input is None: |
515 | | - raise ValueError("Either pass `inference_input` or set `use_default_query` to True") |
516 | | - |
517 | | - if use_default_query: |
518 | | - default_query_config = self.inference_config.get("default_query_model") |
519 | | - if not default_query_config: |
520 | | - raise ValueError(f"Missing default query config for {model_name}") |
521 | | - |
522 | | - if self.inference_config.get("support_multi_default_queries"): |
523 | | - query_config = default_query_config.get(self.inference_type) |
524 | | - if not query_config: |
525 | | - raise ValueError(f"Missing default query for inference type {self.inference_type}") |
526 | | - query_input = query_config.get("query_input", "") |
527 | | - else: |
528 | | - query_input = default_query_config.get("query_input", "") |
529 | | - |
530 | | - # Use the proper JSON body template from runtime config |
531 | | - body_template = self.runtime_config.get("body", "") |
532 | | - if body_template: |
533 | | - # Use template substitution for both model name and query input |
534 | | - template = Template(template=body_template) |
535 | | - body = template.safe_substitute(model_name=model_name, query_input=query_input) |
536 | | - else: |
537 | | - # Fallback to plain text (legacy behavior) |
538 | | - template = Template(template=query_input) |
539 | | - body = template.safe_substitute(model_name=model_name) |
540 | | - else: |
541 | | - # For custom input, create OpenAI-compatible format |
542 | | - if isinstance(inference_input, str): |
543 | | - body = json.dumps({ |
544 | | - "model": model_name, |
545 | | - "messages": [{"role": "user", "content": inference_input}], |
546 | | - "max_tokens": 100, |
547 | | - "temperature": 0.0, |
548 | | - }) |
549 | | - else: |
550 | | - body = json.dumps(inference_input) |
551 | | - |
552 | | - return body |
553 | | - |
554 | | - def generate_command( |
555 | | - self, |
556 | | - model_name: str, |
557 | | - inference_input: str | None = None, |
558 | | - use_default_query: bool = False, |
559 | | - insecure: bool = False, |
560 | | - token: str | None = None, |
561 | | - ) -> str: |
562 | | - """Generate curl command string for LLM inference.""" |
563 | | - base_url = get_llm_inference_url(llm_service=self.llm_service) |
564 | | - endpoint_url = f"{base_url}{LLMEndpoint.CHAT_COMPLETIONS}" |
565 | | - |
566 | | - body = self.get_inference_body( |
567 | | - model_name=model_name, |
568 | | - inference_input=inference_input, |
569 | | - use_default_query=use_default_query, |
570 | | - ) |
571 | | - |
572 | | - header = HTTPRequest.CONTENT_JSON.replace("-H ", "") |
573 | | - cmd_exec = "curl -i -s" |
574 | | - cmd = f"{cmd_exec} -X POST -d '{body}' -H {header} -H 'Accept: application/json'" |
575 | | - |
576 | | - if token: |
577 | | - cmd += f" {HTTPRequest.AUTH_HEADER.format(token=token)}" |
578 | | - |
579 | | - if insecure: |
580 | | - cmd += " --insecure" |
581 | | - else: |
582 | | - try: |
583 | | - from ocp_resources.resource import get_client |
584 | | - |
585 | | - client = get_client() |
586 | | - ca_bundle = get_ca_bundle(client=client) |
587 | | - if ca_bundle: |
588 | | - cmd += f" --cacert {ca_bundle}" |
589 | | - else: |
590 | | - cmd += " --insecure" |
591 | | - except Exception: # noqa: BLE001 |
592 | | - cmd += " --insecure" |
593 | | - |
594 | | - cmd += f" --max-time {LLMEndpoint.DEFAULT_TIMEOUT} {endpoint_url}" |
595 | | - return cmd |
596 | | - |
597 | | - @retry(wait_timeout=Timeout.TIMEOUT_30SEC, sleep=5) |
598 | | - def run_inference( |
599 | | - self, |
600 | | - model_name: str, |
601 | | - inference_input: str | None = None, |
602 | | - use_default_query: bool = False, |
603 | | - insecure: bool = False, |
604 | | - token: str | None = None, |
605 | | - ) -> str: |
606 | | - """Run inference command and return raw output.""" |
607 | | - cmd = self.generate_command( |
608 | | - model_name=model_name, |
609 | | - inference_input=inference_input, |
610 | | - use_default_query=use_default_query, |
611 | | - insecure=insecure, |
612 | | - token=token, |
613 | | - ) |
614 | | - |
615 | | - res, out, err = run_command(command=shlex.split(cmd), verify_stderr=False, check=False) |
616 | | - if res: |
617 | | - return out |
618 | | - raise ValueError(f"Inference failed with error: {err}\nOutput: {out}\nCommand: {cmd}") |
619 | | - |
620 | | - def run_inference_flow( |
621 | | - self, |
622 | | - model_name: str, |
623 | | - inference_input: str | None = None, |
624 | | - use_default_query: bool = False, |
625 | | - insecure: bool = False, |
626 | | - token: str | None = None, |
627 | | - ) -> dict[str, Any]: |
628 | | - """Run LLM inference using the same high-level flow as inference_utils.""" |
629 | | - out = self.run_inference( |
630 | | - model_name=model_name, |
631 | | - inference_input=inference_input, |
632 | | - use_default_query=use_default_query, |
633 | | - insecure=insecure, |
634 | | - token=token, |
635 | | - ) |
636 | | - return {"output": out} |
637 | | - |
638 | | - |
639 | | -def _validate_unauthorized_response(res: dict[str, Any], token: str | None, inference: LLMUserInference) -> None: |
640 | | - """Validate response for unauthorized users.""" |
641 | | - auth_header = "x-ext-auth-reason" |
642 | | - |
643 | | - if auth_reason := re.search(rf"{auth_header}: (.*)", res["output"], re.MULTILINE): |
644 | | - reason = auth_reason.group(1).lower() |
645 | | - |
646 | | - if token: |
647 | | - assert re.search(r"not (?:authenticated|authorized)", reason) |
648 | | - else: |
649 | | - assert "credential not found" in reason |
650 | | - else: |
651 | | - forbidden_patterns = ["Forbidden", "401", "403", "Unauthorized"] |
652 | | - output = res["output"] |
653 | | - |
654 | | - if any(pattern in output for pattern in forbidden_patterns): |
655 | | - return |
656 | | - |
657 | | - raise ValueError(f"Auth header {auth_header} not found in response. Response: {output}") |
658 | | - |
659 | | - |
660 | | -def _validate_authorized_response( |
661 | | - res: dict[str, Any], |
662 | | - inference: LLMUserInference, |
663 | | - inference_config: dict[str, Any], |
664 | | - inference_type: str, |
665 | | - expected_response_text: str | None, |
666 | | - use_default_query: bool, |
667 | | - model_name: str, |
668 | | -) -> None: |
669 | | - """Validate response for authorized users.""" |
670 | | - |
671 | | - use_regex = False |
672 | | - |
673 | | - if use_default_query: |
674 | | - expected_response_text_config = inference_config.get("default_query_model", {}) |
675 | | - use_regex = expected_response_text_config.get("use_regex", False) |
676 | | - |
677 | | - if not expected_response_text_config: |
678 | | - raise ValueError(f"Missing default_query_model config for inference {inference_config}") |
679 | | - |
680 | | - if inference_config.get("support_multi_default_queries"): |
681 | | - query_config = expected_response_text_config.get(inference_type) |
682 | | - if not query_config: |
683 | | - raise ValueError(f"Missing default_query_model config for inference type {inference_type}") |
684 | | - expected_response_text = query_config.get("query_output", "") |
685 | | - use_regex = query_config.get("use_regex", False) |
686 | | - else: |
687 | | - expected_response_text = expected_response_text_config.get("query_output") |
688 | | - |
689 | | - if not expected_response_text: |
690 | | - raise ValueError(f"Missing response text key for inference {inference_config}") |
691 | | - |
692 | | - if isinstance(expected_response_text, str): |
693 | | - expected_response_text = Template(template=expected_response_text).safe_substitute(model_name=model_name) |
694 | | - elif isinstance(expected_response_text, dict): |
695 | | - response_output = expected_response_text.get("response_output") |
696 | | - if response_output is not None: |
697 | | - expected_response_text = Template(template=response_output).safe_substitute(model_name=model_name) |
698 | | - if inference.inference_response_text_key_name: |
699 | | - if inference_type == inference.STREAMING: |
700 | | - if output := re.findall( |
701 | | - rf"{inference.inference_response_text_key_name}\": \"(.*)\"", |
702 | | - res[inference.inference_response_key_name], |
703 | | - re.MULTILINE, |
704 | | - ): |
705 | | - assert "".join(output) == expected_response_text, ( |
706 | | - f"Expected: {expected_response_text} does not match response: {output}" |
707 | | - ) |
708 | | - elif inference_type == inference.INFER or use_regex: |
709 | | - formatted_res = json.dumps(res[inference.inference_response_text_key_name]).replace(" ", "") |
710 | | - if use_regex and expected_response_text is not None: |
711 | | - assert re.search(expected_response_text, formatted_res), ( |
712 | | - f"Expected: {expected_response_text} not found in: {formatted_res}" |
713 | | - ) |
714 | | - else: |
715 | | - formatted_res = json.dumps(res[inference.inference_response_key_name]).replace(" ", "") |
716 | | - assert formatted_res == expected_response_text, ( |
717 | | - f"Expected: {expected_response_text} does not match output: {formatted_res}" |
718 | | - ) |
719 | | - else: |
720 | | - response = res[inference.inference_response_key_name] |
721 | | - if isinstance(response, list): |
722 | | - response = response[0] |
723 | | - |
724 | | - if isinstance(response, dict): |
725 | | - response_text = response[inference.inference_response_text_key_name] |
726 | | - assert response_text == expected_response_text, ( |
727 | | - f"Expected: {expected_response_text} does not match response: {response_text}" |
728 | | - ) |
729 | | - else: |
730 | | - raise InferenceResponseError( |
731 | | - "Inference response output does not match expected output format." |
732 | | - f"Expected: {expected_response_text}.\nResponse: {res}" |
733 | | - ) |
734 | | - else: |
735 | | - raise InferenceResponseError(f"Inference response output not found in response. Response: {res}") |
0 commit comments