|
| 1 | +import torch |
| 2 | +from torch import Tensor |
| 3 | + |
| 4 | +from capymoa.base import BatchClassifier |
| 5 | +from capymoa.ocl.util._coreset import ReservoirSampler |
| 6 | +from capymoa.ocl.base import TrainTaskAware, TestTaskAware |
| 7 | + |
| 8 | +from typing import Callable |
| 9 | + |
| 10 | +from torch import nn |
| 11 | + |
| 12 | + |
class RAR(BatchClassifier, TrainTaskAware, TestTaskAware):
    """Repeated Augmented Rehearsal.

    Repeated Augmented Rehearsal (RAR) [#f0]_ is a replay continual learning
    strategy that combines data augmentation with repeated training on each
    batch to mitigate catastrophic forgetting.

    * Coreset Selection: Reservoir sampling is used to select a fixed-size
      buffer of past examples.

    * Coreset Retrieval: During training, the learner samples uniformly from the
      buffer of past examples.

    * Coreset Exploitation: The learner trains on the current batch of examples
      and the sampled buffer examples, performing multiple optimization steps
      per-batch using random augmentations of the examples.

    * Not :class:`~capymoa.ocl.base.TrainTaskAware` or
      :class:`~capymoa.ocl.base.TestTaskAware`, but will proxy it to the wrapped
      learner.

    >>> from capymoa.ann import Perceptron
    >>> from capymoa.classifier import Finetune
    >>> from capymoa.ocl.strategy import RAR
    >>> from capymoa.ocl.datasets import TinySplitMNIST
    >>> from capymoa.ocl.evaluation import ocl_train_eval_loop
    >>> import torch
    >>> from torch import nn
    >>> _ = torch.manual_seed(0)
    >>> scenario = TinySplitMNIST()
    >>> model = Perceptron(scenario.schema)
    >>> learner = RAR(Finetune(scenario.schema, model), augment=nn.Dropout(p=0.2), repeats=2)
    >>> results = ocl_train_eval_loop(
    ...     learner,
    ...     scenario.train_loaders(32),
    ...     scenario.test_loaders(32),
    ... )
    >>> print(f"{results.accuracy_final*100:.1f}%")
    41.5%

    Usually more complex augmentations are used such as random crops and
    rotations.

    .. [#f0] Zhang, Yaqian, Bernhard Pfahringer, Eibe Frank, Albert Bifet, Nick
       Jin Sean Lim, and Yunzhe Jia. “A Simple but Strong Baseline for Online
       Continual Learning: Repeated Augmented Rehearsal.” In Advances in Neural
       Information Processing Systems 35: Annual Conference on Neural
       Information Processing Systems 2022, NeurIPS 2022, New Orleans, LA, USA,
       November 28 - December 9, 2022, edited by Sanmi Koyejo, S. Mohamed, A.
       Agarwal, Danielle Belgrave, K. Cho, and A. Oh, 2022.
       https://doi.org/10.5555/3600270.3601344.
    """

    def __init__(
        self,
        learner: BatchClassifier,
        coreset_size: int = 200,
        augment: Callable[[Tensor], Tensor] = nn.Identity(),
        repeats: int = 1,
    ) -> None:
        """Initialize Repeated Augmented Rehearsal.

        :param learner: Underlying learner to be trained with RAR.
        :param coreset_size: Size of the coreset buffer.
        :param augment: Data augmentation function to apply to the samples. Should take
            a Tensor of shape ``(batch_size, *schema.shape)`` and return a Tensor of the
            same shape. May be a :class:`torch.nn.Module` or any plain callable;
            only modules are moved to this wrapper's device.
        :param repeats: Number of times to repeat training on each batch, defaults to 1.
        """

        super().__init__(learner.schema)
        num_features = learner.schema.get_num_attributes()
        self.learner = learner

        # ``augment`` is typed as a generic callable, and plain functions or
        # lambdas have no ``.to``; only relocate real modules. The default
        # ``nn.Identity()`` holds no parameters or buffers, so sharing that one
        # instance between RAR objects is harmless.
        if isinstance(augment, nn.Module):
            augment = augment.to(self.device)
        self.augment = augment

        self.repeats = repeats
        self.coreset = ReservoirSampler(
            coreset_size,
            num_features,
            rng=torch.Generator().manual_seed(learner.random_seed),
        )
        self.shape = learner.schema.shape

    def train_step(self, x_fresh: Tensor, y_fresh: Tensor) -> None:
        """Run one optimisation step on the fresh batch plus replayed examples.

        :param x_fresh: Features of the current batch.
        :param y_fresh: Labels of the current batch.
        """
        # Sample from reservoir and augment the data
        n = x_fresh.shape[0]
        x_replay, y_replay = self.coreset.sample(n)

        # Convert each half *before* concatenating so ``torch.cat`` never has
        # to combine tensors living on different devices or with different
        # dtypes (NOTE(review): the coreset's storage device is not visible
        # here; when both halves already match this is a no-op).
        x = torch.cat(
            (
                x_fresh.to(self.device, self.x_dtype),
                x_replay.to(self.device, self.x_dtype),
            ),
            dim=0,
        )
        y = torch.cat(
            (
                y_fresh.to(self.device, self.y_dtype),
                y_replay.to(self.device, self.y_dtype),
            ),
            dim=0,
        )

        # Restore the schema's sample shape so the augmentation sees
        # ``(batch, *shape)`` inputs.
        x = x.view(-1, *self.shape)
        x = self.augment(x)

        # Train the learner
        x = x.to(self.learner.device, self.learner.x_dtype)
        y = y.to(self.learner.device, self.learner.y_dtype)
        self.learner.batch_train(x, y)

    def batch_train(self, x: Tensor, y: Tensor) -> None:
        """Add the batch to the coreset, then train on it ``repeats`` times.

        :param x: Features of the current batch.
        :param y: Labels of the current batch.
        """
        self.coreset.update(x, y)
        for _ in range(self.repeats):
            self.train_step(x, y)

    @torch.no_grad()
    def batch_predict_proba(self, x: Tensor) -> Tensor:
        """Delegate prediction to the wrapped learner without tracking gradients.

        :param x: Features of the batch to predict.
        :return: Whatever the wrapped learner's ``batch_predict_proba`` returns.
        """
        x = x.to(self.learner.device, self.learner.x_dtype)
        return self.learner.batch_predict_proba(x)

    def on_test_task(self, task_id: int) -> None:
        """Forward the test-task notification if the wrapped learner accepts it."""
        if isinstance(self.learner, TestTaskAware):
            self.learner.on_test_task(task_id)

    def on_train_task(self, task_id: int) -> None:
        """Forward the train-task notification if the wrapped learner accepts it."""
        if isinstance(self.learner, TrainTaskAware):
            self.learner.on_train_task(task_id)
0 commit comments