Pccl integration #241

Open · wants to merge 47 commits into prime-v2
Commits (47)
5a45af8 small cleanup (mikex86, Apr 4, 2025)
e96c5e7 introduce mpi info and allow non-mpi runs (mikex86, Apr 4, 2025)
f5263f8 working fsdp with pccl accept loop (mikex86, Apr 5, 2025)
1cd3764 working sync DiLoCo (mikex86, Apr 5, 2025)
43be48c working async DiLoCo (mikex86, Apr 5, 2025)
6e60343 introduce functions for sanity (mikex86, Apr 6, 2025)
471f640 configurable async/non-async DiLoCo (mikex86, Apr 7, 2025)
2a842ad implemented nibble dataset (mikex86, Apr 8, 2025)
12351cb fix bug where outer lr is not set (mikex86, Apr 8, 2025)
efe58d9 fix configs & unit tests (mikex86, Apr 8, 2025)
c74d5b2 fix ruff (mikex86, Apr 8, 2025)
eaeec2a fix ruff (mikex86, Apr 8, 2025)
1160e49 clone pccl dependency via git instead of https (mikex86, Apr 8, 2025)
8bd560d fix pccl git url (mikex86, Apr 8, 2025)
f9e37ea backported ParquetDataset (mikex86, Apr 9, 2025)
906517e fix ruff (mikex86, Apr 9, 2025)
22c7249 fix pending mpi ranks join wait logic (mikex86, Apr 11, 2025)
09b515b add launch prime script & add nibble ds folder listing support (mikex86, Apr 13, 2025)
a072ba2 add streaming data loader support (mikex86, Apr 14, 2025)
8867b06 fix fake data loader (mikex86, Apr 14, 2025)
bf39f61 fix for 8xH100 (mikex86, Apr 14, 2025)
7f6d0af add H100 config (mikex86, Apr 14, 2025)
1762ce2 add topology optimization (mikex86, Apr 14, 2025)
25b27ce change config of launch_prime.sh (mikex86, Apr 14, 2025)
bebdd0f small changes (mikex86, Apr 18, 2025)
cb20ff1 fix training_progress.step (mikex86, Apr 18, 2025)
b9615ac log outer step (mikex86, Apr 18, 2025)
02647c0 fix incompetence (mikex86, Apr 18, 2025)
f89e4a9 set step from shared state synced var exactly post shared state sync (mikex86, Apr 18, 2025)
4573f0d fix NCCL crash on some Lambda nodes (mikex86, Apr 21, 2025)
12f89c4 enable topology optimization (mikex86, Apr 22, 2025)
30ff670 utilize SharedStateSyncStrategy (mikex86, Apr 23, 2025)
933ba5b fix typo (mikex86, Apr 24, 2025)
fee98cd bump pccl commit revision (mikex86, Apr 28, 2025)
2ba792c bump pccl commit revision (mikex86, Apr 28, 2025)
387f0bd bump pccl commit revision (mikex86, Apr 28, 2025)
00b9724 bump pccl commit revision (mikex86, Apr 29, 2025)
3f8f849 fix config dataset (samsja, Apr 29, 2025)
4060e5f add logging pccl (samsja, Apr 29, 2025)
54245d5 add diloco delayed default value (samsja, Apr 30, 2025)
a08b87e fix fake data (samsja, Apr 30, 2025)
01e63a9 bump pccl commit revision (mikex86, May 1, 2025)
3d47532 Add dtype argument (mikex86, May 4, 2025)
4b170e1 Add missing dtype argument (mikex86, May 4, 2025)
4be8810 bump pccl commit revision (mikex86, May 5, 2025)
1642298 revert (mikex86, May 5, 2025)
60d83bd fix unused imports (mikex86, May 22, 2025)
Files changed
21 changes: 18 additions & 3 deletions configs/10B/H100_intellect1.toml
@@ -24,19 +24,34 @@ betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1

[train.outer_lr_scheduler]
lr = 0.7
end_lr = 0.7
num_decay_steps = 0
num_warmup_steps = 0
num_stable_steps = 0

[train.outer_optimizer]
type = "sgd"
momentum = 0.9
nesterov = true

[data]
seq_length = 8192
dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular,/data/datasets/dclm-baseline-1.0-parquet,/data/datasets/open-web-math"
dataset_name_or_paths = "/data/datasets/fineweb-edu.bin,/data/datasets/fineweb.bin,/data/datasets/StackV1-popular.bin,/data/datasets/dclm-baseline-1.0-parquet.bin,/data/datasets/open-web-math.bin"
token_bit_size = 17
dataset_ratio = "55:10:20:10:5"
num_workers = 4
reverse_data_files = true
split_by_data_rank = false # the 10B training assumes that data was already split by data rank; kept for backward compatibility


[diloco]
inner_steps = 100
compression = "uint8"
delayed_update = true

[ckpt]
interval = 100
path = "/data/10B"

[pccl]
ccoip_host = "127.0.0.1:48148"
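
For context on `compression = "uint8"` in the `[diloco]` section above: a minimal sketch of what 8-bit quantization of the DiLoCo pseudo-gradient could look like before it is all-reduced. This is an assumed min/max scheme for illustration, not necessarily the codec pccl implements.

```python
# Sketch only: assumed linear min/max quantization of a pseudo-gradient
# tensor to uint8 before communication, and dequantization after.
import torch

def quantize_uint8(t: torch.Tensor):
    lo, hi = t.min(), t.max()
    scale = (hi - lo).clamp_min(1e-8) / 255.0
    q = ((t - lo) / scale).round().to(torch.uint8)  # 4x smaller than fp32
    return q, lo, scale

def dequantize_uint8(q: torch.Tensor, lo: torch.Tensor, scale: torch.Tensor):
    return q.to(torch.float32) * scale + lo

pseudo_grad = torch.randn(1024)
q, lo, scale = quantize_uint8(pseudo_grad)
recovered = dequantize_uint8(q, lo, scale)  # max abs error is about scale / 2
```
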
4 changes: 4 additions & 0 deletions configs/10B/H100_simple.toml
@@ -25,3 +25,7 @@ weight_decay = 0.1
[data]
seq_length = 8192
num_workers = 4
fake = true

[pccl]
ccoip_host = "127.0.0.1:48148"
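
The new `[pccl]` stanza repeats across these configs. As a small sketch, such a section could be read with tomli (which this PR adds to pyproject.toml); the file path and key handling here are illustrative, not the trainer's actual config plumbing.

```python
# Illustrative only: parse a [pccl] section like the one above with tomli.
import tomli

with open("configs/10B/H100_simple.toml", "rb") as f:  # tomli requires binary mode
    cfg = tomli.load(f)

host, port = cfg["pccl"]["ccoip_host"].rsplit(":", 1)
print(host, int(port))  # 127.0.0.1 48148
```
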
3 changes: 3 additions & 0 deletions configs/13B/H100.toml
@@ -13,3 +13,6 @@ batch_size = 512
[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"

[pccl]
ccoip_host = "127.0.0.1:48148"
28 changes: 23 additions & 5 deletions configs/150M/A100_debug.toml
@@ -1,12 +1,11 @@
project = "debug_150m_zero_band"

model_name = "150M"
model_type = "llama2"
model_type = "llama3"

wandb = false
log_all_ranks = true

[hardware]
micro_batch_size = 64
micro_batch_size = 32
torch_compile = true

[train]
@@ -16,6 +15,25 @@ batch_size = 512
num_warmup_steps = 10
num_decay_steps = 1000

[train.outer_lr_scheduler]
lr = 1.0
end_lr = 1.0
num_decay_steps = 0
num_warmup_steps = 0
num_stable_steps = 0

[train.outer_optimizer]
type = "sgd"

[data]
fake = true
dataset_name_or_paths = 'tests/test_data/parquet/parquet_ds_folder_1,tests/test_data/parquet/parquet_ds_folder_2'
#dataset_name_or_paths = '/home/mike/IntelliJProjects/dataproctest/working_dir/train_0.bin,/home/mike/IntelliJProjects/dataproctest/working_dir/train_1.bin'
dataset_ratio = "50:50"
token_bit_size = 17

[diloco]
inner_steps = 16
delayed_update = true

[pccl]
ccoip_host = "127.0.0.1:48148"
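
On `dataset_ratio = "50:50"` above: a sketch of one way the ratio string could drive mixing across the comma-separated datasets. The parsing and the per-sample draw are assumptions about the intent, not the PR's data-loader code.

```python
# Sketch: turn dataset_name_or_paths + dataset_ratio into weighted sampling.
import random

paths = "tests/test_data/parquet/parquet_ds_folder_1,tests/test_data/parquet/parquet_ds_folder_2".split(",")
weights = [int(x) for x in "50:50".split(":")]
assert len(paths) == len(weights)

rng = random.Random(0)

def next_source() -> str:
    # pick the dataset the next sample is drawn from, proportional to the ratio
    return rng.choices(paths, weights=weights, k=1)[0]
```
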
3 changes: 3 additions & 0 deletions configs/150M/H100.toml
@@ -13,3 +13,6 @@ batch_size = 512
[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"

[pccl]
ccoip_host = "127.0.0.1:48148"
3 changes: 3 additions & 0 deletions configs/150M/H100_best.toml
@@ -24,3 +24,6 @@ weight_decay = 0.24530252977858977
[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"

[pccl]
ccoip_host = "127.0.0.1:48148"
3 changes: 3 additions & 0 deletions configs/1B/H100.toml
@@ -13,3 +13,6 @@ batch_size = 512
[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"

[pccl]
ccoip_host = "127.0.0.1:48148"
7 changes: 7 additions & 0 deletions configs/70M/H100.toml
@@ -13,3 +13,10 @@ batch_size = 512
[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"

[diloco]
inner_steps = 16
delayed_update = true

[pccl]
ccoip_host = "127.0.0.1:48148"
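
Since `[diloco] inner_steps` now appears in several configs, here is a self-contained toy of the schedule it implies: inner_steps local optimizer steps, then an outer SGD step on the pseudo-gradient. It is single-worker with a toy objective, reusing the outer SGD settings seen in the other configs of this PR; the real loop all-reduces pseudo-gradients across peers via pccl, which this omits.

```python
# Toy single-worker sketch of the DiLoCo schedule implied by inner_steps = 16.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
inner_opt = torch.optim.SGD(model.parameters(), lr=3e-4)
anchor = [p.detach().clone() for p in model.parameters()]   # params at last sync
outer_opt = torch.optim.SGD(anchor, lr=0.7, momentum=0.9, nesterov=True)

for _ in range(4):                        # number of outer rounds, illustrative
    for _ in range(16):                   # inner_steps = 16
        loss = model(torch.randn(32, 8)).pow(2).mean()  # stand-in for the LM loss
        inner_opt.zero_grad()
        loss.backward()
        inner_opt.step()
    for p, a in zip(model.parameters(), anchor):
        a.grad = a.data - p.data          # pseudo-gradient: drift since last sync
    outer_opt.step()                      # outer SGD on the anchor copy
    with torch.no_grad():
        for p, a in zip(model.parameters(), anchor):
            p.copy_(a)                    # load the outer result back into the model
```
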
39 changes: 33 additions & 6 deletions configs/7B/H100.toml
@@ -1,15 +1,42 @@
project = "debug_7B_zero_band"

model_name = "7B"
model_type = "llama2"
model_type = "llama3"

wandb = true
log_all_ranks = true

[hardware]
micro_batch_size = 64
reshard_after_forward = false
micro_batch_size = 8
reshard_after_forward = true
torch_compile = false
attn_fn="sdpa"

[train]
batch_size = 512

[train.lr_scheduler]
lr = 3e-4
end_lr = 0.0
num_warmup_steps = 8000
num_decay_steps = 1.2e6

[train.outer_lr_scheduler]
lr = 0.7
end_lr = 0.7
num_decay_steps = 0
num_warmup_steps = 0
num_stable_steps = 0

[train.outer_optimizer]
type = "sgd"

[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"
dataset_name_or_paths = 'http://65.108.32.176:8080/api/v1/datasets/fineweb-edu-train/stream'
token_bit_size = 17

[diloco]
inner_steps = 64
delayed_update = true

[pccl]
ccoip_host = "127.0.0.1:48148"
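
On `token_bit_size = 17`: the model_type here moves to llama3, whose tokenizer vocabulary is roughly 128k entries, so token ids fit in 17 bits (2^17 = 131072) rather than a full 32-bit word. Below is a sketch of one plausible bit-packing, with the caveat that the actual on-disk layout of the nibble/.bin datasets is not specified here.

```python
# Assumed little-endian bit-packing for 17-bit token ids; illustrative only.
import numpy as np

SHIFTS = np.arange(17, dtype=np.uint64)

def pack_17bit(tokens: np.ndarray) -> np.ndarray:
    bits = ((tokens[:, None] >> SHIFTS) & np.uint64(1)).astype(np.uint8)
    return np.packbits(bits.ravel(), bitorder="little")

def unpack_17bit(buf: np.ndarray, n_tokens: int) -> np.ndarray:
    bits = np.unpackbits(buf, bitorder="little")[: n_tokens * 17].reshape(n_tokens, 17)
    return (bits.astype(np.uint64) << SHIFTS).sum(axis=1)

toks = np.array([0, 1, 70_000, 131_071], dtype=np.uint64)   # 131071 = 2**17 - 1
assert (unpack_17bit(pack_17bit(toks), len(toks)) == toks).all()
```
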
33 changes: 26 additions & 7 deletions configs/debug/diloco.toml
@@ -1,20 +1,39 @@
model_name = "debugmodel"
model_type = "llama2"
model_name = "150M"
model_type = "llama3"

wandb = false
log_all_ranks = true

[hardware]
micro_batch_size = 8
micro_batch_size = 32
torch_compile = true

[train]
batch_size = 16
batch_size = 512

[train.lr_scheduler]
num_warmup_steps = 10
num_decay_steps = 10
num_decay_steps = 1000

[train.outer_lr_scheduler]
lr = 1.0
end_lr = 1.0
num_decay_steps = 0
num_warmup_steps = 0
num_stable_steps = 0

[train.outer_optimizer]
type = "sgd"

[data]
fake = true
#dataset_name_or_paths = 'tests/test_data/parquet/parquet_ds_folder_1,tests/test_data/parquet/parquet_ds_folder_2'
dataset_name_or_paths = '/home/mike/IntelliJProjects/dataproctest/working_dir/train'
#dataset_ratio = "50:50"
token_bit_size = 17

[diloco]
inner_steps = 5
inner_steps = 16
delayed_update = true

[pccl]
ccoip_host = "127.0.0.1:48148"
4 changes: 3 additions & 1 deletion configs/debug/normal.toml
@@ -7,7 +7,7 @@ wandb = false
micro_batch_size = 8

[train]
batch_size = 16
batch_size = 32

[train.lr_scheduler]
num_warmup_steps = 10
@@ -16,3 +16,5 @@ num_decay_steps = 10
[data]
fake = true

[pccl]
ccoip_host = "127.0.0.1:48148"
3 changes: 3 additions & 0 deletions configs/test.toml
@@ -24,3 +24,6 @@ num_warmup_steps = 1000
lr = 3e-4
end_lr = 0.0
num_decay_steps = 80000

[pccl]
ccoip_host = "127.0.0.1:48148"
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -11,14 +11,18 @@ dependencies = [
"transformers>=4.44.2",
"datasets>=3.0.0",
"pydantic_config @ git+https://github.com/samsja/pydantic_config.git@b7becc3",
"tomli",
"torchdata>=0.8.0",
"fsspec[gcs]>=2024.3.1",
"ninja",
"zstandard",
"pyarrow",
"psutil",
"wandb",
"imageio[ffmpeg]"
"numba",
"imageio[ffmpeg]",
"pccl @ git+https://github.com/PrimeIntellect-ai/pccl.git@main#subdirectory=python/framework",
"datasetstream @ git+https://github.com/PrimeIntellect-ai/datasetstream@main"
]

[project.optional-dependencies]
@@ -37,4 +41,4 @@ allow-direct-references = true # allow direct references to git repos in dependencies
line-length = 120

[tool.uv]
dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0","pytest>=7.0.0", "faker", "matplotlib"]
dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0", "pytest>=7.0.0", "faker", "matplotlib", "transformers"]
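
A quick, assumption-laden smoke test of the imports these dependency changes introduce; the import name for pccl in particular is inferred from the package name rather than confirmed here.

```python
# Assumed import names for the packages added above; verify in your environment.
import tomli    # TOML reader, used here for the new config stanzas
import numba    # JIT compilation, added alongside the dataset work
import pccl     # PrimeIntellect-ai/pccl Python framework (import name assumed)
```
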
24 changes: 0 additions & 24 deletions scripts/bandwith/down.sh

This file was deleted.

1 change: 0 additions & 1 deletion scripts/bandwith/up.sh

This file was deleted.

35 changes: 0 additions & 35 deletions scripts/convert_dl_ckpt.sh

This file was deleted.
