Skip to content

Commit ac456ab

Browse files
authored
Merge branch 'master' into fix/fsdp-tied-weights-warning
2 parents b64b197 + bb7820f commit ac456ab

2 files changed

Lines changed: 36 additions & 3 deletions

File tree

src/lightning/fabric/utilities/data.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,13 @@ def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] =
382382
for patched_name in ("__setattr__", "__delattr__", "__init__"):
383383
# Check that __old__{init,setattr,delattr} belongs to the class
384384
# https://stackoverflow.com/a/5253424
385-
if f"__old{patched_name}" in cls.__dict__:
386-
setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
387-
delattr(cls, f"__old{patched_name}")
385+
old_name = f"__old{patched_name}"
386+
if old_name in cls.__dict__:
387+
try:
388+
setattr(cls, patched_name, getattr(cls, old_name))
389+
delattr(cls, old_name)
390+
except AttributeError:
391+
pass
388392

389393

390394
def _replace_value_in_saved_args(

tests/tests_fabric/utilities/test_data.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,35 @@ def test_replace_dunder_methods_multiple_loaders_without_init():
8080
assert before[cls] == cls.__init__
8181

8282

83+
def test_replace_dunder_methods_cleanup_tolerates_concurrent_restore():
84+
class ConcurrentCleanupMeta(type):
85+
def __getattribute__(cls, name):
86+
if (
87+
name == "__old__delattr__"
88+
and type.__getattribute__(cls, "_cleanup_started")
89+
and not type.__getattribute__(cls, "_restore_complete")
90+
):
91+
original_method = type.__getattribute__(cls, name)
92+
type.__setattr__(cls, "__delattr__", original_method)
93+
type.__delattr__(cls, name)
94+
type.__setattr__(cls, "_restore_complete", True)
95+
raise AttributeError
96+
return type.__getattribute__(cls, name)
97+
98+
class ConcurrentBatchSampler(BatchSampler, metaclass=ConcurrentCleanupMeta):
99+
_cleanup_started = False
100+
_restore_complete = False
101+
102+
pass
103+
104+
original_delattr = ConcurrentBatchSampler.__delattr__
105+
with _replace_dunder_methods(ConcurrentBatchSampler):
106+
ConcurrentBatchSampler._cleanup_started = True
107+
108+
assert ConcurrentBatchSampler.__delattr__ is original_delattr
109+
assert "__old__delattr__" not in ConcurrentBatchSampler.__dict__
110+
111+
83112
class MyBaseDataLoader(DataLoader):
84113
pass
85114

0 commit comments

Comments
 (0)