83 changes: 37 additions & 46 deletions doc/reproducing_best_results.md
@@ -39,20 +39,20 @@ they could be dispatched to run in parallel with any infrastructure you have
access to.

Times reported are approximate and were measured on GCP instances with one P100
GPU, 16vCPU cores (although the CPU usage was far from full), 155 GB memory, and
GPU, 16vCPU cores (although the CPU usage was far from full), 104 GB memory, and
local SSD storage for records, summaries, and checkpoints.

## Baseline and pre-training on ImageNet
## Pre-training the backbones on ImageNet

Some of the best meta-learning models are initialized from the weights of a
batch baseline (trained on ImageNet). For this reason, we will start with
training the baseline with several backbones (not only the best one). Since not
all backbone variants are needed for the best models, we will only need to train
3 of them.
The best meta-learning models are initialized from the weights of a batch
baseline (trained on ImageNet). For this reason, we start by training the
baseline with several backbones (not only the best one). Since not all
backbone variants are needed for the best models, we only need to train three
of them.

```bash
export EXPNAME=baseline_imagenet
for BACKBONE in resnet mamlconvnet mamlresnet
export EXPNAME=pretrain_imagenet
for BACKBONE in convnet resnet wide_resnet
do
export JOBNAME=${EXPNAME}_${BACKBONE}
python -m meta_dataset.train \
@@ -64,53 +64,39 @@ do
done
```
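
Once a pre-training job finishes, you can sanity-check that its checkpoints were
written. The snippet below is only a sketch: it assumes the jobs save their
checkpoints under `${EXPROOT}/checkpoints/${JOBNAME}` and use the
`model_?????.ckpt` naming referred to later in this document.

```bash
for BACKBONE in convnet resnet wide_resnet
do
  # List the checkpoints written by each pre-training job (assumed layout).
  ls "${EXPROOT}/checkpoints/pretrain_imagenet_${BACKBONE}"/model_*.ckpt*
done
```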

Each of the jobs took between 12 and 18 hours to reach 75k steps (episodes).
Time to reach 50k steps (episodes):

## Training on ImageNet

### k-NN

The `baseline` ("k-NN") model does not have to be trained again, the `resnet`
variant performed the best. For consistency, we can simply add symbolic links:

```bash
ln -s ${EXPROOT}/checkpoints/baseline_imagenet_resnet ${EXPROOT}/checkpoints/baseline_imagenet
ln -s ${EXPROOT}/summaries/baseline_imagenet_resnet ${EXPROOT}/summaries/baseline_imagenet
```

### Finetune, ProtoNet

The best models for `baselinefinetune` ("Finetune") and `prototypical`
("ProtoNet") on ILSVRC-2012 were not initialized from pre-trained model, so
their respective gin config indicates:

- `LearnerConfig.pretrained_source = 'scratch'`, and
- `LearnerConfig.pretrained_checkpoint = ''`

They can be launched right away (in parallel with the pre-training), and their
configuration does not need to be changed.
- `convnet`: 8 hours
- `resnet`: 12 hours
- `wide_resnet`: 24 hours

### Other models
### Update the training files

For the other models, the respective best pre-train model is:
Once the pre-training is complete, edit the following files to replace
`/path/to/checkpoints` with the value of `${EXPROOT}/checkpoints`:

- `matching` ("MatchingNet"): `resnet`
- `maml` ("fo-MAML"): `mamlconvnet` (`four_layer_convnet_maml`)
- `maml_init_with_proto` ("Proto-MAML"): `mamlresnet` (`resnet_maml`)
- `meta_dataset/learn/gin/best/pretrained_convnet.gin`
- `meta_dataset/learn/gin/best/pretrained_resnet.gin`
- `meta_dataset/learn/gin/best/pretrained_wide_resnet.gin`
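
If you prefer not to edit these files by hand, a scripted substitution works as
well. The loop below is only a sketch: it assumes GNU `sed` and that the
`/path/to/checkpoints` placeholder appears verbatim in each of the three files.

```bash
for F in meta_dataset/learn/gin/best/pretrained_{convnet,resnet,wide_resnet}.gin
do
  # Replace the placeholder with the actual checkpoint root, in place.
  sed -i "s|/path/to/checkpoints|${EXPROOT}/checkpoints|g" "$F"
done
```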

The corresponding `.gin` file indicates `LearnerConfig.pretrained_source =
'imagenet'`, and has a placeholder for `LearnerConfig.pretrained_checkpoint`.
In our experience, the exact number of steps of the best checkpoint did not make
a measurable difference, so you can simply update the base path and keep the
step number in `model_?????.ckpt`. If you would like to select the best number
of steps, see the [Get the best checkpoint](#get-the-best-checkpoint) section
below, and update the gin files accordingly.
below, and update these gin files accordingly.

### Command line
## Training on ImageNet

### Episodic training

The episodic models had the option to reload pre-trained weights as a starting
point, or to start from scratch. All the best models started from pre-trained
weights. Each of their `.gin` files includes the corresponding
`best/pretrained_*net.gin` file, so the appropriate weights are reloaded.

```bash
export SOURCE=imagenet
for MODEL in baselinefinetune prototypical matching maml maml_init_with_proto
for MODEL in matching prototypical maml relationnet maml_init_with_proto
do
export EXPNAME=${MODEL}_${SOURCE}
python -m meta_dataset.train \
@@ -121,10 +107,15 @@ do
--gin_bindings="LearnerConfig.experiment_name='$EXPNAME'"
done
```

### Re-training baselines

The batch baselines (`baseline` and `baselinefinetune`) are re-trained as well;
a sketch of the corresponding loop is given below.
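
The sketch simply mirrors the episodic loop above (the remaining
`meta_dataset.train` flags are assumed to be unchanged) and assumes the best
configurations are named `best/<model>_<source>.gin`, as the
`baseline_imagenet.gin` and `baselinefinetune_imagenet.gin` configurations
below suggest.

```bash
export SOURCE=imagenet
for MODEL in baseline baselinefinetune
do
  export EXPNAME=${MODEL}_${SOURCE}
  python -m meta_dataset.train \
    --gin_config=meta_dataset/learn/gin/best/${EXPNAME}.gin \
    --gin_bindings="LearnerConfig.experiment_name='$EXPNAME'"
done
```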


### Command line

Note: rather than editing the `.gin` file, it is also possible to specify the
path of the pretrained checkpoint to load on the command line, by adding
`--gin_bindings="LearnerConfig.pretrained_checkpoint='...'"`.
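
For example (a sketch only: the configuration name, checkpoint path, and step
number are hypothetical and depend on your pre-training run):

```bash
python -m meta_dataset.train \
  --gin_config=meta_dataset/learn/gin/best/prototypical_imagenet.gin \
  --gin_bindings="LearnerConfig.experiment_name='prototypical_imagenet'" \
  --gin_bindings="LearnerConfig.pretrained_checkpoint='${EXPROOT}/checkpoints/pretrain_imagenet_resnet/model_50000.ckpt'"
```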

Run time:

6 changes: 6 additions & 0 deletions meta_dataset/analysis/select_best_model.py
@@ -209,6 +209,12 @@ def keep_variant(name):
if tf.io.gfile.isdir(os.path.join(summary_dir, fname))
]

if not variant_names:
# Maybe there are no variants, and we are already in the directory that
# contains the summaries. In this case, we consider that the current
# directory (.) is the only variant.
variant_names = ['.']

# Further filter variant names based on the given restrictions.
variant_names = [name for name in variant_names if keep_variant(name)]

8 changes: 5 additions & 3 deletions meta_dataset/analyze.py
@@ -470,6 +470,8 @@ def get_event_files(root_dir):
os.path.join(summaries_dir, f)
for f in tf.io.gfile.listdir(summaries_dir)
]
# Filter out non-directory files, if any.
child_dirs = [child for child in child_dirs if tf.io.gfile.isdir(child)]
logging.info('Looking for events in dirs: %s', child_dirs)
for child_dir in child_dirs:
for file_name in tf.io.gfile.listdir(child_dir):
@@ -536,9 +538,9 @@ def analyze_events(paths_to_event_files, experiment_root_dir,

else:
# Read the data from the event files.
(ways, shots, class_props,
class_ids, test_logits, test_targets) = read_data(
path_to_event, do_finegrainedness_analysis, do_imbalance_analysis)
(ways, shots, class_props, class_ids, test_logits,
test_targets) = read_data(path_to_event, do_finegrainedness_analysis,
do_imbalance_analysis)

# A dict mapping each observed 'shot' to a list of class precisions
# obtained by classes that had that shot (regardless of shots of other
16 changes: 7 additions & 9 deletions meta_dataset/learn/gin/best/baseline_all.gin
@@ -3,21 +3,19 @@ include 'meta_dataset/learn/gin/models/baseline_config.gin'
BatchSplitReaderGetReader.add_dataset_offset = True

# Backbone hypers.
LearnerConfig.embedding_network = 'resnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'
include 'meta_dataset/learn/gin/best/pretrained_wide_resnet.gin'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = False
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True

# Data hypers.
DataConfig.image_height = 84
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 100
LearnerConfig.decay_rate = 0.9509421043104758
LearnerConfig.learning_rate = 0.0005102311315353643
weight_decay = 0.0005155272693869694
LearnerConfig.decay_every = 500
LearnerConfig.decay_rate = 0.8778059962506467
LearnerConfig.learning_rate = 0.000253906846867988
weight_decay = 0.00002393929026012612
23 changes: 0 additions & 23 deletions meta_dataset/learn/gin/best/baseline_all_option_to_pretrain.gin

This file was deleted.

16 changes: 8 additions & 8 deletions meta_dataset/learn/gin/best/baseline_imagenet.gin
@@ -3,20 +3,20 @@ include 'meta_dataset/learn/gin/models/baseline_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'
LearnerConfig.pretrained_checkpoint = ''

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_classifier = False
BaselineLearner.cosine_logits_multiplier = 10
BaselineLearner.use_weight_norm = False

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 100
LearnerConfig.decay_rate = 0.5082121576573064
LearnerConfig.learning_rate = 0.007084776688116927
weight_decay = 0.000005078192976503067
LearnerConfig.decay_every = 10000
LearnerConfig.decay_rate = 0.7294597641152971
LearnerConfig.learning_rate = 0.007634189137886614
weight_decay = 0.000007138118976497546

This file was deleted.

16 changes: 8 additions & 8 deletions meta_dataset/learn/gin/best/baselinefinetune_all.gin
@@ -1,24 +1,24 @@
include 'meta_dataset/learn/gin/setups/all.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'
BatchSplitReaderGetReader.add_dataset_offset = True

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_source = 'scratch'
include 'meta_dataset/learn/gin/best/pretrained_wide_resnet.gin'

# Model hypers.
BaselineLearner.cosine_classifier = False
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineFinetuneLearner.num_finetune_steps = 50
BaselineFinetuneLearner.finetune_lr = 0.1
BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_all_layers = True
BaselineFinetuneLearner.finetune_with_adam = True

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 2500
LearnerConfig.decay_rate = 0.5508783586336378
LearnerConfig.learning_rate = 0.005493938830376542
weight_decay = 0.0000031050368100770684
LearnerConfig.decay_every = 5000
LearnerConfig.decay_rate = 0.5559080744371039
LearnerConfig.learning_rate = 0.0027015533546616804
weight_decay = 0.00002266979856832968

This file was deleted.

21 changes: 10 additions & 11 deletions meta_dataset/learn/gin/best/baselinefinetune_imagenet.gin
@@ -2,23 +2,22 @@ include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'resnet'
LearnerConfig.pretrained_source = 'scratch'
include 'meta_dataset/learn/gin/best/pretrained_wide_resnet.gin'

# Model hypers.
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_classifier = False
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_logits_multiplier = 10
BaselineFinetuneLearner.num_finetune_steps = 100
BaselineLearner.cosine_logits_multiplier = 1
BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_all_layers = False
BaselineFinetuneLearner.finetune_all_layers = True
BaselineFinetuneLearner.finetune_with_adam = True

# Data hypers.
DataConfig.image_height = 126
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 2500
LearnerConfig.decay_rate = 0.7427378713742643
LearnerConfig.learning_rate = 0.003725198463674423
weight_decay = 0.000003337891450479888
LearnerConfig.decay_every = 5000
LearnerConfig.decay_rate = 0.5559080744371039
LearnerConfig.learning_rate = 0.0027015533546616804
weight_decay = 0.00002266979856832968

This file was deleted.
