@@ -80,6 +80,35 @@ def test_replace_dunder_methods_multiple_loaders_without_init():
8080 assert before [cls ] == cls .__init__
8181
8282
83+ def test_replace_dunder_methods_cleanup_tolerates_concurrent_restore ():
84+ class ConcurrentCleanupMeta (type ):
85+ def __getattribute__ (cls , name ):
86+ if (
87+ name == "__old__delattr__"
88+ and type .__getattribute__ (cls , "_cleanup_started" )
89+ and not type .__getattribute__ (cls , "_restore_complete" )
90+ ):
91+ original_method = type .__getattribute__ (cls , name )
92+ type .__setattr__ (cls , "__delattr__" , original_method )
93+ type .__delattr__ (cls , name )
94+ type .__setattr__ (cls , "_restore_complete" , True )
95+ raise AttributeError
96+ return type .__getattribute__ (cls , name )
97+
98+ class ConcurrentBatchSampler (BatchSampler , metaclass = ConcurrentCleanupMeta ):
99+ _cleanup_started = False
100+ _restore_complete = False
101+
102+ pass
103+
104+ original_delattr = ConcurrentBatchSampler .__delattr__
105+ with _replace_dunder_methods (ConcurrentBatchSampler ):
106+ ConcurrentBatchSampler ._cleanup_started = True
107+
108+ assert ConcurrentBatchSampler .__delattr__ is original_delattr
109+ assert "__old__delattr__" not in ConcurrentBatchSampler .__dict__
110+
111+
83112class MyBaseDataLoader (DataLoader ):
84113 pass
85114
0 commit comments