Skip to content

Commit 0d35d35

Browse files
committed
Updated docs
1 parent 492b502 commit 0d35d35

8 files changed

Lines changed: 434 additions & 513 deletions

File tree

docs/models/convnext.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Resnet
2+
3+
::: models.convnext.convnext.StreamingConvnext

docs/modules/basemodule.md

Lines changed: 0 additions & 3 deletions
This file was deleted.

docs/modules/constructor.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# StreamingModule
2+
3+
::: modules.constructor.StreamingConstructor
4+
options:
5+
docstring_section_style: table

docs/modules/imagenettemplate.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# BaseModule
2+
3+
::: modules.imagenet_template.ImageNetClassifier

docs/tutorials/classification.md

Lines changed: 212 additions & 29 deletions
Large diffs are not rendered by default.

docs/tutorials/custom_models.md

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
Custom models can be created in one of three ways:
33

44
* Using the `StreamingModule` (recommended)
5-
* Using the `BaseModel` (a subclass of `StreamingModule`)
5+
* Using the `ImageNetClassifier` (a subclass of `StreamingModule`)
66
* Creating your own class (not recommended)
77

8-
The `StreamingModule` and `BaseModel` classes are both regular `LightningModule` objects and should be treated as such.
8+
The `StreamingModule` and `ImageNetClassifier` classes are both regular `LightningModule` objects and should be treated as such.
99
Both classes have several helper functions and a custom initialization that create the streaming model instance. Secondly,
1010
the helper functions make sure that several settings, such as freezing normalization layers and setting them to `eval()` mode, both during training and inference.
1111
This is necessary since streaming does not work with layers that are not locally defined, but rather need the entire input image.
@@ -83,24 +83,44 @@ class StreamingResNet(StreamingModule):
8383
tile_size: int,
8484
loss_fn: torch.nn.functional,
8585
train_streaming_layers: bool = True,
86-
use_streaming: bool = True,
87-
*args,
86+
metrics: MetricCollection | None = None,
8887
**kwargs
8988
):
9089
assert model_name in list(StreamingResNet.model_choices.keys())
91-
network = StreamingResNet.model_choices[model_name](weights="IMAGENET1K_V1")
92-
stream_net, head = split_resnet(network, num_classes=kwargs.get("num_classes"))
90+
network = StreamingResNet.model_choices[model_name](weights="DEFAULT")
91+
stream_network, head = split_resnet(network, num_classes=kwargs.pop("num_classes", 1000))
92+
93+
self._get_streaming_options(**kwargs)
94+
print("metrics", metrics)
9395
super().__init__(
94-
stream_net,
96+
stream_network,
9597
head,
9698
tile_size,
99+
loss_fn,
97100
train_streaming_layers=train_streaming_layers,
98-
use_streaming=use_streaming,
99-
*args,
100-
**kwargs
101+
metrics=metrics,
102+
**self.streaming_options,
101103
)
104+
105+
def _get_streaming_options(self, **kwargs):
106+
"""Set streaming defaults, but overwrite them with values of kwargs if present."""
107+
108+
# We need to add torch.nn.Batchnorm to the keep modules, because of some in-place ops error if we don't
109+
# https://discuss.pytorch.org/t/register-full-backward-hook-for-residual-connection/146850
110+
streaming_options = {
111+
"verbose": True,
112+
"copy_to_gpu": False,
113+
"statistics_on_cpu": True,
114+
"normalize_on_gpu": True,
115+
"mean": [0.485, 0.456, 0.406],
116+
"std": [0.229, 0.224, 0.225],
117+
"add_keep_modules": [torch.nn.BatchNorm2d],
118+
}
119+
self.streaming_options = {**streaming_options, **kwargs}
102120

103121
```
122+
The actual streaming module can be configured with varying settings. These can be passed as a kwarg dictionary within the `super().__init__()` call of the parent class. Under the hood, this dictionary with options is passed to a constructor which takes care of properly building the streaming module.
123+
A more detailed explanation about the constructor can be found below.
104124

105125

106126
## Custom forward/backward logic
@@ -150,10 +170,44 @@ def backward(self, loss):
150170
```
151171

152172

153-
- Hooks: Several hooks in pytorch lightning are used to set the normalization layers to `eval()` and set the inputs/models to the right device (this is not how it should be done, but we are working on a solution for this).
173+
- Hooks: Several hooks in pytorch lightning are used to set the normalization layers to `eval()` and set the inputs/models to the right device.
154174
- on_training_start: Allocates the input and models to the correct device at training time.
155175
- on_validation_start: Allocates the input and models to the correct device at validation time.
156176
- on_test_start: Allocates the input and models to the correct device at test time.
157177
- on_train_epoch_start(self): sets all the normalization layers to eval() during training
158178

159-
**Warning: do not override these hooks with your own code. If you need these hooks for any reason, then call the parent method first using e.g. `super().on_training_start`**
179+
**Warning: do not override these hooks with your own code. If you need these hooks for any reason, then call the parent method first using e.g. `super().on_training_start`**
180+
181+
182+
## Streaming using the constructor
183+
Under the hood of the `StreamingModule` class, we have an additional `Constructor` class that actually builds and defines the streaming module. The following arguments can be passed to it:
184+
At the very least, a torch model and a tile size must be provided.
185+
```python
186+
model: torch.nn.modules,
187+
tile_size: int,
188+
verbose: bool = False,
189+
deterministic: bool = False,
190+
saliency: bool = False,
191+
copy_to_gpu: bool = False,
192+
statistics_on_cpu: bool = False,
193+
normalize_on_gpu: bool = False,
194+
mean: list[float, float, float] | None = None,
195+
std: list[float, float, float] | None = None,
196+
tile_cache: dict | None = None,
197+
add_keep_modules: list[torch.nn.modules] | None = None,
198+
before_streaming_init_callbacks: list[Callable[[torch.nn.modules], None], ...] | None = None,
199+
after_streaming_init_callbacks: list[Callable[[torch.nn.modules], None], ...] | None = None,
200+
```
201+
202+
### Constructor default behaviour
203+
By default, the constructor will perform the following steps:
204+
1. All layers except convolution, local max pooling, and local average pooling layers are set to `nn.Identity`
205+
2. `before_streaming_init_callbacks` are executed. These are user-specified, and by default, no callbacks are executed.
206+
3. The streaming module is constructed from the convolution/local pooling layers. Within this step, the model's weights are altered to calculate tile statistics.
207+
4. The streaming module's `nn.Identity` layers from step 1 are restored back to their old layers, and the correct model weights are reloaded
208+
5. `after_streaming_init_callbacks` are executed. These are user-specified, and by default, no callbacks are executed.
209+
6. The streaming module is returned
210+
211+
The before and after streaming initialization callbacks are added for flexibility, since we cannot take all possible variations for model creation into account.
212+
An example of where these callbacks come in handy is given in the code for the streaming convnext model. Within this model, we need to additionally turn off the stochastic depth operation, as well as the layer scale, which are not normal layers.
213+
For a more detailed example, we invite the reader to look at the code for either the [Resnet](/models/resnet) or [Convnext](/models/convnext) models within the repository.

mkdocs.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ nav:
1212
Trainer options: tutorials/trainer_options.md
1313
- Models:
1414
ResNet: models/resnet.md
15+
ConvNext: models/convnext.md
1516
- Modules:
16-
streamingmodule: modules/streamingmodule.md
17-
basemodule: modules/basemodule.md
17+
streaming: modules/streamingmodule.md
18+
imagenet: modules/imagenettemplate.md
19+
constructor: modules/constructor.md
20+
1821

1922
theme:
2023
features:

0 commit comments

Comments
 (0)