Skip to content

Commit f10efed

Browse files
committed
add previous training code in train.py
1 parent 3d94f65 commit f10efed

3 files changed

Lines changed: 145 additions & 8 deletions

File tree

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
This repository contains the official code for our [ICML 2024 paper](https://openreview.net/forum?id=VyoY3Wh9Wd). `ifBO` is an efficient Bayesian Optimization algorithm that dynamically selects and incrementally evaluates candidates during the optimization process. It uses a model called the `Freeze-Thaw surrogate (FT-PFN)` to predict the performance of candidate configurations as more resources are allocated. The `main` branch includes the necessary API to use `FT-PFN`. Refer to the following sections:
88
- [Surrogate API](#surrogate-api): to learn how to initialize and use the surrogate model.
99
- [Bayesian Optimization with ifBO](#bayesian-optimization-with-ifbo): to understand how to use `ifBO` for Hyperparameter Optimization.
10+
- [Training your own model][#training-your-own-model]: to understand ifBO training pipeline.
1011

1112

1213
> To reproduce experiments from the above paper version, please refer to the branch [`icml-2024`](https://github.com/automl/ifBO/tree/icml-2024).
1314
1415
# Installation
1516

16-
Requires Python 3.11.
17+
Requires Python 3.11 or later.
1718

1819
```bash
1920
pip install -U ifBO
@@ -139,6 +140,15 @@ neps.run(
139140
)
140141
```
141142

143+
## Training your own model
144+
145+
Train ifBO from scratch with the following command:
146+
147+
```bash
148+
python -m ifbo.train --epochs 20 --output_path your_own_ifbo.model --seq_len 1000
149+
```
150+
151+
For more training options, run ``python -m ifbo.train -h`` or inspect ``ifbo/train.py``.
142152

143153

144154
# Citation

ifbo/priors/ftpfn_prior.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,30 @@ def forward(self, *x, **kwargs) -> torch.Tensor:
505505
)
506506
return out
507507

508+
class MultiCurvesEncoderSeqLen(torch.nn.Module):
509+
def __init__(self, in_dim: int, out_dim: int, seq_len: int) -> None:
510+
super().__init__()
511+
self.normalizer = torch.nn.Sequential(
512+
encoders.Normalize(0.5, math.sqrt(1 / 12)),
513+
)
514+
self.epoch_enc = torch.nn.Linear(1, out_dim, bias=False)
515+
self.idcurve_enc = torch.nn.Embedding(seq_len + 1, out_dim)
516+
self.configuration_enc = encoders.get_variable_num_features_encoder(encoders.Linear)(
517+
in_dim - 2, out_dim
518+
)
519+
520+
def forward(self, *x, **kwargs) -> torch.Tensor:
521+
x = torch.cat(x, dim=-1)
522+
out = (
523+
self.epoch_enc(self.normalizer(x[..., 1:2]))
524+
+ self.idcurve_enc(x[..., :1].int()).squeeze(2)
525+
+ self.configuration_enc(x[..., 2:])
526+
)
527+
return out
528+
508529

509-
def get_encoder() -> Callable[[int, int], torch.nn.Module]:
510-
return lambda num_features, emsize: MultiCurvesEncoder(num_features, emsize)
530+
def get_encoder(seq_len) -> Callable[[int, int], torch.nn.Module]:
531+
return lambda num_features, emsize: MultiCurvesEncoderSeqLen(num_features, emsize, seq_len)
511532

512533

513534
def sample_curves(

ifbo/train.py

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,26 @@
66
import time
77
from typing import Any
88

9+
import argparse
910
import torch
1011
from torch import nn
1112
from torch.cuda.amp import autocast
1213
from torch.cuda.amp import GradScaler
1314
from tqdm import tqdm
1415

1516
from ifbo import positional_encodings
16-
from ifbo import utils
17+
from ifbo import utils, encoders, bar_distribution
1718
from ifbo.bar_distribution import BarDistribution
1819
from ifbo.bar_distribution import get_custom_bar_dist
19-
from ifbo.priors import prior
20+
from ifbo.priors import prior, ftpfn_prior
21+
from ifbo.priors.utils import get_batch_to_dataloader
2022
from ifbo.transformer import TransformerModel
2123
from ifbo.utils import get_cosine_schedule_with_warmup
2224
from ifbo.utils import get_openai_lr
2325
from ifbo.utils import init_dist
2426

27+
from ifbo.utils import default_device
28+
2529

2630
class Losses:
2731
def get_cross_entropy_loss(self, num_classes: int) -> nn.CrossEntropyLoss:
@@ -205,8 +209,8 @@ def train_epoch() -> tuple[float, list[float], float, float, float, float, float
205209
total_loss = 0.0
206210
total_positional_losses = torch.zeros(bptt)
207211
total_positional_losses_recorded = torch.zeros(bptt)
208-
nan_steps = torch.zeros(1)
209-
ignore_steps = torch.zeros(1)
212+
nan_steps = torch.zeros(1).to(device)
213+
ignore_steps = torch.zeros(1).to(device)
210214
before_get_batch = time.time()
211215
assert (
212216
len(dl) % aggregate_k_gradients == 0
@@ -384,7 +388,7 @@ def apply_batch_wise_criterion(i: int) -> torch.Tensor:
384388
}
385389
if step_callback is not None and rank == 0:
386390
step_callback(metrics_to_log)
387-
nan_steps += nan_share
391+
nan_steps += nan_share.detach()
388392
ignore_steps += (targets == -100).float().mean()
389393
except Exception as e:
390394
print("Invalid step encountered, skipping...")
@@ -459,3 +463,105 @@ def apply_batch_wise_criterion(i: int) -> torch.Tensor:
459463
return total_loss, total_positional_losses, model.to("cpu"), dl
460464

461465
return None
466+
467+
if __name__ == "__main__":
468+
parser = argparse.ArgumentParser(description="Train an ifBO model")
469+
470+
# transformer model parameters
471+
parser.add_argument("--nlayers", type=int, help="Number of layers", default=6)
472+
parser.add_argument("--emsize", type=int, default=512, help="Size of Embeddings")
473+
parser.add_argument("--nhead", type=int, default=4, help="Number of heads")
474+
475+
# PFN parameters
476+
parser.add_argument(
477+
"--num_borders",
478+
type=int,
479+
default=1000,
480+
help="Number of borders considered in Bar distribution",
481+
)
482+
483+
# Prior parameters
484+
parser.add_argument("--seq_len", type=int, required=True, help="Maximum sequence length")
485+
parser.add_argument(
486+
"--num_features",
487+
type=int,
488+
required=False,
489+
help="The total number of features for each datapoint in an example.",
490+
default=12, # has to be at least 3
491+
)
492+
parser.add_argument(
493+
"--power_single_eval_pos_sampler",
494+
type=int,
495+
required=False,
496+
help="Power of an exponential distribution to weight sampling of single eval pos.",
497+
default=-2,
498+
)
499+
500+
# training parameters
501+
parser.add_argument("--epochs", type=int, required=True, help="Number of Training Epochs")
502+
parser.add_argument("--batch_size", type=int, default=25, help="Batch Size for Training")
503+
parser.add_argument("--lr", type=float, default=0.0001, help="Learning Rate")
504+
parser.add_argument("--steps_per_epoch", type=int, default=100, help="Number of Steps per Epoch")
505+
parser.add_argument(
506+
"--train_mixed_precision",
507+
action="store_true",
508+
help="Enable Mixed Precision Training",
509+
)
510+
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
511+
512+
# other parameters
513+
parser.add_argument("--output_path", type=str, required=True, help="Path to save the model")
514+
515+
args = parser.parse_args()
516+
517+
seq_len = args.seq_len
518+
519+
bucket_limits = torch.linspace(0.0, 1.0, args.num_borders).to(default_device)
520+
criterion = bar_distribution.BarDistribution(bucket_limits)
521+
522+
single_eval_pos_gen = utils.get_weighted_single_eval_pos_sampler(
523+
max_len=seq_len,
524+
min_len=0,
525+
p=args.power_single_eval_pos_sampler,
526+
)
527+
528+
configs_train = {
529+
"nlayers": args.nlayers,
530+
"emsize": args.emsize,
531+
"epochs": args.epochs,
532+
"lr": args.lr,
533+
"nhead": args.nhead,
534+
"bptt": seq_len,
535+
"steps_per_epoch": args.steps_per_epoch,
536+
"train_mixed_precision": args.train_mixed_precision,
537+
"batch_size": args.batch_size,
538+
}
539+
configs_train["bptt"] = seq_len
540+
configs_train["nhid"] = args.emsize * 2
541+
configs_train["warmup_epochs"] = args.epochs // 4
542+
configs_train.update(
543+
dict(
544+
priordataloader_class=get_batch_to_dataloader(ftpfn_prior.get_batch),
545+
criterion=criterion,
546+
encoder_generator=ftpfn_prior.get_encoder(seq_len),
547+
y_encoder_generator=encoders.get_normalized_uniform_encoder(
548+
encoders.Linear
549+
),
550+
extra_prior_kwargs_dict={
551+
"num_features": args.num_features,
552+
},
553+
single_eval_pos_gen=single_eval_pos_gen,
554+
style_encoder_generator=None
555+
)
556+
)
557+
558+
total_loss, total_positional_losses, model, dl = train(
559+
**configs_train
560+
)
561+
print(f"Total loss: {total_loss}, Total positional losses: {total_positional_losses}")
562+
torch.save(
563+
model,
564+
args.output_path,
565+
)
566+
print(f"Model saved to {args.output_path}")
567+

0 commit comments

Comments
 (0)