Commit 3dfc34c

Updating CONTRIBUTING.md. Removing DEBUG, TODO if outdated. Make LongContextInference ready for token generation
1 parent ef29cb7 commit 3dfc34c

31 files changed

Lines changed: 402 additions & 329 deletions

CONTRIBUTING.md

Lines changed: 34 additions & 16 deletions
@@ -1,18 +1,22 @@
 # Contributing Guidelines
 
-Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
-documentation, we greatly value feedback and contributions from our community.
+Thank you for your interest in contributing to our project. Whether it's a bug
+report, a new feature, an integration into existing long context libraries, or
+additional documentation, we greatly value feedback and contributions from our
+community.
 
-Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
-information to effectively respond to your bug report or contribution.
+Please read through this document before submitting any issues or pull requests
+to ensure we have all the necessary information to effectively respond to your
+bug report or contribution.
 
 
 ## Reporting Bugs/Feature Requests
 
 We welcome you to use the GitHub issue tracker to report bugs or suggest features.
 
-When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
-reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
+When filing an issue, please check existing open, or recently closed, issues to
+make sure somebody else hasn't already reported the issue. Please try to include
+as much information as you can. Details like these are incredibly useful:
 
 * A reproducible test case or series of steps
 * The version of our code being used
@@ -21,27 +25,38 @@ reported the issue. Please try to include as much information as you can. Detail
 
 
 ## Contributing via Pull Requests
-Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
+Contributions via pull requests are much appreciated. Before sending us a pull
+request, please ensure that:
 
 1. You are working against the latest source on the *main* branch.
-2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
-3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
+2. You check existing open, and recently merged, pull requests to make sure
+someone else hasn't addressed the problem already. If the pull request is open,
+feel free to add a comment to it, expressing your interest.
+3. You open an issue to discuss any significant work - we would hate for your
+time to be wasted.
 
 To send us a pull request, please:
 
 1. Fork the repository.
-2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
-3. Ensure local tests pass.
+2. Modify the source; please focus on the specific change you are contributing.
+If you also reformat all the code, it will be hard for us to focus on your
+change.
+3. Ensure that all local tests pass.
 4. Commit to your fork using clear commit messages.
-5. Send us a pull request, answering any default questions in the pull request interface.
-6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
+5. Send us a pull request, answering any default questions in the pull request
+interface.
+6. Pay attention to any automated CI failures reported in the pull request, and
+stay involved in the conversation.
 
 GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
 [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
 
 
 ## Finding contributions to work on
-Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
+Looking at the existing issues is a great way to find something to contribute
+on. As our projects, by default, use the default GitHub issue labels
+(enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at
+any 'help wanted' issues is a great place to start.
 
 
 ## Code of Conduct
@@ -51,9 +66,12 @@ opensource-codeofconduct@amazon.com with any additional questions or comments.
 
 
 ## Security issue notifications
-If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
+If you discover a potential security issue in this project we ask that you notify
+AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/).
+Please do **not** create a public github issue.
 
 
 ## Licensing
 
-See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
+See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you
+to confirm the licensing of your contribution.

keys_values/data/evaluation.py

Lines changed: 0 additions & 5 deletions
@@ -233,11 +233,6 @@ def _wrapped_collate_fn(
     task = next(iter(tasks))
     orig_collated_samples = orig_collate_fn(samples)
     orig_idxs = [elem[ORIG_IDX_NAME] for elem in samples]
-    # DEBUG
-    #print(f"*** evaluation._wrapped_collate_fn: {orig_idxs} ({task})")
-    #offset = samples[0]["prefix_len"]
-    #print(orig_collated_samples["input_ids"][:, offset:(offset + 15)])
-    # END DEBUG
     return {
         **orig_collated_samples,
         ORIG_IDX_NAME: orig_idxs,

keys_values/data/longbench_v2.py

Lines changed: 0 additions & 7 deletions
@@ -648,13 +648,6 @@ def filter_and_transform(
         min_length = test_results[0]["num_tokens_instruction"]
         max_length = test_results[-1]["num_tokens_instruction"]
         print(f"Test dataset has {len(test_results)} records, token lengths between {min_length} and {max_length}")
-        # DEBUG
-        #prefix_len = len(tokenizer.encode("\n".join(PROMPTLINES_PREFIX) + "\n"))
-        #test_results = [
-        #    dict(entry, prefix_len=prefix_len)
-        #    for entry in test_results
-        #]
-        # END DEBUG
     else:
         test_results = None
     if seq_lengths is None:

keys_values/data/sequence_classification.py

Lines changed: 0 additions & 1 deletion
@@ -102,7 +102,6 @@ def __getitem__(self, idx: int) -> Dict[str, Union[Tensor, Dict[str, Any]]]:
         return {
             INPUT_IDS_NAME: encoded_prompt,
             LABELS_NAME: label_idx,
-            #"prefix_len": example["prefix_len"], # DEBUG!!
             "token_counts": token_counts,
         }
 

keys_values/finetune/args.py

Lines changed: 2 additions & 2 deletions
@@ -64,8 +64,8 @@ class KVCacheArgs:
         dtype. The default is delayed allocation with first usage
 
     """
-    name: str # TODO: Different per layer
-    cache_length: int # TODO: Different per layer
+    name: str
+    cache_length: int
     chunk_size: int = 16
     cache_kwargs: Optional[Dict[str, Any]] = None
     randomize_chunk_sizes: bool = False

keys_values/finetune/longcon_offload_full.py

Lines changed: 4 additions & 3 deletions
@@ -73,6 +73,7 @@
     print_with_rank_and_timestamp,
     print_message,
     check_kv_cache,
+    adapt_requires_grad,
 )
 from keys_values.gpu_memory import RecordGPUMemory
 from keys_values.head_model import CrossEntropyOnLogits
@@ -210,7 +211,7 @@ def setup(
             - 1: Only record gradient computations (after initial forward). For
                 each update, we store one snapshot file per row of cells being
                 processed.
-            - 2: Special case (DEBUG)
+            - 2: Special case
             - 3: One snapshot file during initial validation
             Defaults to 0.
         record_gpu_memory_period: Only if `record_gpu_memory_snapshots` is used.
@@ -278,7 +279,7 @@ def setup(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
 
     precision = precision or get_default_supported_precision(training=True)
-    # TODO: Currently not used!
+    # Currently not used:
     logger = choose_logger(
         logger_name,
         out_dir,
@@ -409,6 +410,7 @@ def main(
         **head_model_kwargs,
     )
     gpt_model = gpt_model.to(optim_device)
+    adapt_requires_grad(gpt_model, head_model)
     batch_size = train.micro_batch_size
     if eval.micro_batch_size is not None:
         batch_size = max(batch_size, eval.micro_batch_size)
@@ -782,7 +784,6 @@ def fit(
                 )
             else:
                 generate_example_kwargs = None
-            # TODO: Fix bug in generation!
             valid_model = model.copy_model_for_evaluation()
             metrics = validate_and_all_reduce(
                 model=valid_model,

keys_values/finetune/longcon_offload_lora.py

Lines changed: 5 additions & 4 deletions
@@ -79,6 +79,7 @@
     print_with_rank_and_timestamp,
     print_message,
     check_kv_cache,
+    adapt_requires_grad,
 )
 from keys_values.head_model import CrossEntropyOnLogits
 from keys_values.head_model_factory import HeadModelFactory
@@ -230,7 +231,7 @@ def setup(
             - 1: Only record gradient computations (after initial forward). For
                 each update, we store one snapshot file per row of cells being
                 processed.
-            - 2: Special case (DEBUG)
+            - 2: Special case
             - 3: One snapshot file during initial validation
             Defaults to 0.
         record_gpu_memory_period: Only if `record_gpu_memory_snapshots` is used.
@@ -309,7 +310,7 @@ def setup(
     )
 
     precision = precision or get_default_supported_precision(training=True)
-    # TODO: Currently not used!
+    # Currently not used:
     logger = choose_logger(
         logger_name,
         out_dir,
@@ -440,6 +441,8 @@ def main(
         **head_model_kwargs,
     )
     gpt_model = gpt_model.to(optim_device)
+    mark_only_lora_as_trainable(gpt_model)
+    adapt_requires_grad(gpt_model, head_model)
     batch_size = train.micro_batch_size
     if eval.micro_batch_size is not None:
         batch_size = max(batch_size, eval.micro_batch_size)
@@ -456,7 +459,6 @@ def main(
         profile_parts=profile_parts,
         cpu_offload_device=device,
     )
-    mark_only_lora_as_trainable(model.gpt_model)
 
     num_trainable_params = num_parameters(model, requires_grad=True)
     print_message(f"\nNumber of trainable parameters: {num_trainable_params:,}")
@@ -820,7 +822,6 @@ def fit(
                 )
             else:
                 generate_example_kwargs = None
-            # TODO: Fix bug in generation!
             valid_model = model.copy_model_for_evaluation()
             metrics = validate_and_all_reduce(
                 model=valid_model,

keys_values/finetune/longcontext_eval.py

Lines changed: 0 additions & 1 deletion
@@ -220,7 +220,6 @@ def main(
         eos_id=tokenizer.eos_id,
         ignore_index=ignore_index,
     )
-    print(f"\ntokenizer.eos_id = {tokenizer.eos_id}\n") # DEBUG!
 
     fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
 

keys_values/finetune/longcontext_full.py

Lines changed: 31 additions & 9 deletions
@@ -32,7 +32,6 @@
 from keys_values.utils import flush_io_streams
 from litgpt.args import TrainArgs
 from litgpt.data import DataModule
-from litgpt.generate.base import generate
 from litgpt.config import Config
 from litgpt.prompts import save_prompt_style
 from litgpt.tokenizer import Tokenizer
@@ -80,6 +79,7 @@
     print_message,
     check_kv_cache,
 )
+from keys_values.generate.base import generate
 from keys_values.gpu_memory import RecordGPUMemory
 from keys_values.head_model import HeadModel, CrossEntropyOnLogits
 from keys_values.head_model_factory import HeadModelFactory
@@ -169,6 +169,7 @@ def setup(
     record_gpu_memory_snapshots: Optional[int] = None,
     record_gpu_memory_kind: int = 0,
     record_gpu_memory_period: int = 0,
+    generate_with_eval: bool = False,
     profile_grad_times: int = 0,
     profile_parts: Optional[str] = None,
 ) -> None:
@@ -365,6 +366,7 @@ def setup(
         record_gpu_memory_snapshots=record_gpu_memory_snapshots,
         record_gpu_memory_kind=record_gpu_memory_kind,
         record_gpu_memory_period=record_gpu_memory_period,
+        generate_with_eval=generate_with_eval,
         profile_grad_times=profile_grad_times,
         profile_parts=profile_parts,
     )
@@ -394,6 +396,7 @@ def main(
     record_gpu_memory_snapshots: Optional[RecordGPUMemory],
     record_gpu_memory_kind: int,
     record_gpu_memory_period: int,
+    generate_with_eval: bool,
     profile_grad_times: int,
     profile_parts: Optional[str],
 ) -> None:
@@ -540,6 +543,7 @@ def main(
         record_gpu_memory_snapshots=record_gpu_memory_snapshots,
         record_gpu_memory_kind=record_gpu_memory_kind,
         record_gpu_memory_period=record_gpu_memory_period,
+        generate_with_eval=generate_with_eval,
         profile_grad_params=profile_grad_params,
     )
     training_time = time.perf_counter() - train_time
@@ -550,13 +554,21 @@ def main(
     if eval.final_validation:
         print_with_rank_and_timestamp("Starting validation evaluations.", fabric.global_rank)
         print_message("\nFinal validation evaluation ...", fabric)
+        if generate_with_eval:
+            generate_example_kwargs = dict(
+                tokenizer=tokenizer,
+                data=data,
+            )
+        else:
+            generate_example_kwargs = None
         metrics = validate_and_all_reduce(
             model=model,
             val_dataloader=val_dataloader,
             eval=dataclasses.replace(eval, max_iters=len(val_dataloader)),
             batch_transform=batch_transform,
             log_metrics=False,
             fabric=fabric,
+            generate_example_kwargs=generate_example_kwargs,
         )
         fabric.log_dict(metrics, step=state["iter_num"])
         print_message(
@@ -576,7 +588,6 @@ def main(
         save_prompt_style(data.prompt_style, save_dir)
 
 
-# TODO: Support caches of different lengths, maybe even different types
 def wrap_gpt_model(
     gpt_model: GPT,
     head_model: HeadModel,
@@ -723,6 +734,7 @@ def fit(
     record_gpu_memory_snapshots: Optional[RecordGPUMemory],
     record_gpu_memory_kind: int,
     record_gpu_memory_period: int,
+    generate_with_eval: bool,
     profile_grad_params: Optional[Dict[str, Any]],
 ) -> Dict[str, Any]:
     model = state["model"]
@@ -740,12 +752,20 @@ def fit(
     if eval.initial_validation:
         print_with_rank_and_timestamp("Starting validation evaluations.", fabric.global_rank)
         print_message("\nInitial validation evaluation ...", fabric)
+        if generate_with_eval:
+            generate_example_kwargs = dict(
+                tokenizer=tokenizer,
+                data=data,
+            )
+        else:
+            generate_example_kwargs = None
         metrics = validate_and_all_reduce(
             model=model,
             val_dataloader=val_dataloader,
             eval=dataclasses.replace(eval, max_iters=len(val_dataloader)),
             batch_transform=batch_transform,
             fabric=fabric,
+            generate_example_kwargs=generate_example_kwargs,
         )
         val_loss = f"{metrics['val_loss']:.3f}"
         print_message(
@@ -905,17 +925,19 @@ def fit(
         if not is_accumulating and state["step_count"] % eval.interval == 0:
             print_with_rank_and_timestamp("Starting validation evaluations.", fabric.global_rank)
             print_message("\nPeriodic validation evaluation ...", fabric)
-            generate_example_kwargs = dict(
-                tokenizer=tokenizer,
-                data=data,
-            )
-            # TODO: Fix bug in generation!
+            if generate_with_eval:
+                generate_example_kwargs = dict(
+                    tokenizer=tokenizer,
+                    data=data,
+                )
+            else:
+                generate_example_kwargs = None
             metrics = validate_and_all_reduce(
                 model=model,
                 val_dataloader=val_dataloader,
                 eval=eval,
                 batch_transform=batch_transform,
-                # generate_example_kwargs=generate_example_kwargs,
+                generate_example_kwargs=generate_example_kwargs,
                 log_metrics=False,
                 fabric=fabric,
             )
@@ -1041,7 +1063,7 @@ def generate_example(
 
     if max_returned_tokens < gpt_model.max_seq_length:
         output = generate(
-            model=gpt_model,
+            model=model,
             prompt=encoded,
             max_returned_tokens=max_returned_tokens,
             temperature=0.8,
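
Review note: the same `if generate_with_eval: ... else: ...` block now appears before the final, initial, and periodic calls to `validate_and_all_reduce` in keys_values/finetune/longcontext_full.py. A minimal sketch of how the three copies could be folded into one helper is shown below. The name `make_generate_example_kwargs` is hypothetical and not part of this commit, and the sketch assumes `validate_and_all_reduce` treats `generate_example_kwargs=None` as "skip example generation", which is what the diff suggests but does not show.

```python
from typing import Any, Dict, Optional


def make_generate_example_kwargs(
    generate_with_eval: bool,
    tokenizer: Any,
    data: Any,
) -> Optional[Dict[str, Any]]:
    """Return the kwargs for example generation, or None to skip it.

    Hypothetical helper, not part of the commit: it mirrors the
    if/else block that the diff repeats before each call to
    validate_and_all_reduce.
    """
    if not generate_with_eval:
        return None
    return dict(tokenizer=tokenizer, data=data)
```

Each call site would then pass `generate_example_kwargs=make_generate_example_kwargs(generate_with_eval, tokenizer, data)`, so the three validation paths cannot drift apart.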
