Skip to content

Commit b7894a9

Browse files
authored
Merge branch 'master' into expose-fblogger
2 parents 768a803 + 1e2583c commit b7894a9

File tree

13 files changed

+194
-20
lines changed

13 files changed

+194
-20
lines changed

.github/workflows/gpu-hvd-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
pytorch-channel: [pytorch]
2525
fail-fast: false
2626
env:
27-
DOCKER_IMAGE: "pytorch/conda-builder:cuda12.1"
27+
DOCKER_IMAGE: "pytorch/almalinux-builder:cuda12.8"
2828
REPOSITORY: ${{ github.repository }}
2929
PR_NUMBER: ${{ github.event.pull_request.number }}
3030
runs-on: linux.8xlarge.nvidia.gpu
@@ -113,6 +113,10 @@ jobs:
113113
pip install -r requirements-dev.txt
114114
pip install -e .
115115
116+
# Upgrade pyOpenSSL to avoid issue:
117+
# AttributeError: module 'lib' has no attribute 'X509_V_FLAG_NOTIFY_POLICY'. Did you mean: 'X509_V_FLAG_EXPLICIT_POLICY'?
118+
pip install -U pyOpenSSL
119+
116120
EOF
117121
)
118122

.github/workflows/gpu-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
pytorch-channel: [pytorch, pytorch-nightly]
2525
fail-fast: false
2626
env:
27-
DOCKER_IMAGE: "pytorch/almalinux-builder:cuda12.4"
27+
DOCKER_IMAGE: "pytorch/almalinux-builder:cuda12.8"
2828
REPOSITORY: ${{ github.repository }}
2929
PR_NUMBER: ${{ github.event.pull_request.number }}
3030
runs-on: linux.g4dn.12xlarge.nvidia.gpu
@@ -113,6 +113,10 @@ jobs:
113113
pip install -r requirements-dev.txt
114114
pip install -e .
115115
116+
# Upgrade pyOpenSSL to avoid issue:
117+
# AttributeError: module 'lib' has no attribute 'X509_V_FLAG_NOTIFY_POLICY'. Did you mean: 'X509_V_FLAG_EXPLICIT_POLICY'?
118+
pip install -U pyOpenSSL
119+
116120
EOF
117121
)
118122

.github/workflows/mps-tests.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
pytorch-channel: ["pytorch"]
3737
skip-distrib-tests: [1]
3838
fail-fast: false
39-
runs-on: ["macos-m1-stable"]
39+
runs-on: ["macos-14"]
4040
timeout-minutes: 60
4141

4242
steps:
@@ -71,7 +71,7 @@ jobs:
7171
- name: Install uv
7272
shell: bash -l {0}
7373
run: |
74-
conda shell.bash hook
74+
eval "$(conda shell.bash hook)"
7575
conda activate $CONDA_ENV
7676
pip install uv
7777
@@ -85,23 +85,23 @@ jobs:
8585
if: ${{ matrix.pytorch-channel == 'pytorch' }}
8686
shell: bash -l {0}
8787
run: |
88-
conda shell.bash hook
88+
eval "$(conda shell.bash hook)"
8989
conda activate $CONDA_ENV
9090
uv pip install torch torchvision
9191
9292
- name: Install PyTorch (nightly)
9393
if: ${{ matrix.pytorch-channel == 'pytorch-nightly' }}
9494
shell: bash -l {0}
9595
run: |
96-
conda shell.bash hook
96+
eval "$(conda shell.bash hook)"
9797
conda activate $CONDA_ENV
9898
uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
9999
100100
- name: Install dependencies
101101
shell: bash -l {0}
102102
working-directory: ${{ github.repository }}
103103
run: |
104-
conda shell.bash hook
104+
eval "$(conda shell.bash hook)"
105105
conda activate $CONDA_ENV
106106
# TODO: We add set -xe so that the CI explicitly fails if any of the commands fails.
107107
# Somehow the step passes even if a subcommand has failed.
@@ -129,7 +129,7 @@ jobs:
129129
shell: bash -l {0}
130130
working-directory: ${{ github.repository }}
131131
run: |
132-
conda shell.bash hook
132+
eval "$(conda shell.bash hook)"
133133
conda activate $CONDA_ENV
134134
SKIP_DISTRIB_TESTS=${{ matrix.skip-distrib-tests }} bash tests/run_cpu_tests.sh
135135
@@ -144,6 +144,6 @@ jobs:
144144
shell: bash -l {0}
145145
working-directory: ${{ github.repository }}
146146
run: |
147-
conda shell.bash hook
147+
eval "$(conda shell.bash hook)"
148148
conda activate $CONDA_ENV
149149
python examples/mnist/mnist.py --epochs=1

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,7 @@ jobs:
182182
set -e
183183
echo "Testing Caltech101 dataset availability..."
184184
if python -c "import torchvision; torchvision.datasets.Caltech101(root='./data', download=True)"; then
185-
echo "Caltech101 dataset downloaded successfully. Please remove this workaround and restore dataset check."
186-
exit 1
187-
# python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2 --debug
185+
python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2 --debug
188186
else
189187
echo "Caltech101 dataset failed to download. Skipping SR example test."
190188
fi

docs/source/handlers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Complete list of generic handlers
1111
:toctree: generated
1212

1313
checkpoint.Checkpoint
14+
checkpoint.CheckpointEvents
1415
DiskSaver
1516
checkpoint.ModelCheckpoint
1617
ema_handler.EMAHandler

ignite/engine/engine.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,17 @@ class TBPTT_Events(EventEnum):
249249
# we need to update state attributes associated with new custom events
250250
self.state._update_attrs()
251251

252+
def has_registered_events(self, event: Any) -> bool:
253+
"""Check whether the given event is registered with this engine.
254+
255+
Args:
256+
event: Event to check for registration.
257+
258+
Returns:
259+
bool: True if the event is registered, False otherwise.
260+
"""
261+
return event in self._allowed_events
262+
252263
def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
253264
# signature of the following wrapper will be inspected during registering to check if engine is necessary
254265
# we have to build a wrapper with relevant signature : solution is functools.wraps
@@ -328,7 +339,7 @@ def execute_something():
328339

329340
try:
330341
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
331-
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
342+
self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
332343
except ValueError:
333344
_check_signature(handler, "handler", *(event_args + args), **kwargs)
334345
self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +443,15 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
432443
self.last_event_name = event_name
433444
for func, args, kwargs in self._event_handlers[event_name]:
434445
kwargs.update(event_kwargs)
435-
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
446+
if args and isinstance(args[0], weakref.ref):
447+
resolved_engine = args[0]()
448+
if resolved_engine is None:
449+
raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
450+
first, others = ((resolved_engine,), args[1:])
451+
else:
452+
# metrics do not provide engine when registered
453+
first, others = (tuple(), args) # type: ignore[assignment]
454+
436455
func(*first, *(event_args + others), **kwargs)
437456

438457
def fire_event(self, event_name: Any) -> None:

ignite/handlers/checkpoint.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,21 @@
2121

2222
import ignite.distributed as idist
2323
from ignite.base import Serializable
24-
from ignite.engine import Engine, Events
24+
from ignite.engine import Engine, Events, EventEnum
2525
from ignite.utils import _tree_apply2, _tree_map
2626

27-
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
27+
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
28+
29+
30+
class CheckpointEvents(EventEnum):
31+
"""Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
32+
33+
- SAVED_CHECKPOINT : triggered when the checkpoint handler has saved objects
34+
35+
.. versionadded:: 0.5.3
36+
"""
37+
38+
SAVED_CHECKPOINT = "saved_checkpoint"
2839

2940

3041
class BaseSaveHandler(metaclass=ABCMeta):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
264275
to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
265276
)
266277
278+
Respond to checkpoint events:
279+
280+
.. code-block:: python
281+
282+
from ignite.handlers import Checkpoint
283+
from ignite.engine import Engine, Events
284+
285+
checkpoint_handler = Checkpoint(
286+
{'model': model, 'optimizer': optimizer},
287+
save_dir,
288+
n_saved=2
289+
)
290+
291+
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
292+
def on_checkpoint_saved(engine):
293+
print(f"Checkpoint saved at epoch {engine.state.epoch}")
294+
295+
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
296+
297+
Attributes:
298+
SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
299+
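:class:`~ignite.handlers.checkpoint.CheckpointEvents`. The event is registered on the engine the first time the handler is called with it, unless it is already registered.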
300+
267301
.. versionchanged:: 0.4.3
268302
269303
- Checkpoint can save model with same filename.
@@ -274,8 +308,13 @@ class Checkpoint(Serializable):
274308
- `score_name` can be used to define `score_function` automatically without providing `score_function`.
275309
- `save_handler` automatically saves to disk if path to directory is provided.
276310
- `save_on_rank` saves objects on this rank in a distributed configuration.
311+
312+
.. versionchanged:: 0.5.3
313+
314+
- Added ``SAVED_CHECKPOINT`` class attribute.
277315
"""
278316

317+
SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
279318
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
280319
_state_dict_all_req_keys = ("_saved",)
281320

@@ -400,6 +439,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
400439
return new > self._saved[0].priority
401440

402441
def __call__(self, engine: Engine) -> None:
442+
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
443+
engine.register_events(*CheckpointEvents)
403444
global_step = None
404445
if self.global_step_transform is not None:
405446
global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +501,11 @@ def __call__(self, engine: Engine) -> None:
460501
if self.include_self:
461502
# Now that we've updated _saved, we can add our own state_dict.
462503
checkpoint["checkpointer"] = self.state_dict()
463-
464504
try:
465505
self.save_handler(checkpoint, filename, metadata)
466506
except TypeError:
467507
self.save_handler(checkpoint, filename)
508+
engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)
468509

469510
def _setup_checkpoint(self) -> Dict[str, Any]:
470511
if self.to_save is not None:

ignite/handlers/visdom_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Visdom logger and its helper handlers."""
22

33
import os
4-
from typing import Any, Callable, cast, Dict, List, Optional, Union
4+
from typing import Any, Callable, Dict, List, Optional, Union
55

66
import torch
77
import torch.nn as nn
@@ -179,7 +179,7 @@ def __init__(
179179
)
180180

181181
if server is None:
182-
server = cast(str, os.environ.get("VISDOM_SERVER_URL", "localhost"))
182+
server = os.environ.get("VISDOM_SERVER_URL", "localhost")
183183

184184
if port is None:
185185
port = int(os.environ.get("VISDOM_PORT", 8097))

tests/ignite/distributed/test_auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def test_dist_proxy_sampler():
311311
DistributedProxySampler(None)
312312

313313
with pytest.raises(TypeError, match=r"Argument sampler should have length"):
314-
DistributedProxySampler(Sampler([1]))
314+
DistributedProxySampler(Sampler())
315315

316316
with pytest.raises(TypeError, match=r"Argument sampler must not be a distributed sampler already"):
317317
DistributedProxySampler(DistributedSampler(sampler, num_replicas=num_replicas, rank=0))

tests/ignite/engine/test_custom_events.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ def process_func(engine, batch):
4646
assert h.called
4747

4848

49+
def test_has_registered_events_custom():
50+
"""Test has_registered_events with custom events."""
51+
52+
class TestEvents(EventEnum):
53+
CUSTOM_EVENT = "custom_event"
54+
55+
engine = Engine(lambda e, b: None)
56+
57+
# Custom event not registered yet
58+
assert not engine.has_registered_events(TestEvents.CUSTOM_EVENT)
59+
60+
# Register custom event
61+
engine.register_events(TestEvents.CUSTOM_EVENT)
62+
63+
# Now should return True
64+
assert engine.has_registered_events(TestEvents.CUSTOM_EVENT)
65+
66+
4967
def test_custom_events_asserts():
5068
# Dummy engine
5169
engine = Engine(lambda engine, batch: 0)

0 commit comments

Comments
 (0)