Commit d0d2f49

test
1 parent 670bbee commit d0d2f49

2 files changed: +73, -0
ignite/engine/engine.py (27 additions, 0 deletions)

@@ -17,6 +17,8 @@
 __all__ = ["Engine"]
 
 
+REF_COUNTER = 0
+
 class Engine(Serializable):
     """Runs a given ``process_function`` over each batch of a dataset, emitting events as it goes.
 
@@ -163,6 +165,21 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
         # generator provided by self._internal_run_as_gen
         self._internal_run_generator: Optional[Generator[Any, None, State]] = None
 
+    def __del__(self) -> None:
+        """Finalize the engine."""
+        self._event_handlers.clear()
+        print("__del__", self)
+        print("Cleared event handlers for {}".format(self))
+        # global REF_COUNTER
+        # REF_COUNTER += 1
+        # print("counter", REF_COUNTER)
+        # try to clear event handlers to break circular references
+        # try:
+        #     self._event_handlers.clear()
+        #     print("Cleared event handlers for {}".format(self))
+        # except Exception:
+        #     pass
+
     def register_events(
         self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
     ) -> None:
@@ -328,6 +345,7 @@ def execute_something():
 
         try:
             _check_signature(handler, "handler", self, *(event_args + args), **kwargs)
+            # self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
             self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
         except ValueError:
             _check_signature(handler, "handler", *(event_args + args), **kwargs)
@@ -433,6 +451,15 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) -> None:
             for func, args, kwargs in self._event_handlers[event_name]:
                 kwargs.update(event_kwargs)
                 first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
+                # if args and isinstance(args[0], weakref.ref):
+                #     resolved_engine = args[0]()
+                #     if resolved_engine is None:
+                #         raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
+                #     first, others = ((resolved_engine,), args[1:])
+                # else:
+                #     # metrics do not provide engine when registered
+                #     first, others = (tuple(), args)  # type: ignore[assignment]
+
                 func(*first, *(event_args + others), **kwargs)
 
     def fire_event(self, event_name: Any) -> None:
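
The commented-out weakref lines above outline the alternative this commit experiments with: register handlers with weakref.ref(self) instead of self and resolve the reference in _fire_event, so the handler table no longer keeps the engine alive. Below is a minimal, self-contained sketch of that idea for illustration only; it is not part of the commit, and MiniEngine, on and fire are hypothetical stand-ins for Engine, add_event_handler and _fire_event.

# Hedged sketch only: shows how storing weakref.ref(self) in the handler table
# (as the commented-out lines suggest) avoids an engine <-> handler reference cycle.
# MiniEngine/on/fire are hypothetical stand-ins, not ignite APIs.
import gc
import weakref


class MiniEngine:
    def __init__(self):
        self._event_handlers = []

    def on(self, handler, *args):
        # Keep only a weak reference to self; the strong form `(self,) + args`
        # would create a cycle: engine -> _event_handlers -> tuple -> engine.
        self._event_handlers.append((handler, (weakref.ref(self),) + args))

    def fire(self):
        for handler, args in self._event_handlers:
            engine = args[0]()  # resolve the weak reference
            if engine is None:
                raise RuntimeError("Engine reference not resolved.")
            handler(engine, *args[1:])


def on_event(engine, tag):
    print(tag, "handlers:", len(engine._event_handlers))


engine = MiniEngine()
engine.on(on_event, "fired")
engine.fire()

probe = weakref.ref(engine)
del engine
gc.collect()
# With weak references the engine is freed as soon as the last external strong
# reference disappears; with `(self,) + args` it would linger until the cyclic
# garbage collector runs.
print("engine collected:", probe() is None)
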
New file (46 additions, 0 deletions)

@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
+from ignite.handlers import ProgressBar, TensorboardLogger
+from ignite.handlers.tensorboard_logger import OptimizerParamsHandler
+from torch.optim import Adam
+from ignite.metrics import Loss
+from torch.utils.data import DataLoader, TensorDataset
+
+
+def do(model, dataloader, device):
+    optim = Adam(model.parameters(), 1e-4)
+    loss = nn.BCEWithLogitsLoss()
+    trainer = create_supervised_trainer(model, optim, loss, device)
+    metrics = {"Loss": Loss(loss)}
+    evaluator = create_supervised_evaluator(model, metrics, device)
+
+    pbar = ProgressBar()
+    pbar.attach(trainer)
+
+    tb_logger = TensorboardLogger(log_dir="runs")
+    tb_logger.attach(trainer, OptimizerParamsHandler(optim), Events.EPOCH_STARTED)
+
+    trainer.run(dataloader, 1)
+    @trainer.on(Events.COMPLETED)
+    def completed(engine):
+        evaluator.run(dataloader)
+
+    tb_logger.close()
+    pbar.close()
+
+    # del trainer
+    # del evaluator
+
+def test_circular_references():
+    device = torch.device("cuda")
+    x = torch.rand(32, 1, 64, 64, 32)
+    y = torch.round(torch.rand(32, 1))
+    ds = TensorDataset(x, y)
+    dataloader = DataLoader(ds, 6)
+    for i in range(5):
+        N = 3000
+        model = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64 * 32, N), nn.ReLU(), nn.Linear(N, 1))
+        model = model.to(device)
+        do(model, dataloader, device)
+        print("!!!", i, torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
