Commit 0688cc6

🦾 new loss API
1 parent 1cb4457 commit 0688cc6

4 files changed: +251, -50 lines

README.md (+89, -4)

@@ -1,4 +1,4 @@
-# layer-to-layer-pytorch
+# L2L execution algorithm PyTorch [WIP]
 
 <div align="center">
 
@@ -12,12 +12,97 @@
 [![Semantic Versions](https://img.shields.io/badge/%F0%9F%9A%80-semantic%20versions-informational.svg)](https://github.com/TezRomacH/layer-to-layer-pytorch/releases)
 [![License](https://img.shields.io/github/license/TezRomacH/layer-to-layer-pytorch)](https://github.com/TezRomacH/layer-to-layer-pytorch/blob/master/LICENSE)
 
-PyTorch implementation of L2L execution algorithm
+PyTorch implementation of the L2L execution algorithm from the paper [Training Large Neural Networks with Constant Memory using a New Execution Algorithm](https://arxiv.org/abs/2002.05645)
 </div>
 
-## 🚀 Features [WIP]
+## [Not ready yet]
+
+## 🚀 Example
+
+You need to define a torch model whose layers are all registered in an `nn.ModuleList`.
+
+For example:
+
+```python
+from typing import Optional
+
+import torch
+from torch import nn, optim
+
+
+class M(nn.Module):
+    def __init__(self, depth: int, dim: int, hidden_dim: Optional[int] = None):
+        super().__init__()
+        hidden_dim = hidden_dim or dim
+        self.layers = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Linear(dim, hidden_dim),
+                    nn.BatchNorm1d(hidden_dim),
+                    nn.LeakyReLU(),
+                )
+            ]
+            + [
+                nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.BatchNorm1d(hidden_dim),
+                    nn.LeakyReLU(),
+                )
+                for _ in range(depth)
+            ]
+            + [nn.Linear(hidden_dim, dim), nn.Sigmoid()]
+        )
+
+    def forward(self, batch: torch.Tensor) -> torch.Tensor:
+        x = batch
+        for layer in self.layers:
+            x = layer(x)
+
+        return x
 
-## Installation [Not yet ready]
+```
+
+Then you can use the L2L wrapper over this model:
+
+```python
+from layer_to_layer_pytorch.l2l import Layer2Layer
+
+model = M(depth=5, dim=40).train()  # on CPU
+
+l2l_model = Layer2Layer(
+    model,
+    layers_attr="layers",  # name of the attribute that holds the ModuleList
+    microbatch_size=100,   # size of a microbatch inside a minibatch, as in the original paper
+    verbose=False,         # set to True to show tqdm progress bars
+)
+```
+
+And train it (almost) like a regular torch model:
+
+```python
+from tqdm.auto import tqdm, trange
+
+x = torch.rand(1_000, 40)  # on CPU
+y = torch.rand(1_000, 40)  # on CPU
+
+losses = []
+loss_fn = nn.MSELoss(reduction="sum")  # L2L averages the loss over microbatches itself, so we only store the values
+
+optimizer = optim.AdamW(l2l_model.main_model.parameters(), lr=0.001)  # the optimizer works on the main (CPU) copy of the model
+
+for i in trange(5000):
+    l2l_model.zero_grad()
+    l2l_model.forward(x)
+
+    with l2l_model.l2l_loss(loss_fn=loss_fn) as loss:  # APEX-like loss API
+        loss_value = loss(x, y)
+        loss.backward()
+
+        if i % 50 == 0:
+            tqdm.write(f"[{i}] loss = {loss_value.item()}")
+        losses.append(loss_value.item())
+
+    optimizer.step()
+```
+
+## Installation
 
 ```bash
 pip install layer-to-layer-pytorch

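The `microbatch_size=100` argument in the README example controls how each minibatch is chopped up: every layer processes the minibatch in chunks of 100 rows, so GPU memory scales with the microbatch rather than the full batch. A minimal, standalone illustration of that split with the shapes used above (plain PyTorch, not part of the library):

```python
import torch

x = torch.rand(1_000, 40)      # minibatch from the README example
microbatches = x.split(100)    # microbatch_size=100 -> 10 chunks per layer

assert len(microbatches) == 10
assert microbatches[0].shape == (100, 40)
```

These 10 chunks are what the `num_steps = input.shape[0] // microbatch_size` expressions in `l2l.py` count.
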
layer_to_layer_pytorch/__init__.py (+1, -0)

@@ -13,4 +13,5 @@
     __version__ = "unknown"
 
 from layer_to_layer_pytorch.l2l import Layer2Layer
+from layer_to_layer_pytorch.loss import L2LLoss
 from layer_to_layer_pytorch.types import Device, LossFn, TensorOrTensorArray

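With the added import, `L2LLoss` is re-exported from the package root alongside `Layer2Layer`. A quick sanity check, assuming a build that includes this commit is installed:

```python
# Both import paths should resolve to the same class after this commit.
from layer_to_layer_pytorch import L2LLoss
from layer_to_layer_pytorch.loss import L2LLoss as L2LLossDirect

assert L2LLoss is L2LLossDirect
```
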
layer_to_layer_pytorch/l2l.py (+80, -36)

@@ -7,6 +7,7 @@
 from torch import nn
 
 from layer_to_layer_pytorch.helpers import enumerator, iterator, zipper
+from layer_to_layer_pytorch.loss import L2LLoss
 from layer_to_layer_pytorch.types import Device, LossFn, TensorOrTensorArray
 
 
@@ -54,7 +55,7 @@ def zero_grad(self) -> None:
 
     @torch.no_grad()
     def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
-        layers: nn.ModuleList = getattr(self.main_model, self.layers_attr)
+        layers: nn.ModuleList = self._get_layers()
 
         # layer by layer forward pass. only activations are stored
         for idx, l in enumerator(
@@ -72,13 +73,7 @@ def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
             else:
                 input = self._activations[idx - 1]
 
-            # forward with microbatching
-            batch_size = input.shape[0]
-            microbatch_size = (
-                batch_size
-                if self.microbatch_size is None
-                else self.microbatch_size
-            )
+            microbatch_size = self._get_microbatch_size(input)
             num_steps: int = input.shape[0] // microbatch_size
 
             for microbatch in iterator(
@@ -103,12 +98,13 @@ def calculate_gradients(
         target: torch.Tensor,
         loss_fn: LossFn,
         loss_kwargs: dict = None,
+        skip_last_layer: bool = False,
         **forward_kwargs,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         if loss_kwargs is None:
             loss_kwargs = {}
         # layer by layer backward pass (in reverse order)
-        layers: nn.ModuleList = getattr(self.main_model, self.layers_attr)
+        layers: nn.ModuleList = self._get_layers()
         losses: List[torch.Tensor] = []
         num_steps_in_loss: int = 1
 
@@ -122,26 +118,38 @@ def calculate_gradients(
             layer = copy.deepcopy(l).to(self.gpu_device)
             f_idx: int = self.num_layers - idx - 1
 
+            if idx == 0 and skip_last_layer:
+                microbatch_size = self._get_microbatch_size(
+                    self._activations[f_idx]
+                )
+                num_steps: int = (
+                    self._activations[f_idx].shape[0] // microbatch_size
+                )
+                self._copy_grad_to_main_model(
+                    idx,
+                    num_steps,
+                    local_params=layer.parameters(),
+                    main_params=layers[f_idx].parameters(),
+                )
+                continue
+
             for param in layer.parameters():
                 param.requires_grad = True
 
             input: torch.Tensor
             output: torch.Tensor
 
-            if f_idx == 0:
+            if idx == 0:  # last layer
+                input = self._activations[f_idx]
+                output = target
+            elif f_idx == 0:  # first layer
                 input = batch
-                output = self._grads[idx - 1]
-            else:
+                output = self._activations[f_idx]
+            else:  # any other layer
                 input = self._activations[f_idx - 1]
-                output = target
-
-            batch_size = input.shape[0]
-            microbatch_size = (
-                batch_size
-                if self.microbatch_size is None
-                else self.microbatch_size
-            )
+                output = self._activations[f_idx]
 
+            microbatch_size = self._get_microbatch_size(input)
             num_steps: int = input.shape[0] // microbatch_size
             if idx == 0:
                 num_steps_in_loss = num_steps
@@ -160,7 +168,12 @@ def calculate_gradients(
 
                 microtarget = microtarget.to(self.gpu_device)
 
-                activation: torch.Tensor = layer(microbatch, **forward_kwargs)
+                if idx == 0:
+                    activation = microbatch
+                else:
+                    activation: torch.Tensor = layer(
+                        microbatch, **forward_kwargs
+                    )
 
                 if idx == 0:
                     loss = loss_fn(activation, microtarget, **loss_kwargs)
@@ -172,27 +185,58 @@
                     activation.backward(microtarget)
                     self._grads[idx].append(microbatch.grad.cpu())
 
-            for local_param, main_param in zip(
-                layer.parameters(), layers[f_idx].parameters()
-            ):
-                if main_param.grad is None:
-                    main_param.grad = local_param.grad.cpu() / num_steps
-                else:
-                    main_param.grad += local_param.grad.cpu() / num_steps
+            self._copy_grad_to_main_model(
+                idx,
+                num_steps,
+                local_params=layer.parameters(),
+                main_params=layers[f_idx].parameters(),
+            )
 
+        self._grads = list(reversed(self._grads))
+
+        if not skip_last_layer:
             with torch.no_grad():
-                self._grads[idx] = (
-                    torch.cat(self._grads[idx], dim=0).cpu() / num_steps
-                )
+                loss_value = torch.tensor(np.sum(losses) / num_steps_in_loss)
 
-        self._grads = list(reversed(self._grads))
-        with torch.no_grad():
-            loss_value = torch.tensor(np.sum(losses) / num_steps_in_loss)
+            return loss_value
 
-        return loss_value
+        return None
 
     def __call__(self, batch: torch.Tensor) -> torch.Tensor:
         return self.forward(batch)
 
+    def _get_microbatch_size(self, batch: torch.Tensor) -> int:
+        batch_size = batch.shape[0]
+        return (
+            batch_size if self.microbatch_size is None else self.microbatch_size
+        )
+
+    def _get_layers(self) -> nn.ModuleList:
+        return getattr(self.main_model, self.layers_attr)
+
+    def _copy_grad_to_main_model(
+        self, idx: int, num_steps: int, local_params, main_params
+    ):
+        for local_param, main_param in zip(local_params, main_params):
+            if main_param.grad is None:
+                main_param.grad = local_param.grad.cpu() / num_steps
+            else:
+                main_param.grad += local_param.grad.cpu() / num_steps
+
+        with torch.no_grad():
+            self._grads[idx] = (
+                torch.cat(self._grads[idx], dim=0).cpu() / num_steps
+            )
+
+    def l2l_loss(
+        self, loss_fn: LossFn, store_grad_on_calc: bool = True, **forward_kwargs
+    ) -> L2LLoss:
+        return L2LLoss(
+            model=self,
+            loss_fn=loss_fn,
+            store_grad_on_calc=store_grad_on_calc,
+            **forward_kwargs,
+        )
+
 
 __all__ = ["Layer2Layer"]

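The new `Layer2Layer.l2l_loss` method is a thin factory around `L2LLoss`. A minimal sketch of the two equivalent ways to obtain the loss object; the `Tiny` model below is hypothetical, and the wrapper construction only assumes the arguments already shown in the README example:

```python
import torch
from torch import nn

from layer_to_layer_pytorch.l2l import Layer2Layer
from layer_to_layer_pytorch.loss import L2LLoss


class Tiny(nn.Module):
    """Hypothetical two-layer model, only used to build a wrapper."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)])

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        x = batch
        for layer in self.layers:
            x = layer(x)
        return x


l2l_model = Layer2Layer(
    Tiny().train(), layers_attr="layers", microbatch_size=4, verbose=False
)
loss_fn = nn.MSELoss(reduction="sum")

# The factory method added in this commit ...
loss_a = l2l_model.l2l_loss(loss_fn=loss_fn)

# ... simply constructs L2LLoss with the wrapper as `model`:
loss_b = L2LLoss(model=l2l_model, loss_fn=loss_fn, store_grad_on_calc=True)
```

With `store_grad_on_calc=True` (the default), calling the loss object runs the last layer's backward pass immediately, which is why its `backward()` method passes `skip_last_layer=True` to `calculate_gradients`.
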
layer_to_layer_pytorch/loss.py (+81, -10)

@@ -1,14 +1,85 @@
-# from typing import Callable
+from typing import Callable, List
 
-# import torch
-# from torch import nn
+import numpy as np
+import torch
+from torch import nn
 
-# from layer_to_layer_pytorch.types import LossFn
-# from layer_to_layer_pytorch.l2l import Layer2Layer
+from layer_to_layer_pytorch.helpers import zipper
+from layer_to_layer_pytorch.types import LossFn
 
-# class L2LLoss:
-#     def __init__(self, model: Layer2Layer, loss_fn: LossFn):
-#         self.model = model
-#         self.loss_fn = loss_fn
 
-#     def __call__(self, batch: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+class L2LLoss:
+    def __init__(
+        self,
+        model,
+        loss_fn: LossFn,
+        store_grad_on_calc: bool = True,
+        **forward_kwargs,
+    ):
+        self.model = model
+        self.loss_fn = loss_fn
+        self.store_grad_on_calc = store_grad_on_calc
+        self.forward_kwargs = forward_kwargs or {}
+
+        self._batch = None
+        self._target = None
+
+    def __call__(
+        self, batch: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        self._batch = batch
+        self._target = target
+
+        microbatch_size = self.model._get_microbatch_size(batch)
+        num_steps_in_loss = batch.shape[0] // microbatch_size
+        losses: List[torch.Tensor] = []
+
+        layer: nn.Module = self.model._get_layers()[-1].to(
+            self.model.gpu_device
+        )
+
+        for microbatch, microtarget in zipper(
+            batch.split(microbatch_size),
+            target.split(microbatch_size),
+            verbose=False,
+            desc="Microbatching",
+            total=num_steps_in_loss,
+            leave=False,
+        ):
+            microbatch = microbatch.to(self.model.gpu_device)
+            microbatch.requires_grad = True
+
+            microtarget = microtarget.to(self.model.gpu_device)
+
+            activation: torch.Tensor = layer(microbatch, **self.forward_kwargs)
+
+            loss = self.loss_fn(activation, microtarget)
+            losses.append(loss.item())
+
+            if self.store_grad_on_calc:
+                loss.backward()
+                self.model._grads[-1].append(microbatch.grad.cpu())
+
+        with torch.no_grad():
+            loss_value = torch.tensor(np.sum(losses) / num_steps_in_loss)
+
+        return loss_value
+
+    @torch.no_grad()
+    def backward(self) -> None:
+        self.model.calculate_gradients(
+            self._batch,
+            self._target,
+            loss_fn=self.loss_fn,
+            skip_last_layer=self.store_grad_on_calc,
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self._batch = None
+        self._target = None
+
+
+__all__ = ["L2LLoss"]

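This also explains the `nn.MSELoss(reduction="sum")` choice in the README example: `L2LLoss.__call__` sums the per-microbatch losses and divides by the number of microbatches (`np.sum(losses) / num_steps_in_loss`). A standalone sketch of the value it reports, in plain PyTorch only:

```python
import torch
from torch import nn

preds = torch.rand(1_000, 40)
target = torch.rand(1_000, 40)
microbatch_size = 100

loss_fn = nn.MSELoss(reduction="sum")
losses = [
    loss_fn(p, t).item()
    for p, t in zip(preds.split(microbatch_size), target.split(microbatch_size))
]

num_steps_in_loss = preds.shape[0] // microbatch_size
loss_value = sum(losses) / num_steps_in_loss  # what L2LLoss.__call__ returns

# i.e. the summed loss per microbatch, averaged over the number of microbatches
# (not the element-wise mean that reduction="mean" would give).
```
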