[ML] Harden pytorch_inference with TorchScript model graph validation #2999
Draft
edsavage wants to merge 10 commits into elastic:main
This reverts commit 4f1ec3e.
✅ Snyk checks have passed. No issues have been found so far.
…ns quotes Sanitize BUILDKITE_MESSAGE before embedding in generated pipeline YAML to prevent double quotes and multi-line content from breaking the YAML structure. Affects both run_qa_tests.yml.sh and run_pytorch_tests.yml.sh. Made-with: Cursor
Add aten::bmm, aten::ceil, aten::floor_divide, aten::gt, aten::le, and aten::sign to the allowed operations list, fixing graph validation failures for the .rerank-v1 model used by the default rerank endpoint. These operations were identified by running the full Elasticsearch inference and ML integration test suites (974 tests) against a build with graph validation enabled. Made-with: Cursor
Add aten::clamp_min, aten::eq, aten::expand_as, aten::linalg_vector_norm, and aten::sum to the allowed operations list, fixing graph validation failures for distilbert-base-uncased-finetuned-sst-2-english and sentence-transformers/all-distilroberta-v1 models. Made-with: Cursor
…d additional model ops Add .rerank-v1 (52 ops extracted from ml-models.elastic.co .pt file), distilbert-sst2, all-distilroberta-v1, and their Eland-deployed variants to the reference model golden file. All ops extracted with PyTorch 2.7.1. Add aten::detach (Eland-traced models) and aten::masked_fill_ (.rerank-v1) to the allowlist. Made-with: Cursor
Add a parallel Buildkite step that runs Elasticsearch inference integration tests against the ml-cpp build artifacts. The new step runs on its own machine alongside the existing ES tests step.
Tests exercised:
- DefaultEndPointsIT (ELSER, E5, rerank default endpoints)
- TextEmbeddingCrudIT (E5 model CRUD via inference API)
- Semantic text YAML REST tests (indexing and querying with default ELSER 2 endpoint)
All tests use local prepacked models served by the test framework; no external services are required.
Made-with: Cursor
Extract clone/branch selection, Java configuration, and Gradle invocation into run_es_tests_common.sh. Both run_es_tests.sh and run_es_inference_tests.sh are now thin wrappers that pass their Gradle commands as arguments to the common script. Made-with: Cursor
Resolve conflict in dev-tools/run_es_tests.sh: keep the refactored thin-wrapper structure from this branch and integrate the Gradle build cache support (from elastic#2907) into run_es_tests_common.sh. Made-with: Cursor
Pull request overview
Reintroduces and extends TorchScript model-graph validation for pytorch_inference to reduce attack surface by enforcing an operation allowlist/denylist, and adds tooling + tests to keep that allowlist current (including quantized model support).
Changes:
- Add C++ TorchScript graph validation (`CModelGraphValidator` + `CSupportedOperations`) and wire it into `pytorch_inference` startup.
- Add Python tooling (`dev-tools/extract_model_ops/`) + CMake runner to extract/validate op allowlists and detect drift via a golden JSON + C++ test.
- Add adversarial TorchScript fixtures and integration tests; extend Buildkite scripts to run additional Elasticsearch inference integration tests.
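The allow/deny evaluation described in the list above can be sketched roughly as follows. This is a minimal illustration, not the PR's C++ implementation: `Node` is a hypothetical stand-in for a TorchScript graph node, and the forbidden op name `aten::from_file` is an assumption chosen for the example (the allowed names are taken from the commit messages above). The recursion into nested blocks shows why an op hidden inside a conditional still has to be found.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a TorchScript graph node (not the real API)."""
    kind: str                                            # op name, e.g. "aten::bmm"
    blocks: list["Node"] = field(default_factory=list)   # nested if/loop bodies


def collect_ops(nodes: list[Node]) -> set[str]:
    """Collect op names from the nodes and, recursively, from nested blocks."""
    ops: set[str] = set()
    for node in nodes:
        ops.add(node.kind)
        ops |= collect_ops(node.blocks)
    return ops


def evaluate(ops: set[str], allowed: set[str], forbidden: set[str]):
    """Return (passed, forbidden_found, unrecognised) for an op set."""
    forbidden_found = sorted(ops & forbidden)
    unrecognised = sorted(ops - allowed - forbidden)
    passed = not forbidden_found and not unrecognised
    return passed, forbidden_found, unrecognised


# Illustrative lists only; the real ones live in CSupportedOperations.cc.
ALLOWED = {"aten::bmm", "aten::sum", "prim::If"}
FORBIDDEN = {"aten::from_file"}  # assumed example of a forbidden op

# A forbidden op tucked inside a conditional branch.
graph = [Node("aten::bmm"), Node("prim::If", blocks=[Node("aten::from_file")])]

passed, forbidden_found, unrecognised = evaluate(collect_ops(graph), ALLOWED, FORBIDDEN)
print(passed, forbidden_found, unrecognised)
```

An op that appears in neither set is reported as unrecognised and also fails, which is what makes the allowlist the effective gate.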
Reviewed changes
Copilot reviewed 22 out of 45 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| test/CMakeLists.txt | Runs allowlist validation script as part of test_all_parallel (optional) and adds a standalone validation target. |
| docs/CHANGELOG.asciidoc | Adds release note entry for graph validation hardening. |
| dev-tools/run_es_tests_common.sh | Factors shared logic for running ES integration tests from a local Ivy repo. |
| dev-tools/run_es_tests.sh | Uses the common runner to execute core ML REST/YAML tests. |
| dev-tools/run_es_inference_tests.sh | Adds separate runner for ES inference integration tests. |
| dev-tools/generate_malicious_models.py | Adds generator for malicious TorchScript fixtures used by validator integration tests. |
| dev-tools/extract_model_ops/validation_models.json | Defines HF model set to validate allowlist (incl. quantized variants). |
| dev-tools/extract_model_ops/validate_allowlist.py | Adds Python-side allowlist validation against traced models + local .pt fixtures. |
| dev-tools/extract_model_ops/torchscript_utils.py | Shared tracing/inlining + config loading helpers (incl. dynamic quantization). |
| dev-tools/extract_model_ops/requirements.txt | Pins Python deps for extraction/validation tooling (torch + transformers stack). |
| dev-tools/extract_model_ops/reference_models.json | Defines reference HF models used to build the allowlist union (incl. quantized). |
| dev-tools/extract_model_ops/extract_model_ops.py | Adds extractor to generate op unions / C++ initializer / golden JSON. |
| dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt | Adds ES IT TorchScript fixture for validation. |
| dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt | Adds ES IT TorchScript fixture for validation. |
| dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt | Adds ES IT TorchScript fixture for validation. |
| dev-tools/extract_model_ops/es_it_models/README.md | Documents provenance/regeneration of ES IT TorchScript fixtures. |
| dev-tools/extract_model_ops/README.md | Documents extraction/validation workflows and golden drift test. |
| dev-tools/extract_model_ops/.gitignore | Ignores the local Python venv used by the tooling. |
| cmake/run-validation.cmake | Adds portable CMake driver to create venv, install deps, and run validation. |
| cmake/functions.cmake | Wires validation into precommit target (optional). |
| bin/pytorch_inference/unittest/testfiles/reference_model_ops.json | Adds/updates golden per-model op sets for drift detection tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt | Adds malicious TorchScript fixture used by integration tests. |
| bin/pytorch_inference/unittest/CThreadSettingsTest.cc | Switches includes to angle-bracket form. |
| bin/pytorch_inference/unittest/CResultWriterTest.cc | Switches includes to angle-bracket form. |
| bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc | Adds unit + integration tests for validator, fixtures, and allowlist drift. |
| bin/pytorch_inference/unittest/CMakeLists.txt | Adds validator test source and include path for new include style. |
| bin/pytorch_inference/unittest/CCommandParserTest.cc | Switches includes to angle-bracket form. |
| bin/pytorch_inference/Main.cc | Enforces model-graph validation at load time; improves rejection messages. |
| bin/pytorch_inference/CSupportedOperations.h | Declares forbidden/allowed op sets for validation. |
| bin/pytorch_inference/CSupportedOperations.cc | Defines forbidden/allowed TorchScript ops (incl. quantized ops). |
| bin/pytorch_inference/CModelGraphValidator.h | Declares validator API + node-count guard. |
| bin/pytorch_inference/CModelGraphValidator.cc | Implements graph inlining + op collection + allow/deny evaluation. |
| bin/pytorch_inference/CMakeLists.txt | Builds new validator + supported-ops sources into pytorch_inference. |
| .buildkite/scripts/steps/run_es_inference_tests.sh | Adds Buildkite step script for ES inference integration tests. |
| .buildkite/pipelines/run_qa_tests.yml.sh | Sanitizes Buildkite message when triggering downstream QA pipeline. |
| .buildkite/pipelines/run_pytorch_tests.yml.sh | Sanitizes Buildkite message when triggering downstream PyTorch pipeline. |
| .buildkite/pipelines/run_es_inference_tests_x86_64.yml.sh | Adds x86_64 pipeline to run ES inference integration tests. |
| .buildkite/pipeline.json.py | Uploads the new ES inference tests runner pipeline when x86_64 enabled. |
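The golden-file drift check mentioned in the table (per-model op sets compared against the allowlist) might look something like the sketch below. The JSON layout, the model names, and the `find_drift` helper are all assumptions made for illustration; the actual schema of `reference_model_ops.json` and the C++ drift test are not shown on this page.

```python
import json

# Hypothetical golden data: model name -> list of ops observed when tracing it.
GOLDEN_JSON = """
{
  "model-a": ["aten::bmm", "aten::sum"],
  "model-b": ["aten::sum", "aten::masked_fill_"]
}
"""

# Illustrative allowlist that has fallen behind the golden file.
ALLOWED = {"aten::bmm", "aten::sum"}


def find_drift(golden: dict[str, list[str]], allowed: set[str]) -> dict[str, list[str]]:
    """Map each model to the ops it uses that the allowlist does not cover."""
    drift = {}
    for model, ops in golden.items():
        missing = sorted(set(ops) - allowed)
        if missing:
            drift[model] = missing
    return drift


drift = find_drift(json.loads(GOLDEN_JSON), ALLOWED)
print(drift)
```

A non-empty result would indicate that the allowlist needs regenerating, for instance after a PyTorch upgrade changes which ops tracing emits.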
Comment on lines +38 to +143

```python
from torchscript_utils import (
    collect_graph_ops,
    collect_inlined_ops,
    load_and_trace_hf_model,
    load_model_config,
)

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parents[1]
DEFAULT_CONFIG = SCRIPT_DIR / "validation_models.json"
SUPPORTED_OPS_CC = REPO_ROOT / "bin" / "pytorch_inference" / "CSupportedOperations.cc"


def parse_string_set_from_cc(path: Path, variable_name: str) -> set[str]:
    """Extract a set of string literals from a C++ TStringViewSet definition."""
    text = path.read_text()
    pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};'
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise RuntimeError(f"Could not find {variable_name} in {path}")
    block = match.group(1)
    return set(re.findall(r'"([^"]+)"', block))


def load_cpp_sets() -> tuple[set[str], set[str]]:
    """Parse ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS from the C++ source."""
    allowed = parse_string_set_from_cc(SUPPORTED_OPS_CC, "ALLOWED_OPERATIONS")
    forbidden = parse_string_set_from_cc(SUPPORTED_OPS_CC, "FORBIDDEN_OPERATIONS")
    return allowed, forbidden


def load_pt_and_collect_ops(pt_path: str) -> set[str] | None:
    """Load a saved TorchScript .pt file, inline, and return its op set."""
    try:
        module = torch.jit.load(pt_path)
        return collect_inlined_ops(module)
    except Exception as exc:
        print(f" LOAD ERROR: {exc}", file=sys.stderr)
        return None


def check_ops(ops: set[str],
              allowed: set[str],
              forbidden: set[str],
              verbose: bool) -> bool:
    """Check an op set against allowed/forbidden lists. Returns True if all pass."""
    forbidden_found = sorted(ops & forbidden)
    unrecognised = sorted(ops - allowed - forbidden)

    if verbose:
        print(f" {len(ops)} distinct ops", file=sys.stderr)

    if forbidden_found:
        print(f" FORBIDDEN: {forbidden_found}", file=sys.stderr)
    if unrecognised:
        print(f" UNRECOGNISED: {unrecognised}", file=sys.stderr)

    if not forbidden_found and not unrecognised:
        print(f" PASS", file=sys.stderr)
        return True

    print(f" FAIL", file=sys.stderr)
    return False


def validate_model(model_name: str,
                   allowed: set[str],
                   forbidden: set[str],
                   verbose: bool,
                   quantize: bool = False) -> bool:
    """Validate one HuggingFace model. Returns True if all ops pass."""
    label = f"{model_name} (quantized)" if quantize else model_name
    print(f" {label}...", file=sys.stderr)
    traced = load_and_trace_hf_model(model_name, quantize=quantize)
    if traced is None:
        print(f" FAILED (could not load/trace)", file=sys.stderr)
        return False
    ops = collect_inlined_ops(traced)
    return check_ops(ops, allowed, forbidden, verbose)


def validate_pt_file(name: str,
                     pt_path: str,
                     allowed: set[str],
                     forbidden: set[str],
                     verbose: bool) -> bool:
    """Validate a local TorchScript .pt file. Returns True if all ops pass."""
    print(f" {name} ({pt_path})...", file=sys.stderr)
    ops = load_pt_and_collect_ops(pt_path)
    if ops is None:
        print(f" FAILED (could not load)", file=sys.stderr)
        return False
    return check_ops(ops, allowed, forbidden, verbose)


def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "--config", type=Path, default=DEFAULT_CONFIG,
        help="Path to reference_models.json (default: %(default)s)")
    parser.add_argument(
        "--pt-dir", type=Path, default=None,
        help="Directory of pre-saved .pt TorchScript files to validate")
    parser.add_argument(
```
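To see how the regex in `parse_string_set_from_cc` behaves, the function can be run against a small sample. The C++ snippet below is a hypothetical approximation of what `CSupportedOperations.cc` might contain (the real declarations are not shown on this page), written to a temporary file so the sketch is self-contained.

```python
import re
import tempfile
from pathlib import Path


def parse_string_set_from_cc(path: Path, variable_name: str) -> set[str]:
    """Same logic as the excerpt above: grab the brace block, then the literals."""
    text = path.read_text()
    pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};'
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise RuntimeError(f"Could not find {variable_name} in {path}")
    return set(re.findall(r'"([^"]+)"', match.group(1)))


# Hypothetical C++ content, for illustration only.
SAMPLE_CC = '''
const TStringViewSet FORBIDDEN_OPERATIONS = {
    "aten::from_file"sv,
};
const TStringViewSet ALLOWED_OPERATIONS = {
    "aten::bmm"sv, "aten::ceil"sv,
    "aten::sum"sv,
};
'''

with tempfile.TemporaryDirectory() as tmp:
    cc_path = Path(tmp) / "CSupportedOperations.cc"
    cc_path.write_text(SAMPLE_CC)
    allowed = parse_string_set_from_cc(cc_path, "ALLOWED_OPERATIONS")
    forbidden = parse_string_set_from_cc(cc_path, "FORBIDDEN_OPERATIONS")

print(sorted(allowed))
print(sorted(forbidden))
```

Because the pattern is non-greedy, each lookup stops at the first `};` after its variable name. The regex does not understand C++ comments, so a commented-out literal inside the block would still be collected.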
Comment on lines +145 to +150

```cmake
    set(_venv_site_packages "${_venv_dir}/Lib/site-packages")
else()
    # Discover the site-packages directory (Python version varies)
    file(GLOB _venv_site_packages "${_venv_dir}/lib/python*/site-packages")
endif()
set(_torch_lib_dir "${_venv_site_packages}/torch/lib")
```
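A later commit on this PR describes querying the venv's Python for the site-packages path instead of globbing. One way to do that from Python is via the standard `sysconfig` module, as in the sketch below; the exact command the CMake script ends up running is not shown here, so treat this as an assumption about the approach rather than the actual change.

```python
import sysconfig
from pathlib import Path

# "purelib" is the site-packages directory of the interpreter running this code,
# so running it with the venv's python yields exactly one unambiguous path.
site_packages = Path(sysconfig.get_paths()["purelib"])

# Mirrors the _torch_lib_dir derivation in the CMake excerpt above.
torch_lib_dir = site_packages / "torch" / "lib"

print(site_packages)
print(torch_lib_dir)
```

Unlike a glob over `lib/python*/site-packages`, this cannot return multiple matches when more than one Python version directory exists under the venv.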
Comment on lines +11 to +27

```diff
@@ -22,7 +24,7 @@ steps:
   - trigger: appex-qa-stateful-custom-ml-cpp-build-testing
     async: false
     build:
-      message: "${BUILDKITE_MESSAGE}"
+      message: "${SAFE_MESSAGE}"
```
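The sanitization this diff relies on (making the message safe inside a YAML double-quoted scalar) can be sketched in Python as below. The pipeline scripts themselves are shell and their exact transformation is not shown on this page, so the function name and the choice to join lines with spaces are assumptions for illustration.

```python
import json


def sanitize_for_yaml_double_quoted(message: str) -> str:
    """Collapse to one line, then escape backslashes and double quotes."""
    single_line = " ".join(message.splitlines())
    # Backslashes must be escaped first, otherwise the backslashes added
    # for the quotes would themselves be doubled.
    return single_line.replace("\\", "\\\\").replace('"', '\\"')


raw = 'Fix "quoted" title\nSecond line with a \\ backslash'
safe = sanitize_for_yaml_double_quoted(raw)
print(f'message: "{safe}"')

# These two escapes are shared with JSON strings, so a JSON parser can
# confirm that the quoted value decodes back to the single-line message.
assert json.loads(f'"{safe}"') == " ".join(raw.splitlines())
```

Other control characters such as tabs are not handled here and would need their own escapes.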
Comment on lines 565 to 575

```cmake
add_custom_target(precommit
    COMMENT "Running essential tasks prior to code commit"
    DEPENDS format test
    COMMAND ${CMAKE_COMMAND}
        -DSOURCE_DIR=${CMAKE_SOURCE_DIR}
        -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json
        -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models
        -DVALIDATE_VERBOSE=TRUE
        -DOPTIONAL=TRUE
        -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
```
test/CMakeLists.txt (Outdated)
Comment on lines +75 to +79

```cmake
    COMMAND ${CMAKE_COMMAND}
        ${_validation_args}
        -DOPTIONAL=TRUE
        -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
```
cmake/run-validation.cmake (Outdated)
Comment on lines +30 to +44

```cmake
cmake_minimum_required(VERSION 3.16)

if(NOT DEFINED SOURCE_DIR)
    message(FATAL_ERROR "SOURCE_DIR must be defined")
endif()

# Helper: emit a FATAL_ERROR or a WARNING+return depending on OPTIONAL.
macro(_validation_fail _msg)
    if(DEFINED OPTIONAL AND OPTIONAL)
        message(WARNING "Skipping validation: ${_msg}")
        return()
    else()
        message(FATAL_ERROR "${_msg}")
    endif()
endmacro()
```
Comment on lines +128 to +130

```shell
ML_CPP_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
INIT_SCRIPT="$ML_CPP_ROOT/dev-tools/gradle-build-cache-init.gradle"
GRADLE_CACHE_DIR="$HOME/.gradle/caches/build-cache-1"
```
Comment on lines +11 to +27

```diff
@@ -22,13 +24,7 @@ steps:
   - trigger: appex-qa-stateful-custom-ml-cpp-build-testing
     async: false
     build:
-      message: |
-EOL
-
-# Output the message with proper indentation for YAML literal block scalar
-printf '%s\n' "${BUILDKITE_MESSAGE}" | sed 's/^/ /'
-
-cat <<EOL
+      message: "${SAFE_MESSAGE}"
```
- Remove unused collect_graph_ops import and fix help text in validate_allowlist.py
- Query venv Python for site-packages path instead of globbing (which can yield multiple paths) in run-validation.cmake
- Bump cmake_minimum_required to 3.19.2 to match the repo
- Escape backslashes in SAFE_MESSAGE for YAML double-quoted strings in pipeline scripts
- Remove allowlist validation from precommit and test_all_parallel targets to keep them fast; validation remains available via the standalone validate_pytorch_inference_models target
- Document why eval is necessary in run_es_tests_common.sh and properly quote the Ivy repo URL
Made-with: Cursor
Resolve conflict in dev-tools/run_es_tests.sh: incorporate ES_TEST_SUITE support from elastic#2990 (parallel javaRestTest/yamlRestTest steps) into our thin-wrapper architecture that delegates to run_es_tests_common.sh. Made-with: Cursor
Summary
Re-applies #2936 and #2991, which were reverted in #2995.
- Adds TorchScript model graph validation (`CModelGraphValidator`, `CSupportedOperations`) that rejects models containing operations not observed in supported transformer architectures, reducing the attack surface by ensuring only known-safe operation sets are permitted.
- Includes `aten::mul_` and `quantized::linear_dynamic` in the allowed operations for dynamically quantized models (e.g. ELSER v2 imported via Eland).
- Adds Python tooling (`dev-tools/extract_model_ops/`) to trace reference HuggingFace models and collect their op sets, with support for quantized variants.
- Adds a `reference_model_ops.json` golden file and C++ drift test to detect allowlist staleness on PyTorch upgrades.

Test plan
- `test_pytorch_inference` passes
- QA tests (`ci:run-qa-tests` label applied)

WIP - awaiting the output of the QA tests.