Skip to content

Commit b43c01a

Browse files
tachyonicClock authored and hmgomes committed
feat(ocl): add Repeated Augmented Rehearsal (RAR)
1 parent 2520019 commit b43c01a

File tree

5 files changed

+144
-3
lines changed

5 files changed

+144
-3
lines changed

.github/workflows/pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ env:
2525
jobs:
2626
tests:
2727
name: "Tests"
28-
timeout-minutes: 20
28+
timeout-minutes: 30
2929
runs-on: ubuntu-latest
3030
strategy:
3131
fail-fast: true

src/capymoa/ann/_perceptron.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def forward(self, x: Tensor) -> Tensor:
2525
:param x: Input tensor of shape ``(batch_size, num_features)``.
2626
:return: Output tensor of shape ``(batch_size, num_classes)``.
2727
"""
28+
x = x.view(x.size(0), -1) # Flatten input
2829
x = self._fc1(x)
2930
x = self._relu(x)
3031
x = self._fc2(x)

src/capymoa/ocl/strategy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from ._slda import SLDA
33
from ._ncm import NCM
44
from ._gdumb import GDumb
5+
from ._rar import RAR
56

6-
__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb"]
7+
__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR"]

src/capymoa/ocl/strategy/_rar.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import torch
2+
from torch import Tensor
3+
4+
from capymoa.base import BatchClassifier
5+
from capymoa.ocl.util._coreset import ReservoirSampler
6+
from capymoa.ocl.base import TrainTaskAware, TestTaskAware
7+
8+
from typing import Callable
9+
10+
from torch import nn
11+
12+
13+
class RAR(BatchClassifier, TrainTaskAware, TestTaskAware):
    """Repeated Augmented Rehearsal.

    Repeated Augmented Rehearsal (RAR) [#f0]_ is a replay continual learning
    strategy that combines data augmentation with repeated training on each
    batch to mitigate catastrophic forgetting.

    * Coreset Selection: Reservoir sampling is used to select a fixed-size
      buffer of past examples.

    * Coreset Retrieval: During training, the learner samples uniformly from the
      buffer of past examples.

    * Coreset Exploitation: The learner trains on the current batch of examples
      and the sampled buffer examples, performing multiple optimization steps
      per-batch using random augmentations of the examples.

    * Task identities are not used by RAR itself. It implements
      :class:`~capymoa.ocl.base.TrainTaskAware` and
      :class:`~capymoa.ocl.base.TestTaskAware` only to forward task
      notifications to the wrapped learner when that learner supports them.

    >>> from capymoa.ann import Perceptron
    >>> from capymoa.classifier import Finetune
    >>> from capymoa.ocl.strategy import RAR
    >>> from capymoa.ocl.datasets import TinySplitMNIST
    >>> from capymoa.ocl.evaluation import ocl_train_eval_loop
    >>> import torch
    >>> _ = torch.manual_seed(0)
    >>> scenario = TinySplitMNIST()
    >>> model = Perceptron(scenario.schema)
    >>> learner = RAR(Finetune(scenario.schema, model), augment=torch.nn.Dropout(p=0.2), repeats=2)
    >>> results = ocl_train_eval_loop(
    ...     learner,
    ...     scenario.train_loaders(32),
    ...     scenario.test_loaders(32),
    ... )
    >>> print(f"{results.accuracy_final*100:.1f}%")
    41.5%

    Usually more complex augmentations are used such as random crops and
    rotations.

    .. [#f0] Zhang, Yaqian, Bernhard Pfahringer, Eibe Frank, Albert Bifet, Nick
       Jin Sean Lim, and Yunzhe Jia. “A Simple but Strong Baseline for Online
       Continual Learning: Repeated Augmented Rehearsal.” In Advances in Neural
       Information Processing Systems 35: Annual Conference on Neural
       Information Processing Systems 2022, NeurIPS 2022, New Orleans, LA, USA,
       November 28 - December 9, 2022, edited by Sanmi Koyejo, S. Mohamed, A.
       Agarwal, Danielle Belgrave, K. Cho, and A. Oh, 2022.
       https://doi.org/10.5555/3600270.3601344.
    """

    def __init__(
        self,
        learner: BatchClassifier,
        coreset_size: int = 200,
        augment: Callable[[Tensor], Tensor] = nn.Identity(),
        repeats: int = 1,
    ) -> None:
        """Initialize Repeated Augmented Rehearsal.

        :param learner: Underlying learner to be trained with RAR.
        :param coreset_size: Size of the coreset buffer.
        :param augment: Data augmentation function to apply to the samples. Should take
            a Tensor of shape ``(batch_size, *schema.shape)`` and return a Tensor of the
            same shape. May be an :class:`torch.nn.Module` or any plain callable.
        :param repeats: Number of times to repeat training on each batch, defaults to 1.
        """

        super().__init__(learner.schema)
        num_features = learner.schema.get_num_attributes()
        self.learner = learner
        # ``augment`` is typed as a generic callable, so only move it to the
        # device when it actually is a module; functions and lambdas have no
        # ``.to`` method and are used unchanged.
        if isinstance(augment, nn.Module):
            augment = augment.to(self.device)
        self.augment = augment
        self.repeats = repeats
        self.coreset = ReservoirSampler(
            coreset_size,
            num_features,
            rng=torch.Generator().manual_seed(learner.random_seed),
        )
        self.shape = learner.schema.shape

    def train_step(self, x_fresh: Tensor, y_fresh: Tensor) -> None:
        """Run one training step on the fresh batch plus replayed examples.

        :param x_fresh: Features of the incoming batch.
        :param y_fresh: Labels of the incoming batch.
        """
        # Request as many replay examples as there are fresh examples.
        n = x_fresh.shape[0]
        x_replay, y_replay = self.coreset.sample(n)

        # Move both parts to a common device/dtype *before* concatenating so
        # the concatenation cannot fail on a device mismatch between the
        # incoming batch and the buffer.
        x = torch.cat(
            (
                x_fresh.to(self.device, self.x_dtype),
                x_replay.to(self.device, self.x_dtype),
            ),
            dim=0,
        )
        y = torch.cat(
            (
                y_fresh.to(self.device, self.y_dtype),
                y_replay.to(self.device, self.y_dtype),
            ),
            dim=0,
        )

        # Restore the schema's sample shape so the augmentation sees
        # ``(batch_size, *schema.shape)`` as documented in ``__init__``.
        x = x.view(-1, *self.shape)
        x = self.augment(x)

        # Train the learner
        x = x.to(self.learner.device, self.learner.x_dtype)
        y = y.to(self.learner.device, self.learner.y_dtype)
        self.learner.batch_train(x, y)

    def batch_train(self, x: Tensor, y: Tensor) -> None:
        """Update the coreset with the batch, then train on it ``repeats`` times.

        :param x: Features of the incoming batch.
        :param y: Labels of the incoming batch.
        """
        self.coreset.update(x, y)
        for _ in range(self.repeats):
            self.train_step(x, y)

    @torch.no_grad()
    def batch_predict_proba(self, x: Tensor) -> Tensor:
        """Delegate prediction to the wrapped learner without tracking gradients.

        :param x: Features of the batch to predict.
        :return: Whatever the wrapped learner's ``batch_predict_proba`` returns.
        """
        x = x.to(self.learner.device, self.learner.x_dtype)
        return self.learner.batch_predict_proba(x)

    def on_test_task(self, task_id: int) -> None:
        """Forward the test-task notification if the wrapped learner accepts it."""
        if isinstance(self.learner, TestTaskAware):
            self.learner.on_test_task(task_id)

    def on_train_task(self, task_id: int) -> None:
        """Forward the train-task notification if the wrapped learner accepts it."""
        if isinstance(self.learner, TrainTaskAware):
            self.learner.on_train_task(task_id)

tests/ocl/test_strategy.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from capymoa.classifier import Finetune, HoeffdingTree
1212
from capymoa.ocl.datasets import TinySplitMNIST
1313
from capymoa.ocl.evaluation import ocl_train_eval_loop
14-
from capymoa.ocl.strategy import ExperienceReplay, SLDA, NCM, GDumb
14+
from capymoa.ocl.strategy import ExperienceReplay, SLDA, NCM, GDumb, RAR
1515
from capymoa.stream import Schema
1616

1717
import torch
@@ -46,6 +46,15 @@ def pre_processor() -> nn.Module:
4646
)
4747

4848

49+
def _new_rar(schema):
    """Build the RAR learner used as a strategy test case.

    Wraps a ``Finetune`` learner around a ``Perceptron`` and uses dropout
    (``p=0.2``) as the augmentation with two training repeats per batch.
    """
    base_learner = Finetune(schema, Perceptron(schema))
    dropout_augment = nn.Dropout(p=0.2)
    return RAR(base_learner, augment=dropout_augment, repeats=2)
56+
57+
4958
"""
5059
Add new test cases here.
5160
@@ -55,6 +64,7 @@ def pre_processor() -> nn.Module:
5564
TEST_CASES: List[Case] = [
5665
Case("HoeffdingTree", HoeffdingTree, Result(59.49, 42.59, 45.8), batch_size=1),
5766
Case("HoeffdingTree", HoeffdingTree, Result(59.00, 42.80, 42.5), batch_size=32),
67+
Case("RAR", _new_rar, Result(41.50, 28.20, 8.20)),
5868
Case(
5969
"Finetune",
6070
partial(Finetune, model=Perceptron),
@@ -92,7 +102,12 @@ def test_ocl_classifier(case: Case):
92102
if os.environ.get("CI") == "true" and "SLDA" in case.name:
93103
pytest.skip("Skipping SLDA case on CI due to unreliable dataset download")
94104
scenario = TinySplitMNIST()
105+
106+
# Set random seeds for reproducibility
95107
torch.manual_seed(0)
108+
np.random.seed(0)
109+
torch.use_deterministic_algorithms(True)
110+
96111
learner = case.constructor(scenario.schema)
97112
r = ocl_train_eval_loop(
98113
learner,

0 commit comments

Comments
 (0)