Pccl integration #241

base: prime-v2

Conversation
```bash
set -e

# Colors for output
```
Let's keep this file.
```diff
@@ -31,3 +42,46 @@ def set_optimizer_lr(optimizer: torch.optim.Optimizer, lr: float):
     """
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
+
+
+OptimT = TypeVar("OptimT", bound=torch.optim.Optimizer)
```
Suggested change:

```diff
-OptimT = TypeVar("OptimT", bound=torch.optim.Optimizer)
+OptimT: TypeAlias = TypeVar("OptimT", bound=torch.optim.Optimizer)
```
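For context on why the bound matters: a TypeVar bound to torch.optim.Optimizer lets generic helpers preserve the concrete optimizer subtype. A minimal sketch (the `reset_lr` helper below is illustrative, not code from this PR):

```python
from typing import TypeVar

import torch

OptimT = TypeVar("OptimT", bound=torch.optim.Optimizer)


def reset_lr(optimizer: OptimT, lr: float) -> OptimT:
    # Illustrative helper: same body as set_optimizer_lr above, but the
    # bound TypeVar lets the return keep the caller's concrete subtype.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer


model = torch.nn.Linear(4, 4)
opt = reset_lr(torch.optim.AdamW(model.parameters(), lr=1e-3), lr=5e-4)
# A type checker infers `opt: torch.optim.AdamW`, not just `Optimizer`.
```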
```python
    :param optimizer_type: the type of optimizer used.
    """


def _validate_exists(to_check: List[Tuple[str, Optional[torch.Tensor]]]):
```
Suggested change:

```diff
-def _validate_exists(to_check: List[Tuple[str, Optional[torch.Tensor]]]):
+def _validate_exists(to_check: list[tuple[str, torch.Tensor | None]]):
```
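For reference, the hunk only shows the signature; a minimal sketch of what the validator's body might look like (the semantics and error message below are guesses, not code from this PR):

```python
import torch


def _validate_exists(to_check: list[tuple[str, torch.Tensor | None]]) -> None:
    # Sketch: gather the names whose tensors are missing and fail loudly.
    missing = [name for name, tensor in to_check if tensor is None]
    if missing:
        raise ValueError(f"expected tensors to exist: {', '.join(missing)}")


# Usage: pairs of (name, possibly-None tensor).
_validate_exists([("weights", torch.zeros(2)), ("grads", None)])  # raises ValueError
```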
```python
    hf_name="mistralai/Mistral-7B-v0.1",
    # print(len(AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)))
    vocab_size=32000,
    # print(AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True).bos_token_id)
    bot_token=1,
    # print(AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True).eos_token_id)
    eot_token=2,
)
```
Remove the commented-out prints.
```python
return TokenizerInfo(
    hf_name="meta-llama/Meta-Llama-3-8B",
    # print(len(AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)))
    vocab_size=128256,
    # print(AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True).bos_token_id)
    bot_token=128000,
    # print(AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True).eos_token_id)
    eot_token=128001,
)
```
Remove the commented-out prints.
Do we not want to tell people how to re-obtain this number easily?
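For reference, the commented-out prints above boil down to a few lines of the standard transformers API (a sketch, not part of the diff):

```python
from transformers import AutoTokenizer

# Re-derive the hard-coded tokenizer constants from the hub checkpoint.
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
print(len(tok))          # vocab_size -> 128256
print(tok.bos_token_id)  # bot_token  -> 128000
print(tok.eos_token_id)  # eot_token  -> 128001
```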
```python
import copy
import torch
from zeroband.data import InterleaveDataset, ParquetDataset, SequencePackingDataSet, collate_fn
from torch.utils.data import DataLoader
from zeroband.data import load_all_datasets, DataConfig
from zeroband.utils.logger import get_logger
from collections import Counter
from itertools import chain
import pytest
import logging
import pyarrow as pa
import pyarrow.parquet as pq
from faker import Faker
from typing import List
import string
from torchdata.stateful_dataloader import StatefulDataLoader
```
Why are we removing the sequence packing tests?
Need to re-add them; they were incompatible after the port.
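For when they are re-added, a minimal, self-contained sketch of the packing invariant such a test could pin down (the `pack` helper below is illustrative, not the repo's SequencePackingDataSet API):

```python
from itertools import chain


def pack(sequences: list[list[int]], seq_len: int) -> list[list[int]]:
    # Illustrative packer: concatenate token ids and cut fixed-length blocks,
    # dropping the trailing remainder (one common packing convention).
    flat = list(chain.from_iterable(sequences))
    return [flat[i : i + seq_len] for i in range(0, len(flat) - seq_len + 1, seq_len)]


def test_packing_preserves_tokens_up_to_remainder():
    sequences = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    seq_len = 4
    packed = pack(sequences, seq_len)
    # Every packed sample has exactly seq_len tokens...
    assert all(len(block) == seq_len for block in packed)
    # ...and no token is invented: packed tokens are a prefix of the stream.
    flat = list(chain.from_iterable(sequences))
    assert list(chain.from_iterable(packed)) == flat[: len(packed) * seq_len]
```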
```python
if __name__ == '__main__':
    pytest.main()
```
Remove.
Draft, not ready to merge yet.