Commit d98e3d2

Add further details blog post.
1 parent 34b171c commit d98e3d2

File tree: 6 files changed, +70 −12 lines

ICLRblogpost/README.md

+50-1
@@ -1,3 +1,52 @@
 # On the computation of the Fisher Information in continual learning (2025, ICLR Blogpost)
 
-... to be added ...
+The code in this repository is used for the experiments reported in the
+[ICLR 2025 blog post "On the computation of the Fisher Information in continual learning"](https://arxiv.org/abs/2502.11756).
+
+This blog post compares the performance of Elastic Weight Consolidation (EWC) with various ways of computing the diagonal elements of the Fisher Information matrix.
+The following options are considered:
+- **EXACT**
+The elements of the Fisher Information are computed exactly. All training samples are used.
+To use this option: `./main.py --ewc --fisher-labels='all'`
+
+- **EXACT ($n$=500)**
+The elements of the Fisher Information are computed exactly. Only 500 training samples are used.
+To use this option: `./main.py --ewc --fisher-labels='all' --fisher-n=500`
+
+- **SAMPLE**
+The elements of the Fisher Information are estimated using a single Monte Carlo sample. All training samples are used.
+To use this option: `./main.py --ewc --fisher-labels='sample'`
+
+- **EMPIRICAL**
+The empirical Fisher Information is used. All training samples are used.
+To use this option: `./main.py --ewc --fisher-labels='true'`
+
+- **BATCHED ($b$=128)**
+The empirical Fisher Information is approximated using mini-batches (see the blog post for details).
+To use this option: `./main.py --ewc --fisher-labels='true' --fisher-batch=128`
+
+
+To run the experiments from the blog post, the following commands can be used:
+
+```bash
+python3 ICLRblogpost/compare_FI.py --seed=1 --n-seeds=30 --experiment=splitMNIST --scenario=task
+python3 ICLRblogpost/compare_FI.py --seed=1 --n-seeds=30 --experiment=CIFAR10 --scenario=task --reducedResNet --iters=2000 --lr=0.001
+```
+
+
+### Citation
+If this is useful, please consider citing the blog post:
+```
+@inproceedings{vandeven2025fisher,
+  title={On the computation of the {F}isher {I}nformation in continual learning},
+  author={van de Ven, Gido M},
+  booktitle={ICLR Blogposts 2025},
+  year={2025},
+  date={April 28, 2025}
+}
+```
+
+
+### Acknowledgments
+This project has been supported by a senior postdoctoral fellowship from the
+Research Foundation -- Flanders (FWO) under grant number 1266823N.
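Editor's note: the three ways of treating the labels in the options above (full expectation for EXACT, a single Monte Carlo draw for SAMPLE, the observed label for EMPIRICAL) can be illustrated for a toy linear softmax model. This is a minimal NumPy sketch, not the repository's PyTorch implementation; all names here (`fisher_diag`, `grad_logp`) are hypothetical, and only the `labels` values mirror the `--fisher-labels` flag.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def grad_logp(W, x, c):
    # d log p(c|x) / dW for a linear softmax model with logits z = W.T @ x
    p = softmax(W.T @ x)
    onehot = np.zeros_like(p)
    onehot[c] = 1.0
    return np.outer(x, onehot - p)

def fisher_diag(W, data, labels="all"):
    """Diagonal Fisher estimate, averaged over (x, y) pairs in `data`.
    labels='all'    -> EXACT: full expectation over the model's classes
    labels='sample' -> SAMPLE: one Monte Carlo draw y ~ p(y|x)
    labels='true'   -> EMPIRICAL: squared gradient at the observed label
    """
    F = np.zeros_like(W)
    for x, y in data:
        p = softmax(W.T @ x)
        if labels == "all":
            for c in range(len(p)):
                F += p[c] * grad_logp(W, x, c) ** 2
        elif labels == "sample":
            F += grad_logp(W, x, rng.choice(len(p), p=p)) ** 2
        else:  # labels == "true"
            F += grad_logp(W, x, y) ** 2
    return F / len(data)
```

Restricting the loop to the first 500 pairs of `data` would correspond to the EXACT ($n$=500) option; the BATCHED variant instead squares gradients summed over mini-batches, which is not shown here.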

ICLRblogpost/compare_FI.py

+3-4
@@ -10,7 +10,6 @@
 from params.param_values import check_for_errors,set_default_values
 from params import options
 from visual import visual_plt as my_plt
-import torch
 
 
 ## Parameter-values to compare
@@ -33,8 +32,9 @@ def handle_inputs():
     parser.add_argument('--n-seeds', type=int, default=1, help='how often to repeat?')
     # Add options specific for EWC
     param_reg = parser.add_argument_group('Parameter Regularization')
-    param_reg.add_argument('--online', action='store_true', help='use Online EWC rather than Offline EWC')
-    param_reg.add_argument("--fisher-n-all", type=float, default=500, help="how many samples to approximate FI in 'ALL-n=X'")
+    param_reg.add_argument('--offline', action='store_true', help='use Offline EWC rather than Online EWC')
+    param_reg.add_argument("--fisher-n-all", type=float, default=500, metavar='N',
+                           help="how many samples to approximate FI in 'ALL-n=X'")
     # Parse, process (i.e., set defaults for unselected options) and check chosen options
     args = parser.parse_args()
     args.log_per_context = True
@@ -108,7 +108,6 @@ def collect_all(method_dict, seed_list, args, name=None):
     # -set EWC-specific arguments
     args.weight_penalty = True
     args.importance_weighting = 'fisher'
-    args.offline = False if args.online else True
 
     ## EWC, "sample"
     SAMPLE = {}
README.md

+1-1
@@ -48,7 +48,7 @@ see the folder [NeurIPS-tutorial](NeurIPStutorial).
 
 
 ## ICLR blog post "On the computation of the Fisher Information in continual learning"
-This code repository is also used for the
+This repository is also used for the
 [ICLR 2025 blog post "On the computation of the Fisher Information in continual learning"](https://arxiv.org/abs/2502.11756).
 For details and instructions on how to re-run the experiments reported in this blog post,
 see the folder [ICLR-blogpost](ICLRblogpost).

all_results.sh

+2-2
@@ -5,8 +5,8 @@
 
 ########### ICLR 2025 Blogpost ###########
 
-python3 ICLRblogpost/compare_FI.py --seed=1 --n-seeds=30 --experiment=splitMNIST --scenario=task --online
-python3 ICLRblogpost/compare_FI.py --seed=1 --n-seeds=30 --experiment=CIFAR10 --scenario=task --contexts=5 --conv-type=resNet --fc-layers=1 --iters=2000 --reducing-layers=3 --depth=5 --global-pooling --channels=20 --lr=0.001 --online
+python3 ICLRblogpost/compare_FI.py --seed=1 --n-seeds=30 --experiment=splitMNIST --scenario=task
+python3 ICLRblogpost/compare_FI.py --seed=1 --n-seeds=30 --experiment=CIFAR10 --scenario=task --reducedResNet --iters=2000 --lr=0.001
 
 
 
params/options.py

+6-3
@@ -95,17 +95,20 @@ def add_problem_options(parser, pretrain=False, no_boundaries=False, **kwargs):
 
 def add_model_options(parser, pretrain=False, compare_replay=False, **kwargs):
     model = parser.add_argument_group('Parameters Main Model')
+    # 'Convenience-commands' that select the defaults for specific architectures
+    model.add_argument('--reducedResNet', action='store_true', help="select defaults for 'Reduced ResNet-18' (e.g., as in Hess et al, 2023)")
     # -convolutional layers
     model.add_argument('--conv-type', type=str, default="standard", choices=["standard", "resNet"])
     model.add_argument('--n-blocks', type=int, default=2, help="# blocks per conv-layer (only for 'resNet')")
     model.add_argument('--depth', type=int, default=None, help="# of convolutional layers (0 = only fc-layers)")
-    model.add_argument('--reducing-layers', type=int, dest='rl', help="# of layers with stride (=image-size halved)")
-    model.add_argument('--channels', type=int, default=16, help="# of channels 1st conv-layer (doubled every 'rl')")
+    model.add_argument('--reducing-layers', type=int, dest='rl', default=None,
+                       help="# of layers with stride (=image-size halved)")
+    model.add_argument('--channels', type=int, default=None, help="# of channels 1st conv-layer (doubled every 'rl')")
     model.add_argument('--conv-bn', type=str, default="yes", help="use batch-norm in the conv-layers (yes|no)")
     model.add_argument('--conv-nl', type=str, default="relu", choices=["relu", "leakyrelu"])
     model.add_argument('--global-pooling', action='store_true', dest='gp', help="ave global pool after conv-layers")
     # -fully connected layers
-    model.add_argument('--fc-layers', type=int, default=3, dest='fc_lay', help="# of fully-connected layers")
+    model.add_argument('--fc-layers', type=int, default=None, dest='fc_lay', help="# of fully-connected layers")
     model.add_argument('--fc-units', type=int, metavar="N", help="# of units in hidden fc-layers")
     model.add_argument('--fc-drop', type=float, default=0., help="dropout probability for fc-units")
     model.add_argument('--fc-bn', type=str, default="no", help="use batch-norm in the fc-layers (no|yes)")

params/param_values.py

+8-1
@@ -49,7 +49,14 @@ def set_method_options(args, **kwargs):
 def set_default_values(args, also_hyper_params=True, single_context=False, no_boundaries=False):
     # -set default-values for certain arguments based on chosen experiment
     args.normalize = args.normalize if args.experiment in ('CIFAR10', 'CIFAR100') else False
-    args.depth = (5 if args.experiment in ('CIFAR10', 'CIFAR100') else 0) if args.depth is None else args.depth
+    args.depth = (
+        5 if (args.experiment in ('CIFAR10', 'CIFAR100')) or checkattr(args, 'reducedResNet') else 0
+    ) if args.depth is None else args.depth
+    args.fc_lay = (1 if checkattr(args, 'reducedResNet') else 3) if args.fc_lay is None else args.fc_lay
+    args.channels = (20 if checkattr(args, 'reducedResNet') else 16) if args.channels is None else args.channels
+    args.rl = 3 if checkattr(args, 'reducedResNet') and (args.rl is None) else args.rl
+    args.gp = True if checkattr(args, 'reducedResNet') else args.gp
+    args.conv_type = 'resNet' if checkattr(args, 'reducedResNet') else args.conv_type
     if not single_context:
         args.contexts = (
             5 if args.experiment in ('splitMNIST', 'CIFAR10') else 10
