Commit c8f7b56

Add additional asserts and update post training readme (#1300)

* add asserts and fix post training readme
* precommit

Co-authored-by: Quentin Anthony <[email protected]>

Parent: 774eb58

File tree: 2 files changed (+10, -2 lines)

megatron/training.py (+10)

```diff
@@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator):
             datatype=datatype,
         )
     elif neox_args.train_impl == "kto":
+        assert (
+            neox_args.train_micro_batch_size_per_gpu > 1
+        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
         tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -459,6 +462,13 @@ def get_batch(neox_args, data_iterator):
 
 def get_batch_pipe(data, neox_args, curr_scheduler=None):
     """A modification of get_batch() to work with the latest batch instead of an iterator."""
+
+    assert neox_args.train_impl not in [
+        "kto",
+        "dpo",
+        "rm",
+    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
+
     # Items and their type.
     keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
     datatype = torch.int64
```
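To see how the two new guards behave in isolation, here is a minimal standalone sketch. The `SimpleNamespace` stand-in for the real `neox_args` object and the `check_*` helper names are illustrative assumptions, not part of the repo; the assert conditions and messages are taken from the diff.

```python
from types import SimpleNamespace


def check_kto_batch(neox_args):
    # Mirrors the guard added to get_batch(): the commit requires a
    # micro-batch larger than 1 for KTO (presumably because the KTO loss
    # needs more than one example per micro-batch).
    if neox_args.train_impl == "kto":
        assert (
            neox_args.train_micro_batch_size_per_gpu > 1
        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."


def check_pipe_compat(neox_args):
    # Mirrors the guard added to get_batch_pipe(): these post-training
    # objectives are rejected under pipeline parallelism.
    assert neox_args.train_impl not in [
        "kto",
        "dpo",
        "rm",
    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"


ok = SimpleNamespace(train_impl="kto", train_micro_batch_size_per_gpu=2)
check_kto_batch(ok)  # passes silently

bad = SimpleNamespace(train_impl="kto", train_micro_batch_size_per_gpu=1)
try:
    check_kto_batch(bad)
except AssertionError as e:
    print("rejected:", e)
```

Raising early in the batch-building path surfaces the misconfiguration at startup rather than as a confusing shape error mid-training.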

post-training/README.md (-2)

````diff
@@ -34,15 +34,13 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis
 
 ## SFT data
 ```bash
-python post-training/llama_dpo_data.py
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 ```
 
 ## KTO data
 ```bash
-python post-training/llama_dpo_data.py
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
````
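The KTO commands pass `--jsonl-keys messages --reward-key reward`, implying each JSONL record carries a chat `messages` list plus a `reward` field. A minimal sketch of such a record follows; the exact schema expected by `preprocess_data_with_chat_template.py` is an assumption here, not confirmed from the script.

```python
import io
import json

# Hypothetical KTO record shape implied by --jsonl-keys messages --reward-key
# reward; field contents are illustrative, not taken from the repo's data.
record = {
    "messages": [
        {"role": "user", "content": "Summarize KTO in one sentence."},
        {"role": "assistant", "content": "KTO trains on per-example desirability labels."},
    ],
    "reward": True,  # e.g. desirable (True) vs. undesirable (False) completion
}

# JSONL: one JSON object per line.
buf = io.StringIO()
buf.write(json.dumps(record) + "\n")

# Read it back the way a --jsonl-keys / --reward-key consumer might.
row = json.loads(buf.getvalue().splitlines()[0])
print(row["reward"], len(row["messages"]))
```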
