# @package bypass
# Bypass Distillation Configuration
# This config defines parameters for blockwise local distillation (BLD),
# which trains alternative transformer block configurations using per-block
# knowledge distillation from a teacher model.
#
# NOTE(review): nesting depth was reconstructed from a whitespace-collapsed
# copy of this file - confirm against the consumer's schema before merging.

# Runtime Configuration
dtype: "bf16"  # Model precision: bf16 for efficiency, fp32 for stability
seed: 42  # Random seed for reproducibility

# Experiment Tracking
experiment_id: null  # Unique identifier for this experiment. Set dynamically.
experiment_dir: null  # Directory for this experiment. Set dynamically.
iter_num: 1  # Current iteration number
step_num: 1  # Current step number within iteration
token_count: 0  # Token count tracker (auto-updated during training)

# Data Configuration
data:
  data_column: "messages"
  block_size: 512  # Sequence length (tokens per sample)
  bos_rate: 0.5
  fim_rate: 0
  fim_spm_rate: 0
  source_datasets_to_discard: []
  load_from_disk: true  # true: load preprocessed data from disk; false: stream it
  keep_in_memory: false
  val_dataset_name: valid
  max_eval_samples: 4
  eval_samples_per_process: null  # Samples per GPU during distributed eval (auto if null)
  shuffle_train_data_seed: ${random_int:0,9999}  # Seed for shuffling train data

# Training Configuration
# Floats are written with an explicit decimal point (1.0e-4, not 1e-4) so that
# YAML 1.1 loaders also read them as numbers rather than strings.
training:
  learning_rate: 1.0e-4  # Initial learning rate (1e-4 = 0.0001)
  training_tokens: 1.0e+4  # Total training tokens (10K tokens - sanity check)
  micro_batch_size: 2
  val_micro_batch_size: 1
  warmup_ratio: 0.05
  # Auto-calculated warmup steps
  warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}}
  min_lr_factor: 1.0e-5
  grad_accumulation_steps: 1
  # Use for debugging, or to skip a few batches that cause crashes or
  # optimization issues.
  skip_first_batches: 0
  weight_decay: 0.1
  decay_lr: true
  beta1: 0.9
  beta2: 0.95
  use_grad_scaling: false
  grad_clip: 1.0
  grad_clip_type: norm
  clipping_count: 0
  log_interval: 5
  eval_interval: 5

# Model Loading Configuration
resume_checkpoint_path: null  # Path to resume training from checkpoint
find_last_ckpt_for_resume: true  # Auto-resume by finding last checkpoint (bool)
parameter_count: null
init_checkpoint_path: null  # Path to initialize weights from

model:
  student_weights_dtype: "bf16"  # Student model weight precision

  model_overrides:
    delete_old_checkpoints: true  # Clean up old checkpoints to save disk space
    save_interval_seconds: 12900  # Save checkpoint every ~3.5 hours
    save_interval: 1.0e+9  # Save checkpoint every 1B steps (effectively disabled)
    save_checkpoint_when_done: true  # Save final checkpoint when training completes

  # Architecture modifications for student model
  model_config_overrides:
    ffn:
      - intermediate_size: null
        no_op: null  # Disable FFN entirely (true/false)
    attention:
      - num_key_value_heads: null  # Number of kv-heads (for GQA)
        no_op: null  # Disable attention entirely (true/false)

# Model Factory Configuration - Controls student model creation and initialization
model_factory:
  factory: bypass_factory_fn  # Unified factory supporting all layer types
  # Loss function for comparing teacher/student blocks. Options:
  # vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss
  block_loss_func: normalized_mse_loss
  gqa_init_mode: AverageKV  # How to initialize K/V heads in GQA. See GQAInitMode for all options.
  mlp_init_mode: Truncate  # MLP initialization. See MlpInitMode for all options.
  mlp_init_config:  # Configuration for MLP initialization (if needed)
    # Directory with activation statistics (required for PruneByActivationsLog)
    activations_log_dir: null
  linear_init_mode: FromTeacher  # How to initialize linear layers: FromTeacher, Random, etc.
  submodule_for_loss_calculation: null  # Specific submodule for loss calc.
  # Subblock(s) to train: entire_block, subblock_attention, subblock_ffn,
  # subblock_mamba, or a list of those.
  keys_to_learn: null

# Validation Configuration
disable_initial_validate: false
validate_teacher_model: true
validate_student_model: true
disable_validation: false  # Keep validation enabled to exercise all code paths
best_val_loss: 1.0e+9  # Track best validation loss achieved

# Performance Optimization
compile: false  # Use PyTorch compilation
disable_fa2: false  # Disable Flash Attention 2 (false = use FA2 if available)
teacher_model_load_on_cpu: false

# Checkpoint Management
save_checkpoint_before_training: false  # Save initial checkpoint before training
disable_checkpoint_save: false  # Disable all checkpoint saving
save_best_ckpt: true  # Save checkpoint when validation improves
kill_after_first_save: false  # Exit after first checkpoint save (for testing)
realize_best_or_latest: "best"

wandb_log: false
wandb:
  project: null
  entity: null

# Multiple bypass configurations to train sequentially.
# Each entry overrides model.model_config_overrides and optionally
# model_factory.keys_to_learn.
# If empty or absent, a single run uses the settings above.
configs:
  - model_config_overrides:
      ffn:
        - intermediate_size: 3072
      attention:
        - num_key_value_heads: 8
    keys_to_learn: subblock_ffn
  - model_config_overrides:
      ffn:
        - intermediate_size: 5888
      attention:
        - num_key_value_heads: 8
    keys_to_learn: subblock_ffn