Commit 84f68e7

SweepsPytestReport [ci skip]
1 parent b92be07 commit 84f68e7

4 files changed: +173, −134 lines

forge/test/operators/pytorch/conftest.py

Lines changed: 3 additions & 133 deletions
@@ -8,20 +8,7 @@
 import _pytest.reports
 import _pytest.runner
 
-from ..utils.frontend import XLA_MODE
-
-if XLA_MODE:
-    import pluggy
-else:
-    import pluggy.callers
-
-from loguru import logger
-
-from ..utils import PyTestUtils
-from ..utils import SweepsTagsLogger
-from ..utils import FailingReasonsFinder
-from ..utils import FailingReasonsValidation
-from ..utils import TestPlanUtils
+from ..utils import SweepsPytestReport
 
 
 def pytest_generate_tests(metafunc):
@@ -33,123 +20,6 @@ def pytest_generate_tests(metafunc):
 
 @pytest.hookimpl(hookwrapper=True)
 def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runner.CallInfo):
-    if XLA_MODE:
-        outcome: pluggy.Result = yield
-    else:
-        outcome: pluggy.callers._Result = yield
+    outcome = yield
     report: _pytest.reports.TestReport = outcome.get_result()
-
-    xfail_reason = None
-
-    if report.when == "call" or (report.when == "setup" and report.skipped):
-        xfail_reason = PyTestUtils.get_xfail_reason(item)
-
-    # This hook function is called after each step of the test execution (setup, call, teardown)
-    if call.when == "call":  # 'call' is a phase when the test is actually executed
-
-        if xfail_reason is not None:  # an xfail reason is defined for the test
-            SweepsTagsLogger.log_expected_failing_reason(xfail_reason=xfail_reason)
-
-        if call.excinfo is not None:  # an exception occurred during the test execution
-
-            logger.trace(
-                f"Test: skipped: {report.skipped} failed: {report.failed} passed: {report.passed} report: {report}"
-            )
-
-            exception_value = call.excinfo.value
-            long_repr = call.excinfo.getrepr(style="long")
-            exception_traceback = str(long_repr)
-
-            log_error_properties(item, exception_value, exception_traceback)
-
-            ex_data = FailingReasonsFinder.build_ex_data(exception_value, exception_traceback)
-            SweepsTagsLogger.log_detected_failing_reason(ex_data)
-
-            if xfail_reason is not None:  # an xfail reason is defined for the test
-                valid_reason = FailingReasonsValidation.validate_exception(
-                    exception_value, exception_traceback, xfail_reason
-                )
-
-                # if reason is not valid, mark the test as failed and keep the original exception
-                if valid_reason == False:
-                    # Replace test report with a new one with outcome set to 'failed' and exception details
-                    new_report = _pytest.reports.TestReport(
-                        item=item,
-                        when=call.when,
-                        outcome="failed",
-                        longrepr=call.excinfo.getrepr(style="long"),
-                        sections=report.sections,
-                        nodeid=report.nodeid,
-                        location=report.location,
-                        keywords=report.keywords,
-                    )
-                    outcome.force_result(new_report)
-            else:
-                logger.debug(f"Test '{item.name}' failed with exception: {type(exception_value)} '{exception_value}'")
-
-    if report.when == "call" or (report.when == "setup" and report.skipped):
-        try:
-            log_test_vector_properties(
-                item=item,
-                report=report,
-                xfail_reason=xfail_reason,
-                exception=call.excinfo.value if call.excinfo is not None else None,
-            )
-        except Exception as e:
-            logger.error(f"Failed to log test vector properties: {e}")
-            logger.exception(e)
-
-
-def log_error_properties(item: _pytest.python.Function, exception_value, exception_traceback):
-    ex_class_name = f"{type(exception_value).__module__}.{type(exception_value).__name__}"
-    ex_class_name = ex_class_name.replace("builtins.", "")
-    item.user_properties.append(("exception_value", f"{ex_class_name}: {exception_value}"))
-    item.user_properties.append(("exception_traceback", exception_traceback))
-
-
-def log_test_vector_properties(
-    item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str, exception: Exception
-):
-    original_name = item.originalname
-    test_id = item.name
-    test_id = test_id.replace(f"{original_name}[", "")
-    test_id = test_id.replace("]", "")
-    if test_id == "no_device-test_vector0":
-        # This is not a valid test id. It happens when no tests are selected to run.
-        return
-    test_vector = TestPlanUtils.test_id_to_test_vector(test_id)
-
-    SweepsTagsLogger.log_test_properties(test_vector=test_vector)
-
-    item.user_properties.append(("id", test_id))
-    item.user_properties.append(("operator", test_vector.operator))
-    item.user_properties.append(
-        ("input_source", test_vector.input_source.name if test_vector.input_source is not None else None)
-    )
-    item.user_properties.append(("dev_data_format", TestPlanUtils.dev_data_format_to_str(test_vector.dev_data_format)))
-    item.user_properties.append(
-        ("math_fidelity", test_vector.math_fidelity.name if test_vector.math_fidelity is not None else None)
-    )
-    item.user_properties.append(("input_shape", test_vector.input_shape))
-    item.user_properties.append(("kwargs", test_vector.kwargs))
-    if xfail_reason is not None:
-        item.user_properties.append(("xfail_reason", xfail_reason))
-    item.user_properties.append(("outcome", report.outcome))
-
-    if exception is not None:
-        error_message = f"{exception}"
-
-        if "Observed maximum relative diff" in error_message:
-            error_message_lines = error_message.split("\n")
-            observed_error_lines = [line for line in error_message_lines if "Observed maximum relative diff" in line]
-            if observed_error_lines:
-                observed_error_line = observed_error_lines[0]
-                # Example: "- Observed maximum relative diff: 0.0008770461427047849, maximum absolute diff: 0.0009063482284545898"
-                rtol = float(observed_error_line.split(",")[0].split(":")[1].strip())
-                atol = float(observed_error_line.split(",")[1].split(":")[1].strip())
-            else:
-                logger.error(f"Error parsing 'Observed maximum relative diff' from the exception: {error_message}")
-                rtol = None
-                atol = None
-            item.user_properties.append(("all_close_rtol", rtol))
-            item.user_properties.append(("all_close_atol", atol))
+    SweepsPytestReport.adjust_report(item, call, outcome, report)
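
Note: the rewritten hook keeps only the hookwrapper plumbing and delegates the rest to SweepsPytestReport. As a reminder of how the hookwrapper protocol works, a minimal sketch in a hypothetical standalone conftest.py (not part of this commit): code before the yield runs ahead of the default hook, the yield hands back a pluggy result object, and get_result() / force_result() read or replace the generated TestReport.

import pytest
import _pytest.reports


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    outcome = yield  # the default hook has produced the report by now
    report: _pytest.reports.TestReport = outcome.get_result()
    if call.when == "call":
        # a wrapper may inspect the report and attach metadata to the test item
        item.user_properties.append(("observed_outcome", report.outcome))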

forge/test/operators/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@
 from .failing_reasons_validation import FailingReasonsValidation
 from .pytest import PyTestUtils
 from .pytest import PytestParamsUtils
+from .pytest_report import SweepsPytestReport
 from .datatypes import TestDevice
 
 
@@ -79,5 +80,6 @@
     "FailingReasonsValidation",
     "PyTestUtils",
     "PytestParamsUtils",
+    "SweepsPytestReport",
     "TestDevice",
 ]
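
With the re-export and the __all__ entry in place, callers such as the rewritten conftest.py above can import the helper from the utils package rather than the pytest_report submodule. Assuming the repository root is on the import path, an absolute import would look like:

from forge.test.operators.utils import SweepsPytestReport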

forge/test/operators/utils/failing_reasons_validation.py

Lines changed: 5 additions & 1 deletion
@@ -7,14 +7,18 @@
 
 from loguru import logger
 
+from typing import Optional
+
 from .failing_reasons import ExceptionData
 from .failing_reasons import FailingReasons
 from .failing_reasons import FailingReasonsFinder
 
 
 class FailingReasonsValidation:
     @classmethod
-    def validate_exception(cls, exception_value: Exception, exception_traceback: str, xfail_reason: str):
+    def validate_exception(
+        cls, exception_value: Exception, exception_traceback: str, xfail_reason: str
+    ) -> Optional[bool]:
         """Validate exception based on xfail reason
 
         Args:
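
The added Optional[bool] annotation makes the tri-state contract of validate_exception explicit. A minimal sketch of how a caller is expected to interpret it, mirroring the handling in SweepsPytestReport.adjust_report below (variable names are illustrative):

# True  -> the exception matches the declared xfail reason
# False -> mismatch: override the xfail and mark the test as failed
# None  -> no verdict: treated as valid so unknown reasons do not fail the run
valid = FailingReasonsValidation.validate_exception(exc_value, exc_traceback, xfail_reason)
if valid is None:
    valid = True
if not valid:
    ...  # replace the report with a 'failed' one via outcome.force_result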
forge/test/operators/utils/pytest_report.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+# Pytest report adjustments and logging
+
+from typing import TypeAlias
+
+import _pytest
+import _pytest.python
+import _pytest.reports
+import _pytest.runner
+
+from loguru import logger
+
+from ..utils import PyTestUtils
+from ..utils import SweepsTagsLogger
+from ..utils import FailingReasonsFinder
+from ..utils import FailingReasonsValidation
+from ..utils import TestPlanUtils
+
+from ..utils.frontend import XLA_MODE
+
+if XLA_MODE:
+    import pluggy
+
+    ResultType: TypeAlias = pluggy.Result
+else:
+    import pluggy.callers
+
+    ResultType: TypeAlias = pluggy.callers._Result
+
+
+class SweepsPytestReport:
+    @classmethod
+    def adjust_report(
+        cls, item: _pytest.python.Function, call: _pytest.runner.CallInfo, outcome, report: _pytest.reports.TestReport
+    ):
+        outcome: ResultType = outcome
+
+        xfail_reason = None
+
+        if report.when == "call" or (report.when == "setup" and report.skipped):
+            xfail_reason = PyTestUtils.get_xfail_reason(item)
+
+        # This hook function is called after each step of the test execution (setup, call, teardown)
+        if call.when == "call":  # 'call' is a phase when the test is actually executed
+
+            if xfail_reason is not None:  # an xfail reason is defined for the test
+                SweepsTagsLogger.log_expected_failing_reason(xfail_reason=xfail_reason)
+
+            if call.excinfo is not None:  # an exception occurred during the test execution
+
+                logger.trace(
+                    f"Test: skipped: {report.skipped} failed: {report.failed} passed: {report.passed} report: {report}"
+                )
+
+                exception_value = call.excinfo.value
+                long_repr = call.excinfo.getrepr(style="long")
+                exception_traceback = str(long_repr)
+
+                cls.log_error_properties(item, exception_value, exception_traceback)
+
+                ex_data = FailingReasonsFinder.build_ex_data(exception_value, exception_traceback)
+                SweepsTagsLogger.log_detected_failing_reason(ex_data)
+
+                if xfail_reason is not None:  # an xfail reason is defined for the test
+                    valid_reason = FailingReasonsValidation.validate_exception(
+                        exception_value, exception_traceback, xfail_reason
+                    )
+
+                    if valid_reason is None:
+                        # Consider unknown valid_reason as valid
+                        valid_reason = True
+                    # if reason is not valid, mark the test as failed and keep the original exception
+                    if not valid_reason:
+                        # Replace test report with a new one with outcome set to 'failed' and exception details
+                        new_report = _pytest.reports.TestReport(
+                            item=item,
+                            when=call.when,
+                            outcome="failed",
+                            longrepr=call.excinfo.getrepr(style="long"),
+                            sections=report.sections,
+                            nodeid=report.nodeid,
+                            location=report.location,
+                            keywords=report.keywords,
+                        )
+                        outcome.force_result(new_report)
+                else:
+                    logger.debug(
+                        f"Test '{item.name}' failed with exception: {type(exception_value)} '{exception_value}'"
+                    )
+
+        if report.when == "call" or (report.when == "setup" and report.skipped):
+            try:
+                cls.log_test_vector_properties(
+                    item=item,
+                    report=report,
+                    xfail_reason=xfail_reason,
+                    exception=call.excinfo.value if call.excinfo is not None else None,
+                )
+            except Exception as e:
+                logger.error(f"Failed to log test vector properties: {e}")
+                logger.exception(e)
+
+    @classmethod
+    def log_error_properties(cls, item: _pytest.python.Function, exception_value, exception_traceback):
+        ex_class_name = f"{type(exception_value).__module__}.{type(exception_value).__name__}"
+        ex_class_name = ex_class_name.replace("builtins.", "")
+        item.user_properties.append(("exception_value", f"{ex_class_name}: {exception_value}"))
+        item.user_properties.append(("exception_traceback", exception_traceback))
+
+    @classmethod
+    def log_test_vector_properties(
+        cls, item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str, exception: Exception
+    ):
+        original_name = item.originalname
+        test_id = item.name
+        test_id = test_id.replace(f"{original_name}[", "")
+        test_id = test_id.replace("]", "")
+        if test_id == "no_device-test_vector0":
+            # This is not a valid test id. It happens when no tests are selected to run.
+            return
+        test_vector = TestPlanUtils.test_id_to_test_vector(test_id)
+
+        SweepsTagsLogger.log_test_properties(test_vector=test_vector)
+
+        item.user_properties.append(("id", test_id))
+        item.user_properties.append(("operator", test_vector.operator))
+        item.user_properties.append(
+            ("input_source", test_vector.input_source.name if test_vector.input_source is not None else None)
+        )
+        item.user_properties.append(
+            ("dev_data_format", TestPlanUtils.dev_data_format_to_str(test_vector.dev_data_format))
+        )
+        item.user_properties.append(
+            ("math_fidelity", test_vector.math_fidelity.name if test_vector.math_fidelity is not None else None)
+        )
+        item.user_properties.append(("input_shape", test_vector.input_shape))
+        item.user_properties.append(("kwargs", test_vector.kwargs))
+        if xfail_reason is not None:
+            item.user_properties.append(("xfail_reason", xfail_reason))
+        item.user_properties.append(("outcome", report.outcome))
+
+        if exception is not None:
+            error_message = f"{exception}"
+
+            if "Observed maximum relative diff" in error_message:
+                error_message_lines = error_message.split("\n")
+                observed_error_lines = [
+                    line for line in error_message_lines if "Observed maximum relative diff" in line
+                ]
+                if observed_error_lines:
+                    observed_error_line = observed_error_lines[0]
+                    # Example: "- Observed maximum relative diff: 0.0008770461427047849, maximum absolute diff: 0.0009063482284545898"
+                    rtol = float(observed_error_line.split(",")[0].split(":")[1].strip())
+                    atol = float(observed_error_line.split(",")[1].split(":")[1].strip())
+                else:
+                    logger.error(f"Error parsing 'Observed maximum relative diff' from the exception: {error_message}")
+                    rtol = None
+                    atol = None
+                item.user_properties.append(("all_close_rtol", rtol))
+                item.user_properties.append(("all_close_atol", atol))
