Skip to content

Commit 902767e

Browse files
committed
Use weak references to resolve circular references and add test
1 parent 670bbee commit 902767e

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

ignite/engine/engine.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def execute_something():
328328

329329
try:
330330
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
331-
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
331+
self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
332332
except ValueError:
333333
_check_signature(handler, "handler", *(event_args + args), **kwargs)
334334
self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +432,15 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
432432
self.last_event_name = event_name
433433
for func, args, kwargs in self._event_handlers[event_name]:
434434
kwargs.update(event_kwargs)
435-
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
435+
if args and isinstance(args[0], weakref.ref):
436+
resolved_engine = args[0]()
437+
if resolved_engine is None:
438+
raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
439+
first, others = ((resolved_engine,), args[1:])
440+
else:
441+
# metrics do not provide engine when registered
442+
first, others = (tuple(), args) # type: ignore[assignment]
443+
436444
func(*first, *(event_args + others), **kwargs)
437445

438446
def fire_event(self, event_name: Any) -> None:
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import sys
2+
import weakref
3+
import torch
4+
import torch.nn as nn
5+
from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
6+
from ignite.handlers import ProgressBar, TensorboardLogger
7+
from ignite.handlers.tensorboard_logger import OptimizerParamsHandler
8+
from torch.optim import Adam
9+
from ignite.metrics import Loss
10+
from torch.utils.data import DataLoader, TensorDataset
11+
12+
13+
class TestEngineMemoryLeak:
    """Regression test: engines must be freed by reference counting alone.

    Each call to ``do`` builds a trainer/evaluator pair, attaches a progress
    bar and a TensorBoard logger, runs one epoch and records weak references
    to both engines.  After ``do`` returns, every recorded weak reference is
    expected to be dead; a live one means something still holds the engine.
    """

    # Maps weakref(engine) -> reference count observed at the end of ``do``.
    # Class-level (shared) on purpose so ``do`` can publish into it; it is
    # cleared at the start of the test to avoid stale entries between runs.
    ENGINE_WEAK_REFS = {}

    def do(self, model, dataloader, device, runs_folder):
        """Build, run and tear down one trainer/evaluator pair.

        Args:
            model: module to train; its parameters are given to ``Adam``.
            dataloader: data passed to ``trainer.run`` for a single epoch.
            device: device forwarded to the supervised engine factories.
            runs_folder: log directory handed to ``TensorboardLogger``.
        """
        optim = Adam(model.parameters(), 1e-4)
        loss = nn.BCEWithLogitsLoss()
        trainer = create_supervised_trainer(model, optim, loss, device)
        metrics = {"Loss": Loss(loss)}
        evaluator = create_supervised_evaluator(model, metrics, device)

        pbar = ProgressBar()
        pbar.attach(trainer)

        tb_logger = TensorboardLogger(log_dir=runs_folder)
        tb_logger.attach(trainer, OptimizerParamsHandler(optim), Events.EPOCH_STARTED)

        trainer.run(dataloader, 1)

        # NOTE(review): this handler is registered after the only
        # ``trainer.run`` call, so it is never fired here; it only adds a
        # trainer -> closure -> evaluator reference.  Confirm that is intended.
        @trainer.on(Events.COMPLETED)
        def completed(engine):
            evaluator.run(dataloader)

        tb_logger.close()
        pbar.close()

        # ``sys.getrefcount`` counts its own argument, hence the ``- 1``.
        # The local variable is still included in the recorded number.
        self.ENGINE_WEAK_REFS[weakref.ref(trainer)] = sys.getrefcount(trainer) - 1
        self.ENGINE_WEAK_REFS[weakref.ref(evaluator)] = sys.getrefcount(evaluator) - 1

    def test_circular_references(self, tmp_path):
        """Run ``do`` several times and fail if any engine outlives its call.

        No explicit ``gc.collect()`` is issued: an engine that is only kept
        alive by a reference cycle stays reachable through its weak reference
        and makes the test fail.
        """
        # Drop entries left over from a previous run in the same process.
        self.ENGINE_WEAK_REFS.clear()

        runs_folder = tmp_path / "runs"
        runs_folder.mkdir()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        x = torch.rand(32, 1, 64, 64, 32)
        y = torch.round(torch.rand(32, 1))
        ds = TensorDataset(x, y)
        dataloader = DataLoader(ds, 6)
        for i in range(5):
            N = 3000
            model = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64 * 32, N), nn.ReLU(), nn.Linear(N, 1))
            model = model.to(device)
            self.do(model, dataloader, device, runs_folder)

            for engine_weak_ref in self.ENGINE_WEAK_REFS:
                engine = engine_weak_ref()
                if engine is None:
                    # Already collected: no leak for this engine.
                    continue
                # The count includes the local ``engine`` binding above, so it
                # is at least 1 here; any live engine is reported as a leak.
                ref_count = sys.getrefcount(engine) - 1
                error_message = f"Engine Memory Leak: {engine} - Ref Count: {ref_count}"
                print(error_message)
                raise AssertionError(error_message)

            print("!!!", i, torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())

0 commit comments

Comments
 (0)