Skip to content

Commit 17190b9

Browse files
committed
Initial fixes with weak references
1 parent 670bbee commit 17190b9

File tree

4 files changed

+66
-10
lines changed

4 files changed

+66
-10
lines changed

ignite/engine/engine.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ def execute_something():
328328

329329
try:
330330
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
331-
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
331+
# Use weak reference to break circular reference
332+
self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
332333
except ValueError:
333334
_check_signature(handler, "handler", *(event_args + args), **kwargs)
334335
self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +433,15 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
432433
self.last_event_name = event_name
433434
for func, args, kwargs in self._event_handlers[event_name]:
434435
kwargs.update(event_kwargs)
435-
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
436+
# Resolve weak references if present
437+
if args and isinstance(args[0], weakref.ref):
438+
resolved_engine = args[0]()
439+
if resolved_engine is None:
440+
# Engine was garbage collected, skip this handler
441+
continue
442+
first, others = ((resolved_engine,), args[1:])
443+
else:
444+
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else (tuple(), args)
436445
func(*first, *(event_args + others), **kwargs)
437446

438447
def fire_event(self, event_name: Any) -> None:

ignite/handlers/base_logger.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections.abc as collections
44
import numbers
55
import warnings
6+
import weakref
67
from abc import ABCMeta, abstractmethod
78
from collections import OrderedDict
89
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -261,18 +262,34 @@ def attach(
261262
:class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
262263
"""
263264
if isinstance(event_name, EventsList):
265+
# Use weak reference to break circular reference: engine -> _event_handlers -> BaseLogger
266+
weak_self = weakref.ref(self)
267+
268+
def weak_log_handler(engine, event_name):
269+
logger_obj = weak_self()
270+
if logger_obj is not None:
271+
log_handler(engine, logger_obj, event_name)
272+
264273
for name in event_name:
265274
if name not in State.event_to_attr:
266275
raise RuntimeError(f"Unknown event name '{name}'")
267-
engine.add_event_handler(name, log_handler, self, name)
276+
engine.add_event_handler(name, weak_log_handler, name)
268277

269278
return RemovableEventHandle(event_name, log_handler, engine)
270279

271280
else:
272281
if event_name not in State.event_to_attr:
273282
raise RuntimeError(f"Unknown event name '{event_name}'")
274283

275-
return engine.add_event_handler(event_name, log_handler, self, event_name, *args, **kwargs)
284+
# Use weak reference to break circular reference: engine -> _event_handlers -> BaseLogger
285+
weak_self = weakref.ref(self)
286+
287+
def weak_log_handler(engine):
288+
logger_obj = weak_self()
289+
if logger_obj is not None:
290+
log_handler(engine, logger_obj, event_name, *args, **kwargs)
291+
292+
return engine.add_event_handler(event_name, weak_log_handler)
276293

277294
def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
278295
"""Shortcut method to attach `OutputHandler` to the logger.

ignite/handlers/tqdm_logger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""TQDM logger."""
33
from collections import OrderedDict
44
from typing import Any, Callable, List, Optional, Union
5+
import weakref
56

67
from ignite.engine import Engine, Events
78
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle
@@ -221,7 +222,15 @@ def attach( # type: ignore[override]
221222
)
222223

223224
super(ProgressBar, self).attach(engine, log_handler, event_name)
224-
engine.add_event_handler(closing_event_name, self._close)
225+
# Use weak reference to break circular reference: engine -> _event_handlers -> ProgressBar
226+
weak_self = weakref.ref(self)
227+
228+
def weak_close(engine):
229+
pbar_obj = weak_self()
230+
if pbar_obj is not None:
231+
pbar_obj._close(engine)
232+
233+
engine.add_event_handler(closing_event_name, weak_close)
225234

226235
def attach_opt_params_handler( # type: ignore[empty-body]
227236
self, engine: Engine, event_name: Union[str, Events], *args: Any, **kwargs: Any

ignite/metrics/metric.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import wraps
55
from numbers import Number
66
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
7+
import weakref
78

89
import torch
910

@@ -560,11 +561,31 @@ def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = Epo
560561
assert metric.is_attached(engine, usage=BatchWise.usage_name)
561562
"""
562563
usage = self._check_usage(usage)
563-
if not engine.has_event_handler(self.started, usage.STARTED):
564-
engine.add_event_handler(usage.STARTED, self.started)
565-
if not engine.has_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED):
566-
engine.add_event_handler(usage.ITERATION_COMPLETED, self.iteration_completed)
567-
engine.add_event_handler(usage.COMPLETED, self.completed, name)
564+
565+
# Use weak reference to break circular reference: engine -> _event_handlers -> Metric
566+
weak_self = weakref.ref(self)
567+
568+
def weak_started(engine: Engine) -> None:
569+
metric_obj = weak_self()
570+
if metric_obj is not None:
571+
metric_obj.started(engine)
572+
573+
def weak_iteration_completed(engine: Engine) -> None:
574+
metric_obj = weak_self()
575+
if metric_obj is not None:
576+
metric_obj.iteration_completed(engine)
577+
578+
def weak_completed(engine: Engine, name: str) -> None:
579+
metric_obj = weak_self()
580+
if metric_obj is not None:
581+
metric_obj.completed(engine, name)
582+
583+
# Skip has_event_handler check since we're using different handler functions
584+
# Note: This means handlers may be added multiple times, but that's generally safe
585+
# for metrics as they are typically attached once per metric
586+
engine.add_event_handler(usage.STARTED, weak_started)
587+
engine.add_event_handler(usage.ITERATION_COMPLETED, weak_iteration_completed)
588+
engine.add_event_handler(usage.COMPLETED, weak_completed, name)
568589

569590
def detach(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> None:
570591
"""

0 commit comments

Comments
 (0)