-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_distill_e2d2_gsm8k.sh
More file actions
98 lines (91 loc) · 3.59 KB
/
run_distill_e2d2_gsm8k.sh
File metadata and controls
98 lines (91 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/bin/bash
# Setup environment.
#
# `cd ../` is resolved against the caller's current working directory, so this
# script is expected to be launched from a directory directly below the repo
# root (the later relative paths, e.g. scripts/composer_scripts/..., depend on
# ending up in the repo root).
cd ../ || exit 1 # Go to the root directory of the repo

# Fail fast when the environment file is missing. Without this check a failed
# `source` only prints an error and the script would carry on and launch
# training without the environment settings applied.
if [[ ! -f setup_env.sh ]]; then
  printf 'error: setup_env.sh not found in %s; launch this script from a directory directly below the repo root\n' "${PWD}" >&2
  exit 1
fi
# Source the file that was just checked (explicit ./ avoids a PATH lookup).
source ./setup_env.sh
# ---------------------------------------------------------------------------
# Model architecture
# These values are forwarded to the launch command below, mostly as
# model.config.backbone_config.* overrides.
# ---------------------------------------------------------------------------
BLOCK_SIZE='1'                 # -> block_size

# Encoder
N_ENCODER_LAYERS='28'          # -> num_encoder_layers
ENCODER_TOP_LAYERS='false'     # -> keep_top_encoder_layers; also picks the "TOPenc" run-name tag
REINIT_ENCODER='false'         # -> reinit_encoder
FREEZE_ENCODER='false'         # -> freeze_encoder; adds "_freeze-enc" to the run name when true
ENCODER_CAUSAL_MASK='false'    # -> use_encoder_causal_mask; adds "_encoder-causal-mask" when true

# Decoder
N_DECODER_LAYERS='14'          # -> num_decoder_layers
DECODER_TOP_LAYERS='true'      # -> keep_top_decoder_layers; also picks the "TOPdec" run-name tag
REINIT_DECODER='false'         # -> reinit_decoder

# Shared
TIE_WEIGHTS='true'             # -> tie_encoder_decoder_weights; adds "_tie-weights" when true

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
LR='1e-5'                      # -> composer.optimizer.lr
WARMUP_DURATION='100ba'        # -> composer.lr_scheduler.t_warmup
ALPHA_F='0.5'                  # -> composer.lr_scheduler.alpha_f
BATCH_SIZE='1'                 # -> training.global_batch_size (also used to derive grad_accum)
MAX_DURATION='80000ba'         # -> composer.trainer.max_duration
PRECISION='amp_bf16'           # -> composer.trainer.precision
PRETRAINED_MODEL_NAME_OR_PATH='Qwen/Qwen3-1.7B-Base'  # -> pretrained_model_name_or_path
TRAIN_ON_CONTEXT='false'       # -> model.config.train_on_context

# Free-form suffix appended to the generated run name.
TAG='e2d2_anneal'
# Layer tags used in the run name: "enc<N>" / "dec<N>", with a "TOP" prefix
# when the matching *_TOP_LAYERS flag is exactly "true".
ENC_LAYERS="enc${N_ENCODER_LAYERS}"
if [[ "${ENCODER_TOP_LAYERS}" == "true" ]]; then
  ENC_LAYERS="TOP${ENC_LAYERS}"
fi

DEC_LAYERS="dec${N_DECODER_LAYERS}"
if [[ "${DECODER_TOP_LAYERS}" == "true" ]]; then
  DEC_LAYERS="TOP${DEC_LAYERS}"
fi

# Base run name: encodes the main hyperparameters so runs are distinguishable
# by name alone. Built in pieces purely for readability.
RUN_NAME="gsm8k-distill-block${BLOCK_SIZE}"
RUN_NAME+="_lr${LR}_bsz${BATCH_SIZE}_warm${WARMUP_DURATION}"
RUN_NAME+="_alphaf${ALPHA_F}_max-dur${MAX_DURATION}_${PRECISION}"
RUN_NAME+="_${ENC_LAYERS}_${DEC_LAYERS}_${TAG}"

# Optional suffixes, appended in a fixed order for each enabled flag.
if [[ "${TIE_WEIGHTS}" == "true" ]]; then
  RUN_NAME+="_tie-weights"
fi
if [[ "${ENCODER_CAUSAL_MASK}" == "true" ]]; then
  RUN_NAME+="_encoder-causal-mask"
fi
if [[ "${FREEZE_ENCODER}" == "true" ]]; then
  RUN_NAME+="_freeze-enc"
fi
# Launch settings local to this run.
MICRO_BATCH_SIZE=1 # divisor used when deriving training.grad_accum below
NUM_WORKERS=0      # -> train_dataloader.num_workers

# NUM_VISIBLE_DEVICES and RUN_DIR are not defined in this script; presumably
# they are exported by setup_env.sh (TODO confirm). Fail fast with a clear
# message instead of hitting an arithmetic error or writing under "/<run_name>".
: "${NUM_VISIBLE_DEVICES:?NUM_VISIBLE_DEVICES must be set (expected from setup_env.sh)}"
: "${RUN_DIR:?RUN_DIR must be set (expected from setup_env.sh)}"
if ! (( NUM_VISIBLE_DEVICES >= 1 )); then
  printf 'error: NUM_VISIBLE_DEVICES must be a positive integer (got "%s")\n' "${NUM_VISIBLE_DEVICES}" >&2
  exit 1
fi
# Normalised integer form (drops any surrounding whitespace, as the original
# unquoted expansion did).
num_devices=$(( NUM_VISIBLE_DEVICES ))

# Gradient accumulation steps. Integer division truncates, so this becomes 0
# whenever BATCH_SIZE < devices * MICRO_BATCH_SIZE. The value is passed through
# unchanged, but the situation is reported because it is easy to miss.
grad_accum=$(( BATCH_SIZE / num_devices / MICRO_BATCH_SIZE ))
if (( grad_accum < 1 )); then
  printf 'warning: training.grad_accum=%s (BATCH_SIZE=%s, devices=%s, MICRO_BATCH_SIZE=%s)\n' \
    "${grad_accum}" "${BATCH_SIZE}" "${num_devices}" "${MICRO_BATCH_SIZE}" >&2
fi

# key=value overrides handed to the training script, in the original order.
# Every element is quoted so values cannot be word-split or glob-expanded.
overrides=(
  # Run identity, base model and datasets
  "run_name=${RUN_NAME}"
  "pretrained_model_name_or_path=${PRETRAINED_MODEL_NAME_OR_PATH}"
  "dataset@train_dataset=gsm8k_train_distill"
  "dataset@eval_dataset=gsm8k_eval_distill"

  # Trainer, optimizer and LR schedule
  "composer=distill_composer"
  "composer.optimizer.lr=${LR}"
  "composer.trainer.precision=${PRECISION}"
  "composer.trainer.eval_interval=1000ba"
  "composer.trainer.max_duration=${MAX_DURATION}"
  "composer/lr_scheduler=cosine_annealing_with_warmup"
  "composer.lr_scheduler.t_warmup=${WARMUP_DURATION}"
  "composer.lr_scheduler.alpha_f=${ALPHA_F}"

  # Model and backbone configuration
  "model=e2d2"
  "model.config.attn_backend=sdpa"
  "training.compile_backbone=false"
  "model.config.length=1024"
  "model/backbone@model.config.backbone_config=llm_as_encoder_decoder_share_kv"
  "model.config.backbone_config.use_encoder_causal_mask=${ENCODER_CAUSAL_MASK}"
  "model.config.backbone_config.num_encoder_layers=${N_ENCODER_LAYERS}"
  "model.config.backbone_config.num_decoder_layers=${N_DECODER_LAYERS}"
  "model.config.backbone_config.tie_encoder_decoder_weights=${TIE_WEIGHTS}"
  "model.config.backbone_config.freeze_encoder=${FREEZE_ENCODER}"
  "model.config.backbone_config.reinit_decoder=${REINIT_DECODER}"
  "model.config.backbone_config.reinit_encoder=${REINIT_ENCODER}"
  "model.config.backbone_config.keep_top_decoder_layers=${DECODER_TOP_LAYERS}"
  "model.config.backbone_config.keep_top_encoder_layers=${ENCODER_TOP_LAYERS}"
  "model.config.backbone_config.use_gradient_checkpointing=false"

  # Batch sizing and training options
  "training.global_batch_size=${BATCH_SIZE}"
  "training.grad_accum=${grad_accum}"
  "block_size=${BLOCK_SIZE}"
  "training.antithetic_sampling=false"

  # Output directory, checkpointing and logging
  "hydra.run.dir=${RUN_DIR}/${RUN_NAME}"
  "composer.callbacks.hf_compatible_checkpointing.save_interval=1000ba"
  "composer.callbacks.hf_compatible_checkpointing.num_checkpoints_to_keep=1"
  "composer.callbacks.hf_compatible_checkpointing.disable_hf=true"
  "composer.loggers.name=${RUN_NAME}"

  # Data loading and evaluation
  "train_dataloader.num_workers=${NUM_WORKERS}"
  "eval_dataloader.batch_size=8"
  "eval_evaluator.device_eval_microbatch_size=8"

  # Context handling, block-size annealing and best-checkpoint saving
  "model.config.train_on_context=${TRAIN_ON_CONTEXT}"
  "model.config.bidirectional_ctx_attn=false"
  "composer.algorithms.block_size_annealing.schedule=1000ba"
  "composer.callbacks.save_best_checkpointing.start=5000ba"
)

# Launch training; this stays the last command, so the script's exit status is
# the launcher's exit status.
composer -n "${num_devices}" scripts/composer_scripts/train_discrete_denoiser.py "${overrides[@]}"