Merge pull request #9 from romulus0914/tf_like
Match the TensorFlow implementation more closely
gabikadlecova authored Dec 5, 2022
2 parents 4020a3f + c0fd05b commit e7c82a2
Showing 8 changed files with 109 additions and 50 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -11,6 +11,8 @@ checkpoint/*
dist/*
build/*

*.ipynb

**/__pycache__/**

*.egg-info/*
12 changes: 10 additions & 2 deletions Contributors.md
@@ -1,4 +1,12 @@
# Contributors

- [@romulus0914](https://github.com/romulus0914) (Romulus Hong)
- Author of the code - NAS-Bench-101 implementation in PyTorch
- [@gabikadlecova](https://github.com/gabikadlecova)
- Maintainer of the repository
- Package structure, reproducibility
---------
- [@abhash-er](https://github.com/abhash-er/) (Abhash Jha)
- modified the model code so that cast to double is possible
- Modified the model code so that cast to double is possible
- [@longerHost](https://github.com/longerHost)
- Reproducibility of the original NAS-Bench-101
- Comparison of training results and API results
22 changes: 22 additions & 0 deletions README.md
@@ -5,6 +5,9 @@ implementation is written in TensorFlow, and this project contains
some files from the original repository (in the directory
`nasbench_pytorch/model/`).

**Important:** if you want to reproduce the original results, please refer to the
[Reproducibility](#repro) section.

# Overview
A PyTorch implementation of *training* of the NAS-Bench-101 dataset: [NAS-Bench-101: Towards Reproducible Neural Architecture Search](https://arxiv.org/abs/1902.09635).
The dataset contains 423,624 unique neural networks exhaustively generated and evaluated from a fixed graph-based search space.
@@ -64,6 +67,25 @@ Then, you can train it just like the example network in `main.py`.
Example architecture (picture from the original repository)
![architecture](./assets/architecture.png)

# Reproducibility <a id="repro"></a>
The code should closely match the TensorFlow version (including the hyperparameters), but there are some differences:
- The RMSProp implementations in TensorFlow and PyTorch are **different**
  - For more information, refer to [here](https://github.com/pytorch/pytorch/issues/32545) and [here](https://github.com/pytorch/pytorch/issues/23796).
  - Optionally, you can install pytorch-image-models, where a [TensorFlow-like RMSProp](https://github.com/rwightman/pytorch-image-models/blob/main/timm/optim/rmsprop_tf.py#L5) is implemented
    - `pip install timm`
    - Then, pass `--optimizer rmsprop_tf` to `main.py` to use it (see the sketch below)

- You can turn gradient clipping off by setting `--grad_clip_off True`

- The original training was done on TPUs; this code supports only GPU and CPU training
- Input data augmentation methods are the same, but due to randomness they are not applied in the same manner
  - Cause: batches and images cannot be shuffled as in the original TPU training, and the augmentation seed is also different
  - Results may still differ due to TensorFlow/PyTorch implementation differences

Refer to this [issue](https://github.com/romulus0914/NASBench-PyTorch/issues/6) for more information and for a comparison with API results.
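
For illustration, here is a minimal sketch of how the TF-like optimizer can be constructed, mirroring what `main.py` does with the defaults above (`lr=0.02`, `momentum=0.9`, `weight_decay=1e-4`, `eps=1.0`). The `make_optimizer` helper is a hypothetical name used only for this example, not part of the package:

```python
import torch.optim as optim

def make_optimizer(net, name='rmsprop_tf', lr=0.02, momentum=0.9, weight_decay=1e-4, eps=1.0):
    """Pick an optimizer the same way main.py does (illustrative helper only)."""
    if name == 'rmsprop_tf':
        from timm.optim import RMSpropTF  # requires `pip install timm`
        opt_cls, extra = RMSpropTF, {'eps': eps}
    elif name == 'rmsprop':
        opt_cls, extra = optim.RMSprop, {'eps': eps}
    else:
        opt_cls, extra = optim.SGD, {}
    return opt_cls(net.parameters(), lr=lr, momentum=momentum,
                   weight_decay=weight_decay, **extra)
```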

# Disclaimer
Modified from [NASBench: A Neural Architecture Search Dataset and Benchmark](https://github.com/google-research/nasbench).
*graph_util.py* and *model_spec.py* are directly copied from the original repo. Original license can be found [here](https://github.com/google-research/nasbench/blob/master/LICENSE).
40 changes: 24 additions & 16 deletions main.py
@@ -41,31 +41,31 @@ def reload_checkpoint(path, device=None):
parser = argparse.ArgumentParser(description='NASBench')
parser.add_argument('--random_state', default=1, type=int, help='Random seed.')
parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.')
parser.add_argument('--module_vertices', default=7, type=int, help='#vertices in graph')
parser.add_argument('--max_edges', default=9, type=int, help='max edges in graph')
parser.add_argument('--available_ops', default=['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'],
type=list, help='available operations performed on vertex')
parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.')
parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--test_batch_size', default=100, type=int, help='test set batch size')
parser.add_argument('--epochs', default=100, type=int, help='#epochs of training')
parser.add_argument('--validation_size', default=0, type=int, help="Size of the validation set to split off.")
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--test_batch_size', default=256, type=int, help='test set batch size')
parser.add_argument('--epochs', default=108, type=int, help='#epochs of training')
parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.")
parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")
parser.add_argument('--learning_rate', default=0.025, type=float, help='base learning rate')
parser.add_argument('--learning_rate', default=0.02, type=float, help='base learning rate')
parser.add_argument('--lr_decay_method', default='COSINE_BY_STEP', type=str, help='learning decay method')
parser.add_argument('--optimizer', default='sgd', type=str, help='Optimizer (sgd or rmsprop)')
parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)')
parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight')
parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping')
parser.add_argument('--batch_norm_momentum', default=0.1, type=float, help='Batch normalization momentum')
parser.add_argument('--grad_clip_off', default=False, type=bool, help='If True, turn off gradient clipping.')
parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum')
parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')
parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint')
parser.add_argument('--num_labels', default=10, type=int, help='#classes')
parser.add_argument('--device', default='cuda', type=str, help='Device for network training.')
parser.add_argument('--print_freq', default=100, type=int, help='Batch print frequency.')
parser.add_argument('--tf_like', default=False, type=bool,
help='If true, use same weight initialization as in the tensorflow version.')

args = parser.parse_args()

@@ -79,9 +79,10 @@ def reload_checkpoint(path, device=None):

# model
spec = ModelSpec(matrix, operations)
net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels, stem_out_channels=args.stem_out_channels,
num_stacks=args.num_stacks, num_modules_per_stack=args.num_modules_per_stack,
momentum=args.batch_norm_momentum, eps=args.batch_norm_eps)
net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels,
stem_out_channels=args.stem_out_channels, num_stacks=args.num_stacks,
num_modules_per_stack=args.num_modules_per_stack,
momentum=args.batch_norm_momentum, eps=args.batch_norm_eps, tf_like=args.tf_like)

if args.load_checkpoint != '':
net.load_state_dict(reload_checkpoint(args.load_checkpoint))
@@ -91,16 +92,23 @@ def reload_checkpoint(path, device=None):

if args.optimizer.lower() == 'sgd':
optimizer = optim.SGD
optimizer_kwargs = {}
elif args.optimizer.lower() == 'rmsprop':
optimizer = optim.RMSprop
optimizer_kwargs = {'eps': args.rmsprop_eps}
elif args.optimizer.lower() == 'rmsprop_tf':
from timm.optim import RMSpropTF
optimizer = RMSpropTF
optimizer_kwargs = {'eps': args.rmsprop_eps}
else:
raise ValueError(f"Invalid optimizer {args.optimizer}, possible: SGD, RMSProp, RMSProp_TF")

optimizer = optimizer(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
weight_decay=args.weight_decay, **optimizer_kwargs)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler, grad_clip=args.grad_clip,
result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler,
grad_clip=args.grad_clip if not args.grad_clip_off else None,
num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader,
device=args.device, print_frequency=args.print_freq)

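For context, a minimal sketch of building the network with the new `tf_like` option outside of `main.py` could look as follows; the 3-vertex matrix/ops and the exact import paths are illustrative assumptions, not something this diff prescribes:

```python
import numpy as np
from nasbench_pytorch.model.model import Network
from nasbench_pytorch.model.model_spec import ModelSpec

# Toy cell: input -> conv3x3 -> output (for illustration only)
matrix = np.array([[0, 1, 0],
                   [0, 0, 1],
                   [0, 0, 0]])
operations = ['input', 'conv3x3-bn-relu', 'output']

spec = ModelSpec(matrix, operations)
net = Network(spec, num_labels=10, in_channels=3, stem_out_channels=128,
              num_stacks=3, num_modules_per_stack=3,
              momentum=0.997, eps=1e-5, tf_like=True)
```
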
27 changes: 19 additions & 8 deletions nasbench_pytorch/datasets/cifar10.py
@@ -22,13 +22,14 @@ def train_valid_split(dataset_size, valid_size, random_state=None):


def seed_worker(seed, worker_id):
seed = seed if seed is not None else 0
worker_seed = seed + worker_id
np.random.seed(worker_seed)
random.seed(worker_seed)


def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_size=0, random_state=None,
set_global_seed=False, no_valid_transform=False,
def prepare_dataset(batch_size, test_batch_size=256, root='./data/', use_validation=True, split_from_end=True,
validation_size=10000, random_state=None, set_global_seed=False, no_valid_transform=True,
num_workers=0, num_val_workers=0, num_test_workers=0):
"""
Download the CIFAR-10 dataset and prepare train and test DataLoaders (optionally also validation loader).
@@ -37,8 +38,9 @@ def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_
batch_size: Batch size for the train (and validation) loader.
test_batch_size: Batch size for the test loader.
root: Directory path to download the CIFAR-10 dataset to.
use_validation: If False, don't split off the validation set.
split_from_end: If True, split off `validation_size` images from the end, if False, choose images randomly.
validation_size: Size of the validation dataset to split off the train set.
If == 0, don't return the validation set.
random_state: Seed for the random functions (generators from numpy and random)
set_global_seed: If True, call np.random.seed(random_state) and random.seed(random_state). Useful when
@@ -66,7 +68,7 @@ def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_
if random_state is not None:
worker_fn = partial(seed_worker, random_state)
else:
worker_fn=None
worker_fn = None

print('\n--- Preparing CIFAR10 Data ---')

@@ -87,10 +89,19 @@ def prepare_dataset(batch_size, test_batch_size=100, root='./data/', validation_
valid_set = valid_set if no_valid_transform else train_set
train_size = len(train_set)

# split off random validation set
if validation_size > 0:
train_sampler, valid_sampler = train_valid_split(train_size, validation_size, random_state=random_state)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False,
if use_validation:
if split_from_end:
# get last n images
indices = np.arange(len(train_set))
train_set = torch.utils.data.Subset(train_set, indices[:-validation_size])
valid_set = torch.utils.data.Subset(valid_set, indices[-validation_size:])
train_sampler, valid_sampler = None, None
else:
# split off random validation set
train_sampler, valid_sampler = train_valid_split(train_size, validation_size, random_state=random_state)

# shuffle is True if split_from_end otherwise False
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=split_from_end,
sampler=train_sampler, num_workers=num_workers,
worker_init_fn=worker_fn)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False,
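
To illustrate the new validation-split behaviour, a call might look like the sketch below; the exact objects `prepare_dataset` returns are not shown in this diff, so the unpacking is deliberately left generic:

```python
from nasbench_pytorch.datasets.cifar10 import prepare_dataset

# Split the last 10,000 CIFAR-10 images off as the validation set (TF-style);
# with split_from_end=True the train loader uses shuffle=True instead of a sampler.
loaders = prepare_dataset(batch_size=256, test_batch_size=256, root='./data/',
                          use_validation=True, split_from_end=True,
                          validation_size=10000, random_state=1, num_workers=0)
```
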
49 changes: 28 additions & 21 deletions nasbench_pytorch/model/model.py
@@ -21,11 +21,12 @@

import torch
import torch.nn as nn
from torch.nn.init import _calculate_fan_in_and_fan_out


class Network(nn.Module):
def __init__(self, spec, num_labels=10,
in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, momentum=0.1, eps=1e-5):
def __init__(self, spec, num_labels=10, in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3,
momentum=0.997, eps=1e-5, tf_like=False):
"""
Args:
@@ -45,6 +46,7 @@ def __init__(self, spec, num_labels=10,

self.cell_indices = set()

self.tf_like = tf_like
self.layers = nn.ModuleList([])

# initial stem convolution
@@ -84,21 +86,31 @@ def forward(self, x):
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
if self.tf_like:
fan_in, _ = _calculate_fan_in_and_fan_out(m.weight)
torch.nn.init.normal_(m.weight, mean=0, std=1.0 / torch.sqrt(torch.tensor(fan_in)))
else:
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))

if m.bias is not None:
m.bias.data.zero_()

elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
if self.tf_like:
torch.nn.init.xavier_uniform_(m.weight)
else:
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()


class Cell(nn.Module):
"""
Builds the model using the adjacency matrix and op labels specified. Channels
controls the module output channel count but the interior channels are
control the module output channel count but the interior channels are
determined via equally splitting the channel count whenever there is a
concatenation of Tensors.
"""
@@ -111,7 +123,7 @@ def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5):
self.num_vertices = np.shape(self.matrix)[0]

# vertex_channels[i] = number of output channels of vertex i
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.matrix)
self.vertex_channels = compute_vertex_channels(in_channels, out_channels, self.matrix)
#self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)

# operation for each node
@@ -124,11 +136,11 @@ def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5):
self.input_op = nn.ModuleList([Placeholder()])
for t in range(1, self.num_vertices):
if self.matrix[0, t]:
self.input_op.append(Projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps))
self.input_op.append(projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps))
else:
self.input_op.append(Placeholder())

self.last_inop : Projection = self.input_op[self.num_vertices-1]
self.last_inop : projection = self.input_op[self.num_vertices - 1]

def forward(self, x):
tensors = [x]
@@ -141,20 +153,17 @@ def forward(self, x):
fan_in = []
for src in range(1, t):
if self.matrix[src, t]:
fan_in.append(Truncate(tensors[src], torch.tensor(self.vertex_channels[t])))
fan_in.append(truncate(tensors[src], torch.tensor(self.vertex_channels[t])))

if self.matrix[0, t]:
l = inmod(x)
fan_in.append(l)

# perform operation on node
#vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
vertex_input = torch.zeros_like(fan_in[0]).to(self.dev_param.device)

for val in fan_in:
vertex_input += val
#vertex_input = sum(fan_in)
#vertex_input = sum(fan_in) / len(fan_in)

vertex_output = outmod(vertex_input)

tensors.append(vertex_output)
@@ -173,19 +182,15 @@ def forward(self, x):
if self.matrix[0, self.num_vertices-1]:
outputs = outputs + self.last_inop(tensors[0])

#if self.matrix[0, self.num_vertices-1]:
# out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
#outputs = sum(out_concat) / len(out_concat)

return outputs


def Projection(in_channels, out_channels, momentum=0.1, eps=1e-5):
def projection(in_channels, out_channels, momentum=0.1, eps=1e-5):
"""1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
return ConvBnRelu(in_channels, out_channels, 1, momentum=momentum, eps=eps)


def Truncate(inputs, channels):
def truncate(inputs, channels):
"""Slice the inputs to channels if necessary."""
input_channels = inputs.size()[1]
if input_channels < channels:
@@ -200,7 +205,7 @@ def Truncate(inputs, channels):
return inputs[:, :channels, :, :]


def ComputeVertexChannels(in_channels, out_channels, matrix):
def compute_vertex_channels(in_channels, out_channels, matrix):
"""Computes the number of channels at every vertex.
Given the input channels and output channels, this calculates the number of
@@ -210,6 +215,8 @@ def ComputeVertexChannels(in_channels, out_channels, matrix):
When the division is not even, some vertices may receive an extra channel to
compensate.
Code from https://github.com/google-research/nasbench/
Returns:
list of channel counts, in order of the vertices.
"""
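
As a quick, illustrative check of the "equal split plus correction" rule described in the docstring above (an example only, not part of the diff):

```python
# With 128 output channels and three interior vertices feeding the output vertex,
# the split is 128 // 3 = 42 with remainder 2, so two vertices get one extra channel.
out_channels, fan_in_to_output = 128, 3
base, extra = divmod(out_channels, fan_in_to_output)
channel_counts = [base + 1] * extra + [base] * (fan_in_to_output - extra)
assert channel_counts == [43, 43, 42] and sum(channel_counts) == out_channels
```
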
3 changes: 2 additions & 1 deletion nasbench_pytorch/trainer.py
@@ -74,7 +74,8 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
optimizer.zero_grad()
curr_loss = loss(outputs, targets)
curr_loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
if grad_clip is not None:
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
optimizer.step()

# metrics
4 changes: 2 additions & 2 deletions setup.py
@@ -2,8 +2,8 @@

setuptools.setup(
name='nasbench_pytorch',
version='1.2.3',
version='1.3',
license='Apache License 2.0',
author='Romulus Hong, Gabriela Suchopárová',
author='Romulus Hong, Gabriela Kadlecová',
packages=setuptools.find_packages()
)
