[ML] Harden pytorch_inference with TorchScript model graph validation #2999

Draft
edsavage wants to merge 10 commits into elastic:main from edsavage:feature/harden-pytorch-inference-v2

@edsavage edsavage commented Mar 15, 2026

Summary

Re-applies #2936 and #2991, which were reverted in #2995.

  • Adds a static TorchScript graph validation layer (CModelGraphValidator, CSupportedOperations) that rejects models containing operations not observed in supported transformer architectures, reducing the attack surface by ensuring only known-safe operation sets are permitted.
  • Includes aten::mul_ and quantized::linear_dynamic in the allowed operations for dynamically quantized models (e.g. ELSER v2 imported via Eland).
  • Adds Python extraction tooling (dev-tools/extract_model_ops/) to trace reference HuggingFace models and collect their op sets, with support for quantized variants.
  • Adds reference_model_ops.json golden file and C++ drift test to detect allowlist staleness on PyTorch upgrades.
  • Adds adversarial "evil model" integration tests to verify rejection of forbidden operations.
  • Adds CHANGELOG entry.
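
The allowlist/denylist check described in the bullets above can be sketched in a few lines of Python. This is only an illustration of the idea, not the CModelGraphValidator implementation: the names `ValidationResult` and `classify_ops` are hypothetical, the op sets are tiny samples, and treating `aten::from_file` as forbidden is an assumption made for the example.

```python
from typing import NamedTuple


class ValidationResult(NamedTuple):
    """Outcome of checking one model's op set (hypothetical type)."""
    is_valid: bool
    forbidden: list[str]
    unrecognised: list[str]


def classify_ops(ops: set[str], allowed: set[str], forbidden: set[str]) -> ValidationResult:
    """Split a model's ops into forbidden and unrecognised groups."""
    # Ops that are explicitly denied
    forbidden_found = sorted(ops & forbidden)
    # Ops that are neither explicitly allowed nor explicitly denied
    unrecognised = sorted(ops - allowed - forbidden)
    # A model passes only if both groups are empty
    return ValidationResult(not forbidden_found and not unrecognised,
                            forbidden_found, unrecognised)


# Sample sets for illustration only; the real lists live in CSupportedOperations.cc.
ALLOWED = {"aten::mul_", "quantized::linear_dynamic", "aten::bmm"}
FORBIDDEN = {"aten::from_file"}  # assumed example of a forbidden op

good = classify_ops({"aten::mul_", "aten::bmm"}, ALLOWED, FORBIDDEN)
bad = classify_ops({"aten::mul_", "aten::from_file", "custom::op"}, ALLOWED, FORBIDDEN)
print(good.is_valid)
print(bad.forbidden, bad.unrecognised)
```

Rejecting unrecognised ops as well as forbidden ones is what makes this an allowlist: an op has to be observed in a supported model before it is accepted.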

Test plan

  • Local test_pytorch_inference passes
  • CI builds pass on all platforms
  • QA tests pass (ci:run-qa-tests label applied)

WIP - awaiting the output of the QA tests.

prodsecmachine commented Mar 15, 2026

Snyk checks have passed. No issues have been found so far.

Scan Engine           Critical  High  Medium  Low  Total (0)
Open Source Security  0         0     0       0    0 issues
Licenses              0         0     0       0    0 issues

@edsavage edsavage marked this pull request as draft March 15, 2026 21:19
…ns quotes

Sanitize BUILDKITE_MESSAGE before embedding in generated pipeline YAML
to prevent double quotes and multi-line content from breaking the YAML
structure. Affects both run_qa_tests.yml.sh and run_pytorch_tests.yml.sh.

Made-with: Cursor
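
The sanitization described in this commit can be illustrated with a short Python sketch. The pipeline scripts themselves are shell and are not shown here, so everything below is an assumption: the function name `sanitize_for_yaml_double_quoted` is hypothetical, and collapsing multi-line messages into one line is just one possible policy. The sketch also escapes backslashes, which a YAML double-quoted scalar requires.

```python
def sanitize_for_yaml_double_quoted(message: str) -> str:
    """Return a single-line string that is safe between YAML double quotes."""
    # Collapse every whitespace run, including newlines, into a single space
    single_line = " ".join(message.split())
    # Escape backslashes first so the quote escapes added next stay intact
    escaped = single_line.replace("\\", "\\\\")
    # Escape double quotes so they cannot terminate the YAML scalar early
    return escaped.replace('"', '\\"')


raw = 'Fix "quotes"\nand a\\path'
safe = sanitize_for_yaml_double_quoted(raw)
print(f'message: "{safe}"')
```

The order of the two replacements matters: escaping quotes first would introduce backslashes that the second pass would then double.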
Add aten::bmm, aten::ceil, aten::floor_divide, aten::gt, aten::le, and
aten::sign to the allowed operations list, fixing graph validation
failures for the .rerank-v1 model used by the default rerank endpoint.

These operations were identified by running the full Elasticsearch
inference and ML integration test suites (974 tests) against a build
with graph validation enabled.

Made-with: Cursor

Add aten::clamp_min, aten::eq, aten::expand_as, aten::linalg_vector_norm,
and aten::sum to the allowed operations list, fixing graph validation
failures for distilbert-base-uncased-finetuned-sst-2-english and
sentence-transformers/all-distilroberta-v1 models.

Made-with: Cursor

…d additional model ops

Add .rerank-v1 (52 ops extracted from ml-models.elastic.co .pt file),
distilbert-sst2, all-distilroberta-v1, and their Eland-deployed variants
to the reference model golden file. All ops extracted with PyTorch 2.7.1.

Add aten::detach (Eland-traced models) and aten::masked_fill_ (.rerank-v1)
to the allowlist.

Made-with: Cursor

Add a parallel Buildkite step that runs Elasticsearch inference
integration tests against the ml-cpp build artifacts. The new step
runs on its own machine alongside the existing ES tests step.

Tests exercised:
- DefaultEndPointsIT (ELSER, E5, rerank default endpoints)
- TextEmbeddingCrudIT (E5 model CRUD via inference API)
- Semantic text YAML REST tests (indexing and querying with default
  ELSER 2 endpoint)

All tests use local prepacked models served by the test framework —
no external services required.

Made-with: Cursor

Extract clone/branch selection, Java configuration, and Gradle
invocation into run_es_tests_common.sh.  Both run_es_tests.sh and
run_es_inference_tests.sh are now thin wrappers that pass their
Gradle commands as arguments to the common script.

Made-with: Cursor

Resolve conflict in dev-tools/run_es_tests.sh: keep the refactored
thin-wrapper structure from this branch and integrate the Gradle build
cache support (from elastic#2907) into run_es_tests_common.sh.

Made-with: Cursor

Copilot AI left a comment

Pull request overview

Reintroduces and extends TorchScript model-graph validation for pytorch_inference to reduce attack surface by enforcing an operation allowlist/denylist, and adds tooling + tests to keep that allowlist current (including quantized model support).

Changes:

  • Add C++ TorchScript graph validation (CModelGraphValidator + CSupportedOperations) and wire it into pytorch_inference startup.
  • Add Python tooling (dev-tools/extract_model_ops/) + CMake runner to extract/validate op allowlists and detect drift via a golden JSON + C++ test.
  • Add adversarial TorchScript fixtures and integration tests; extend Buildkite scripts to run additional Elasticsearch inference integration tests.
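
As a rough illustration of the drift detection mentioned in the list above, the sketch below compares a golden JSON of per-model op sets against an allowlist and reports any op the allowlist no longer covers. The actual drift test is written in C++; the JSON layout (model name mapped to a list of op names), the function name `find_allowlist_drift`, and the sample data are all assumptions made for this example.

```python
import json


def find_allowlist_drift(golden_json: str, allowed: set[str]) -> dict[str, list[str]]:
    """Map each model to the ops in its golden set that the allowlist lacks."""
    golden = json.loads(golden_json)
    drift: dict[str, list[str]] = {}
    for model, ops in golden.items():
        missing = sorted(set(ops) - allowed)
        if missing:
            drift[model] = missing
    return drift


# Hypothetical golden file content: model name -> list of op names.
GOLDEN = json.dumps({
    "model-a": ["aten::bmm", "aten::sum"],
    "model-b": ["aten::bmm", "aten::new_op"],
})
ALLOWED = {"aten::bmm", "aten::sum"}

drift = find_allowlist_drift(GOLDEN, ALLOWED)
print(drift)
```

An empty result means every op recorded for the reference models is still allowed; a non-empty one points at the models and ops to review after a PyTorch upgrade.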

Reviewed changes

Copilot reviewed 22 out of 45 changed files in this pull request and generated 8 comments.

File Description
test/CMakeLists.txt Runs allowlist validation script as part of test_all_parallel (optional) and adds a standalone validation target.
docs/CHANGELOG.asciidoc Adds release note entry for graph validation hardening.
dev-tools/run_es_tests_common.sh Factors shared logic for running ES integration tests from a local Ivy repo.
dev-tools/run_es_tests.sh Uses the common runner to execute core ML REST/YAML tests.
dev-tools/run_es_inference_tests.sh Adds separate runner for ES inference integration tests.
dev-tools/generate_malicious_models.py Adds generator for malicious TorchScript fixtures used by validator integration tests.
dev-tools/extract_model_ops/validation_models.json Defines HF model set to validate allowlist (incl. quantized variants).
dev-tools/extract_model_ops/validate_allowlist.py Adds Python-side allowlist validation against traced models + local .pt fixtures.
dev-tools/extract_model_ops/torchscript_utils.py Shared tracing/inlining + config loading helpers (incl. dynamic quantization).
dev-tools/extract_model_ops/requirements.txt Pins Python deps for extraction/validation tooling (torch + transformers stack).
dev-tools/extract_model_ops/reference_models.json Defines reference HF models used to build the allowlist union (incl. quantized).
dev-tools/extract_model_ops/extract_model_ops.py Adds extractor to generate op unions / C++ initializer / golden JSON.
dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt Adds ES IT TorchScript fixture for validation.
dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt Adds ES IT TorchScript fixture for validation.
dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt Adds ES IT TorchScript fixture for validation.
dev-tools/extract_model_ops/es_it_models/README.md Documents provenance/regeneration of ES IT TorchScript fixtures.
dev-tools/extract_model_ops/README.md Documents extraction/validation workflows and golden drift test.
dev-tools/extract_model_ops/.gitignore Ignores the local Python venv used by the tooling.
cmake/run-validation.cmake Adds portable CMake driver to create venv, install deps, and run validation.
cmake/functions.cmake Wires validation into precommit target (optional).
bin/pytorch_inference/unittest/testfiles/reference_model_ops.json Adds/updates golden per-model op sets for drift detection tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt Adds malicious TorchScript fixture used by integration tests.
bin/pytorch_inference/unittest/CThreadSettingsTest.cc Switches includes to angle-bracket form.
bin/pytorch_inference/unittest/CResultWriterTest.cc Switches includes to angle-bracket form.
bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc Adds unit + integration tests for validator, fixtures, and allowlist drift.
bin/pytorch_inference/unittest/CMakeLists.txt Adds validator test source and include path for new include style.
bin/pytorch_inference/unittest/CCommandParserTest.cc Switches includes to angle-bracket form.
bin/pytorch_inference/Main.cc Enforces model-graph validation at load time; improves rejection messages.
bin/pytorch_inference/CSupportedOperations.h Declares forbidden/allowed op sets for validation.
bin/pytorch_inference/CSupportedOperations.cc Defines forbidden/allowed TorchScript ops (incl. quantized ops).
bin/pytorch_inference/CModelGraphValidator.h Declares validator API + node-count guard.
bin/pytorch_inference/CModelGraphValidator.cc Implements graph inlining + op collection + allow/deny evaluation.
bin/pytorch_inference/CMakeLists.txt Builds new validator + supported-ops sources into pytorch_inference.
.buildkite/scripts/steps/run_es_inference_tests.sh Adds Buildkite step script for ES inference integration tests.
.buildkite/pipelines/run_qa_tests.yml.sh Sanitizes Buildkite message when triggering downstream QA pipeline.
.buildkite/pipelines/run_pytorch_tests.yml.sh Sanitizes Buildkite message when triggering downstream PyTorch pipeline.
.buildkite/pipelines/run_es_inference_tests_x86_64.yml.sh Adds x86_64 pipeline to run ES inference integration tests.
.buildkite/pipeline.json.py Uploads the new ES inference tests runner pipeline when x86_64 enabled.

Comment on lines +38 to +143
from torchscript_utils import (
    collect_graph_ops,
    collect_inlined_ops,
    load_and_trace_hf_model,
    load_model_config,
)

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parents[1]
DEFAULT_CONFIG = SCRIPT_DIR / "validation_models.json"
SUPPORTED_OPS_CC = REPO_ROOT / "bin" / "pytorch_inference" / "CSupportedOperations.cc"


def parse_string_set_from_cc(path: Path, variable_name: str) -> set[str]:
    """Extract a set of string literals from a C++ TStringViewSet definition."""
    text = path.read_text()
    pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};'
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise RuntimeError(f"Could not find {variable_name} in {path}")
    block = match.group(1)
    return set(re.findall(r'"([^"]+)"', block))


def load_cpp_sets() -> tuple[set[str], set[str]]:
    """Parse ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS from the C++ source."""
    allowed = parse_string_set_from_cc(SUPPORTED_OPS_CC, "ALLOWED_OPERATIONS")
    forbidden = parse_string_set_from_cc(SUPPORTED_OPS_CC, "FORBIDDEN_OPERATIONS")
    return allowed, forbidden


def load_pt_and_collect_ops(pt_path: str) -> set[str] | None:
    """Load a saved TorchScript .pt file, inline, and return its op set."""
    try:
        module = torch.jit.load(pt_path)
        return collect_inlined_ops(module)
    except Exception as exc:
        print(f" LOAD ERROR: {exc}", file=sys.stderr)
        return None


def check_ops(ops: set[str],
              allowed: set[str],
              forbidden: set[str],
              verbose: bool) -> bool:
    """Check an op set against allowed/forbidden lists. Returns True if all pass."""
    forbidden_found = sorted(ops & forbidden)
    unrecognised = sorted(ops - allowed - forbidden)

    if verbose:
        print(f" {len(ops)} distinct ops", file=sys.stderr)

    if forbidden_found:
        print(f" FORBIDDEN: {forbidden_found}", file=sys.stderr)
    if unrecognised:
        print(f" UNRECOGNISED: {unrecognised}", file=sys.stderr)

    if not forbidden_found and not unrecognised:
        print(f" PASS", file=sys.stderr)
        return True

    print(f" FAIL", file=sys.stderr)
    return False


def validate_model(model_name: str,
                   allowed: set[str],
                   forbidden: set[str],
                   verbose: bool,
                   quantize: bool = False) -> bool:
    """Validate one HuggingFace model. Returns True if all ops pass."""
    label = f"{model_name} (quantized)" if quantize else model_name
    print(f" {label}...", file=sys.stderr)
    traced = load_and_trace_hf_model(model_name, quantize=quantize)
    if traced is None:
        print(f" FAILED (could not load/trace)", file=sys.stderr)
        return False
    ops = collect_inlined_ops(traced)
    return check_ops(ops, allowed, forbidden, verbose)


def validate_pt_file(name: str,
                     pt_path: str,
                     allowed: set[str],
                     forbidden: set[str],
                     verbose: bool) -> bool:
    """Validate a local TorchScript .pt file. Returns True if all ops pass."""
    print(f" {name} ({pt_path})...", file=sys.stderr)
    ops = load_pt_and_collect_ops(pt_path)
    if ops is None:
        print(f" FAILED (could not load)", file=sys.stderr)
        return False
    return check_ops(ops, allowed, forbidden, verbose)


def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "--config", type=Path, default=DEFAULT_CONFIG,
        help="Path to reference_models.json (default: %(default)s)")
    parser.add_argument(
        "--pt-dir", type=Path, default=None,
        help="Directory of pre-saved .pt TorchScript files to validate")
    parser.add_argument(
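
To see how the regex in parse_string_set_from_cc behaves, here is a small self-contained sketch that applies the same pattern to an in-memory string instead of a file. The C++ snippet in `SAMPLE_CC` is invented for illustration (the real layout of CSupportedOperations.cc is not shown in this excerpt), and `parse_string_set` is a hypothetical text-based variant of the function above.

```python
import re


def parse_string_set(text: str, variable_name: str) -> set[str]:
    """Extract the string literals inside `<variable_name> = { ... };`."""
    # In an rf-string, `\{{` and `\}}` produce the regex tokens `\{` and `\}`
    pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};'
    # DOTALL lets the non-greedy group span the multi-line initializer
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise RuntimeError(f"Could not find {variable_name}")
    return set(re.findall(r'"([^"]+)"', match.group(1)))


# Assumed shape of the C++ definitions, for illustration only.
SAMPLE_CC = '''
const TStringViewSet FORBIDDEN_OPERATIONS = {
    "aten::from_file"sv,
};
const TStringViewSet ALLOWED_OPERATIONS = {
    "aten::bmm"sv,
    "aten::mul_"sv,
};
'''

allowed = parse_string_set(SAMPLE_CC, "ALLOWED_OPERATIONS")
forbidden = parse_string_set(SAMPLE_CC, "FORBIDDEN_OPERATIONS")
print(sorted(allowed), sorted(forbidden))
```

Because the group is non-greedy, each match stops at the first `};` after its own opening brace, so the two definitions do not bleed into each other.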
Comment on lines +145 to +150
    set(_venv_site_packages "${_venv_dir}/Lib/site-packages")
else()
    # Discover the site-packages directory (Python version varies)
    file(GLOB _venv_site_packages "${_venv_dir}/lib/python*/site-packages")
endif()
set(_torch_lib_dir "${_venv_site_packages}/torch/lib")
Comment on lines +11 to +27
@@ -22,7 +24,7 @@ steps:
   - trigger: appex-qa-stateful-custom-ml-cpp-build-testing
     async: false
     build:
-      message: "${BUILDKITE_MESSAGE}"
+      message: "${SAFE_MESSAGE}"
Comment on lines 565 to 575
add_custom_target(precommit
    COMMENT "Running essential tasks prior to code commit"
    DEPENDS format test
    COMMAND ${CMAKE_COMMAND}
        -DSOURCE_DIR=${CMAKE_SOURCE_DIR}
        -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json
        -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models
        -DVALIDATE_VERBOSE=TRUE
        -DOPTIONAL=TRUE
        -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
Comment on lines +75 to +79
    COMMAND ${CMAKE_COMMAND}
        ${_validation_args}
        -DOPTIONAL=TRUE
        -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
Comment on lines +30 to +44
cmake_minimum_required(VERSION 3.16)

if(NOT DEFINED SOURCE_DIR)
    message(FATAL_ERROR "SOURCE_DIR must be defined")
endif()

# Helper: emit a FATAL_ERROR or a WARNING+return depending on OPTIONAL.
macro(_validation_fail _msg)
    if(DEFINED OPTIONAL AND OPTIONAL)
        message(WARNING "Skipping validation: ${_msg}")
        return()
    else()
        message(FATAL_ERROR "${_msg}")
    endif()
endmacro()
Comment on lines +128 to +130
ML_CPP_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
INIT_SCRIPT="$ML_CPP_ROOT/dev-tools/gradle-build-cache-init.gradle"
GRADLE_CACHE_DIR="$HOME/.gradle/caches/build-cache-1"
Comment on lines +11 to +27
@@ -22,13 +24,7 @@ steps:
   - trigger: appex-qa-stateful-custom-ml-cpp-build-testing
     async: false
     build:
-      message: |
-EOL
-
-# Output the message with proper indentation for YAML literal block scalar
-printf '%s\n' "${BUILDKITE_MESSAGE}" | sed 's/^/ /'
-
-cat <<EOL
+      message: "${SAFE_MESSAGE}"
- Remove unused collect_graph_ops import and fix help text in
  validate_allowlist.py
- Query venv Python for site-packages path instead of globbing
  (which can yield multiple paths) in run-validation.cmake
- Bump cmake_minimum_required to 3.19.2 to match the repo
- Escape backslashes in SAFE_MESSAGE for YAML double-quoted strings
  in pipeline scripts
- Remove allowlist validation from precommit and test_all_parallel
  targets to keep them fast; validation remains available via the
  standalone validate_pytorch_inference_models target
- Document why eval is necessary in run_es_tests_common.sh and
  properly quote the Ivy repo URL

Made-with: Cursor
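
One way to ask an interpreter for its site-packages directory, rather than globbing for it, is sketched below. The exact command used in run-validation.cmake is not shown in this conversation, so this is an assumption about the approach: the helper name `query_site_packages` is hypothetical, and `sys.executable` stands in for the venv's Python.

```python
import subprocess
import sys


def query_site_packages(python_executable: str) -> str:
    """Ask the given interpreter where its pure-Python packages are installed."""
    result = subprocess.run(
        [python_executable, "-c",
         "import sysconfig; print(sysconfig.get_paths()['purelib'])"],
        check=True, capture_output=True, text=True)
    # The interpreter prints exactly one path, unlike a glob over lib/python*
    return result.stdout.strip()


# The current interpreter is used here only so the sketch runs anywhere.
site_packages = query_site_packages(sys.executable)
print(site_packages)
```

Querying the interpreter returns a single path on every platform, which avoids both the Windows/POSIX layout branch and the multiple-match problem of a `lib/python*` glob.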
Resolve conflict in dev-tools/run_es_tests.sh: incorporate
ES_TEST_SUITE support from elastic#2990 (parallel javaRestTest/yamlRestTest
steps) into our thin-wrapper architecture that delegates to
run_es_tests_common.sh.

Made-with: Cursor