
Commit f7a67c5

v2.1.18 (#662)
2 parents: 322bf3d + 2476012

11 files changed, +266 -47 lines changed


docs/miner.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -173,9 +173,9 @@ This guide will help you set up and run a miner for **τemplar**. We'll cover bo
 # Dataset R2 credentials - You may set up your own Shared Sharded Dataset, but must at minimum set these keys
 # See docs/shared_sharded_dataset.md for instructions
 export R2_DATASET_ACCOUNT_ID="8af7f92a8a0661cf7f1ac0420c932980"
-export R2_DATASET_BUCKET_NAME="gemma-migration"
-export R2_DATASET_READ_ACCESS_KEY_ID="a733fac6c32a549e0d48f9f7cf67d758"
-export R2_DATASET_READ_SECRET_ACCESS_KEY="f50cab456587f015ad21c48c3e23c7ff0e6f1ad5a22c814c3a50d1a4b7c76bb9"
+export R2_DATASET_BUCKET_NAME="mixed-dataset-migration"
+export R2_DATASET_READ_ACCESS_KEY_ID="e70cd26850f697479bbb5fd9413713f4"
+export R2_DATASET_READ_SECRET_ACCESS_KEY="11e3364d6ef70e44d671863fb6de32d474aa6220fa2c9c3df45c5e012ebfbda3"
 export DATASET_BINS_PATH="tokenized/"

```
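As context (not part of the commit): the miner expects the dataset variables above in its environment before launch. The sketch below is a minimal, illustrative preflight check; the variable names come from the doc hunk above, everything else is assumed.

```python
# Illustrative preflight check (not part of the repo): fail fast if the dataset
# R2 variables documented above are missing from the environment.
import os
import sys

REQUIRED = [
    "R2_DATASET_ACCOUNT_ID",
    "R2_DATASET_BUCKET_NAME",
    "R2_DATASET_READ_ACCESS_KEY_ID",
    "R2_DATASET_READ_SECRET_ACCESS_KEY",
    "DATASET_BINS_PATH",
]

missing = [name for name in REQUIRED if not os.environ.get(name)]
if missing:
    sys.exit(f"Missing dataset env vars: {', '.join(missing)}")
print("All dataset R2 credentials are set.")
```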

docs/shared_sharded_dataset.md

Lines changed: 16 additions & 16 deletions
````diff
@@ -19,7 +19,7 @@ The Shared Sharded dataset is based on the [mlfoundations/dclm-baseline-1.0-parq
 For the fastest training, our optimized version includes:

 - Pretokenized numpy arrays in .npy files
-- Array slicing provided via .bin files
+- Sample ID arrays provided via .npy files

 ## System Requirements

@@ -43,9 +43,9 @@ Append the following env keys:

 ```bash
 R2_DATASET_ACCOUNT_ID=8af7f92a8a0661cf7f1ac0420c932980
-R2_DATASET_BUCKET_NAME=gemma-migration
-R2_DATASET_READ_ACCESS_KEY_ID=a733fac6c32a549e0d48f9f7cf67d758
-R2_DATASET_READ_SECRET_ACCESS_KEY=f50cab456587f015ad21c48c3e23c7ff0e6f1ad5a22c814c3a50d1a4b7c76bb9
+R2_DATASET_BUCKET_NAME=mixed-dataset-migration
+R2_DATASET_READ_ACCESS_KEY_ID=e70cd26850f697479bbb5fd9413713f4
+R2_DATASET_READ_SECRET_ACCESS_KEY=11e3364d6ef70e44d671863fb6de32d474aa6220fa2c9c3df45c5e012ebfbda3
 DATASET_BINS_PATH="tokenized/"
 ```

@@ -92,16 +92,16 @@ Use the CloudFlare migration tool for the easiest setup. Here are the key-value

 - Bucket Information
   `Source bucket provider`: `S3-Compatible Storage`
-  `Bucket name`: `gemma-migration`
-  `S3-compatible endpoint URL`: `https://8af7f92a8a0661cf7f1ac0420c932980.r2.cloudflarestorage.com/gemma-migration`
+  `Bucket name`: `mixed-dataset-migration`
+  `S3-compatible endpoint URL`: `https://8af7f92a8a0661cf7f1ac0420c932980.r2.cloudflarestorage.com/mixed-dataset-migration`
 - Required Credentials
-  `Access Key ID`: `a733fac6c32a549e0d48f9f7cf67d758`
-  `Secret Access Key`: `f50cab456587f015ad21c48c3e23c7ff0e6f1ad5a22c814c3a50d1a4b7c76bb9`
+  `Access Key ID`: `e70cd26850f697479bbb5fd9413713f4`
+  `Secret Access Key`: `11e3364d6ef70e44d671863fb6de32d474aa6220fa2c9c3df45c5e012ebfbda3`

 #### Page 2

 - Select destination R2 bucket
-  `Bucket name`: `gemma-migration`
+  `Bucket name`: `mixed-dataset-migration`
   `Access Key ID`: your_write_id
   `Access Key`: your_secret_write_id
   `Overwrite files?`: `Yes, overwrite (recommended)`
@@ -122,8 +122,8 @@ curl https://rclone.org/install.sh | sudo bash
 # Configure source (read-only)
 rclone config create r2-source s3 \
   provider=Cloudflare \
-  access_key_id=a733fac6c32a549e0d48f9f7cf67d758 \
-  secret_access_key=f50cab456587f015ad21c48c3e23c7ff0e6f1ad5a22c814c3a50d1a4b7c76bb9 \
+  access_key_id=e70cd26850f697479bbb5fd9413713f4 \
+  secret_access_key=11e3364d6ef70e44d671863fb6de32d474aa6220fa2c9c3df45c5e012ebfbda3 \
   endpoint=https://8af7f92a8a0661cf7f1ac0420c932980.r2.cloudflarestorage.com \
   acl=private

@@ -139,7 +139,7 @@ rclone config create r2-dest s3 \
 ##### Copy all shards (Full Migration)
 ```bash
 # Copy entire tokenized directory (all shards and sample IDs)
-rclone copy r2-source:gemma-migration/tokenized/ r2-dest:<your-bucket-name>/tokenized/ \
+rclone copy r2-source:mixed-dataset-migration/tokenized/ r2-dest:<your-bucket-name>/tokenized/ \
   --transfers 32 \
   --checkers 16 \
   --progress
@@ -149,10 +149,10 @@ rclone copy r2-source:gemma-migration/tokenized/ r2-dest:<your-bucket-name>/toke
 If you want to test with just the first two shards:
 ```bash
 # Copy first two training shards and their sample IDs
-rclone copy r2-source:gemma-migration/tokenized/train_000000.npy r2-dest:<your-bucket-name>/tokenized/ --progress
-rclone copy r2-source:gemma-migration/tokenized/train_000001.npy r2-dest:<your-bucket-name>/tokenized/ --progress
-rclone copy r2-source:gemma-migration/tokenized/sample_ids_000000.bin r2-dest:<your-bucket-name>/tokenized/ --progress
-rclone copy r2-source:gemma-migration/tokenized/sample_ids_000001.bin r2-dest:<your-bucket-name>/tokenized/ --progress
+rclone copy r2-source:mixed-dataset-migration/tokenized/train_000000.npy r2-dest:<your-bucket-name>/tokenized/ --progress
+rclone copy r2-source:mixed-dataset-migration/tokenized/train_000001.npy r2-dest:<your-bucket-name>/tokenized/ --progress
+rclone copy r2-source:mixed-dataset-migration/tokenized/sample_ids_000000.npy r2-dest:<your-bucket-name>/tokenized/ --progress
+rclone copy r2-source:mixed-dataset-migration/tokenized/sample_ids_000001.npy r2-dest:<your-bucket-name>/tokenized/ --progress
 ```

 After migration, update your environment variables to point to your bucket:
````
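For context (not from the commit): after the rclone copy above, one way to confirm that both the `train_*.npy` shards and the renamed `sample_ids_*.npy` files landed in your destination bucket is to list the `tokenized/` prefix with boto3. This is a minimal sketch; the endpoint, bucket name, and credential placeholders are your own values, not values from this diff.

```python
# Sketch: verify migrated objects in your own R2 bucket (placeholders only).
import boto3

s3 = boto3.client(
    "s3",
    endpoint_url="https://<your-account-id>.r2.cloudflarestorage.com",
    aws_access_key_id="<your-read-key-id>",
    aws_secret_access_key="<your-read-secret>",
)

paginator = s3.get_paginator("list_objects_v2")
train, sample_ids = [], []
for page in paginator.paginate(Bucket="<your-bucket-name>", Prefix="tokenized/"):
    for obj in page.get("Contents", []):
        key = obj["Key"]
        if "sample_ids_" in key and key.endswith(".npy"):
            sample_ids.append(key)
        elif "train_" in key and key.endswith(".npy"):
            train.append(key)

# Expect matching counts; per this change, sample IDs now ship as .npy rather than .bin.
print(f"train shards: {len(train)}, sample id files: {len(sample_ids)}")
```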

hparams/hparams.json

Lines changed: 7 additions & 2 deletions
```diff
@@ -11,6 +11,7 @@
   "blocks_per_window": 115,
   "windows_per_weights": 3,
   "outer_steps_per_shard": 455,
+  "shard_reset_outer_step": 4040,
   "momentum_decay": 0.95,
   "topk_compression": 64,
   "target_chunk": 64,
@@ -43,8 +44,8 @@
   "eval_lr_factor": 0.2,
   "openskill_beta": 7,
   "openskill_tau": 0.1,
-  "checkpoint_init_version": "2.1.15",
-  "checkpoint_init_window": 59637,
+  "checkpoint_init_version": "2.1.17",
+  "checkpoint_init_window": 60711,
   "num_evaluation_bins": 5,
   "quantization_bins": 4,
   "quantization_range": 6,
@@ -62,6 +63,8 @@
     "scheduler": {
       "warmup_steps": 1500,
       "warmup_inner_steps": 30,
+      "initial_warmup_inner_steps": 200,
+      "replay_rewind_inner_steps": 20000,
       "t_max": 140000,
       "eta_min_factor": 0.1,
       "flatten_start_step": 2740,
@@ -78,6 +81,8 @@
     "scheduler": {
       "warmup_steps": 1500,
       "warmup_inner_steps": 30,
+      "initial_warmup_inner_steps": 200,
+      "replay_rewind_inner_steps": 20000,
       "t_max": 140000,
       "eta_min_factor": 0.1,
       "flatten_start_step": null,
```

neurons/miner.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -347,6 +347,9 @@ def __init__(self):
             token_dtype=np.uint32,  # Match preprocessing script dtype
         )
         self.outer_steps_per_shard = getattr(self.hparams, "outer_steps_per_shard")
+        self.shard_reset_outer_step = getattr(
+            self.hparams, "shard_reset_outer_step", None
+        )

         tplr.logger.info("[Init] ✔ fully done – entering run()")

@@ -417,7 +420,11 @@ async def run(self):

         self.comms.start_commitment_fetcher()

-        current_shard = self.global_step // self.outer_steps_per_shard
+        current_shard_epoch, current_shard = tplr.sharded_dataset.compute_shard_state(
+            self.global_step,
+            self.outer_steps_per_shard,
+            self.shard_reset_outer_step,
+        )
         tplr.logger.info(
             f"Starting with global_step={self.global_step} (actual outer steps)"
         )
@@ -432,6 +439,7 @@ async def run(self):
         self.set_dataloader()

         # Track the current shard to avoid double-swapping at initialization
+        last_shard_epoch = current_shard_epoch
         last_shard = current_shard

         # Put a dummy gradient to mark this miner as active for validators
@@ -474,8 +482,24 @@ async def run(self):
             self.sampler.set_window_uid(self.uid, step_window)

             # Check if we need to swap dataset based on shard index change
-            current_shard_check = self.global_step // self.outer_steps_per_shard
-            if current_shard_check > last_shard:
+            shard_epoch_check, current_shard_check = (
+                tplr.sharded_dataset.compute_shard_state(
+                    self.global_step,
+                    self.outer_steps_per_shard,
+                    self.shard_reset_outer_step,
+                )
+            )
+            if shard_epoch_check != last_shard_epoch:
+                tplr.logger.info(
+                    f"Resetting shard schedule at outer_step {self.global_step} "
+                    f"to shard {current_shard_check}"
+                )
+                await self.dataset_manager.initialize_datasets(current_shard_check)
+                self.set_dataloader()
+                dist_helper.safe_barrier("sync_shard_switch", self.local_rank)
+                last_shard_epoch = shard_epoch_check
+                last_shard = current_shard_check
+            elif current_shard_check > last_shard:
                 tplr.logger.info(
                     f"Swapping dataset after {self.global_step} outer steps at window {step_window}"
                 )
```
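The implementation of `tplr.sharded_dataset.compute_shard_state` is not part of this diff, so the sketch below only mirrors how the call sites here (and in neurons/validator.py below) appear to use it: given `global_step`, `outer_steps_per_shard`, and the new `shard_reset_outer_step` hparam, it returns a `(shard_epoch, shard_index)` pair, and a changed epoch triggers a full dataset re-initialization. Treat this as a guess at the semantics, not the actual helper.

```python
# Illustrative guess at compute_shard_state, based only on the call sites above.
def compute_shard_state(
    global_step: int,
    outer_steps_per_shard: int,
    shard_reset_outer_step: int | None,
) -> tuple[int, int]:
    """Return (shard_epoch, shard_index) for the current outer step.

    Before the reset point the schedule matches the old
    global_step // outer_steps_per_shard. At or after shard_reset_outer_step the
    epoch bumps and the shard index restarts from zero, counted from the reset step.
    """
    if shard_reset_outer_step is None or global_step < shard_reset_outer_step:
        return 0, global_step // outer_steps_per_shard
    steps_since_reset = global_step - shard_reset_outer_step
    return 1, steps_since_reset // outer_steps_per_shard


# With the new hparams (outer_steps_per_shard=455, shard_reset_outer_step=4040):
# compute_shard_state(4039, 455, 4040) -> (0, 8)
# compute_shard_state(4040, 455, 4040) -> (1, 0)
```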

neurons/validator.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -521,6 +521,9 @@ def __init__(self):
         self.param_change_alpha = 0.2

         self.outer_steps_per_shard = getattr(self.hparams, "outer_steps_per_shard")
+        self.shard_reset_outer_step = getattr(
+            self.hparams, "shard_reset_outer_step", None
+        )
         self.dataset_manager = tplr.sharded_dataset.ShardedDatasetManager(
             sequence_length=self.hparams.sequence_length,
             rank=self.local_rank,  # Use local_rank for proper file operations
@@ -1245,7 +1248,11 @@ async def run(self):
             aggregator_device="cpu",
         )

-        current_shard = self.global_step // self.outer_steps_per_shard
+        shard_epoch, current_shard = tplr.sharded_dataset.compute_shard_state(
+            self.global_step,
+            self.outer_steps_per_shard,
+            self.shard_reset_outer_step,
+        )

         # Initialize datasets (only rank 0 downloads, handled internally by dataset_manager)
         _ = await self.dataset_manager.initialize_datasets(current_shard)
@@ -1256,6 +1263,7 @@ async def run(self):
         self.set_dataloader(validator=True)

         # Track the current shard to avoid double-swapping at initialization
+        last_shard_epoch = shard_epoch
         last_shard = current_shard

         if self.is_master:
@@ -1287,8 +1295,24 @@ async def run(self):
             window_start = tplr.T()

             # Check if we need to swap dataset based on shard index change
-            current_shard_check = self.global_step // self.outer_steps_per_shard
-            if current_shard_check > last_shard:
+            shard_epoch_check, current_shard_check = (
+                tplr.sharded_dataset.compute_shard_state(
+                    self.global_step,
+                    self.outer_steps_per_shard,
+                    self.shard_reset_outer_step,
+                )
+            )
+            if shard_epoch_check != last_shard_epoch:
+                tplr.logger.info(
+                    f"Resetting shard schedule at outer_step {self.global_step} "
+                    f"to shard {current_shard_check}"
+                )
+                await self.dataset_manager.initialize_datasets(current_shard_check)
+                self.set_dataloader(validator=True)
+                dist_helper.safe_barrier("sync_shard_switch", self.local_rank)
+                last_shard_epoch = shard_epoch_check
+                last_shard = current_shard_check
+            elif current_shard_check > last_shard:
                 tplr.logger.info(
                     f"Swapping dataset after {self.global_step} outer steps at window {self.current_window}"
                 )
```

scripts/dataset_prep/02_consolidate_shards.py

Lines changed: 30 additions & 4 deletions
```diff
@@ -108,10 +108,37 @@ async def run_preprocessing(
                 )
             )

-            tokens_view = np.memmap(tokens_file, dtype=token_dtype, mode="r")
-            tok_u32 = tokens_view.view(np.uint32)  # reinterpret for 4-byte hashing
+            # Load tokens - if .npy file, use np.load to respect embedded dtype
+            if tokens_file.endswith(".npy"):
+                tokens_view = np.load(tokens_file, mmap_mode="r", allow_pickle=False)
+                # Ensure it's uint32 (step 01 saves as uint32)
+                if tokens_view.dtype != np.uint32:
+                    tqdm.write(
+                        f"Warning: Shard {i} has dtype {tokens_view.dtype}, expected uint32. "
+                        f"Converting (this may indicate a preprocessing mismatch)."
+                    )
+                    tok_u32 = tokens_view.astype(np.uint32)
+                else:
+                    tok_u32 = tokens_view
+            else:
+                # Raw binary file - use specified dtype then reinterpret as uint32
+                tokens_view = np.memmap(tokens_file, dtype=token_dtype, mode="r")
+                tok_u32 = tokens_view.view(np.uint32)
+
+            # Only create sample IDs for complete sequences (discard partial sequence at end)
+            total_tokens = tok_u32.shape[0]
+            num_complete_samples = total_tokens // seq_len
+
+            # Warn if there's a partial sequence that will be discarded
+            remainder = total_tokens % seq_len
+            if remainder > 0:
+                tqdm.write(
+                    f"Warning: Shard {i} has {remainder} tokens remaining after chunking "
+                    f"(total: {total_tokens}, seq_len: {seq_len}). "
+                    f"Creating {num_complete_samples} complete samples."
+                )

-            raw_idx = np.arange(0, tok_u32.shape[0] + 1, seq_len)
+            raw_idx = np.arange(0, num_complete_samples * seq_len + 1, seq_len)
             starts = raw_idx[:-1]
             ends = raw_idx[1:]

@@ -208,7 +235,6 @@ async def main() -> None:
     print(f"  • Shards path: {args.r2_prefix}")
     print(f"  • Sequence length: {args.seq_len}")
     print(f"  • Token dtype: {args.token_dtype}")
-    print(f"  • Skip Validation: {args.skip_validation}")
     print()

     success = await run_preprocessing(args, args.seq_len, token_dtype)
```
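A small standalone illustration of the sample-ID arithmetic in the hunk above: with the new bound, the slice boundaries only cover complete `seq_len`-token sequences, and any partial tail is counted into `remainder` and warned about. Everything here is toy data; only the variable names mirror the diff.

```python
# Worked example of the sample-ID slicing above (toy numbers, not repo data).
import numpy as np

seq_len = 4
tok_u32 = np.arange(10, dtype=np.uint32)  # 10 tokens -> 2 complete sequences + 2 leftover

total_tokens = tok_u32.shape[0]                 # 10
num_complete_samples = total_tokens // seq_len  # 2
remainder = total_tokens % seq_len              # 2 tokens reported and dropped

raw_idx = np.arange(0, num_complete_samples * seq_len + 1, seq_len)  # [0, 4, 8]
starts, ends = raw_idx[:-1], raw_idx[1:]        # starts=[0, 4], ends=[4, 8]

# The upper bound num_complete_samples * seq_len makes the "complete sequences only"
# intent explicit: every (start, end) pair spans exactly seq_len tokens.
print(list(zip(starts.tolist(), ends.tolist())))  # [(0, 4), (4, 8)]
```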

src/tplr/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@
 # mypy: ignore-errors
 # type: ignore

-__version__ = "2.1.17"
+__version__ = "2.1.18"

 # Import package.
 from .chain import *
```

src/tplr/neurons.py

Lines changed: 43 additions & 1 deletion
```diff
@@ -689,6 +689,38 @@ async def handle_checkpoint_catchup(
         from_bootstrap: Whether checkpoint was from bootstrap version
         aggregator_device: which device to load aggregation results to
     """
+    # Determine scheduler config and warmup settings
+    optimizer_cfg = getattr(instance.hparams, "optimizer", {})
+    opt_type = optimizer_cfg.get("type", "adamw").lower()
+    opt_cfg = optimizer_cfg.get(opt_type, {})
+    scheduler_cfg = opt_cfg.get("scheduler", {})
+
+    default_warmup_inner = scheduler_cfg.get(
+        "warmup_inner_steps", getattr(instance, "warmup_inner_steps", 0)
+    )
+    startup_warmup_inner = scheduler_cfg.get(
+        "initial_warmup_inner_steps", default_warmup_inner
+    )
+
+    # Set warmup length:
+    # - If resuming from bootstrap, use the longer startup warmup.
+    # - If resuming from a regular checkpoint, use the default.
+    # - If global_step is 0, leave as-is; the scheduler's own warmup covers this.
+    if ckpt_global_step == 0:
+        tplr.logger.info("Global step is 0; leaving warmup settings unchanged.")
+    elif from_bootstrap:
+        instance.warmup_inner_steps = startup_warmup_inner
+        tplr.logger.info(
+            f"Applying startup warmup_inner_steps={startup_warmup_inner} (bootstrap resume)"
+        )
+        instance.warmup_steps_taken = 0
+    else:
+        instance.warmup_inner_steps = default_warmup_inner
+        tplr.logger.info(
+            f"Applying resumed warmup_inner_steps={default_warmup_inner} (checkpoint resume)"
+        )
+        instance.warmup_steps_taken = 0
+
     # Decide catch-up windows and run catch-up on ALL ranks
     # When loading from bootstrap, we always need to catch up from start_window
     # to ensure we're using current version's gradients
@@ -726,7 +758,17 @@ async def handle_checkpoint_catchup(
     # Replay scheduler steps based on windows completed from checkpoint
     # ckpt_global_step tracks windows, scheduler needs inner_steps per window
     total_inner_steps = ckpt_global_step * instance.hparams.inner_steps
-    if total_inner_steps > 0:
+
+    # Apply configurable rewind before replaying scheduler to give slack on restarts
+    rewind_inner_steps = scheduler_cfg.get("replay_rewind_inner_steps", 0)
+    if rewind_inner_steps > 0:
+        total_inner_steps = max(total_inner_steps - rewind_inner_steps, 0)
+        tplr.logger.info(
+            f"Rewinding scheduler replay by {rewind_inner_steps} inner steps; "
+            f"{total_inner_steps} steps remain to replay"
+        )
+
+    if total_inner_steps > 0 and getattr(instance, "inner_scheduler", None) is not None:
         for _ in range(total_inner_steps):
             # Respect flatten window during replay
```