|
6 | 6 | import time |
7 | 7 | from typing import Any |
8 | 8 |
|
| 9 | +import argparse |
9 | 10 | import torch |
10 | 11 | from torch import nn |
11 | 12 | from torch.cuda.amp import autocast |
12 | 13 | from torch.cuda.amp import GradScaler |
13 | 14 | from tqdm import tqdm |
14 | 15 |
|
15 | 16 | from ifbo import positional_encodings |
16 | | -from ifbo import utils |
| 17 | +from ifbo import utils, encoders, bar_distribution |
17 | 18 | from ifbo.bar_distribution import BarDistribution |
18 | 19 | from ifbo.bar_distribution import get_custom_bar_dist |
19 | | -from ifbo.priors import prior |
| 20 | +from ifbo.priors import prior, ftpfn_prior |
| 21 | +from ifbo.priors.utils import get_batch_to_dataloader |
20 | 22 | from ifbo.transformer import TransformerModel |
21 | 23 | from ifbo.utils import get_cosine_schedule_with_warmup |
22 | 24 | from ifbo.utils import get_openai_lr |
23 | 25 | from ifbo.utils import init_dist |
24 | 26 |
|
| 27 | +from ifbo.utils import default_device |
| 28 | + |
25 | 29 |
|
26 | 30 | class Losses: |
27 | 31 | def get_cross_entropy_loss(self, num_classes: int) -> nn.CrossEntropyLoss: |
@@ -205,8 +209,8 @@ def train_epoch() -> tuple[float, list[float], float, float, float, float, float |
205 | 209 | total_loss = 0.0 |
206 | 210 | total_positional_losses = torch.zeros(bptt) |
207 | 211 | total_positional_losses_recorded = torch.zeros(bptt) |
208 | | - nan_steps = torch.zeros(1) |
209 | | - ignore_steps = torch.zeros(1) |
| 212 | + nan_steps = torch.zeros(1).to(device) |
| 213 | + ignore_steps = torch.zeros(1).to(device) |
210 | 214 | before_get_batch = time.time() |
211 | 215 | assert ( |
212 | 216 | len(dl) % aggregate_k_gradients == 0 |
@@ -384,7 +388,7 @@ def apply_batch_wise_criterion(i: int) -> torch.Tensor: |
384 | 388 | } |
385 | 389 | if step_callback is not None and rank == 0: |
386 | 390 | step_callback(metrics_to_log) |
387 | | - nan_steps += nan_share |
| 391 | + nan_steps += nan_share.detach() |
388 | 392 | ignore_steps += (targets == -100).float().mean() |
389 | 393 | except Exception as e: |
390 | 394 | print("Invalid step encountered, skipping...") |
@@ -459,3 +463,105 @@ def apply_batch_wise_criterion(i: int) -> torch.Tensor: |
459 | 463 | return total_loss, total_positional_losses, model.to("cpu"), dl |
460 | 464 |
|
461 | 465 | return None |
| 466 | + |
| 467 | +if __name__ == "__main__": |
| 468 | + parser = argparse.ArgumentParser(description="Train an ifBO model") |
| 469 | + |
| 470 | + # transformer model parameters |
| 471 | + parser.add_argument("--nlayers", type=int, help="Number of layers", default=6) |
| 472 | + parser.add_argument("--emsize", type=int, default=512, help="Size of Embeddings") |
| 473 | + parser.add_argument("--nhead", type=int, default=4, help="Number of heads") |
| 474 | + |
| 475 | + # PFN parameters |
| 476 | + parser.add_argument( |
| 477 | + "--num_borders", |
| 478 | + type=int, |
| 479 | + default=1000, |
| 480 | + help="Number of borders considered in Bar distribution", |
| 481 | + ) |
| 482 | + |
| 483 | + # Prior parameters |
| 484 | + parser.add_argument("--seq_len", type=int, required=True, help="Maximum sequence length") |
| 485 | + parser.add_argument( |
| 486 | + "--num_features", |
| 487 | + type=int, |
| 488 | + required=False, |
| 489 | + help="The total number of features for each datapoint in an example.", |
| 490 | + default=12, # has to be at least 3 |
| 491 | + ) |
| 492 | + parser.add_argument( |
| 493 | + "--power_single_eval_pos_sampler", |
| 494 | + type=int, |
| 495 | + required=False, |
| 496 | + help="Power of an exponential distribution to weight sampling of single eval pos.", |
| 497 | + default=-2, |
| 498 | + ) |
| 499 | + |
| 500 | + # training parameters |
| 501 | + parser.add_argument("--epochs", type=int, required=True, help="Number of Training Epochs") |
| 502 | + parser.add_argument("--batch_size", type=int, default=25, help="Batch Size for Training") |
| 503 | + parser.add_argument("--lr", type=float, default=0.0001, help="Learning Rate") |
| 504 | + parser.add_argument("--steps_per_epoch", type=int, default=100, help="Number of Steps per Epoch") |
| 505 | + parser.add_argument( |
| 506 | + "--train_mixed_precision", |
| 507 | + action="store_true", |
| 508 | + help="Enable Mixed Precision Training", |
| 509 | + ) |
| 510 | + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use") |
| 511 | + |
| 512 | + # other parameters |
| 513 | + parser.add_argument("--output_path", type=str, required=True, help="Path to save the model") |
| 514 | + |
| 515 | + args = parser.parse_args() |
| 516 | + |
| 517 | + seq_len = args.seq_len |
| 518 | + |
| 519 | + bucket_limits = torch.linspace(0.0, 1.0, args.num_borders).to(default_device) |
| 520 | + criterion = bar_distribution.BarDistribution(bucket_limits) |
| 521 | + |
| 522 | + single_eval_pos_gen = utils.get_weighted_single_eval_pos_sampler( |
| 523 | + max_len=seq_len, |
| 524 | + min_len=0, |
| 525 | + p=args.power_single_eval_pos_sampler, |
| 526 | + ) |
| 527 | + |
| 528 | + configs_train = { |
| 529 | + "nlayers": args.nlayers, |
| 530 | + "emsize": args.emsize, |
| 531 | + "epochs": args.epochs, |
| 532 | + "lr": args.lr, |
| 533 | + "nhead": args.nhead, |
| 534 | + "bptt": seq_len, |
| 535 | + "steps_per_epoch": args.steps_per_epoch, |
| 536 | + "train_mixed_precision": args.train_mixed_precision, |
| 537 | + "batch_size": args.batch_size, |
| 538 | + } |
| 539 | + configs_train["bptt"] = seq_len |
| 540 | + configs_train["nhid"] = args.emsize * 2 |
| 541 | + configs_train["warmup_epochs"] = args.epochs // 4 |
| 542 | + configs_train.update( |
| 543 | + dict( |
| 544 | + priordataloader_class=get_batch_to_dataloader(ftpfn_prior.get_batch), |
| 545 | + criterion=criterion, |
| 546 | + encoder_generator=ftpfn_prior.get_encoder(seq_len), |
| 547 | + y_encoder_generator=encoders.get_normalized_uniform_encoder( |
| 548 | + encoders.Linear |
| 549 | + ), |
| 550 | + extra_prior_kwargs_dict={ |
| 551 | + "num_features": args.num_features, |
| 552 | + }, |
| 553 | + single_eval_pos_gen=single_eval_pos_gen, |
| 554 | + style_encoder_generator=None |
| 555 | + ) |
| 556 | + ) |
| 557 | + |
| 558 | + total_loss, total_positional_losses, model, dl = train( |
| 559 | + **configs_train |
| 560 | + ) |
| 561 | + print(f"Total loss: {total_loss}, Total positional losses: {total_positional_losses}") |
| 562 | + torch.save( |
| 563 | + model, |
| 564 | + args.output_path, |
| 565 | + ) |
| 566 | + print(f"Model saved to {args.output_path}") |
| 567 | + |
0 commit comments