Commit e7c82a2

Merge pull request #9 from romulus0914/tf_like
Match the TensorFlow implementation more closely
2 parents 4020a3f + c0fd05b commit e7c82a2

File tree

8 files changed: +109 lines, -50 lines

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,8 @@ checkpoint/*
 dist/*
 build/*
 
+*.ipynb
+
 **/__pycache__/**
 
 *.egg-info/*

Contributors.md

Lines changed: 10 additions & 2 deletions

@@ -1,4 +1,12 @@
 # Contributors
-
+- [@romulus0914](https://github.com/romulus0914) (Romulus Hong)
+  - Author of the code - NAS-Bench-101 implementation in PyTorch
+- [@gabikadlecova](https://github.com/gabikadlecova)
+  - Maintainer of the repository
+  - Package structure, reproducibility
+---------
 - [@abhash-er](https://github.com/abhash-er/) (Abhash Jha)
-  - modified the model code so that cast to double is possible
+  - Modified the model code so that cast to double is possible
+- [@longerHost](https://github.com/longerHost)
+  - Reproducibility of the original NAS-Bench-101
+  - Comparison of training results and API results

README.md

Lines changed: 22 additions & 0 deletions

@@ -5,6 +5,9 @@ implementation is written in TensorFlow, and this projects contains
 some files from the original repository (in the directory
 `nasbench_pytorch/model/`).
 
+**Important:** if you want to reproduce the original results, please refer to the
+[Reproducibility](#repro) section.
+
 # Overview
 A PyTorch implementation of *training* of NAS-Bench-101 dataset: [NAS-Bench-101: Towards Reproducible Neural Architecture Search](https://arxiv.org/abs/1902.09635).
 The dataset contains 423,624 unique neural networks exhaustively generated and evaluated from a fixed graph-based search space.

@@ -64,6 +67,25 @@ Then, you can train it just like the example network in `main.py`.
 Example architecture (picture from the original repository)
 ![archtecture](./assets/architecture.png)
 
+# Reproducibility <a id="repro"></a>
+The code should closely match the TensorFlow version (including the hyperparameters), but there are some differences:
+- RMSProp implementation in TensorFlow and PyTorch is **different**
+  - For more information refer to [here](https://github.com/pytorch/pytorch/issues/32545) and [here](https://github.com/pytorch/pytorch/issues/23796).
+  - Optionally, you can install pytorch-image-models where a [TensorFlow-like RMSProp](https://github.com/rwightman/pytorch-image-models/blob/main/timm/optim/rmsprop_tf.py#L5) is implemented
+    - `pip install timm`
+    - Then, pass `--optimizer rmsprop_tf` to `main.py` to use it
+
+
+- You can turn gradient clipping off by setting `--grad_clip_off True`
+
+
+- The original training was on TPUs, this code enables only GPU and CPU training
+- Input data augmentation methods are the same, but due to randomness they are not applied in the same manner
+  - Cause: Batches and images cannot be shuffled as in the original TPU training, and the augmentation seed is also different
+  - Results may still differ due to TensorFlow/PyTorch implementation differences
+
+Refer to this [issue](https://github.com/romulus0914/NASBench-PyTorch/issues/6) for more information and for comparison with API results.
+
 # Disclaimer
 Modified from [NASBench: A Neural Architecture Search Dataset and Benchmark](https://github.com/google-research/nasbench).
 *graph_util.py* and *model_spec.py* are directly copied from the original repo. Original license can be found [here](https://github.com/google-research/nasbench/blob/master/LICENSE).
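
For context, the optimizer difference called out in the new Reproducibility section comes down to where epsilon enters the update. Below is a minimal sketch of the two choices, assuming `pip install timm`; `build_optimizer` is a hypothetical helper, not code from this repository:

# Minimal sketch of the two RMSProp variants discussed above.
# `build_optimizer` is a hypothetical helper; assumes timm is installed.
import torch.optim as optim


def build_optimizer(params, name, lr=0.02, momentum=0.9, eps=1.0, weight_decay=1e-4):
    if name == 'rmsprop':
        # Native PyTorch RMSProp: eps is added outside the square root.
        return optim.RMSprop(params, lr=lr, momentum=momentum, eps=eps,
                             weight_decay=weight_decay)
    if name == 'rmsprop_tf':
        # timm's TensorFlow-style RMSProp: per the linked PyTorch issues, eps is
        # applied inside the square root, matching the TF behaviour.
        from timm.optim import RMSpropTF
        return RMSpropTF(params, lr=lr, momentum=momentum, eps=eps,
                         weight_decay=weight_decay)
    raise ValueError(f'unknown optimizer: {name}')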

main.py

Lines changed: 24 additions & 16 deletions

@@ -41,31 +41,31 @@ def reload_checkpoint(path, device=None):
 parser = argparse.ArgumentParser(description='NASBench')
 parser.add_argument('--random_state', default=1, type=int, help='Random seed.')
 parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.')
-parser.add_argument('--module_vertices', default=7, type=int, help='#vertices in graph')
-parser.add_argument('--max_edges', default=9, type=int, help='max edges in graph')
-parser.add_argument('--available_ops', default=['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'],
-                    type=list, help='available operations performed on vertex')
 parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.')
 parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution')
 parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
 parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')
-parser.add_argument('--batch_size', default=128, type=int, help='batch size')
-parser.add_argument('--test_batch_size', default=100, type=int, help='test set batch size')
-parser.add_argument('--epochs', default=100, type=int, help='#epochs of training')
-parser.add_argument('--validation_size', default=0, type=int, help="Size of the validation set to split off.")
+parser.add_argument('--batch_size', default=256, type=int, help='batch size')
+parser.add_argument('--test_batch_size', default=256, type=int, help='test set batch size')
+parser.add_argument('--epochs', default=108, type=int, help='#epochs of training')
+parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.")
 parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")
-parser.add_argument('--learning_rate', default=0.025, type=float, help='base learning rate')
+parser.add_argument('--learning_rate', default=0.02, type=float, help='base learning rate')
 parser.add_argument('--lr_decay_method', default='COSINE_BY_STEP', type=str, help='learning decay method')
-parser.add_argument('--optimizer', default='sgd', type=str, help='Optimizer (sgd or rmsprop)')
+parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)')
+parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.')
 parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
 parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight')
 parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping')
-parser.add_argument('--batch_norm_momentum', default=0.1, type=float, help='Batch normalization momentum')
+parser.add_argument('--grad_clip_off', default=False, type=bool, help='If True, turn off gradient clipping.')
+parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum')
 parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')
 parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint')
 parser.add_argument('--num_labels', default=10, type=int, help='#classes')
 parser.add_argument('--device', default='cuda', type=str, help='Device for network training.')
 parser.add_argument('--print_freq', default=100, type=int, help='Batch print frequency.')
+parser.add_argument('--tf_like', default=False, type=bool,
+                    help='If true, use same weight initialization as in the tensorflow version.')
 
 args = parser.parse_args()
 

@@ -79,9 +79,10 @@ def reload_checkpoint(path, device=None):
 
 # model
 spec = ModelSpec(matrix, operations)
-net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels, stem_out_channels=args.stem_out_channels,
-              num_stacks=args.num_stacks, num_modules_per_stack=args.num_modules_per_stack,
-              momentum=args.batch_norm_momentum, eps=args.batch_norm_eps)
+net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels,
+              stem_out_channels=args.stem_out_channels, num_stacks=args.num_stacks,
+              num_modules_per_stack=args.num_modules_per_stack,
+              momentum=args.batch_norm_momentum, eps=args.batch_norm_eps, tf_like=args.tf_like)
 
 if args.load_checkpoint != '':
     net.load_state_dict(reload_checkpoint(args.load_checkpoint))

@@ -91,16 +92,23 @@ def reload_checkpoint(path, device=None):
 
 if args.optimizer.lower() == 'sgd':
     optimizer = optim.SGD
+    optimizer_kwargs = {}
 elif args.optimizer.lower() == 'rmsprop':
     optimizer = optim.RMSprop
+    optimizer_kwargs = {'eps': args.rmsprop_eps}
+elif args.optimizer.lower() == 'rmsprop_tf':
+    from timm.optim import RMSpropTF
+    optimizer = RMSpropTF
+    optimizer_kwargs = {'eps': args.rmsprop_eps}
 else:
     raise ValueError(f"Invalid optimizer {args.optimizer}, possible: SGD, RMSProp")
 
 optimizer = optimizer(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
-                      weight_decay=args.weight_decay)
+                      weight_decay=args.weight_decay, **optimizer_kwargs)
 scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
 
-result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler, grad_clip=args.grad_clip,
+result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler,
+               grad_clip=args.grad_clip if not args.grad_clip_off else None,
                num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader,
                device=args.device, print_frequency=args.print_freq)
 
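Taken together, the new defaults mean a bare `python main.py` run now trains with NAS-Bench-101-style settings (batch size 256, 108 epochs, RMSProp with eps=1.0, a 10,000-image validation split). A condensed sketch of the flow follows; the import paths, the return shape of `prepare_dataset`, and the toy cell spec are assumptions for illustration, not part of this diff:

# Condensed sketch of the new main.py flow. Import paths are assumed from the
# file tree; the 6-tuple unpacking and the toy spec are illustrative only.
import torch.nn as nn
import torch.optim as optim
from nasbench_pytorch.datasets.cifar10 import prepare_dataset  # assumed path
from nasbench_pytorch.model import Network, ModelSpec          # assumed path
from nasbench_pytorch.trainer import train                     # assumed path

# Toy 3-vertex cell: input -> conv3x3 -> output (upper-triangular adjacency).
matrix = [[0, 1, 1],
          [0, 0, 1],
          [0, 0, 0]]
operations = ['input', 'conv3x3-bn-relu', 'output']

# Return shape assumed; see the cifar10.py changes below.
train_loader, train_size, valid_loader, valid_size, test_loader, test_size = \
    prepare_dataset(batch_size=256, validation_size=10000)

spec = ModelSpec(matrix, operations)
net = Network(spec, momentum=0.997, eps=1e-5, tf_like=True)

from timm.optim import RMSpropTF  # optional TF-like optimizer, `pip install timm`
optimizer = RMSpropTF(net.parameters(), lr=0.02, momentum=0.9, eps=1.0, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 108)

result = train(net, train_loader, loss=nn.CrossEntropyLoss(), optimizer=optimizer,
               scheduler=scheduler, grad_clip=None,  # None here == --grad_clip_off True
               num_epochs=108, num_validation=10000, validation_loader=valid_loader)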

nasbench_pytorch/datasets/cifar10.py

Lines changed: 19 additions & 8 deletions

@@ -22,13 +22,14 @@ def train_valid_split(dataset_size, valid_size, random_state=None):
 
 
 def seed_worker(seed, worker_id):
+    seed = seed if seed is not None else 0
     worker_seed = seed + worker_id
     np.random.seed(worker_seed)
     random.seed(worker_seed)
 
 
-def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_size=0, random_state=None,
-                    set_global_seed=False, no_valid_transform=False,
+def prepare_dataset(batch_size, test_batch_size=256, root='./data/', use_validation=True, split_from_end=True,
+                    validation_size=10000, random_state=None, set_global_seed=False, no_valid_transform=True,
                     num_workers=0, num_val_workers=0, num_test_workers=0):
     """
     Download the CIFAR-10 dataset and prepare train and test DataLoaders (optionally also validation loader).

@@ -37,8 +38,9 @@ def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_
         batch_size: Batch size for the train (and validation) loader.
         test_batch_size: Batch size for the test loader.
         root: Directory path to download the CIFAR-10 dataset to.
+        use_validation: If False, don't split off the validation set.
+        split_from_end: If True, split off `validation_size` images from the end, if False, choose images randomly.
         validation_size: Size of the validation dataset to split off the train set.
-            If == 0, don't return the validation set.
 
         random_state: Seed for the random functions (generators from numpy and random)
         set_global_seed: If True, call np.random.seed(random_state) and random.seed(random_state). Useful when

@@ -66,7 +68,7 @@ def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_
     if random_state is not None:
         worker_fn = partial(seed_worker, random_state)
     else:
-        worker_fn=None
+        worker_fn = None
 
     print('\n--- Preparing CIFAR10 Data ---')
 

@@ -87,10 +89,19 @@ def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_
     valid_set = valid_set if no_valid_transform else train_set
     train_size = len(train_set)
 
-    # split off random validation set
-    if validation_size > 0:
-        train_sampler, valid_sampler = train_valid_split(train_size, validation_size, random_state=random_state)
-        train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False,
+    if use_validation:
+        if split_from_end:
+            # get last n images
+            indices = np.arange(len(train_set))
+            train_set = torch.utils.data.Subset(train_set, indices[:-validation_size])
+            valid_set = torch.utils.data.Subset(valid_set, indices[-validation_size:])
+            train_sampler, valid_sampler = None, None
+        else:
+            # split off random validation set
+            train_sampler, valid_sampler = train_valid_split(train_size, validation_size, random_state=random_state)
+
+        # shuffle is True if split_from_end otherwise False
+        train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=split_from_end,
                                                    sampler=train_sampler, num_workers=num_workers,
                                                    worker_init_fn=worker_fn)
         valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False,
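
The new split-from-end path reserves the last `validation_size` images deterministically instead of sampling them, so repeated runs see the identical validation set. A toy illustration of the `Subset` logic above (standalone code, not from the repository):

# Toy illustration of the split-from-end logic (not repo code).
import numpy as np
import torch

full = torch.utils.data.TensorDataset(torch.arange(10.).unsqueeze(1))
validation_size = 3

indices = np.arange(len(full))
train_set = torch.utils.data.Subset(full, indices[:-validation_size])  # first 7 samples
valid_set = torch.utils.data.Subset(full, indices[-validation_size:])  # last 3 samples

print(len(train_set), len(valid_set))  # 7 3
print(valid_set[0])                    # (tensor([7.]),) -- always the same images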

nasbench_pytorch/model/model.py

Lines changed: 28 additions & 21 deletions

@@ -21,11 +21,12 @@
 
 import torch
 import torch.nn as nn
+from torch.nn.init import _calculate_fan_in_and_fan_out
 
 
 class Network(nn.Module):
-    def __init__(self, spec, num_labels=10,
-                 in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, momentum=0.1, eps=1e-5):
+    def __init__(self, spec, num_labels=10, in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3,
+                 momentum=0.997, eps=1e-5, tf_like=False):
         """
 
         Args:

@@ -45,6 +46,7 @@ def __init__(self, spec, num_labels=10,
 
         self.cell_indices = set()
 
+        self.tf_like = tf_like
         self.layers = nn.ModuleList([])
 
         # initial stem convolution

@@ -84,21 +86,31 @@ def forward(self, x):
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2.0 / n))
+                if self.tf_like:
+                    fan_in, _ = _calculate_fan_in_and_fan_out(m.weight)
+                    torch.nn.init.normal_(m.weight, mean=0, std=1.0 / torch.sqrt(torch.tensor(fan_in)))
+                else:
+                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                    m.weight.data.normal_(0, math.sqrt(2.0 / n))
+
                 if m.bias is not None:
                     m.bias.data.zero_()
+
             elif isinstance(m, nn.BatchNorm2d):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
             elif isinstance(m, nn.Linear):
-                m.weight.data.normal_(0, 0.01)
+                if self.tf_like:
+                    torch.nn.init.xavier_uniform_(m.weight)
+                else:
+                    m.weight.data.normal_(0, 0.01)
                 m.bias.data.zero_()
 
+
 class Cell(nn.Module):
     """
     Builds the model using the adjacency matrix and op labels specified. Channels
-    controls the module output channel count but the interior channels are
+    control the module output channel count but the interior channels are
     determined via equally splitting the channel count whenever there is a
     concatenation of Tensors.
     """

@@ -111,7 +123,7 @@ def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5):
         self.num_vertices = np.shape(self.matrix)[0]
 
         # vertex_channels[i] = number of output channels of vertex i
-        self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.matrix)
+        self.vertex_channels = compute_vertex_channels(in_channels, out_channels, self.matrix)
         #self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)
 
         # operation for each node

@@ -124,11 +136,11 @@ def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5):
         self.input_op = nn.ModuleList([Placeholder()])
         for t in range(1, self.num_vertices):
             if self.matrix[0, t]:
-                self.input_op.append(Projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps))
+                self.input_op.append(projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps))
             else:
                 self.input_op.append(Placeholder())
 
-        self.last_inop : Projection = self.input_op[self.num_vertices-1]
+        self.last_inop : projection = self.input_op[self.num_vertices - 1]
 
     def forward(self, x):
         tensors = [x]

@@ -141,20 +153,17 @@ def forward(self, x):
             fan_in = []
             for src in range(1, t):
                 if self.matrix[src, t]:
-                    fan_in.append(Truncate(tensors[src], torch.tensor(self.vertex_channels[t])))
+                    fan_in.append(truncate(tensors[src], torch.tensor(self.vertex_channels[t])))
 
             if self.matrix[0, t]:
                 l = inmod(x)
                 fan_in.append(l)
 
             # perform operation on node
-            #vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
             vertex_input = torch.zeros_like(fan_in[0]).to(self.dev_param.device)
-
             for val in fan_in:
                 vertex_input += val
-            #vertex_input = sum(fan_in)
-            #vertex_input = sum(fan_in) / len(fan_in)
+
             vertex_output = outmod(vertex_input)
 
             tensors.append(vertex_output)

@@ -173,19 +182,15 @@ def forward(self, x):
         if self.matrix[0, self.num_vertices-1]:
             outputs = outputs + self.last_inop(tensors[0])
 
-        #if self.matrix[0, self.num_vertices-1]:
-        #    out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
-        #outputs = sum(out_concat) / len(out_concat)
-
         return outputs
 
 
-def Projection(in_channels, out_channels, momentum=0.1, eps=1e-5):
+def projection(in_channels, out_channels, momentum=0.1, eps=1e-5):
     """1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
     return ConvBnRelu(in_channels, out_channels, 1, momentum=momentum, eps=eps)
 
 
-def Truncate(inputs, channels):
+def truncate(inputs, channels):
     """Slice the inputs to channels if necessary."""
     input_channels = inputs.size()[1]
     if input_channels < channels:

@@ -200,7 +205,7 @@ def Truncate(inputs, channels):
     return inputs[:, :channels, :, :]
 
 
-def ComputeVertexChannels(in_channels, out_channels, matrix):
+def compute_vertex_channels(in_channels, out_channels, matrix):
     """Computes the number of channels at every vertex.
 
     Given the input channels and output channels, this calculates the number of

@@ -210,6 +215,8 @@ def ComputeVertexChannels(in_channels, out_channels, matrix):
     When the division is not even, some vertices may receive an extra channel to
     compensate.
 
+    Code from https://github.com/google-research/nasbench/
+
     Returns:
         list of channel counts, in order of the vertices.
     """

nasbench_pytorch/trainer.py

Lines changed: 2 additions & 1 deletion

@@ -74,7 +74,8 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
         optimizer.zero_grad()
         curr_loss = loss(outputs, targets)
         curr_loss.backward()
-        nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
+        if grad_clip is not None:
+            nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
         optimizer.step()
 
         # metrics

setup.py

Lines changed: 2 additions & 2 deletions

@@ -2,8 +2,8 @@
 
 setuptools.setup(
     name='nasbench_pytorch',
-    version='1.2.3',
+    version='1.3',
     license='Apache License 2.0',
-    author='Romulus Hong, Gabriela Suchopárová',
+    author='Romulus Hong, Gabriela Kadlecová',
     packages=setuptools.find_packages()
 )
