-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
208 lines (183 loc) · 8.26 KB
/
train.py
File metadata and controls
208 lines (183 loc) · 8.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""
Trains a model, logs to wandb, and saves it to local and GCS.
Evaluation runs async on a separate machine. See eval/eval_poller.py
"""
import logging
import os
import shlex
import subprocess
import warnings

import torch
from tap import tapify

import grouping_trainer as gt
# Environment variable name handed to `gt.launch.bootstrap_run`; for remote launches the run name is also set on the
# remote job under this name via `env_var_to_value`.
_RUN_NAME_ENV_VAR = "GROUPING_TRAINER_RUN_NAME"

logger = logging.getLogger(__name__)

# Default `per_device_token_budget_scale` per base model (per-device token budget = scale * 8192), used by `run`
# when the caller doesn't pass a scale. Models missing from this table fall back to a scale of 3.
#
# Not including the v1 jinaai/jina-embeddings-v2-base-code model b/c it doesn't support SDPA.
# Pls don't use models that don't support flash attention.
base_model_to_per_device_token_budget_scale: dict[str, int] = dict(
    [
        ("lightonai/modernbert-embed-large", 4),
        ("Alibaba-NLP/gte-modernbert-base", 6),
        ("Qwen/Qwen3-Embedding-0.6B", 3),
        ("jinaai/jina-embeddings-v5-text-nano-text-matching", 4),
    ]
)
def _build_training_config(
    *,
    base_model: str,
    run_shortname: str,
    per_device_token_budget_scale: int | None,
    global_train_batch_size: int,
    learning_rate: float,
    tiny_run: bool,
) -> gt.train.TrainingConfig:
    """Build the TrainingConfig for either a tiny plumbing-check run or a full training run."""
    if tiny_run:
        # Minimal sizes so the whole pipeline can be exercised quickly; `run` doesn't require CUDA for this path.
        return gt.train.TrainingConfig(
            run_shortname=run_shortname,
            global_train_batch_size=2,
            per_device_token_budget=64,
            gradient_checkpointing=True,
            sample_size_train=30,
            num_logs=30,
            num_checkpoints=2,
            loss_type="contrastive",
            contrastive_margin=0.5,
        )
    # An explicit scale wins; otherwise use the model's historically known-good scale, falling back to 3.
    scale = per_device_token_budget_scale or base_model_to_per_device_token_budget_scale.get(base_model, 3)
    return gt.train.TrainingConfig(
        run_shortname=run_shortname,
        global_train_batch_size=global_train_batch_size,
        per_device_token_budget=8192 * scale,
        warmup_ratio=0.25,
        learning_rate=learning_rate,
        loss_type="contrastive",
        contrastive_margin=0.5,
        training_csvs=gt.data.DEFAULT_TRAIN_PATHS,
    )


def _build_eval_command(
    *,
    run_gcs_dir,
    base_model: str,
    run_name,
    training_config: gt.train.TrainingConfig,
    use_text_prefix: bool,
    tiny_run: bool,
) -> str:
    """Return the shell command that starts eval/eval_poller.py for this run.

    The command ends up in the eval instance's startup script, so every interpolated value is passed through
    `shlex.quote`. Ordinary values (model IDs, gs:// URIs, run names, numbers) contain only shell-safe characters
    and are emitted unchanged.
    """
    flag_to_value = {
        "--run_gcs_dir": run_gcs_dir,
        "--base_model": base_model,
        "--wandb_run_id": run_name,
        "--loss_type": training_config.loss_type,
        "--contrastive_margin": training_config.contrastive_margin,
    }
    parts = ["python", "eval/eval_poller.py"]
    for flag, value in flag_to_value.items():
        # f-string formatting (not str()) to render values exactly as the previous f-string-built command did.
        parts += [flag, shlex.quote(f"{value}")]
    if use_text_prefix:
        parts.append("--use_text_prefix")
    if tiny_run:
        parts += ["--sample_val", "200", "--use_simple_precisions"]
    return " ".join(parts)


def run(
    base_model: str = "lightonai/modernbert-embed-large",
    run_shortname: str | None = None,
    per_device_token_budget_scale: int | None = None,
    global_train_batch_size: int = 256,
    learning_rate: float = 1e-4,
    tiny_run: bool = False,
    use_text_prefix: bool = False,
    *,
    gpu: gt.launch.TrainingGpuType | None = None,  # type: ignore[valid-type]
    zone: str | None = None,
    sync_start: bool = False,
    multi_flex_start: bool = False,
):
    """
    Train a grouping model. Writes checkpoints to GCS. Logs to wandb.

    Parameters
    ----------
    base_model
        HuggingFace model ID or local path for the base encoder, or a `gs://...` path to a custom model directory in our
        bucket, e.g., the checkpoint to a model pretrained using pretrain.py. Others we've tried:
        Alibaba-NLP/gte-modernbert-base, Qwen/Qwen3-Embedding-0.6B, jinaai/jina-embeddings-v5-text-nano-text-matching
    run_shortname
        Short name for the run. Doesn't need to be unique b/c it's appended to the timestamp. Required unless
        --tiny_run is passed.
    per_device_token_budget_scale
        The scale in per_device_token_budget = scale * 8192. This is the memory and throughput knob. By default, if the
        base_model has a historically known good scale, we use that, o.w. uses 3.
    global_train_batch_size
        Total training batch size across all devices. Only used for non-tiny runs. Technically this can be arbitrarily
        high b/c we accumulate the gradient based on per_device_token_budget_scale.
    learning_rate
        Learning rate for training. Only used for non-tiny runs.
    tiny_run
        Run a tiny training run to sanity check plumbing. Doesn't require CUDA and doesn't launch the async eval
        instance.
    use_text_prefix
        If True, add the model's designated prefix to the input text. Didn't help lightonai/modernbert-embed-large.
    gpu
        The type of GPU to flex-start and run on. If set, the job is handed off to a remote instance of that type
        instead of training locally.
    zone
        Override the default GCP zone for the gpu type when launching the GPU instance. Useful when flex-start capacity
        is dry in the default zone for the requested gpu type.
    sync_start
        If False (default), flex-starts the instance—GCP waits up to 2h to find one. `--sync_start` uses on-demand
        pricing and finds an instance in any zone, as flex-starting often can't find instances in time. Mutually
        exclusive with --multi_flex_start.
    multi_flex_start
        Fan out async flex-start submits across 10 zones simultaneously; first VM to boot claims a GCS lock and the rest
        self-delete. Better odds than a single-zone flex-start when capacity is dry. Mutually exclusive with
        --sync_start.

    Raises
    ------
    ValueError
        If both --sync_start and --multi_flex_start are set, or if --run_shortname is missing for a non-tiny run.
    """
    # Cheap argument checks first: a bad invocation should fail here, not after launching (and possibly waiting
    # hours for) a remote instance that would hit the same error.
    if sync_start and multi_flex_start:
        raise ValueError("--sync_start and --multi_flex_start are mutually exclusive")
    if not tiny_run and not run_shortname:
        raise ValueError("--run_shortname is required unless --tiny_run is set")
    # Fail fast on a typo'd gs:// model URI before wasting time launching training.
    if gt.utils.is_gcs_uri(base_model):
        gt.utils.assert_gcs_path_exists(base_model)

    tiny_shortname = f"tiny-{gt.launch.JobType.TRAIN}"
    # For non-tiny runs the check above guarantees run_shortname is non-empty, so this is run_shortname itself.
    shortname = run_shortname or tiny_shortname

    run_name, run_gcs_dir = gt.launch.bootstrap_run(
        run_shortname=run_shortname,
        default_shortname=tiny_shortname,
        run_name_env_var=_RUN_NAME_ENV_VAR,
        process_type="training",
        tiny_run=tiny_run,
    )
    if gpu is not None:
        # Hand the job to a remote GPU instance and stop here. The run name is forwarded via _RUN_NAME_ENV_VAR
        # (presumably so bootstrap_run on the remote side reuses it).
        gt.launch.run_argv_remotely(
            gpu=gpu,
            job_type=gt.launch.JobType.TRAIN,
            name_suffix=shortname,
            sync_start=sync_start,
            multi_flex_start=multi_flex_start,
            zone=zone,
            env_var_to_value={_RUN_NAME_ENV_VAR: run_name},
        )
        return

    if not tiny_run:
        assert torch.cuda.is_available(), "CUDA is required for full training. Did you mean to pass --tiny_run ?"
        assert torch.cuda.is_bf16_supported(), "Get a GPU that supports bfloat16"

    model = gt.utils.encoder_from_base(base_model, use_text_prefix=use_text_prefix)
    training_config = _build_training_config(
        base_model=base_model,
        run_shortname=shortname,
        per_device_token_budget_scale=per_device_token_budget_scale,
        global_train_batch_size=global_train_batch_size,
        learning_rate=learning_rate,
        tiny_run=tiny_run,
    )
    gt.data.ensure_local(training_config.training_csvs)
    trainer = gt.train.make_trainer(model, training_config, run_name=run_name)

    # Only the main process uploads metadata, inits wandb and launches the eval instance; it also tracks whether
    # the eval instance exists so the `finally` below knows to signal it.
    eval_was_launched = False
    if trainer.accelerator.is_main_process:
        gt.launch.upload_run_metadata(run_gcs_dir, training_config, config_filename="training_config.json")
        gt.launch.init_wandb(run_name=run_name, x_label="train")
        # Start eval poller on a separate machine. WANDB_ENTITY/WANDB_PROJECT are forwarded by gce_vm.
        eval_command = _build_eval_command(
            run_gcs_dir=run_gcs_dir,
            base_model=base_model,
            run_name=run_name,
            training_config=training_config,
            use_text_prefix=use_text_prefix,
            tiny_run=tiny_run,
        )
        logger.info("\nThis command will be run to evaluate the model:\n\n%s\n", eval_command)
        if tiny_run:
            logger.info("Skipping async eval on L4 for tiny_run")
        else:
            gt.launch.gce_vm(
                gpu="l4",
                job_type=gt.launch.JobType.EVAL,
                name_suffix=shortname,
                command=eval_command,
            )
            logger.info("Created l4-eval instance with eval poller in startup script")
            eval_was_launched = True

    # Everything after the eval launch is inside `try` so the sentinel below is written on any failure.
    try:
        trainer.add_callback(gt.train.GCSCheckpointUploadCallback(run_gcs_dir=run_gcs_dir))
        warnings.filterwarnings(
            "ignore",
            message=".*torch.utils.checkpoint: the use_reentrant parameter.*",
            category=UserWarning,
        )
        logger.info("Training - start")
        trainer.train(resume_from_checkpoint=training_config.resume_from_checkpoint)
        logger.info("Training - complete")
        if trainer.accelerator.is_main_process:
            assert trainer.args.output_dir is not None
            dir_inference = os.path.join(trainer.args.output_dir, "inference")
            trainer.model.encoder.save_pretrained(dir_inference)
            subprocess.run(
                ["gcloud", "storage", "cp", "-r", "wandb", f"{run_gcs_dir}/wandb"],
                check=True,
            )
            subprocess.run(
                ["gcloud", "storage", "rsync", "-r", dir_inference, f"{run_gcs_dir}/inference"],
                check=True,
            )
            logger.info("Uploaded wandb artifacts and model to %s", run_gcs_dir)
    finally:
        # Write the (empty) TRAINING_DONE sentinel even if training failed, so the eval poller always stops polling,
        # exits, and then its instance shuts down. Best-effort: check=False.
        if eval_was_launched:
            subprocess.run(
                ["gcloud", "storage", "cp", "-", f"{run_gcs_dir}/{gt.sentinels.TRAINING_DONE}"],
                input=b"",
                check=False,
            )
if __name__ == "__main__":
    # CLI entry point: `tapify` maps `run`'s parameters to command-line flags (e.g. --tiny_run, --sync_start).
    tapify(run)