Commit 4fb30b8
Fix TVM verification and add docstring for RemoveEmptyConcat pattern callback (#2516)

1) Added a docstring for the RemoveEmptyConcat pattern callback and renamed the callback from RemoveConcat to RemoveEmptyConcat based on [this comment](#2421 (comment)).
2) The TVM verification function (verify_tvm_compile) was improperly called in the run_pattern_callbacks function in `forge/forge/tvm_calls/relay/op/forge_passes.py`, and verify_tvm_compile is also used by both run_forge_compile_passes and compile_for_forge. To avoid circular import issues, moved the TVM verification functions (verify_tvm_compile and its helpers) to `forge/forge/tvm_calls/relay/op/utils.py`.
3) Framework outputs are now extracted only when the verify_tvm_compile config is enabled.
1 parent 786f15c commit 4fb30b8
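For context on the circular-import fix in point 2, here is a schematic of the dependency change (module contents elided; the cycle direction is inferred from the commit message rather than verified against the full sources):

```python
# Before: forge_passes.py could not cleanly import verify_tvm_compile from
# forge.py, because forge.py already depends on forge_passes.py, so importing
# back would close a cycle:
#
#   forge.py --imports--> forge_passes.py --would import--> forge.py
#
# After: verify_tvm_compile and its helpers live in utils.py, a leaf module
# that imports neither of the other two, so both can depend on it safely:
#
#   forge.py        --> utils.py
#   forge_passes.py --> utils.py
from forge.tvm_calls.relay.op.utils import verify_tvm_compile
```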

File tree

forge/forge/tvm_calls/forge_compile.py
forge/forge/tvm_calls/forge_utils.py
forge/forge/tvm_calls/relay/op/forge.py
forge/forge/tvm_calls/relay/op/forge_passes.py
forge/forge/tvm_calls/relay/op/utils.py

5 files changed: +133 −109 lines

forge/forge/tvm_calls/forge_compile.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@
 import onnx
 import onnx.numpy_helper
 from tvm.relay.expr import Tuple
-from forge.tvm_calls.relay.op.forge import verify_tvm_compile, flatten_IO, compile_for_forge, partition_for_forge
+from forge.tvm_calls.relay.op.forge import flatten_IO, compile_for_forge, partition_for_forge
+from forge.tvm_calls.relay.op.utils import verify_tvm_compile
 from jax.experimental import jax2tf
 from jax.tools.jax_to_ir import tf_wrap_with_input_names
 from transformers import FlaxPreTrainedModel

forge/forge/tvm_calls/forge_utils.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def extract_framework_model_outputs(
 ):
     framework_outputs = []

-    if verify_tvm_compile:
+    if not verify_tvm_compile:
         return framework_outputs

     if framework == "pytorch" or framework == "paddle":
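A minimal sketch of the corrected guard (the function's full signature is elided in the hunk above; the signature and PyTorch branch below are illustrative): with verification disabled, the function now returns the empty list immediately, whereas the old check skipped extraction exactly when verification was enabled.

```python
# Hypothetical, trimmed-down extract_framework_model_outputs; only the guard
# is taken from the diff above, the rest is a sketch.
def extract_framework_model_outputs(framework, model, inputs, verify_tvm_compile=False):
    framework_outputs = []

    # Fixed guard: skip the framework forward pass unless verification is enabled.
    if not verify_tvm_compile:
        return framework_outputs

    if framework == "pytorch" or framework == "paddle":
        outputs = model(*inputs)
        framework_outputs = [out.detach().numpy() for out in outputs]
    return framework_outputs
```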

forge/forge/tvm_calls/relay/op/forge.py

Lines changed: 0 additions & 101 deletions
@@ -1004,107 +1004,6 @@ def visit_call(self, call):
         super().visit_call(call)


-def get_relay_output(mod, params, inputs, target):
-    # Build and run Relay modules with inputs given as (key: tensor) pairs,
-    # so inputs don't need to be in the same order as 'mod' defines.
-    ret_type = mod["main"].checked_type.ret_type
-    with tvm.transform.PassContext(opt_level=0):
-        lib = relay.build_module.build(mod, target=target, params=params)
-    m = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
-    m.run(**inputs)
-
-    def _unflatten(flat_iter, cur_type):
-        import tvm.relay.ty as _ty
-
-        if isinstance(cur_type, _ty.TensorType):
-            return next(flat_iter)
-        if isinstance(cur_type, _ty.TupleType):
-            fields = []
-            for field_type in cur_type.fields:
-                field = _unflatten(flat_iter, field_type)
-                fields.append(field)
-            return fields
-        raise ValueError("Return type", ret_type, "contains unsupported type", cur_type)
-
-    flattened = []
-    import tvm.runtime.ndarray as _nd
-
-    for i in range(m.get_num_outputs()):
-        flattened.append(m.get_output(i).copyto(_nd.cpu(0)))
-    relay_outputs = _unflatten(iter(flattened), ret_type)
-
-    if not isinstance(relay_outputs, (list, tuple)):
-        relay_outputs = [relay_outputs]
-    relay_outputs = [x.numpy() for x in flattened]
-    return relay_outputs
-
-
-def verify_outputs(framework_outputs, relay_outputs, compile_location, rtol=1e-02, atol=1e-04, pcc=None):
-    allowed_to_fail = False
-    if len(framework_outputs) != len(relay_outputs):
-        logger.error(
-            f"Different number of outputs. Framework: {len(framework_outputs)}, TVM: {len(relay_outputs)} after {compile_location}"
-        )
-
-    for i, (fr_out, tvm_out) in enumerate(zip(framework_outputs, relay_outputs)):
-        if fr_out.shape != tvm_out.shape:
-            logger.error(
-                f"Different shapes for outputs. Framework: {fr_out.shape}, TVM: {tvm_out.shape} after {compile_location}"
-            )
-
-        if pcc is None:
-            ok = np.allclose(fr_out, tvm_out, rtol=rtol, atol=atol, equal_nan=True)
-        else:
-            pcc_value = np.min(
-                np.ma.corrcoef(np.ma.masked_invalid(fr_out.flatten()), np.ma.masked_invalid(tvm_out.flatten()))
-            )
-            if isinstance(pcc_value, np.ma.core.MaskedConstant):
-                pcc_value = 1.0
-            ok = pcc_value >= pcc
-
-        if not ok:
-            logger.error(f"Tensor mismatch on output {i} between framework and TVM after {compile_location}.")
-            logger.trace(f"Framework: (shape = {fr_out.shape})")
-            logger.trace(fr_out)
-            logger.trace(f"TVM: (shape = {tvm_out.shape})")
-            logger.trace(tvm_out)
-            logger.info(
-                "Max ATOL Delta: "
-                + "{:.3e}".format(np.max(np.abs((fr_out - tvm_out))).item())
-                + ", atol="
-                + "{}".format(atol)
-            )
-            logger.info(
-                "Max RTOL Delta: "
-                + "{:.3e}".format(np.max(np.abs((fr_out - tvm_out)) / tvm_out).item())
-                + ", rtol="
-                + "{}".format(rtol)
-            )
-            if pcc is not None:
-                logger.info(f"PCC got={pcc_value}, required={pcc}")
-            if not allowed_to_fail:
-                raise RuntimeError
-
-    logger.info(f"Verified TVM Relay outputs against framework outputs after {compile_location}")
-
-
-def verify_tvm_compile(mod, params, inputs, target, framework_outputs, compile_location, verify_cfg=None):
-    relay_outputs = get_relay_output(mod, params, inputs, target)
-
-    # Verify compile passes (original relay passes + forge passes)
-    if verify_cfg:
-        verify_outputs(
-            framework_outputs,
-            relay_outputs,
-            compile_location,
-            rtol=verify_cfg.rtol,
-            atol=verify_cfg.atol,
-            pcc=verify_cfg.pcc,
-        )
-    else:
-        verify_outputs(framework_outputs, relay_outputs, compile_location)
-
-
 class CompareWarner(DFPatternCallback):
     def __init__(self):
         super().__init__(require_type=True, rewrite_once=True)
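These functions are not deleted outright; they move verbatim to `forge/forge/tvm_calls/relay/op/utils.py` (last file below). For reference, a self-contained sketch of the build-and-run flow that get_relay_output wraps, on a made-up one-op module:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Trivial Relay module: y = relu(x).
x = relay.var("x", shape=(1, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

# opt_level=0 mirrors get_relay_output: compile without optimizations so the
# executed graph stays close to the Relay being verified.
with tvm.transform.PassContext(opt_level=0):
    lib = relay.build_module.build(mod, target="llvm")

# Inputs are passed by name, so their order need not match the module's.
m = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
m.run(x=np.array([[-1.0, 0.0, 2.0, -3.0]], dtype="float32"))
print(m.get_output(0).numpy())  # [[0. 0. 2. 0.]]
```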

forge/forge/tvm_calls/relay/op/forge_passes.py

Lines changed: 26 additions & 5 deletions
@@ -4990,7 +4990,30 @@ def callback(self, pre, post, node_map):
         return out


-class RemoveConcat(DFPatternCallback):
+class RemoveEmptyConcat(DFPatternCallback):
+    """
+    Relay pass to eliminate unnecessary `concatenate` ops involving empty tensors.
+
+    In some models (e.g., Phi-3), rotary embedding logic performs slicing on the last
+    dimension of the query tensor to split it into two parts. If the slicing boundaries
+    are incorrectly defined, this may create an empty tensor (e.g., a shape with size 0
+    along the concatenation axis).
+
+    This pass identifies `concatenate` operations between two tensors where one has
+    dimension size 0 along the concatenation axis, and removes the redundant concat
+    by returning the non-empty operand directly.
+
+    This prevents downstream errors like:
+        `AssertionError: start < operandA.shape[dim]`
+    which occur due to operations on invalid tensor slices.
+
+    Example:
+        q_rot = query[..., 0:96]
+        q_pass = query[..., 96:96]      # shape: (1, 32, 256, 0)
+        concat(q_rot, q_pass, axis=-1)  # → rewritten to just q_rot
+    """
+
     def __init__(self):
         super().__init__(rewrite_once=False, require_type=True)
         self.act1 = wildcard()

@@ -5053,9 +5076,7 @@ def run_pattern_callbacks(
             raise ex
         if run_verify:
             logger.trace(f"Verifying {callback_name}")
-            tvm.relay.op.contrib.forge.forge.verify_tvm_compile(
-                relay_module, params, inputs, target, framework_outputs, callback_name, verify_cfg
-            )
+            verify_tvm_compile(relay_module, params, inputs, target, framework_outputs, callback_name, verify_cfg)

     return relay_module

@@ -5161,7 +5182,7 @@ def run_forge_compile_passes(
             SimplifyVITOnnxAttention(),
             GQABroadcastReshape(),
             RemoveDenseInputSqueeze(),
-            RemoveConcat(),
+            RemoveEmptyConcat(),
         ],
         params=params,
         inputs=inputs,
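For readers unfamiliar with TVM's dataflow pattern API, a self-contained sketch of the kind of rewrite RemoveEmptyConcat performs; the pattern and callback below are illustrative, not the exact implementation above:

```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, is_tuple, rewrite, wildcard


class RemoveEmptyConcatSketch(DFPatternCallback):
    def __init__(self):
        super().__init__(require_type=True)
        self.act1 = wildcard()
        self.act2 = wildcard()
        # Match a two-operand concatenate.
        self.pattern = is_op("concatenate")(is_tuple([self.act1, self.act2]))

    def callback(self, pre, post, node_map):
        a = node_map[self.act1][0]
        b = node_map[self.act2][0]
        axis = int(pre.attrs.axis)
        if axis < 0:
            axis += len(a.checked_type.shape)
        # If one operand is empty along the concat axis, the concat is a no-op:
        # return the other operand directly.
        if int(a.checked_type.shape[axis]) == 0:
            return b
        if int(b.checked_type.shape[axis]) == 0:
            return a
        return post


# Shapes mirror the Phi-3 example from the docstring.
q_rot = relay.var("q_rot", shape=(1, 32, 256, 96))
q_pass = relay.var("q_pass", shape=(1, 32, 256, 0))
func = relay.Function([q_rot, q_pass], relay.concatenate([q_rot, q_pass], axis=-1))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
print(rewrite(RemoveEmptyConcatSketch(), mod["main"]))  # body becomes just q_rot
```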

forge/forge/tvm_calls/relay/op/utils.py

Lines changed: 104 additions & 1 deletion
@@ -2,13 +2,116 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import numpy as np
 import numpy as np
 from tvm.relay.dataflow_pattern import *
+import tvm
+from tvm import relay
+from tvm.contrib import graph_executor

 from loguru import logger


+def get_relay_output(mod, params, inputs, target):
+    # Build and run Relay modules with inputs given as (key: tensor) pairs,
+    # so inputs don't need to be in the same order as 'mod' defines.
+    ret_type = mod["main"].checked_type.ret_type
+    with tvm.transform.PassContext(opt_level=0):
+        lib = relay.build_module.build(mod, target=target, params=params)
+    m = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
+    m.run(**inputs)
+
+    def _unflatten(flat_iter, cur_type):
+        import tvm.relay.ty as _ty
+
+        if isinstance(cur_type, _ty.TensorType):
+            return next(flat_iter)
+        if isinstance(cur_type, _ty.TupleType):
+            fields = []
+            for field_type in cur_type.fields:
+                field = _unflatten(flat_iter, field_type)
+                fields.append(field)
+            return fields
+        raise ValueError("Return type", ret_type, "contains unsupported type", cur_type)
+
+    flattened = []
+    import tvm.runtime.ndarray as _nd
+
+    for i in range(m.get_num_outputs()):
+        flattened.append(m.get_output(i).copyto(_nd.cpu(0)))
+    relay_outputs = _unflatten(iter(flattened), ret_type)
+
+    if not isinstance(relay_outputs, (list, tuple)):
+        relay_outputs = [relay_outputs]
+    relay_outputs = [x.numpy() for x in flattened]
+    return relay_outputs
+
+
+def verify_outputs(framework_outputs, relay_outputs, compile_location, rtol=1e-02, atol=1e-04, pcc=None):
+    allowed_to_fail = False
+    if len(framework_outputs) != len(relay_outputs):
+        logger.error(
+            f"Different number of outputs. Framework: {len(framework_outputs)}, TVM: {len(relay_outputs)} after {compile_location}"
+        )
+
+    for i, (fr_out, tvm_out) in enumerate(zip(framework_outputs, relay_outputs)):
+        if fr_out.shape != tvm_out.shape:
+            logger.error(
+                f"Different shapes for outputs. Framework: {fr_out.shape}, TVM: {tvm_out.shape} after {compile_location}"
+            )
+
+        if pcc is None:
+            ok = np.allclose(fr_out, tvm_out, rtol=rtol, atol=atol, equal_nan=True)
+        else:
+            pcc_value = np.min(
+                np.ma.corrcoef(np.ma.masked_invalid(fr_out.flatten()), np.ma.masked_invalid(tvm_out.flatten()))
+            )
+            if isinstance(pcc_value, np.ma.core.MaskedConstant):
+                pcc_value = 1.0
+            ok = pcc_value >= pcc
+
+        if not ok:
+            logger.error(f"Tensor mismatch on output {i} between framework and TVM after {compile_location}.")
+            logger.trace(f"Framework: (shape = {fr_out.shape})")
+            logger.trace(fr_out)
+            logger.trace(f"TVM: (shape = {tvm_out.shape})")
+            logger.trace(tvm_out)
+            logger.info(
+                "Max ATOL Delta: "
+                + "{:.3e}".format(np.max(np.abs((fr_out - tvm_out))).item())
+                + ", atol="
+                + "{}".format(atol)
+            )
+            logger.info(
+                "Max RTOL Delta: "
+                + "{:.3e}".format(np.max(np.abs((fr_out - tvm_out)) / tvm_out).item())
+                + ", rtol="
+                + "{}".format(rtol)
+            )
+            if pcc is not None:
+                logger.info(f"PCC got={pcc_value}, required={pcc}")
+            if not allowed_to_fail:
+                raise RuntimeError
+
+    logger.info(f"Verified TVM Relay outputs against framework outputs after {compile_location}")
+
+
+def verify_tvm_compile(mod, params, inputs, target, framework_outputs, compile_location, verify_cfg=None):
+    relay_outputs = get_relay_output(mod, params, inputs, target)
+
+    # Verify compile passes (original relay passes + forge passes)
+    if verify_cfg:
+        verify_outputs(
+            framework_outputs,
+            relay_outputs,
+            compile_location,
+            rtol=verify_cfg.rtol,
+            atol=verify_cfg.atol,
+            pcc=verify_cfg.pcc,
+        )
+    else:
+        verify_outputs(framework_outputs, relay_outputs, compile_location)
+
+
 def is_unsqueeze(call):
     input_shape = call.args[0].checked_type.shape
     output_shape = call.checked_type.shape
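As a quick illustration of the PCC branch in verify_outputs, a runnable snippet computing the same statistic on made-up outputs:

```python
import numpy as np

fr_out = np.array([1.0, 2.0, 3.0, 4.0])     # pretend framework output
tvm_out = np.array([1.0, 2.0, 3.0, 4.001])  # pretend TVM output

# Pearson correlation of the flattened outputs; masked_invalid ignores
# NaN/Inf entries, and np.min picks the off-diagonal entry of the 2x2
# correlation matrix (the diagonal is always 1).
pcc_value = np.min(
    np.ma.corrcoef(np.ma.masked_invalid(fr_out.flatten()), np.ma.masked_invalid(tvm_out.flatten()))
)
print(pcc_value)          # ~0.9999999
print(pcc_value >= 0.99)  # True -> would pass verification at pcc=0.99
```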
