Commit 50b8b7f

Add stability gap example.
1 parent d98e3d2

File tree

7 files changed: +219 -3 lines

ICLRblogpost/compare_FI.py (+1 -1)

@@ -2,7 +2,7 @@
 import sys
 import os
 import numpy as np
-# -change working directory to parent directory
+# -expand module search path to parent directory
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 # -custom-written code
 import main

NeurIPStutorial/compare_for_tutorial.py (+1 -1)

@@ -2,7 +2,7 @@
 import sys
 import os
 import numpy as np
-# -change working directory to parent directory
+# -expand module search path to parent directory
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 # -custom-written code
 import main

StabilityGap/README.md (new file, +14)

# Stability Gap example

The script `stability_gap_example.py` provides a simple example of the **stability gap** [(De Lange et al., 2023, *ICLR*)](https://openreview.net/forum?id=Zy350cRstc6). This phenomenon of temporary forgetting can be consistently observed when state-of-the-art continual learning methods (e.g., replay or regularization) are used to incrementally train a deep neural network on multiple tasks. Strikingly, as described by [Hess et al. (2023, *ContinualAI Unconference*)](https://proceedings.mlr.press/v249/hess24a.html), the stability gap occurs even with **incremental joint training** (i.e., when training on a new task, all previous tasks are fully retrained as well), which can be interpreted as "full replay" or "perfect regularization".

The example in this script uses **Rotated MNIST** with three tasks (rotations: 0°, 80° and 160°) as the task sequence:

![image](../figures/rotatedMNIST.png)

This task sequence is presented according to the domain-incremental learning scenario ([van de Ven et al., 2022, *Nat Mach Intell*](https://www.nature.com/articles/s42256-022-00568-3)).
A fully connected neural network (with two hidden layers of 400 ReLUs each) is trained on this task sequence using incremental joint training, while the model's performance on the first task is evaluated after each training iteration.

Running this script should produce a plot similar to:

![image](../figures/stabilityGap.png)
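
To quantify the dip rather than only inspect it in the plot, the per-iteration accuracies recorded by the script can be post-processed. Below is a minimal sketch (an illustration, not part of this commit; it assumes a per-iteration accuracy list such as the `performance_task1` list produced by `stability_gap_example.py`, and the helper name `stability_gap` is hypothetical):

import numpy as np

def stability_gap(acc_per_iter, iters_per_task):
    '''Hypothetical helper: for each task switch, report how far accuracy on the
    first task drops below its value just before the switch (in percentage points).'''
    acc = np.asarray(acc_per_iter)
    gaps = []
    for switch in range(iters_per_task, len(acc), iters_per_task):
        acc_before = acc[switch - 1]                         # accuracy just before the task switch
        acc_min = acc[switch:switch + iters_per_task].min()  # lowest accuracy after the switch
        gaps.append(acc_before - acc_min)
    return gaps

# e.g., with the settings used in the script: stability_gap(performance_task1, iters_per_task=500)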

StabilityGap/stability_gap_example.py (new file, +193)

#!/usr/bin/env python3

# Standard libraries
import sys
import os
import numpy as np
import tqdm
# Pytorch
import torch
from torch.nn import functional as F
from torchvision import datasets, transforms
# For visualization
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Expand the module search path to parent directory
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Load custom-written code
import utils
from visual import visual_plt
from eval.evaluate import test_acc
from models.classifier import Classifier
from data.manipulate import TransformedDataset


################## INITIAL SET-UP ##################

# Specify directories, and if needed create them
p_dir = "./store/plots"
d_dir = "./store/data"
if not os.path.isdir(p_dir):
    print("Creating directory: {}".format(p_dir))
    os.makedirs(p_dir)
if not os.path.isdir(d_dir):
    print("Creating directory: {}".format(d_dir))
    os.makedirs(d_dir)

# Open pdf for plotting
plot_name = "stability_gap_example"
full_plot_name = "{}/{}.pdf".format(p_dir, plot_name)
pp = visual_plt.open_pdf(full_plot_name)
figure_list = []


################## CREATE TASK SEQUENCE ##################

## Download the MNIST dataset
print("\n\n " + ' LOAD DATA '.center(70, '*'))
MNIST_trainset = datasets.MNIST(root='data/', train=True, download=True,
                                transform=transforms.ToTensor())
MNIST_testset = datasets.MNIST(root='data/', train=False, download=True,
                               transform=transforms.ToTensor())
config = {'size': 28, 'channels': 1, 'classes': 10}

# Set for each task the amount of rotation to use
rotations = [0, 80, 160]

# Specify for each task the transformed train- and testset
n_tasks = len(rotations)
train_datasets = []
test_datasets = []
for rotation in rotations:
    train_datasets.append(TransformedDataset(
        MNIST_trainset, transform=transforms.RandomRotation(degrees=(rotation, rotation)),
    ))
    test_datasets.append(TransformedDataset(
        MNIST_testset, transform=transforms.RandomRotation(degrees=(rotation, rotation)),
    ))

# Visualize the different tasks
figure, axis = plt.subplots(1, n_tasks, figsize=(3*n_tasks, 4))
n_samples = 49
for task_id in range(len(train_datasets)):
    # Show [n_samples] examples from the training set of each task
    data_loader = torch.utils.data.DataLoader(train_datasets[task_id], batch_size=n_samples, shuffle=True)
    image_tensor, _ = next(iter(data_loader))
    image_grid = make_grid(image_tensor, nrow=int(np.sqrt(n_samples)), pad_value=1)  # pad_value=0 would give black borders
    axis[task_id].imshow(np.transpose(image_grid.numpy(), (1, 2, 0)))
    axis[task_id].set_title("Task {}".format(task_id+1))
    axis[task_id].axis('off')
figure_list.append(figure)


################## SET UP THE MODEL ##################

print("\n\n " + ' DEFINE THE CLASSIFIER '.center(70, '*'))

# Specify the architectural layout of the network to use
fc_lay = 3      #--> number of fully-connected layers
fc_units = 400  #--> number of units in each hidden layer

# Define the model
model = Classifier(image_size=config['size'], image_channels=config['channels'], classes=config['classes'],
                   fc_layers=fc_lay, fc_units=fc_units, fc_bn=False)

# Print some model info to screen
utils.print_model_info(model)


################## TRAINING AND EVALUATION ##################

print('\n\n' + ' TRAINING + CONTINUAL EVALUATION '.center(70, '*'))

# Define a function to train a model, while also evaluating its performance after each iteration
def train_and_evaluate(model, trainset, iters, lr, batch_size, testset,
                       test_size=512, performance=[]):
    '''Train [model] on [trainset] for [iters] iterations,
    while evaluating its accuracy on [testset] after each training iteration.'''

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    iters_left = 1
    progress_bar = tqdm.tqdm(range(1, iters+1))

    for _ in range(1, iters+1):
        optimizer.zero_grad()

        # Collect data from [trainset] and compute the loss
        iters_left -= 1
        if iters_left == 0:
            data_loader = iter(torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                           shuffle=True, drop_last=True))
            iters_left = len(data_loader)
        x, y = next(data_loader)
        y_hat = model(x)
        loss = torch.nn.functional.cross_entropy(input=y_hat, target=y, reduction='mean')

        # Calculate test accuracy (in %)
        accuracy = 100*test_acc(model, testset, test_size=test_size, verbose=False, batch_size=512)
        performance.append(accuracy)

        # Take gradient step
        loss.backward()
        optimizer.step()
        progress_bar.set_description(
            '<CLASSIFIER> | training loss: {loss:.3} | test accuracy: {prec:.3}% |'
            .format(loss=loss.item(), prec=accuracy)
        )
        progress_bar.update(1)
    progress_bar.close()

# Specify the training parameters
iters = 500       #--> for how many iterations to train on each task?
lr = 0.1          #--> learning rate
batch_size = 128  #--> size of mini-batches
test_size = 2000  #--> number of test samples to evaluate on after each iteration

# Define a list to keep track of the performance on task 1 after each iteration
performance_task1 = []

# Iterate through the tasks
for task_id in range(n_tasks):
    current_task = task_id+1

    # Concatenate the training data of all tasks so far
    joint_dataset = torch.utils.data.ConcatDataset(train_datasets[:current_task])

    # Determine the batch size to use
    batch_size_to_use = current_task*batch_size

    # Train
    print('Training after arrival of Task {}:'.format(current_task))
    train_and_evaluate(model, trainset=joint_dataset, iters=iters, lr=lr,
                       batch_size=batch_size_to_use, testset=test_datasets[0],
                       test_size=test_size, performance=performance_task1)


################## PLOTTING ##################

## Plot per-iteration performance curve
figure = visual_plt.plot_lines(
    [performance_task1], x_axes=list(range(n_tasks*iters)),
    line_names=['Incremental Joint'],
    title="Performance on Task 1 throughout 'Incremental Joint Training'",
    ylabel="Test Accuracy (%) on Task 1",
    xlabel="Total number of training iterations", figsize=(10,5),
    v_line=[iters*(i+1) for i in range(n_tasks-1)], v_label='Task switch', ylim=(70,100),
)
figure_list.append(figure)

## Finalize the pdf with the plots
# -add figures to pdf
for figure in figure_list:
    pp.savefig(figure)
# -close pdf
pp.close()
# -print name of generated plot on screen
print("\nGenerated plot: {}\n".format(full_plot_name))
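
One design choice in the training loop above is worth calling out: the mini-batch size is scaled with the number of tasks trained on so far (`batch_size_to_use = current_task*batch_size`), so that, in expectation, each task keeps contributing roughly `batch_size` samples per mini-batch throughout incremental joint training. A quick illustrative check of that expectation (not part of the commit):

# With [current_task] equally sized datasets concatenated and sampled uniformly,
# the expected number of samples per task in each mini-batch stays constant:
batch_size = 128
for current_task in (1, 2, 3):
    expected_per_task = (current_task * batch_size) / current_task
    print("Tasks so far: {} --> ~{:.0f} samples per task per batch".format(current_task, expected_per_task))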

figures/rotatedMNIST.png (183 KB)

figures/stabilityGap.png (71 KB)

visual/visual_plt.py (+10 -1)

@@ -103,7 +103,8 @@
 def plot_lines(list_with_lines, x_axes=None, line_names=None, colors=None, title=None,
                title_top=None, xlabel=None, ylabel=None, ylim=None, figsize=None, list_with_errors=None, errors="shaded",
                x_log=False, with_dots=False, linestyle='solid', h_line=None, h_label=None, h_error=None,
-               h_lines=None, h_colors=None, h_labels=None, h_errors=None):
+               h_lines=None, h_colors=None, h_labels=None, h_errors=None,
+               v_line=None, v_label=None):
     '''Generates a figure containing multiple lines in one plot.
 
     :param list_with_lines: <list> of all lines to plot (with each line being a <list> as well)
@@ -179,6 +180,14 @@
                 axarr.axhline(y=new_h_line-h_errors[line_id], label=None,
                               color=None if (h_colors is None) else h_colors[line_id], linewidth=1,
                               linestyle='dashed')
+
+    # add vertical line(s)
+    if v_line is not None:
+        if type(v_line)==list:
+            for id, new_line in enumerate(v_line):
+                axarr.axvline(x=new_line, label=v_label if id==0 else None, color='black')
+        else:
+            axarr.axvline(x=v_line, label=v_label, color='black')
 
     # finish layout
     # -set y-axis
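
For reference, the new `v_line` / `v_label` arguments are what `StabilityGap/stability_gap_example.py` uses to mark the task switches in its performance plot. A minimal standalone call might look like this (placeholder accuracy values; assumes the repository's `visual_plt` module is on the import path):

from visual import visual_plt

# Placeholder per-iteration accuracies for 150 iterations (with dips after iterations 50 and 100)
accs = [95.0]*50 + [80.0, 90.0] + [95.0]*48 + [82.0, 91.0] + [95.0]*48
fig = visual_plt.plot_lines(
    [accs], x_axes=list(range(len(accs))),
    line_names=['Incremental Joint'],
    ylabel="Test Accuracy (%)", xlabel="Iteration",
    v_line=[50, 100], v_label='Task switch',  # vertical lines at the task switches, one shared legend label
)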
