Skip to content

Commit 3d57b85

Browse files
committed
Support gradient clipping. Simplify loss normalization (sum losses over chunks, normalize at end) (#94)
1 parent e2e9c47 commit 3d57b85

13 files changed

Lines changed: 310 additions & 216 deletions

keys_values/data/iterators.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,6 @@ def __init__(
203203
):
204204
assert micro_batch_size >= 1
205205
assert num_devices >= 1
206-
if micro_batch_size == 1 and num_devices == 1:
207-
raise ValueError(
208-
"This sampler requires micro_batch_size > 1 or num_devices > 1"
209-
)
210206
if shortest_first and longest_first:
211207
raise ValueError("Cannot set both shortest_first and longest_first")
212208
if num_devices > 1:

keys_values/data/load_helmet_dev_eval.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def load_rag(
456456
"8k": "kilt/popqa_test_1000_k50_dep6.jsonl",
457457
},
458458
} # the load paths can only be stored in this way, as they are hard-coded from the original code
459-
instruction_template = get_instruction_template(dataset_key)
459+
instruction_template = get_instruction_template(dataset_key)[0]
460460

461461
instance_path = str(
462462
Path(dataset_parent_dir) / Path(file_paths[dataset_key][max_length])
@@ -589,7 +589,7 @@ def load_cited_generation(
589589
"16k": 75,
590590
"8k": 30,
591591
}
592-
instruction_template = get_instruction_template(dataset_key)
592+
instruction_template = get_instruction_template(dataset_key)[0]
593593
demo_template = "Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing a document, surround its ID with square brackets, such as [x] to cite document x. To cite multiple documents, simply concatenate the citation markers; for example, use [x][y][z] to cite the documents with ID x, y, and z. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.\n\nQuestion: {question}\n\n{context}\n\nAnswer: {answer}"
594594
doc_template = "Document [{ID}](Title: {title}): {text}"
595595

@@ -658,7 +658,7 @@ def load_rerank(
658658
"8k": "msmarco/test_reranking_data_k50_dep3.jsonl",
659659
},
660660
}
661-
instruction_template = get_instruction_template(dataset_key)
661+
instruction_template = get_instruction_template(dataset_key)[0]
662662

663663
instance_path = str(
664664
Path(dataset_parent_dir) / Path(file_paths[dataset_key][max_length])
@@ -765,7 +765,7 @@ def load_icl(
765765
"banking77": 77,
766766
"clinc150": 151,
767767
}
768-
instruction_template = get_instruction_template(dataset_key)
768+
instruction_template = get_instruction_template(dataset_key)[0]
769769
demo_template = "{text}\nlabel: {label}"
770770

771771
if dataset_key == "trec_coarse":
@@ -938,7 +938,7 @@ def load_long_doc_qa(
938938
)
939939
eval_questions_num = 100
940940

941-
instruction_template = get_instruction_template(dataset_key)
941+
instruction_template = get_instruction_template(dataset_key)[0]
942942
if dataset_key == "narrative_qa":
943943
all_data = load_dataset("narrativeqa")
944944
instance_data = all_data["test"].shuffle(seed=seed)
@@ -1114,7 +1114,7 @@ def load_summarization(
11141114
"meta-llama/Llama-2-7b-hf", token=HF_TOKEN
11151115
)
11161116

1117-
instruction_template = get_instruction_template(dataset_key)
1117+
instruction_template = get_instruction_template(dataset_key)[0]
11181118
if dataset_key == "infinite_bench_sum":
11191119
eval_questions_num = 50 # different from HELMET
11201120
ft = Features(
@@ -1302,7 +1302,7 @@ def load_synthetic(
13021302
)
13031303
instance_data = load_dataset("json", data_files=data_path)["train"]
13041304

1305-
instruction_template = get_instruction_template(dataset_key)
1305+
instruction_template = get_instruction_template(dataset_key)[0]
13061306
if dataset_key == "json_kv":
13071307
demo_template = "Key: {key}\nCorresponding value:{value}"
13081308

keys_values/finetune/args.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,14 @@ class TrainArgs:
421421
Args:
422422
intermed_save_interval: See above
423423
intermed_save_num: See above
424+
max_grad_norm: If not `None`, we use gradient clipping (so
425+
`torch.nn.utils.clip_grad_norm_`) with this maximum norm.
426+
Defaults to 1.0.
427+
average_loss_per_batch: If `True`, the sum of loss values for a batch
428+
is normalized by the number of non-masked target tokens in that
429+
batch. Otherwise (`False`, the default), we average the sum of loss
430+
values per data case (by the number of non-masked target tokens),
431+
then use the uniform average over the batch.
424432
"""
425433

426434
save_interval: Optional[int] = 1000
@@ -450,6 +458,8 @@ class TrainArgs:
450458
"""Whether to tie the embedding weights with the language modeling head weights"""
451459
intermed_save_interval: Optional[int] = None
452460
intermed_save_num: Optional[int] = None
461+
max_grad_norm: Optional[float] = 1.0
462+
average_loss_per_batch: Optional[bool] = False
453463

454464
def __post_init__(self) -> None:
455465
if self.lr_warmup_fraction and self.lr_warmup_steps:
@@ -492,6 +502,8 @@ def __post_init__(self) -> None:
492502
raise ValueError(
493503
"intermed_save_num only needed if intermed_save_interval is given"
494504
)
505+
if self.max_grad_norm is not None and self.max_grad_norm <= 0:
506+
raise ValueError("max_grad_norm must be positive (or `None` to disable)")
495507

496508
def gradient_accumulation_iters(self, devices: int, num_nodes: int = 1) -> int:
497509
"""Number of iterations between gradient synchronizations"""

keys_values/finetune/longcon_offload_full.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def setup(
4949
max_seq_length=None,
5050
intermed_save_interval=None,
5151
intermed_save_num=None,
52+
max_grad_norm=1.0,
5253
),
5354
eval: EvalArgs = EvalArgs(
5455
interval=100,

keys_values/finetune/longcon_offload_lora.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def setup(
5050
max_seq_length=None,
5151
intermed_save_interval=None,
5252
intermed_save_num=None,
53+
max_grad_norm=1.0,
5354
),
5455
lora: LoRAArgs = LoRAArgs(
5556
r=8,

keys_values/finetune/longcontext_eval.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ def main(
419419
attention_backward_temp_size_gb=None,
420420
max_batch_size=batch_size,
421421
dtype=dtype,
422+
average_loss_per_batch=False,
422423
fabric=fabric,
423424
model_kwargs=None,
424425
)
@@ -502,16 +503,17 @@ def main(
502503
t0 = time.perf_counter()
503504
# One entry per batch dimension:
504505
input_ids = batch[INPUT_IDS_NAME]
506+
targets = batch["targets"]
505507
if evaluator is None:
506508
with torch.no_grad():
507-
metric_values = model(input_ids, batch["targets"])
509+
metric_values = model(input_ids, targets)
508510
metric_name = "eval_loss"
509511
else:
510512
metric_name = evaluator.metrics[0]
511-
targets = batch[TARGETS_STRINGS_NAME]
512-
prompt_len = input_ids.shape[1] - batch["targets"].shape[1] + 1
513+
prompt_len = input_ids.shape[1] - targets.shape[1] + 1
513514
prompts = input_ids[:, :prompt_len]
514-
metric_values = evaluator(model, prompts, targets)[metric_name]
515+
raw_targets = batch[TARGETS_STRINGS_NAME]
516+
metric_values = evaluator(model, prompts, raw_targets)[metric_name]
515517
eval_time = time.perf_counter() - t0
516518
print_with_rank_and_timestamp(
517519
f"Batch {task}, {orig_idxs}: {metric_name} = {metric_values.mean().item():.3f}, eval_time = {eval_time * 1000:.2f} ms",

0 commit comments

Comments
 (0)