Understanding max_mem option of OrtArenaCfg class #23121

Open
@vsbaldeev

Describe the issue

I tried using the max_mem option of the OrtArenaCfg class (from here) in Python to limit the memory consumption of the model, but it behaves unpredictably. With max_mem set to 8192, I expect the model to accept up to 8192 bytes of input and not a byte more. In practice, the largest input that succeeds is about 5616 bytes, and a 5760-byte input fails. The RuntimeException reports that the requested allocation is only 2048 bytes while the available memory is 0. I can't see the relationship between these numbers.

I need to set an exact upper bound on the number of bytes my model can consume. Does this option, or any other mechanism in the onnxruntime library, support that?
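
A note on the sizes quoted in this report: they come from numpy.ndarray.nbytes in the script below. NumPy stores a str array as fixed-width UTF-32 (4 bytes per character), so nbytes measures the NumPy buffer, which is not necessarily what ONNX Runtime takes from the arena. The following is a minimal standard-library sketch of that arithmetic; the helper name numpy_unicode_nbytes is hypothetical and only mirrors what nbytes reports for a one-element str array.

```python
import uuid


def numpy_unicode_nbytes(text: str) -> int:
    """Size NumPy reports (nbytes) for a one-element str array holding `text`.

    NumPy's fixed-width unicode dtype uses 4 bytes per code point, which
    equals the length of the UTF-32 encoding without a byte-order mark.
    """
    return len(text.encode("utf-32-le"))


# A canonical UUID string is 36 characters: 32 hex digits plus 4 hyphens.
failing_input = str(uuid.uuid4()) * 40  # 1440 characters
passing_input = str(uuid.uuid4()) * 39  # 1404 characters

failing_nbytes = numpy_unicode_nbytes(failing_input)
passing_nbytes = numpy_unicode_nbytes(passing_input)

# The same text encoded as UTF-8 is four times smaller (ASCII only).
failing_utf8 = len(failing_input.encode("utf-8"))
passing_utf8 = len(passing_input.encode("utf-8"))

print(f"failing input: nbytes={failing_nbytes}, utf-8 bytes={failing_utf8}")
print(f"passing input: nbytes={passing_nbytes}, utf-8 bytes={passing_utf8}")
# failing input: nbytes=5760, utf-8 bytes=1440
# passing input: nbytes=5616, utf-8 bytes=1404
```

The 5760 and 5616 figures therefore describe the Python-side array. How many bytes the runtime requests from the arena for the same input is a separate quantity that this sketch does not measure.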

requirements.txt:
onnxruntime==1.20.1
numpy==2.2.0
skl2onnx==1.17.0

Python:
3.13

To reproduce

import uuid

import numpy
import onnxruntime
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline


def create_trained_inference_session(memory_limit: int):
    x = [str(uuid.uuid4()) for _ in range(100)]
    y = ["a" for _ in range(50)] + ["b" for _ in range(50)]

    sklearn_model = Pipeline(
        steps=[
            ("vectorizer", CountVectorizer()),
            ("classifier", RandomForestClassifier())
        ]
    )

    sklearn_model.fit(x, y)

    onnx_model_proto_string = convert_sklearn(
        sklearn_model,
        initial_types=[('features', StringTensorType((None,)))],
        verbose=False
    ).SerializeToString()

    ort_memory_info = onnxruntime.OrtMemoryInfo(
        "Cpu", onnxruntime.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0, onnxruntime.OrtMemType.CPU
    )
    ort_arena_cfg = onnxruntime.OrtArenaCfg({
        "max_mem": memory_limit,
    })
    onnxruntime.create_and_register_allocator(ort_memory_info, ort_arena_cfg)

    session_options = onnxruntime.SessionOptions()
    session_options.log_severity_level = 0
    session_options.log_verbosity_level = 0
    session_options.add_session_config_entry("session.use_env_allocators", "1")

    return onnxruntime.InferenceSession(
        onnx_model_proto_string,
        providers=["CPUExecutionProvider"],
        sess_options=session_options
    )

def main():
    onnx_model = create_trained_inference_session(memory_limit=8192)

    print("This call fails on Macos 15.2")
    try:
        input_data = numpy.array([str(uuid.uuid4()) * 40])
        print(f"Input data size in bytes = {input_data.nbytes}")

        result = onnx_model.run(
            [onnx_model.get_outputs()[1].name],
            {onnx_model.get_inputs()[0].name: input_data}
        )
        print(result)

    except Exception as exception:
        print(f"Caught exception: {exception}")

    print("This call will complete successfully")
    try:
        input_data = numpy.array([str(uuid.uuid4()) * 39])
        print(f"Input data size in bytes = {input_data.nbytes}")

        result = onnx_model.run(
            [onnx_model.get_outputs()[1].name],
            {onnx_model.get_inputs()[0].name: input_data}
        )
        print(result)

    except Exception as exception:
        print(f"Caught exception: {exception}")

if __name__ == '__main__':
    main()

The code above produces the following output:

`2024-12-16 16:48:18.318947 [I:onnxruntime:, inference_session.cc:583 TraceSessionOptions] Session Options { execution_mode:0 execution_order:DEFAULT enable_profiling:0 optimized_model_filepath:"" enable_mem_pattern:1 enable_mem_reuse:1 enable_cpu_mem_arena:1 profile_file_prefix:onnxruntime_profile_ session_logid: session_log_severity_level:0 session_log_verbosity_level:0 max_num_graph_transformation_steps:10 graph_optimization_level:3 intra_op_param:OrtThreadPoolParams { thread_pool_size: 0 auto_set_affinity: 0 allow_spinning: 1 dynamic_block_base_: 0 stack_size: 0 affinity_str: set_denormal_as_zero: 0 } inter_op_param:OrtThreadPoolParams { thread_pool_size: 0 auto_set_affinity: 0 allow_spinning: 1 dynamic_block_base_: 0 stack_size: 0 affinity_str: set_denormal_as_zero: 0 } use_per_session_threads:1 thread_pool_allow_spinning:1 use_deterministic_compute:0 config_options: { session.use_env_allocators: 1 } }
2024-12-16 16:48:18.318967 [I:onnxruntime:, inference_session.cc:483 operator()] Flush-to-zero and denormal-as-zero are off
2024-12-16 16:48:18.319016 [I:onnxruntime:, inference_session.cc:491 ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
2024-12-16 16:48:18.319022 [I:onnxruntime:, inference_session.cc:509 ConstructorCommon] Dynamic block base set to 0
2024-12-16 16:48:18.320176 [I:onnxruntime:, inference_session.cc:1669 Initialize] Initializing session.
2024-12-16 16:48:18.320185 [I:onnxruntime:, inference_session.cc:1751 Initialize] This session will use the allocator registered with the environment.
2024-12-16 16:48:18.321055 [I:onnxruntime:, graph_partitioner.cc:898 InlineFunctionsAOT] This model does not have any local functions defined. AOT Inlining is not performed
2024-12-16 16:48:18.321067 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer EnsureUniqueDQForNodeUnit modified: 0 with status: OK
2024-12-16 16:48:18.321074 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer Level1_RuleBasedTransformer modified: 1 with status: OK
2024-12-16 16:48:18.321099 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer DoubleQDQPairsRemover modified: 0 with status: OK
2024-12-16 16:48:18.321114 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ConstantSharing modified: 0 with status: OK
2024-12-16 16:48:18.321303 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2024-12-16 16:48:18.321308 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ConstantFolding modified: 0 with status: OK
2024-12-16 16:48:18.321311 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulAddFusion modified: 0 with status: OK
2024-12-16 16:48:18.321315 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ReshapeFusion modified: 0 with status: OK
2024-12-16 16:48:18.321317 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer FreeDimensionOverrideTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321320 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer GeluFusionL1 modified: 0 with status: OK
2024-12-16 16:48:18.321323 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusionL1 modified: 0 with status: OK
2024-12-16 16:48:18.321328 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer QDQPropagationTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321330 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer EnsureUniqueDQForNodeUnit modified: 0 with status: OK
2024-12-16 16:48:18.321333 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer RocmBlasAltImpl modified: 0 with status: OK
2024-12-16 16:48:18.321344 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer TransposeOptimizer modified: 0 with status: OK
2024-12-16 16:48:18.321347 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer Level1_RuleBasedTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321349 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer DoubleQDQPairsRemover modified: 0 with status: OK
2024-12-16 16:48:18.321532 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2024-12-16 16:48:18.321535 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ConstantFolding modified: 0 with status: OK
2024-12-16 16:48:18.321538 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulAddFusion modified: 0 with status: OK
2024-12-16 16:48:18.321541 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ReshapeFusion modified: 0 with status: OK
2024-12-16 16:48:18.321543 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer FreeDimensionOverrideTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321546 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer GeluFusionL1 modified: 0 with status: OK
2024-12-16 16:48:18.321549 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusionL1 modified: 0 with status: OK
2024-12-16 16:48:18.321552 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer QDQPropagationTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321745 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer EnsureUniqueDQForNodeUnit modified: 0 with status: OK
2024-12-16 16:48:18.321761 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer RocmBlasAltImpl modified: 0 with status: OK
2024-12-16 16:48:18.321834 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer Level2_RuleBasedTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321850 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer TransposeOptimizer_CPUExecutionProvider modified: 0 with status: OK
2024-12-16 16:48:18.321868 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer QDQSelectorActionTransformer modified: 0 with status: OK
2024-12-16 16:48:18.321876 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer GemmActivationFusion modified: 0 with status: OK
2024-12-16 16:48:18.321883 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulIntegerToFloatFusion modified: 0 with status: OK
2024-12-16 16:48:18.321889 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer DynamicQuantizeMatMulFusion modified: 0 with status: OK
2024-12-16 16:48:18.321895 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ConvActivationFusion modified: 0 with status: OK
2024-12-16 16:48:18.321900 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer GeluFusionL2 modified: 0 with status: OK
2024-12-16 16:48:18.321904 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusionL2 modified: 0 with status: OK
2024-12-16 16:48:18.321910 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer SimplifiedLayerNormFusion modified: 0 with status: OK
2024-12-16 16:48:18.321989 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer AttentionFusion modified: 0 with status: OK
2024-12-16 16:48:18.321997 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer EmbedLayerNormFusion modified: 0 with status: OK
2024-12-16 16:48:18.322005 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer GatherSliceToSplitFusion modified: 0 with status: OK
2024-12-16 16:48:18.322011 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer GatherToSliceFusion modified: 0 with status: OK
2024-12-16 16:48:18.322019 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatmulTransposeFusion modified: 0 with status: OK
2024-12-16 16:48:18.322025 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer BiasGeluFusion modified: 0 with status: OK
2024-12-16 16:48:18.322031 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer SkipLayerNormFusion modified: 0 with status: OK
2024-12-16 16:48:18.322125 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer FastGeluFusion modified: 0 with status: OK
2024-12-16 16:48:18.322134 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer QuickGeluFusion modified: 0 with status: OK
2024-12-16 16:48:18.322139 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer BiasSoftmaxFusion modified: 0 with status: OK
2024-12-16 16:48:18.322143 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer BiasDropoutFusion modified: 0 with status: OK
2024-12-16 16:48:18.322147 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulScaleFusion modified: 0 with status: OK
2024-12-16 16:48:18.322151 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulActivationFusion modified: 0 with status: OK
2024-12-16 16:48:18.322158 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulNBitsFusion modified: 0 with status: OK
2024-12-16 16:48:18.322267 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer QDQFinalCleanupTransformer modified: 0 with status: OK
2024-12-16 16:48:18.322276 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer NhwcTransformer modified: 0 with status: OK
2024-12-16 16:48:18.322281 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer ConvAddActivationFusion modified: 0 with status: OK
2024-12-16 16:48:18.322295 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer RemoveDuplicateCastTransformer modified: 0 with status: OK
2024-12-16 16:48:18.322298 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer CastFloat16Transformer modified: 0 with status: OK
2024-12-16 16:48:18.322306 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MemcpyTransformer modified: 0 with status: OK
2024-12-16 16:48:18.322316 [V:onnxruntime:, session_state.cc:1148 VerifyEachNodeIsAssignedToAnEp] Node placements
2024-12-16 16:48:18.322319 [V:onnxruntime:, session_state.cc:1151 VerifyEachNodeIsAssignedToAnEp] All nodes placed on [CPUExecutionProvider]. Number of nodes: 7
2024-12-16 16:48:18.322423 [V:onnxruntime:, session_state.cc:128 CreateGraphInfo] SaveMLValueNameIndexMapping
2024-12-16 16:48:18.322437 [V:onnxruntime:, session_state.cc:174 CreateGraphInfo] Done saving OrtValue mappings.
2024-12-16 16:48:18.322443 [I:onnxruntime:, allocation_planner.cc:2567 CreateGraphPartitioner] Use DeviceBasedPartition as default
2024-12-16 16:48:18.322475 [I:onnxruntime:, session_state_utils.cc:276 SaveInitializedTensors] Saving initialized tensors.
2024-12-16 16:48:18.322491 [I:onnxruntime:, session_state_utils.cc:427 SaveInitializedTensors] Done saving initialized tensors
2024-12-16 16:48:18.323604 [I:onnxruntime:, session_state.cc:262 PruneRemovableAttributes] removed 14 removable attributes for node 'TreeEnsembleClassifier' ('TreeEnsembleClassifier'), among attributes: base_values, nodes_falsenodeids, nodes_featureids, nodes_hitrates, nodes_missing_value_tracks_true, nodes_modes, nodes_nodeids, nodes_treeids, nodes_truenodeids, nodes_values, class_ids, class_treeids, class_nodeids, class_weights, classlabels_strings, classlabels_int64sbase_values_as_tensor, nodes_hitrates_as_tensor, nodes_values_as_tensor, class_weights_as_tensor.
2024-12-16 16:48:18.323609 [I:onnxruntime:, inference_session.cc:2106 Initialize] Session successfully initialized.
This call fails on Macos 15.2
Input data size in bytes = 5760
2024-12-16 16:48:18.323840 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running TfIdfVectorizer node. Name:'TfIdfVectorizer' Status Message: /Users/runner/work/1/s/onnxruntime/core/framework/bfc_arena.cc:376 void *onnxruntime::BFCArena::AllocateRawInternal(size_t, bool, onnxruntime::Stream *, bool, onnxruntime::WaitNotificationFn) Available memory of 0 is smaller than requested bytes of 2048

Caught exception: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running TfIdfVectorizer node. Name:'TfIdfVectorizer' Status Message: /Users/runner/work/1/s/onnxruntime/core/framework/bfc_arena.cc:376 void *onnxruntime::BFCArena::AllocateRawInternal(size_t, bool, onnxruntime::Stream *, bool, onnxruntime::WaitNotificationFn) Available memory of 0 is smaller than requested bytes of 2048

This call will complete successfully
Input data size in bytes = 5616
[[{'a': 0.5100000500679016, 'b': 0.4899999499320984}]]
`
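
The error above is raised in bfc_arena.cc and reports "Available memory of 0", which suggests that the arena had already handed out its whole budget by the time the TfIdfVectorizer kernel asked for 2048 bytes. The toy model below is not ONNX Runtime's implementation. It assumes that every request is rounded up to a multiple of 256 bytes and that every allocation counts against the cap, and the request sizes are invented for illustration, not measured from this model. It only shows how a small request can fail under a hard cap once earlier allocations have consumed it.

```python
class ArenaOutOfMemory(RuntimeError):
    """Raised when the toy arena cannot satisfy a request under its cap."""


class ToyCappedArena:
    """Simplified stand-in for an arena allocator with a hard byte cap."""

    ROUNDING = 256  # assumed allocation granularity, in bytes

    def __init__(self, max_mem: int) -> None:
        self.max_mem = max_mem
        self.reserved = 0  # bytes handed out so far

    def allocate(self, requested: int) -> int:
        # Round the request up to the next multiple of ROUNDING.
        rounded = -(-requested // self.ROUNDING) * self.ROUNDING
        # Round the remaining budget down to a multiple of ROUNDING.
        available = (self.max_mem - self.reserved) // self.ROUNDING * self.ROUNDING
        if rounded > available:
            raise ArenaOutOfMemory(
                f"Available memory of {available} is smaller than "
                f"requested bytes of {rounded}"
            )
        self.reserved += rounded
        return rounded


arena = ToyCappedArena(max_mem=8192)

# Hypothetical request sizes, chosen only to exhaust the 8192-byte cap.
hypothetical_requests = [4096, 1900, 2048, 2048]

granted = []
error_message = None
for size in hypothetical_requests:
    try:
        granted.append(arena.allocate(size))
    except ArenaOutOfMemory as exc:
        error_message = str(exc)
        break

print(granted)        # [4096, 2048, 2048]
print(error_message)  # Available memory of 0 is smaller than requested bytes of 2048
```

In this model the failing request is far below the cap, yet it fails because the cap applies to the sum of everything the arena has served, including rounding overhead. Whether the same accounting explains the 5616 versus 5760 boundary in the real run would need to be confirmed against the ONNX Runtime source.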

Urgency

No response

Platform

Mac

OS Version

15.2

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

Python

Architecture

Other / Unknown

Execution Provider

Default CPU

Execution Provider Library Version

No response

Labels: core runtime, stale
