Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
import re

import pytest


Expand All @@ -31,3 +34,36 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if "regression" in item.keywords:
item.add_marker(skip_regression)


# TODO: remove this once support for PyTorch 2.2 (the latest one still supported by GitHub MacOS x86_64 runners) is
# dropped, or if MacOS is removed from the test matrix, see https://github.com/huggingface/peft/issues/2431.
# Note: the function name is fixed by the pytest plugin system, don't change it
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
"""
Plug into the pytest test report generation to skip a specific MacOS failure caused by transformers.

The error was introduced by https://github.com/huggingface/transformers/pull/37785, which results in torch.load
failing when using torch < 2.6.

Since the MacOS x86 runners need to use an older torch version, those steps are necessary to get the CI green.
"""
outcome = yield
rep = outcome.get_result()
# ref:
# https://github.com/huggingface/transformers/blob/858ce6879a4aa7fa76a7c4e2ac20388e087ace26/src/transformers/utils/import_utils.py#L1418
error_msg = re.compile(r"Due to a serious vulnerability issue in `torch.load`")

# notes:
# - pytest uses hard-coded strings, we cannot import and use constants
# https://docs.pytest.org/en/stable/reference/reference.html#pytest.TestReport
# - errors can happen during call (running the test) but also setup (e.g. in fixtures)
if rep.failed and (rep.when in ("setup", "call")) and (platform.system() == "Darwin"):
exc_msg = str(call.excinfo.value)
if error_msg.search(exc_msg):
# turn this failure into an xfail:
rep.outcome = "skipped"
# for this attribute, see:
# https://github.com/pytest-dev/pytest/blob/bd6877e5874b50ee57d0f63b342a67298ee9a1c3/src/_pytest/reports.py#L266C5-L266C13
rep.wasxfail = "Error known to occur on MacOS with older torch versions, won't be fixed"
Loading