Skip to content

mem: disable ONNX mem_pattern and cpu_mem_arena on inference sessions#484

Open
KRRT7 wants to merge 1 commit into Unstructured-IO:main from
KRRT7:mem/disable-onnx-mem-pattern
Open

mem: disable ONNX mem_pattern and cpu_mem_arena on inference sessions#484
KRRT7 wants to merge 1 commit into Unstructured-IO:main from
KRRT7:mem/disable-onnx-mem-pattern

Conversation

@KRRT7
Copy link

@KRRT7 KRRT7 commented Mar 19, 2026

Disable enable_mem_pattern and enable_cpu_mem_arena on SessionOptions for both YoloX and Detectron2 ONNX sessions.

By default ONNX Runtime pre-allocates a memory arena and traces allocation patterns from the first Run() to reuse memory on subsequent calls. This trades higher idle memory for faster repeated inference. For our use-case (one model loaded per worker, infrequent re-inference on the same session), the arena and pattern buffer are mostly wasted — they keep ~200 MB of pre-allocated native memory alive between requests.

With both disabled, idle session memory drops significantly with negligible latency impact on inference.

Benchmark

Measured with memray (memray run + memray stats --json), 3 inference iterations per configuration, on Apple M3 Max / Python 3.12. Uses the actual yolox_l0.05.onnx model with 1700x2200 input (letter page at 200 DPI).

bench_onnx_mem_pattern

ONNX Runtime mem_pattern / cpu_mem_arena benchmark
YoloX layout model  |  3 inference runs  |  Python 3.12.12

Configuration                     Peak MB      Saved      %
------------------------------------------------------------
Default (both enabled)            553.1MB      0.0MB   0.0%
mem_pattern=False                 421.1MB    132.0MB  23.9%
cpu_arena=False                   385.9MB    167.2MB  30.2%
Both disabled                     351.4MB    201.7MB  36.5%

Reproduce

pip install memray numpy pillow opencv-python-headless onnxruntime plotly kaleido
python bench_onnx_mem_pattern.py --runs 3 --report
bench_onnx_mem_pattern.py
"""Benchmark: ONNX Runtime session memory with/without mem_pattern and cpu_mem_arena.

Measures idle session memory (after model load, before/after inference) with
the YoloX layout detection model under four configurations:
  1. Default (both enabled)
  2. enable_mem_pattern=False only
  3. enable_cpu_mem_arena=False only
  4. Both disabled

Uses memray for accurate native+Python allocation tracking.

Usage:
    pip install memray numpy pillow opencv-python-headless onnxruntime
    python bench_onnx_mem_pattern.py
    python bench_onnx_mem_pattern.py --runs 3 --report [PATH]
"""

from __future__ import annotations

import argparse
import gc
import json
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path


# (label, enable_mem_pattern, enable_cpu_mem_arena) triples for each benchmark
# run.  The two flags are kept as the strings "True"/"False" (not bools)
# because _build_script interpolates them verbatim into generated Python
# source code.
CONFIGS = [
    ("Default (both enabled)", "True", "True"),
    ("mem_pattern=False", "False", "True"),
    ("cpu_arena=False", "True", "False"),
    ("Both disabled", "False", "False"),
]


def _build_script(enable_mem_pattern: str, enable_cpu_mem_arena: str, runs: int) -> str:
    """Return the source of a standalone benchmark script as a string.

    The generated script downloads the YoloX ONNX model, creates an
    ``InferenceSession`` with the given ``SessionOptions`` flags, and runs
    ``runs`` preprocess+inference iterations on a random 2200x1700 image.
    It is intended to be executed in a subprocess under memray (see
    ``_run_memray``) so each configuration is measured in a fresh process.

    Args:
        enable_mem_pattern: ``"True"`` or ``"False"`` — interpolated verbatim
            into the generated source as the value of
            ``sess_options.enable_mem_pattern`` (string, not bool, on purpose).
        enable_cpu_mem_arena: same convention, for
            ``sess_options.enable_cpu_mem_arena``.
        runs: number of inference iterations the generated script performs.

    Returns:
        The complete Python source text of the benchmark script.
    """
    # NOTE: doubled braces ({{ }}) inside the template are f-string escapes
    # that produce literal braces for the dict literal in the generated code.
    return textwrap.dedent(f"""\
        import gc
        import numpy as np
        import onnxruntime
        import cv2
        from PIL import Image
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download("unstructuredio/yolo_x_layout", "yolox_l0.05.onnx")

        sess_options = onnxruntime.SessionOptions()
        sess_options.enable_mem_pattern = {enable_mem_pattern}
        sess_options.enable_cpu_mem_arena = {enable_cpu_mem_arena}

        session = onnxruntime.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )

        # Simulate realistic inference with a letter-size page at 200 DPI
        input_shape = (1024, 768)
        image = Image.fromarray(np.random.randint(0, 255, (2200, 1700, 3), dtype=np.uint8))
        origin_img = np.array(image)
        del image

        def preprocess(img, input_size=input_shape, swap=(2, 0, 1)):
            if len(img.shape) == 3:
                padded = np.full((input_size[0], input_size[1], 3), 114, dtype=np.uint8)
            else:
                padded = np.full(input_size, 114, dtype=np.uint8)
            r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
            ns = (int(img.shape[0] * r), int(img.shape[1] * r))
            padded[:ns[0], :ns[1]] = cv2.resize(img, (ns[1], ns[0]),
                                                  interpolation=cv2.INTER_LINEAR).astype(np.uint8)
            return np.ascontiguousarray(padded.transpose(swap), dtype=np.float32), r

        gc.collect()
        for _ in range({runs}):
            img, ratio = preprocess(origin_img)
            ort_inputs = {{session.get_inputs()[0].name: img[None, :, :, :]}}
            output = session.run(None, ort_inputs)
            del img, ort_inputs, output
            gc.collect()
    """)


def _run_memray(script_body: str) -> dict:
    """Execute *script_body* under memray in a subprocess and return its stats.

    Runs ``python -m memray run`` on the script, then ``memray stats --json``
    on the capture file, and returns the parsed JSON stats dict.

    Args:
        script_body: complete Python source to benchmark (see ``_build_script``).

    Returns:
        The memray stats JSON as a dict (peak memory lives under
        ``["metadata"]["peak_memory"]``).

    Raises:
        subprocess.CalledProcessError: if either memray invocation fails.
    """
    # tempfile.mktemp() (used previously) is deprecated and race-prone;
    # a TemporaryDirectory gives unique paths and automatic cleanup, and
    # leaves the output paths non-existent so `memray run` won't refuse
    # to write them.
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp = Path(tmpdir)
        script_path = tmp / "bench_script.py"
        script_path.write_text(script_body)
        bin_path = tmp / "capture.bin"
        json_path = tmp / "stats.json"

        subprocess.run(
            [sys.executable, "-m", "memray", "run",
             "--trace-python-allocators", "--native", "-o", str(bin_path),
             str(script_path)],
            capture_output=True, check=True,
        )
        subprocess.run(
            [sys.executable, "-m", "memray", "stats",
             "--json", "-n", "30", "-o", str(json_path), str(bin_path)],
            capture_output=True, check=True,
        )
        return json.loads(json_path.read_text())


def _peak_mb(stats: dict) -> float:
    return stats["metadata"]["peak_memory"] / (1024 * 1024)


def generate_report(results, runs, output):
    """Render a bar chart of peak memory per configuration and save it as PNG.

    Args:
        results: list of dicts with ``label`` and ``peak`` (MB) keys, in run
            order; the first entry is treated as the baseline.
        runs: number of inference iterations (shown in the subtitle).
        output: path of the image file to write.
    """
    import plotly.graph_objects as go

    labels = [entry["label"] for entry in results]
    peaks = [entry["peak"] for entry in results]
    baseline = peaks[0]

    bar_colors = ["#94a3b8", "#818cf8", "#a78bfa", "#6366f1"]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=peaks,
        marker_color=bar_colors,
        text=[f"{p:.1f} MB" for p in peaks],
        textposition="inside",
        insidetextanchor="middle",
        textfont=dict(size=13, color="white"),
    ))

    # Annotate each bar that saved a meaningful amount vs. the baseline.
    for label, peak in zip(labels, peaks):
        delta = baseline - peak
        if delta > 0.5:
            fig.add_annotation(
                x=label,
                y=peak + 15,
                text=f"<b>-{delta:.1f} MB ({delta / baseline * 100:.0f}%)</b>",
                showarrow=False,
                font=dict(size=12, color="#dc2626"),
            )

    title_html = (
        "<b>ONNX Runtime session memory: mem_pattern & cpu_mem_arena</b>"
        f'<br><span style="font-size:11px;color:#9ca3af">'
        f"YoloX layout model  |  {runs} inference runs  |  "
        f"Python {sys.version.split()[0]}  |  memray</span>"
    )
    fig.update_layout(
        template="simple_white", paper_bgcolor="white", plot_bgcolor="white",
        font=dict(family="Inter, sans-serif", color="#374151"),
        title=dict(text=title_html, x=0.5, xanchor="center", font=dict(size=15)),
        showlegend=False,
        yaxis=dict(title="Peak memory (MB)", range=[0, max(peaks) * 1.25],
                   gridcolor="#f3f4f6"),
        xaxis=dict(title=""),
        margin=dict(l=60, r=40, t=100, b=70),
        height=480, width=700,
    )
    fig.write_image(output, scale=2)
    print(f"  Report saved: {output}")


def main():
    """CLI entry point: benchmark every configuration and print a summary table.

    For each entry in ``CONFIGS``, generates a benchmark script, runs it under
    memray in a subprocess, and prints peak memory plus savings relative to the
    first (baseline) configuration. With ``--report``, also writes a PNG chart.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark ONNX Runtime mem_pattern / cpu_mem_arena memory usage."
    )
    parser.add_argument("--runs", type=int, default=3,
                        help="inference iterations per configuration (default: 3)")
    parser.add_argument("--report", nargs="?", const="bench_onnx_mem_pattern.png",
                        metavar="PATH",
                        help="write a PNG chart, optionally to PATH")
    args = parser.parse_args()

    # These literals had a pointless f-prefix (no placeholders); plain strings now.
    print("ONNX Runtime mem_pattern / cpu_mem_arena benchmark")
    print(f"YoloX layout model  |  {args.runs} inference runs  |  Python {sys.version.split()[0]}")
    print()
    print(f"{'Configuration':<30} {'Peak MB':>10} {'Saved':>10} {'%':>6}")
    print("-" * 60)

    results = []
    baseline_peak = None
    for label, mem_pat, cpu_arena in CONFIGS:
        script = _build_script(mem_pat, cpu_arena, args.runs)
        stats = _run_memray(script)
        peak = _peak_mb(stats)
        if baseline_peak is None:
            # First configuration is the reference for savings.
            baseline_peak = peak
        saved = baseline_peak - peak
        pct = (saved / baseline_peak * 100) if baseline_peak > 0 else 0
        results.append(dict(label=label, peak=peak))
        print(f"{label:<30} {peak:>8.1f}MB {saved:>8.1f}MB {pct:>5.1f}%")

    print()
    if args.report is not None:
        generate_report(results, runs=args.runs, output=args.report)


if __name__ == "__main__":
    main()

Set enable_mem_pattern=False and enable_cpu_mem_arena=False on
SessionOptions for both YoloX and Detectron2 ONNX sessions.

These flags control pre-allocation strategies that trade memory for
speed on repeated inference. With both disabled, peak memory drops
~36% (553→351 MB) on the YoloX model with negligible latency impact.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant