Describe the bug
Description
- RequirementsManager installs per-model dependencies at runtime, inside the same Python process that runs the tests.
- If a package (e.g., transformers) has already been imported before the RequirementsManager context is entered, and a different version is installed for a specific test, Python continues to use the previously imported module objects.
- This creates version skew between:
a) The newly installed package versions on disk
b) The in-memory module objects already loaded in sys.modules
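Because pip install runs in a subprocess, the running Python process does not reload the updated packages. An import statement first looks the module up in sys.modules and returns the cached object if one is present, so subsequent imports of an already-loaded module resolve to the stale in-memory object instead of the files now on disk.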
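This caching behaviour can be reproduced without pip or transformers. The minimal sketch below uses a throwaway package, skewpkg (a hypothetical name), whose files are overwritten on disk after the package root has been imported, standing in for a pip install run from a subprocess. The already-imported root keeps its old contents, while a submodule imported for the first time afterwards is read from the new files: the same mix of versions described above.

```python
import sys
import tempfile
from pathlib import Path


def demo_version_skew() -> tuple[str, str]:
    """Overwrite an already-imported package on disk and import a submodule."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = Path(tmp) / "skewpkg"  # hypothetical package name
        pkg.mkdir()
        # "Old release" on disk; only the package root is imported up front,
        # as happens during test collection.
        (pkg / "__init__.py").write_text('VERSION = "old"\n')
        (pkg / "utils.py").write_text('VERSION = "old"\n')
        sys.path.insert(0, tmp)
        try:
            import skewpkg

            # Simulate the installer replacing every file with a new release.
            (pkg / "__init__.py").write_text('VERSION = "new"\n')
            (pkg / "utils.py").write_text('VERSION = "new"\n')
            # skewpkg is served from sys.modules; skewpkg.utils was never
            # loaded, so it is read from the new files on disk.
            import skewpkg.utils

            return skewpkg.VERSION, skewpkg.utils.VERSION
        finally:
            sys.path.remove(tmp)
            for name in ("skewpkg.utils", "skewpkg"):
                sys.modules.pop(name, None)


print(demo_version_skew())  # ('old', 'new')
```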
Primary Failure Observed
Reproduction
# Run the failing test on main
pytest -svv "tests/runner/test_models.py::test_all_models_torch[pi_0/pytorch-lerobot_pi0_libero_base-single_device-inference]" &> pi_0.log
# Reinstall the environment’s default package version. The current RequirementsManager does not track packages installed via @ references, so those packages are never properly rolled back; for more, see https://github.com/tenstorrent/tt-xla/issues/3317
pip install transformers==4.57.1
Complete Log - pi_0_failure.log
Example Failure
ImportError: cannot import name 'LossKwargs' from 'transformers.utils'
Root Cause
- transformers==4.57.1 was imported earlier, during test collection/setup.
- RequirementsManager installs transformers==4.53.3 for the pi_0 model.
- sys.modules still references the in-memory 4.57.1 modules.
- The result is a mix of old in-memory modules and the newly installed package files, which produces API mismatches and missing symbols. Here, from transformers.utils import LossKwargs is resolved against the cached 4.57.1 transformers.utils, which does not provide LossKwargs.
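When an error like this appears, one way to confirm the root cause is to compare the __version__ of the module held in sys.modules with the version recorded in the installed package metadata. The helper below, find_version_skew, is a hypothetical sketch; it assumes the distribution name matches the top-level import name and that the package exposes __version__ (transformers does, but not every package does).

```python
import sys
from importlib import metadata
from typing import Callable


def find_version_skew(
    dist_names: list[str],
    lookup: Callable[[str], str] = metadata.version,
) -> dict[str, tuple[str, str]]:
    """Return {name: (in_memory, on_disk)} for packages whose versions disagree.

    Hypothetical helper. Assumes distribution name == top-level import name
    and that the imported package defines __version__.
    """
    skew = {}
    for name in dist_names:
        module = sys.modules.get(name)
        if module is None:
            continue  # not imported yet, so the next import reads from disk
        in_memory = getattr(module, "__version__", None)
        try:
            on_disk = lookup(name)  # reads the installed dist-info metadata
        except metadata.PackageNotFoundError:
            continue
        if in_memory is not None and in_memory != on_disk:
            skew[name] = (in_memory, on_disk)
    return skew
```

Called inside the RequirementsManager context in the pi_0 scenario, find_version_skew(["transformers"]) would be expected to report ("4.57.1", "4.53.3"). The lookup parameter exists only so the check can be exercised without a real install.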
Why This Is a Bug
- The requirements flow is intended to provide temporary, per-test dependency overrides.
- However, performing dependency swaps inside a live Python process does not reliably update already-imported modules.
- This leads to non-deterministic behavior and test failures unrelated to actual model logic.
Attempted Mitigation (Branch: akannan/clear_sys_modules)
- Introduced logic to:
a) Detect packages whose versions changed.
b) Clear corresponding entries from sys.modules.
c) Force a re-import of those packages.
This resolved the original stale-import issue for models such as pi_0.
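The three steps above can be sketched as follows. This is not the code on akannan/clear_sys_modules: snapshot_versions, changed_packages, and purge_modules are hypothetical names, and the sketch assumes each distribution name equals its top-level import name. importlib.invalidate_caches() is called so the import system rescans directories that pip may have modified.

```python
import importlib
import sys
from importlib import metadata
from typing import Dict, Iterable, List, Optional, Set


def snapshot_versions(dist_names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Step a) input: record the installed version of each distribution."""
    versions: Dict[str, Optional[str]] = {}
    for name in dist_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed
    return versions


def changed_packages(
    before: Dict[str, Optional[str]], after: Dict[str, Optional[str]]
) -> Set[str]:
    """Step a): names whose installed version differs between snapshots."""
    return {n for n in before.keys() | after.keys() if before.get(n) != after.get(n)}


def purge_modules(top_level_names: Set[str]) -> List[str]:
    """Step b): drop a package and all of its submodules from sys.modules."""
    removed = []
    for module_name in list(sys.modules):
        if module_name.split(".", 1)[0] in top_level_names:
            del sys.modules[module_name]
            removed.append(module_name)
    importlib.invalidate_caches()  # forget cached directory listings
    return sorted(removed)


# Hypothetical usage around a per-test install:
#   before = snapshot_versions(["transformers"])
#   ... install the model's requirements ...
#   after = snapshot_versions(["transformers"])
#   purge_modules(changed_packages(before, after))
#   importlib.import_module("transformers")  # step c): force the re-import
```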
Validation Steps
git checkout akannan/clear_sys_modules
# pi_0 test – no more import error
pytest -svv "tests/runner/test_models.py::test_all_models_torch[pi_0/pytorch-lerobot_pi0_libero_base-single_device-inference]" &> pi_0.log
Complete log - pi_0_after_fixlog.log
Secondary Issue Introduced by the Mitigation
- While clearing sys.modules fixes stale import resolution, it introduces a new problem.
- If objects were instantiated from the original import of a module and that module is later re-imported, the re-imported classes are new class objects, distinct from the ones those instances were created from.
- As a result, isinstance checks fail even though the class names are identical.
Validation Steps for Example Failure (GPT-2)
pytest -svv "tests/runner/test_models.py::test_all_models_jax[gpt2/causal_lm/jax-Base-single_device-inference]" &> gpt2.log
Failure:
tests/infra/testers/single_chip/model/jax_model_tester.py:67: in _configure_model_for_inference
assert isinstance(self._model, (nnx.Module, linen.Module, FlaxPreTrainedModel))
E AssertionError
Complete log - gpt2_issue_after_fix.log
This regression also appeared in CI after introducing the sys.modules clearing logic:
- CI Failure - https://github.com/tenstorrent/tt-xla/actions/runs/21979514830/job/63499436546#step:17:401
Why This Happens
- The model instance was created from classes loaded before sys.modules was cleared.
- The assertion references classes from the re-imported module.
- Although names match, Python treats them as different class objects due to reloading.
- isinstance therefore returns False.
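The identity mismatch can be reproduced in isolation. In the minimal sketch below, modelpkg is a hypothetical single-file module containing one class: an instance is created, the module is removed from sys.modules the way the mitigation does, and the module is imported again. The class name is unchanged, but the re-import executes the module body a second time and creates a new class object.

```python
import sys
import tempfile
from pathlib import Path


def demo_identity_mismatch() -> tuple[bool, bool, bool]:
    """Return (same qualname, same class object, isinstance after re-import)."""
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "modelpkg.py").write_text("class Module:\n    pass\n")
        sys.path.insert(0, tmp)
        try:
            import modelpkg as first_import

            model = first_import.Module()  # instance built before the purge
            del sys.modules["modelpkg"]  # what clearing sys.modules does
            import modelpkg as second_import  # module body runs again

            same_name = type(model).__qualname__ == second_import.Module.__qualname__
            same_class = first_import.Module is second_import.Module
            return same_name, same_class, isinstance(model, second_import.Module)
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("modelpkg", None)


print(demo_identity_mismatch())  # (True, False, False)
```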
Summary
Original issue: In-process dependency swaps cause stale imports and version skew.
Mitigation side effect: Clearing sys.modules introduces class identity mismatches, leading to assertion failures.