159 changes: 118 additions & 41 deletions src/Runtime/python/onnxmlirtorch/src/onnxmlirtorch/backend.py
@@ -21,6 +21,7 @@
import functools
import pickle
import pickletools
+from collections import deque

import numpy as np
import torch
@@ -37,6 +38,7 @@

from .onnxmlirdocker import InferenceSession
from .sessioncache import SessionCache, CacheValue
+from . import config

"""
This file provides an onnx-mlir compiler backend for torch.compile().
@@ -78,13 +80,8 @@

logger = logging.getLogger(__name__)


-class ONNXMLIRConfig:
-    cache_size = 3


# An instance to cache onnx_mlir session so that there is no need to recompile the same model.
-global_session_cache = SessionCache(ONNXMLIRConfig.cache_size)
+global_session_cache = SessionCache(config.session_cache_limit)


# Backend function for torch.compile.
@@ -145,22 +142,68 @@ def __init__(self, gm: torch.fx.GraphModule, compile_options) -> None:
def generate_hash_key(
    gm: torch.fx.GraphModule, compile_options, use_lightweight_hashing=True
) -> str:
-    start = time.time()
+    start = time.perf_counter()
    if use_lightweight_hashing:
        # Hash the graph module.
        # Touch the code to materialize.
        _ = gm.code

        # Generate a unique string to represent the graph module.
-        node_info = []
+        graph_info = []
        placeholder_counter = 0
+        dim_counter = 0
+        dim_dict = {}
        for node in gm.graph.nodes:
-            # Use stable names for placeholders.
+            node_info = []
+            # Use stable names for placeholders and symbolic dimensions.
if node.op == "placeholder":
node_info.append(f"om_placeholder_{placeholder_counter}")
if "example_value" in node.meta and isinstance(
node.meta["example_value"], torch.Tensor
):
shape = []
for d in node.meta["example_value"].shape:
s = str(d)
if isinstance(d, torch.SymInt):
if s in dim_dict:
shape.append(dim_dict[s])
else:
dim_str = f"dim_{dim_counter}"
dim_dict[s] = dim_str
dim_counter += 1
shape.append(dim_str)
else:
shape.append(s)
shape_str = ",".join(shape)
node_info.append(
f"om_placeholder_{placeholder_counter}_[{shape_str}]"
)
else:
node_info.append(f"om_placeholder_{placeholder_counter}")
placeholder_counter += 1
            else:
-                node_info.append(f"{node.op}_{node.target}")
-        graph_str = " ".join(node_info)
+                node_info.append(f"{node.op}_{torch.typename(node.target)}")
+            # Append information from input nodes.
+            for inode in node._input_nodes.keys():
+                if inode.op == "get_attr":
+                    try:
+                        t = gm._parameters[inode.target]
+                    except KeyError:
+                        t = None
+                    if t is not None and isinstance(t, torch.nn.Parameter):
+                        sample_values = [
+                            str(s)
+                            for s in t.view(-1)[
+                                : config.sample_parameter_values_limit
+                            ].tolist()
+                        ]
+                        sample_str = ".".join(sample_values)
+                    else:
+                        sample_str = "."
+                    node_info.append(f"{inode.name}.{sample_str}")
+                else:
+                    node_info.append(f"{inode.name}")
+            graph_info.append(";".join(node_info))
+        graph_str = " ".join(graph_info)
        graph_hash = sha256_hash(graph_str.encode())

    # Hash the options.
@@ -175,29 +218,59 @@ def generate_hash_key(
    key = pickler.get_hash(details)

    key = "_om_" + key
-    logger.info(f"Creating a cache key took {(time.time() - start)*1000} ms: {key}")
+    logger.info(
+        f"Creating a cache key took {(time.perf_counter() - start)*1000} ms: {key}"
+    )
    return key


class ONNXMLIRTorch:
    def __init__(self, gm: torch.fx.GraphModule, *args, **kwargs):
        # Input graph module.
        self.gm = gm
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f"Original graph module: {self.gm}")
-
-        # Generate an unique key from the graph module.
-        self.cache_key = generate_hash_key(self.gm, kwargs["options"])
+        self.gm.eval()

-        # Check whether there is any cached compiled model.
-        self.cached_session = global_session_cache.get(self.cache_key)
-        if self.cached_session is None:
-            # Rewrite the graph for exporting to onnx.
-            self.example_inputs_indices, _ = self.rewrite_gm_for_export(*args)
-            if len(self.example_inputs_indices) < len(args):
-                # Cache the rewritten graph module.
-                self.cache_key = generate_hash_key(self.gm, kwargs["options"])
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(f"Original graph module {self.gm}")

+        # If the model was rewritten, the cache key is stored in "om_hash" in gm.meta.
+        need_rewrite = False
+        if "om_hash" not in self.gm.meta:
+            # Rewrite the graph the first time it is seen.
+            need_rewrite = True
+            self.gm.meta["om_same_hash_counter"] = deque([])
+        else:
+            same_hash_counter = self.gm.meta["om_same_hash_counter"]
+            same_hash_size = max(0, config.same_hash_size)
+            if len(same_hash_counter) == same_hash_size and all(same_hash_counter):
+                self.cache_key = self.gm.meta["om_hash"]
+            else:
+                self.cache_key = generate_hash_key(self.gm, kwargs["options"])
+                if self.cache_key == self.gm.meta["om_hash"]:
+                    same_hash_counter.append(True)
+                else:
+                    # Rewrite the graph if it has changed.
+                    same_hash_counter.append(False)
+                    need_rewrite = True
+                while len(same_hash_counter) > same_hash_size:
+                    if same_hash_counter:
+                        same_hash_counter.popleft()
+                self.gm.meta["om_same_hash_counter"] = same_hash_counter
+
+        if need_rewrite:
+            # Rewrite the graph for exporting to onnx.
+            (
+                self.example_inputs_indices,
+                removed_example_inputs_indices,
+                placeholders_to_replace,
+            ) = self.rewrite_gm_for_export(*args)
+            self.cache_key = generate_hash_key(self.gm, kwargs["options"])
+            self.gm.meta["om_hash"] = self.cache_key
+
+        # Cache the rewritten graph module.
+        assert self.cache_key, "cache key does not exist"
+        self.cached_session = global_session_cache.get(self.cache_key)
+        if self.cached_session:
+            self.example_inputs_indices = self.cached_session.example_inputs_indices

        # Touch the code to materialize before exporting.
Expand Down Expand Up @@ -263,9 +336,9 @@ def forward(self, *example_inputs):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"onnx_mlir input sig: {sess.input_signature()}")
            logger.debug(f"onnx_mlir output sig: {sess.output_signature()}")
-        start = time.time()
+        start = time.perf_counter()
        om_outputs = sess.run(om_inputs)
-        logger.info(f"sess.run took {(time.time() - start)*1000} ms")
+        logger.info(f"sess.run took {(time.perf_counter() - start)*1000} ms")
        return [torch.from_numpy(output) for output in om_outputs]

    def get_tensor_example_inputs(self, example_inputs):
Expand Down Expand Up @@ -309,7 +382,8 @@ def get_dynamic_shapes_for_export(self) -> ([str], dict[str, dict[int, str]]):
            dynamic_dims = {}
            for dim_idx, dim_size in enumerate(input_arg.shape):
                if isinstance(dim_size, torch.SymInt):
-                    dynamic_dims[dim_idx] = "dim" + str(dim_size)
+                    if not str(dim_size).isdigit():
+                        dynamic_dims[dim_idx] = "dim" + str(dim_size)
            if dynamic_dims:
                dynamic_shapes[input_name] = dynamic_dims
            else:
@@ -328,30 +402,31 @@ def rewrite_gm_for_export(self, *example_inputs):
        self.freeze_scalar_constant_args(constant_values)
        # Since onnx does not support scalar inputs, symbolic integer arguments
        # are converted to tensor arguments.
-        self.convert_symint_args_to_tensors()
+        placeholders_to_replace = self.convert_symint_args_to_tensors()
        # After rewriting the argument list of the graph module, we maintain
        # a list of un-removed arguments that are used in forward for passing
        # correct example inputs to the rewritten graph module.
-        return example_inputs_indices, removed_example_inputs
+        return example_inputs_indices, removed_example_inputs, placeholders_to_replace

    def convert_symint_args_to_tensors(self):
        # Important note: do not cast SymInt to int by int(SymInt)
        # since that concretizes symbolic dimensions in related Tensors.

        graph = self.gm.graph
        placeholders_to_replace = []

        # First pass: collect SymInt placeholders.
-        for node in list(graph.nodes):
+        for node in graph.nodes:
if node.op == "placeholder" and node.type in [int, torch.SymInt]:
new_name = f"{node.name}_tensor"
if node.type is torch.SymInt:
value = int(node.meta["example_value"])
else:
value = node.meta["example_value"]
with graph.inserting_before(node):
new_node = graph.placeholder(new_name)
new_node.meta = {
"tensor_meta": {"shape": [1], "dtype": torch.int64},
"example_value": torch.tensor([value], dtype=torch.int64),
}
new_node.meta = node.meta
new_node.meta["tensor_meta"] = {"shape": [1], "dtype": torch.int64}
if node.type is int:
new_node.meta["example_value"] = torch.tensor(
[value], dtype=torch.int64
)
new_node.type = torch.Tensor
placeholders_to_replace.append((node, new_node))

@@ -367,6 +442,8 @@ def convert_symint_args_to_tensors(self):
        self.gm.graph.lint()
        self.gm.recompile()

+        return placeholders_to_replace
+
    def extract_scalar_constant_args(self, example_inputs: tuple):
        graph = self.gm.graph
        placeholder_nodes = [n for n in graph.nodes if n.op == "placeholder"]
Expand Down Expand Up @@ -498,7 +575,7 @@ def create_onnxmlir_session(self) -> InferenceSession:

    def cleanup_onnxmlir_files(self, tag_id):
        base = os.path.join(self.workdir.name, self.default_model_name + str(tag_id))
-        old_files = [base + ".onnx", base + ".so"]
+        old_files = [base + ".onnx", base + ".constants.bin", base + ".so"]
        for f in old_files:
            if os.path.exists(f):
                os.remove(f)
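For orientation, here is a minimal sketch of how this backend is typically driven. It is not part of the diff: it assumes that `onnxmlirtorch` re-exports `onnxmlir_backend` from `.backend` (as the final hunk in this diff suggests), and the model is a hypothetical stand-in.

```python
import torch
import onnxmlirtorch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel().eval()

# Compile with the onnx-mlir backend.
compiled = torch.compile(model, backend=onnxmlirtorch.onnxmlir_backend)

# First call: the backend hashes the FX graph, compiles it via onnx-mlir,
# and stores the session in global_session_cache.
print(compiled(torch.randn(2, 8)))

# Later calls that produce the same hash key reuse the cached session
# instead of recompiling.
print(compiled(torch.randn(2, 8)))
```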
29 changes: 29 additions & 0 deletions src/Runtime/python/onnxmlirtorch/src/onnxmlirtorch/config.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0

##################### config.py ###############################################
#
# Copyright 2025 The IBM Research Authors.
#
################################################################################
#
# This module contains configuration flags and settings that control the backend.
# They can be set in a user's script through the package's config module, e.g.:
# ```python
# import onnxmlirtorch
# onnxmlirtorch.config.same_hash_size = 1
# ```
#
################################################################################

# If the graph module's hash stays the same for this many consecutive runs, the
# backend skips hashing the module on later runs to reduce inference overhead.
same_hash_size = 3

# Controls how many leading values of each constant tensor (parameter) are used
# for hashing the graph module. Larger values increase hashing time, since more
# values must be read and hashed.
sample_parameter_values_limit = 3

# Controls the maximum number of compiled sessions cached at runtime.
session_cache_limit = 3
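As a usage sketch (assuming the package is importable as `onnxmlirtorch` and that flags are set before the first compilation), the flags above can be tuned like this:

```python
import onnxmlirtorch

# Trust a stable graph sooner: one matching hash is enough to skip re-hashing.
onnxmlirtorch.config.same_hash_size = 1

# Sample more parameter values per tensor: lower collision risk, slower hashing.
onnxmlirtorch.config.sample_parameter_values_limit = 8
```

One caveat worth noting: `backend.py` above constructs `global_session_cache` with `config.session_cache_limit` at import time, so changing that particular flag after the package has been imported does not resize the already-created cache.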
@@ -21,7 +21,7 @@
torch._dynamo.config.prepare_freezing = 1

# Exporting with dynamic shapes.
-# torch._dynamo.config.assume_static_by_default = False
+torch._dynamo.config.assume_static_by_default = False

from .backend import (
    onnxmlir_backend,
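Uncommenting `assume_static_by_default = False` makes Dynamo treat input sizes as symbolic by default, which is what the `SymInt` handling in `generate_hash_key` and `get_dynamic_shapes_for_export` above relies on. A minimal sketch of the intended effect, reusing the hypothetical setup from the earlier example:

```python
import torch
import onnxmlirtorch

model = torch.nn.Linear(8, 4).eval()
compiled = torch.compile(model, backend=onnxmlirtorch.onnxmlir_backend)

# With dynamic shapes on by default, the batch dimension becomes a SymInt,
# so varying it should hit a single compiled artifact rather than trigger
# a recompilation per shape.
for batch in (2, 3, 5):
    compiled(torch.randn(batch, 8))
```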