Skip to content

Commit 697cdd0

Browse files
committed
Use weak references to resolve circular references and add test
1 parent 670bbee commit 697cdd0

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

ignite/engine/engine.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def execute_something():
328328

329329
try:
330330
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
331-
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
331+
self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
332332
except ValueError:
333333
_check_signature(handler, "handler", *(event_args + args), **kwargs)
334334
self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +432,15 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
432432
self.last_event_name = event_name
433433
for func, args, kwargs in self._event_handlers[event_name]:
434434
kwargs.update(event_kwargs)
435-
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
435+
if args and isinstance(args[0], weakref.ref):
436+
resolved_engine = args[0]()
437+
if resolved_engine is None:
438+
raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
439+
first, others = ((resolved_engine,), args[1:])
440+
else:
441+
# metrics do not provide engine when registered
442+
first, others = (tuple(), args) # type: ignore[assignment]
443+
436444
func(*first, *(event_args + others), **kwargs)
437445

438446
def fire_event(self, event_name: Any) -> None:
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import sys
2+
import weakref
3+
4+
import torch
5+
import torch.nn as nn
6+
from torch.optim import Adam
7+
from torch.utils.data import DataLoader, TensorDataset
8+
9+
from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
10+
from ignite.handlers import ProgressBar, TensorboardLogger
11+
from ignite.handlers.tensorboard_logger import OptimizerParamsHandler
12+
from ignite.metrics import Loss
13+
14+
15+
class TestEngineMemoryLeak:
16+
"""See: https://github.com/pytorch/ignite/issues/3438"""
17+
18+
ENGINE_WEAK_REFS = {}
19+
20+
def do(self, model, dataloader, device, runs_folder):
21+
optim = Adam(model.parameters(), 1e-4)
22+
loss = nn.BCEWithLogitsLoss()
23+
trainer = create_supervised_trainer(model, optim, loss, device)
24+
metrics = {"Loss": Loss(loss)}
25+
evaluator = create_supervised_evaluator(model, metrics, device)
26+
27+
pbar = ProgressBar()
28+
pbar.attach(trainer)
29+
30+
tb_logger = TensorboardLogger(log_dir=runs_folder)
31+
tb_logger.attach(trainer, OptimizerParamsHandler(optim), Events.EPOCH_STARTED)
32+
33+
trainer.run(dataloader, 1)
34+
35+
@trainer.on(Events.COMPLETED)
36+
def completed(engine):
37+
evaluator.run(dataloader)
38+
39+
tb_logger.close()
40+
pbar.close()
41+
42+
self.ENGINE_WEAK_REFS[weakref.ref(trainer)] = sys.getrefcount(trainer) - 1
43+
self.ENGINE_WEAK_REFS[weakref.ref(evaluator)] = sys.getrefcount(evaluator) - 1
44+
45+
def test_circular_references(self, tmp_path):
46+
runs_folder = tmp_path / "runs"
47+
runs_folder.mkdir()
48+
49+
if torch.cuda.is_available():
50+
device = torch.device("cuda")
51+
else:
52+
device = torch.device("cpu")
53+
54+
x = torch.rand(32, 1, 64, 64, 32)
55+
y = torch.round(torch.rand(32, 1))
56+
ds = TensorDataset(x, y)
57+
dataloader = DataLoader(ds, 6)
58+
for i in range(5):
59+
N = 3000
60+
model = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64 * 32, N), nn.ReLU(), nn.Linear(N, 1))
61+
model = model.to(device)
62+
self.do(model, dataloader, device, runs_folder)
63+
for engine_weak_ref, val in self.ENGINE_WEAK_REFS.items():
64+
engine = engine_weak_ref()
65+
if engine is not None:
66+
ref_count = sys.getrefcount(engine) - 1
67+
error_message = f"Engine Memory Leak: {engine} - Ref Count: {ref_count}"
68+
print(error_message)
69+
assert ref_count == 0
70+
71+
print("!!!", i, torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())

0 commit comments

Comments
 (0)