Skip to content

Commit dcac448

Browse files
goanpecavfdev-5
andauthored
Use weak reference to break circular reference and memory leaks (#3447)
Fixes #3438 --- @vfdev-5 I can confirm this PR fixes the issue. <img width="900" height="416" alt="Screenshot 2025-09-05 at 10 21 49 PM" src="https://github.com/user-attachments/assets/8ce24a22-6606-4122-9bfa-77a161087d97" /> --------- Co-authored-by: vfdev <[email protected]>
1 parent 3b04445 commit dcac448

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

ignite/engine/engine.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def execute_something():
339339

340340
try:
341341
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
342-
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
342+
self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
343343
except ValueError:
344344
_check_signature(handler, "handler", *(event_args + args), **kwargs)
345345
self._event_handlers[event_name].append((handler, args, kwargs))
@@ -443,7 +443,15 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
443443
self.last_event_name = event_name
444444
for func, args, kwargs in self._event_handlers[event_name]:
445445
kwargs.update(event_kwargs)
446-
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
446+
if args and isinstance(args[0], weakref.ref):
447+
resolved_engine = args[0]()
448+
if resolved_engine is None:
449+
raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
450+
first, others = ((resolved_engine,), args[1:])
451+
else:
452+
# metrics do not provide engine when registered
453+
first, others = (tuple(), args) # type: ignore[assignment]
454+
447455
func(*first, *(event_args + others), **kwargs)
448456

449457
def fire_event(self, event_name: Any) -> None:

ignite/handlers/visdom_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Visdom logger and its helper handlers."""
22

33
import os
4-
from typing import Any, Callable, cast, Dict, List, Optional, Union
4+
from typing import Any, Callable, Dict, List, Optional, Union
55

66
import torch
77
import torch.nn as nn
@@ -179,7 +179,7 @@ def __init__(
179179
)
180180

181181
if server is None:
182-
server = cast(str, os.environ.get("VISDOM_SERVER_URL", "localhost"))
182+
server = os.environ.get("VISDOM_SERVER_URL", "localhost")
183183

184184
if port is None:
185185
port = int(os.environ.get("VISDOM_PORT", 8097))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import weakref
2+
3+
import pytest
4+
5+
from ignite.engine import Engine, Events
6+
7+
8+
class TestEngineMemoryLeak:
9+
"""See: https://github.com/pytorch/ignite/issues/3438"""
10+
11+
ENGINE_WEAK_REFS = set()
12+
13+
def do_train(self, cls, with_handler) -> None:
14+
engine = cls(lambda e, b: None)
15+
16+
if with_handler:
17+
18+
@engine.on(Events.EPOCH_STARTED)
19+
def handler(engine) -> None:
20+
pass
21+
22+
engine.run(range(5), max_epochs=5)
23+
self.ENGINE_WEAK_REFS.add(weakref.ref(engine))
24+
25+
@pytest.mark.parametrize("with_handler", [True, False])
26+
def test_memory_leak(self, with_handler):
27+
num_iters = 5
28+
counter = 0
29+
30+
class EngineForTests(Engine):
31+
32+
def __del__(self):
33+
nonlocal counter
34+
counter += 1
35+
36+
for i in range(num_iters):
37+
self.do_train(EngineForTests, with_handler)
38+
for weak_engine_ref in self.ENGINE_WEAK_REFS:
39+
engine = weak_engine_ref()
40+
assert engine is None
41+
42+
assert counter == i + 1
43+
44+
assert counter == i + 1

0 commit comments

Comments
 (0)