|
2 | 2 | Custom models can be created in one of three ways: |
3 | 3 |
|
4 | 4 | * Using the `StreamingModule` (recommended) |
5 | | -* Using the `BaseModel` (a subclass of `StreamingModule`) |
| 5 | +* Using the `ImageNetClassifier` (a subclass of `StreamingModule`) |
6 | 6 | * Creating your own class (not recommended) |
7 | 7 |
|
8 | | -The `StreamingModule` and `BaseModel` classes are both regular `LightningModule` objects and should be treated as such. |
| 8 | +The `StreamingModule` and `ImageNetClassifier` classes are both regular `LightningModule` objects and should be treated as such. |
9 | 9 | Both classes have several helper functions and a custom initialization that creates the streaming model instance. In addition, |
10 | 10 | the helper functions make sure that several settings are applied, such as freezing normalization layers and setting them to `eval()` mode, during both training and inference. |
11 | 11 | This is necessary because streaming does not work with layers that are not locally defined and instead need the entire input image. |
@@ -83,24 +83,44 @@ class StreamingResNet(StreamingModule): |
83 | 83 | tile_size: int, |
84 | 84 | loss_fn: torch.nn.functional, |
85 | 85 | train_streaming_layers: bool = True, |
86 | | - use_streaming: bool = True, |
87 | | - *args, |
| 86 | + metrics: MetricCollection | None = None, |
88 | 87 | **kwargs |
89 | 88 | ): |
90 | 89 | assert model_name in list(StreamingResNet.model_choices.keys()) |
91 | | - network = StreamingResNet.model_choices[model_name](weights="IMAGENET1K_V1") |
92 | | - stream_net, head = split_resnet(network, num_classes=kwargs.get("num_classes")) |
| 90 | + network = StreamingResNet.model_choices[model_name](weights="DEFAULT") |
| 91 | + stream_network, head = split_resnet(network, num_classes=kwargs.pop("num_classes", 1000)) |
| 92 | + |
| 93 | + self._get_streaming_options(**kwargs) |
| 94 | + print("metrics", metrics) |
93 | 95 | super().__init__( |
94 | | - stream_net, |
| 96 | + stream_network, |
95 | 97 | head, |
96 | 98 | tile_size, |
| 99 | + loss_fn, |
97 | 100 | train_streaming_layers=train_streaming_layers, |
98 | | - use_streaming=use_streaming, |
99 | | - *args, |
100 | | - **kwargs |
| 101 | + metrics=metrics, |
| 102 | + **self.streaming_options, |
101 | 103 | ) |
| 104 | + |
| 105 | + def _get_streaming_options(self, **kwargs): |
| 106 | + """Set streaming defaults, but overwrite them with values of kwargs if present.""" |
| 107 | + |
| 108 | + # We need to add torch.nn.BatchNorm2d to the keep modules; without it, an in-place ops error is raised |
| 109 | + # https://discuss.pytorch.org/t/register-full-backward-hook-for-residual-connection/146850 |
| 110 | + streaming_options = { |
| 111 | + "verbose": True, |
| 112 | + "copy_to_gpu": False, |
| 113 | + "statistics_on_cpu": True, |
| 114 | + "normalize_on_gpu": True, |
| 115 | + "mean": [0.485, 0.456, 0.406], |
| 116 | + "std": [0.229, 0.224, 0.225], |
| 117 | + "add_keep_modules": [torch.nn.BatchNorm2d], |
| 118 | + } |
| 119 | + self.streaming_options = {**streaming_options, **kwargs} |
102 | 120 |
|
103 | 121 | ``` |
| 122 | +The streaming module itself can be configured with various settings. These are passed as keyword arguments in the `super().__init__()` call to the parent class. Under the hood, this dictionary of options is passed to a constructor that takes care of building the streaming module properly. |
| 123 | +A more detailed explanation of the constructor can be found below. |
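The default-merging pattern used in `_get_streaming_options` above can be illustrated in isolation. This is a minimal sketch in plain Python: the helper name `merge_streaming_options` is hypothetical and not part of the library, the defaults are copied from the example above, and `add_keep_modules` is left out so that the snippet runs without torch.

```python
# Hypothetical helper (not part of the library) that mirrors the dict merge
# performed by `_get_streaming_options` in the example above.
DEFAULT_STREAMING_OPTIONS = {
    "verbose": True,
    "copy_to_gpu": False,
    "statistics_on_cpu": True,
    "normalize_on_gpu": True,
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}


def merge_streaming_options(**kwargs):
    """Return the defaults, overwritten by any user-supplied kwargs."""
    # In a dict display, later entries win, so user values take precedence.
    return {**DEFAULT_STREAMING_OPTIONS, **kwargs}


# Override one default and add one option that has no default here.
options = merge_streaming_options(verbose=False, deterministic=True)
```

Because the merge builds a new dictionary, the defaults themselves are left unchanged between calls.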
104 | 124 |
|
105 | 125 |
|
106 | 126 | ## Custom forward/backward logic |
@@ -150,10 +170,44 @@ def backward(self, loss): |
150 | 170 | ``` |
151 | 171 |
|
152 | 172 |
|
153 | | -- Hooks: Several hooks in pytorch lightning are used to set the normalization layers to `eval()` and set the inputs/models to the right device (this is not how it should be done, but we are working on a solution for this). |
| 173 | +- Hooks: Several PyTorch Lightning hooks are used to set the normalization layers to `eval()` and to move the inputs/models to the right device. |
154 | 174 | - on_train_start: Allocates the input and models to the correct device at training time. |
155 | 175 | - on_validation_start: Allocates the input and models to the correct device at validation time. |
156 | 176 | - on_test_start: Allocates the input and models to the correct device at test time. |
157 | 177 | - on_train_epoch_start: Sets all the normalization layers to `eval()` during training. |
158 | 178 |
|
159 | | -**Warning: do not override these hooks with your own code. If you need these hooks for any reason, then call the parent method first using e.g. `super().on_training_start`** |
| 179 | +**Warning: do not override these hooks without calling the parent implementation. If you need one of these hooks for any reason, call the parent method first, e.g. `super().on_train_start()`.** |
| 180 | + |
| 181 | + |
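The warning above can be illustrated with a small, torch-free sketch. `ParentModule` is a hypothetical stand-in for `StreamingModule` that only records that its hook ran; the real class does the actual device placement and `eval()` handling.

```python
class ParentModule:
    """Hypothetical stand-in for `StreamingModule`; only records hook calls."""

    def __init__(self):
        self.calls = []

    def on_train_epoch_start(self):
        # In the real class, normalization layers would be set to eval() here.
        self.calls.append("parent: normalization layers set to eval()")


class MyModel(ParentModule):
    def on_train_epoch_start(self):
        # Call the parent hook first so that its setup is not skipped.
        super().on_train_epoch_start()
        # Custom logic follows the parent call.
        self.calls.append("child: custom epoch-start logic")


model = MyModel()
model.on_train_epoch_start()
```

Leaving out the `super()` call would silently skip the parent's setup, which is the failure the warning is meant to prevent.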
| 182 | +## Streaming using the constructor |
| 183 | +Under the hood of the `StreamingModule` class, an additional `Constructor` class builds and defines the streaming module. At the very least, a torch model and a tile size must be provided. |
| 184 | +The following arguments can be passed to it: |
| 185 | +```python |
| 186 | +model: torch.nn.modules, |
| 187 | +tile_size: int, |
| 188 | +verbose: bool = False, |
| 189 | +deterministic: bool = False, |
| 190 | +saliency: bool = False, |
| 191 | +copy_to_gpu: bool = False, |
| 192 | +statistics_on_cpu: bool = False, |
| 193 | +normalize_on_gpu: bool = False, |
| 194 | +mean: list[float, float, float] | None = None, |
| 195 | +std: list[float, float, float] | None = None, |
| 196 | +tile_cache: dict | None = None, |
| 197 | +add_keep_modules: list[torch.nn.modules] | None = None, |
| 198 | +before_streaming_init_callbacks: list[Callable[[torch.nn.modules], None], ...] | None = None, |
| 199 | +after_streaming_init_callbacks: list[Callable[[torch.nn.modules], None], ...] | None = None, |
| 200 | +``` |
| 201 | + |
| 202 | +### Constructor default behaviour |
| 203 | +By default, the constructor will perform the following steps: |
| 204 | +1. All layers except convolution, local max pooling, and local average pooling layers are set to `nn.Identity`. |
| 205 | +2. `before_streaming_init_callbacks` are executed. These are user-specified, and by default, no callbacks are executed. |
| 206 | +3. The streaming module is constructed from the convolution/local pooling layers. Within this step, the model's weights are altered to calculate tile statistics. |
| 207 | +4. The streaming module's `nn.Identity` layers from step 1 are restored to their original layers, and the correct model weights are reloaded. |
| 208 | +5. `after_streaming_init_callbacks` are executed. These are user-specified, and by default, no callbacks are executed. |
| 209 | +6. The streaming module is returned. |
| 210 | + |
| 211 | +The before and after streaming initialization callbacks are added for flexibility, since we cannot take all possible variations for model creation into account. |
| 212 | +An example of where these callbacks come in handy is given in the code for the streaming ConvNeXt model. Within this model, we additionally need to turn off the stochastic depth operation and the layer scale, neither of which is a normal layer. |
| 213 | +For a more detailed example, we invite the reader to look at the code for either the [ResNet](/models/resnet) or [ConvNeXt](/models/convnext) models within the repository. |
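The order of the six default steps, and the point at which each set of callbacks sees the model, can be sketched with a toy example. This is not the library's `Constructor`: the function `build_streaming_module`, the `LOCAL_KINDS` set, and the dict-based "model" are all hypothetical stand-ins chosen so that the snippet runs without torch.

```python
# Toy stand-ins: a "model" is a dict mapping layer names to layer kinds.
# None of these names come from the library; they only illustrate the step order.
LOCAL_KINDS = {"conv", "maxpool", "avgpool"}


def build_streaming_module(model, before_callbacks=None, after_callbacks=None):
    """Toy walk-through of the six default steps listed above."""
    original = dict(model)  # remember the original layers for step 4

    # Step 1: replace every non-local layer with an identity placeholder.
    for name, kind in original.items():
        if kind not in LOCAL_KINDS:
            model[name] = "identity"

    # Step 2: run the user-specified "before" callbacks (none by default).
    for callback in before_callbacks or []:
        callback(model)

    # Step 3: "construct" the streaming module; here we only snapshot
    # what a builder would see at this point.
    seen_during_build = dict(model)

    # Step 4: restore the placeholders from step 1 to their original layers.
    for name, kind in original.items():
        if kind not in LOCAL_KINDS:
            model[name] = kind

    # Step 5: run the user-specified "after" callbacks (none by default).
    for callback in after_callbacks or []:
        callback(model)

    # Step 6: return the finished module.
    return model, seen_during_build


events = []
toy_model = {"conv1": "conv", "bn1": "batchnorm", "pool": "maxpool"}
result, seen = build_streaming_module(
    toy_model,
    before_callbacks=[lambda m: events.append(("before", dict(m)))],
    after_callbacks=[lambda m: events.append(("after", dict(m)))],
)
```

In this sketch, the "before" callback sees the model with its non-local layers replaced, while the "after" callback sees the restored model.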