Skip to content

Commit 25e9c35

Browse files
committed
fix natural proportions
1 parent a1107ba commit 25e9c35

File tree

2 files changed

+61
-26
lines changed

2 files changed

+61
-26
lines changed

experiments/domain_phase_mix/config.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,21 @@ class Domain:
6464

6565
name: str
6666
components: list[DatasetComponent]
67-
natural_proportion: float = 1.0
67+
natural_proportion: float | None = None # If None, computed from total_weight
6868
description: str = ""
6969

70+
@property
71+
def total_weight(self) -> float:
72+
"""Sum of all component weights (typically token counts)."""
73+
return sum(c.weight for c in self.components)
74+
7075
def get_component_weights(self) -> dict[str, float]:
7176
"""Get normalized weights for components within this domain.
7277
7378
Returns:
7479
Dictionary mapping component names to their normalized weights.
7580
"""
76-
total = sum(c.weight for c in self.components)
81+
total = self.total_weight
7782
if total == 0:
7883
# Uniform weights if all are zero
7984
n = len(self.components)
@@ -300,9 +305,24 @@ def experiment_budget(self) -> int:
300305
return self.total_steps * self.tokens_per_step
301306

302307
def get_natural_proportions(self) -> dict[str, float]:
303-
"""Get natural proportions for all domains (normalized)."""
304-
total = sum(d.natural_proportion for d in self.domains)
305-
return {d.name: d.natural_proportion / total for d in self.domains}
308+
"""Get natural proportions for all domains (normalized).
309+
310+
If a domain has natural_proportion set, uses that value.
311+
Otherwise, uses the domain's total_weight (sum of component weights).
312+
The final proportions are normalized to sum to 1.
313+
"""
314+
315+
def get_domain_weight(d: Domain) -> float:
316+
if d.natural_proportion is not None:
317+
return d.natural_proportion
318+
return d.total_weight
319+
320+
total = sum(get_domain_weight(d) for d in self.domains)
321+
if total == 0:
322+
# Uniform if all weights are zero
323+
n = len(self.domains)
324+
return {d.name: 1.0 / n for d in self.domains}
325+
return {d.name: get_domain_weight(d) / total for d in self.domains}
306326

307327
def get_all_components(self) -> dict[str, ExecutorStep]:
308328
"""Get all dataset components across all domains."""

experiments/domain_phase_mix/domains.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,69 +120,74 @@ def _nemotron_low_actual():
120120
return _get_nemotron_tokenized()["nemotron_cc/low_actual"]
121121

122122

123+
# Conversion factor: ~500B tokens per TiB for typical text data
124+
TIB_TO_TOKENS_B = 500.0
125+
123126
# High-quality Nemotron splits (no synthetic)
127+
# Convert TiB to billions of tokens for consistent units
124128
NEMOTRON_HQ_DOMAIN = register_domain(
125129
Domain(
126130
name="nemotron_hq",
127131
components=[
128132
DatasetComponent(
129133
name="nemotron_cc/hq_actual",
130134
step_fn=_nemotron_hq_actual,
131-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/hq_actual", 0.91),
135+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/hq_actual", 0.91) * TIB_TO_TOKENS_B,
132136
),
133137
DatasetComponent(
134138
name="nemotron_cc/medium_high",
135139
step_fn=_nemotron_medium_high,
136-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium_high", 0.82),
140+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium_high", 0.82) * TIB_TO_TOKENS_B,
137141
),
138142
DatasetComponent(
139143
name="nemotron_cc/medium",
140144
step_fn=_nemotron_medium,
141-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium", 3.38),
145+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium", 3.38) * TIB_TO_TOKENS_B,
142146
),
143147
],
144-
natural_proportion=0.70,
148+
# natural_proportion computed from total_weight (~2.5T tokens)
145149
description="High-quality Nemotron CC splits (hq_actual, medium_high, medium) - no synthetic data",
146150
)
147151
)
148152

149153
# Full Nemotron domain (including synthetic and lower quality)
154+
# Convert TiB to billions of tokens for consistent units
150155
NEMOTRON_FULL_DOMAIN = register_domain(
151156
Domain(
152157
name="nemotron_full",
153158
components=[
154159
DatasetComponent(
155160
name="nemotron_cc/hq_actual",
156161
step_fn=_nemotron_hq_actual,
157-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/hq_actual", 0.91),
162+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/hq_actual", 0.91) * TIB_TO_TOKENS_B,
158163
),
159164
DatasetComponent(
160165
name="nemotron_cc/hq_synth",
161166
step_fn=_nemotron_hq_synth,
162-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/hq_synth", 0.5),
167+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/hq_synth", 0.5) * TIB_TO_TOKENS_B,
163168
),
164169
DatasetComponent(
165170
name="nemotron_cc/medium_high",
166171
step_fn=_nemotron_medium_high,
167-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium_high", 0.82),
172+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium_high", 0.82) * TIB_TO_TOKENS_B,
168173
),
169174
DatasetComponent(
170175
name="nemotron_cc/medium",
171176
step_fn=_nemotron_medium,
172-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium", 3.38),
177+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium", 3.38) * TIB_TO_TOKENS_B,
173178
),
174179
DatasetComponent(
175180
name="nemotron_cc/medium_low",
176181
step_fn=_nemotron_medium_low,
177-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium_low", 1.0),
182+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/medium_low", 1.0) * TIB_TO_TOKENS_B,
178183
),
179184
DatasetComponent(
180185
name="nemotron_cc/low_actual",
181186
step_fn=_nemotron_low_actual,
182-
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/low_actual", 0.5),
187+
weight=NEMOTRON_WEIGHTS.get("nemotron_cc/low_actual", 0.5) * TIB_TO_TOKENS_B,
183188
),
184189
],
185-
natural_proportion=0.70,
190+
# natural_proportion computed from total_weight
186191
description="Full Nemotron CC dataset including synthetic and lower quality splits",
187192
)
188193
)
@@ -207,17 +212,20 @@ def _get_fineweb_edu():
207212
return _fineweb_edu_cache
208213

209214

215+
# Weight in billions of tokens
216+
FINEWEB_EDU_TOKENS_B = 1300.0 # ~1.3T tokens
217+
210218
FINEWEB_EDU_DOMAIN = register_domain(
211219
Domain(
212220
name="fineweb_edu",
213221
components=[
214222
DatasetComponent(
215223
name="fineweb_edu",
216224
step_fn=_get_fineweb_edu,
217-
weight=1.0,
225+
weight=FINEWEB_EDU_TOKENS_B,
218226
),
219227
],
220-
natural_proportion=0.25,
228+
# natural_proportion computed from total_weight (~1.3T tokens)
221229
description="FineWeb-Edu dataset (~1.3T tokens of educational web content)",
222230
)
223231
)
@@ -273,6 +281,7 @@ def _dolmino_wiki():
273281
# Total: ~832.56B tokens
274282

275283
# Full Dolmino domain with all non-math splits
284+
# Component weights are in billions of tokens, total ~833B tokens
276285
DOLMINO_DOMAIN = register_domain(
277286
Domain(
278287
name="dolmino",
@@ -303,7 +312,7 @@ def _dolmino_wiki():
303312
weight=DOLMINO_WEIGHTS["wiki"],
304313
),
305314
],
306-
natural_proportion=0.25,
315+
# natural_proportion computed from total_weight (~833B tokens)
307316
description="Full Dolmino dataset (dclm, flan, pes2o, stackexchange, wiki) for mid-training",
308317
)
309318
)
@@ -313,24 +322,29 @@ def _dolmino_wiki():
313322
# SFT DOMAINS
314323
# ============================================================================
315324

316-
# SFT dataset definitions
325+
# SFT dataset definitions with estimated token counts (in billions)
326+
# Token estimates verified against HuggingFace dataset pages
317327
SFT_DATASETS = {
318328
"tulu_3_sft_mixture": {
319329
"hf_id": "allenai/tulu-3-sft-mixture",
320330
"sample_count": 939343,
331+
"tokens_b": 0.15, # ~939K samples, HF estimates 100M-200M tokens
321332
"description": "General instruction tuning mixture",
322333
},
323334
"openthoughts_114k_math": {
324335
"hf_id": "open-r1/OpenThoughts-114k-math",
325336
"sample_count": 89120,
337+
"tokens_b": 0.45, # ~89K samples with long CoT reasoning (~5K tokens/sample)
326338
"description": "Math reasoning with chain-of-thought",
327339
},
328340
"verifiable_math_problems": {
329341
"hf_id": "PrimeIntellect/verifiable-math-problems",
330342
"sample_count": 777457,
343+
"tokens_b": 1.4, # ~777K samples, HF estimates ~1.4B tokens
331344
"description": "Verifiable math problem solving",
332345
},
333346
}
347+
# Total SFT: ~2.0B tokens
334348

335349
# Pre-tokenized paths (if available)
336350
SFT_TOKENIZED_PATHS = {
@@ -373,27 +387,28 @@ def _verifiable_math():
373387

374388

375389
# Math-focused SFT domain
390+
# Component weights in billions of tokens, total ~1.6B tokens
376391
MATH_SFT_DOMAIN = register_domain(
377392
Domain(
378393
name="math_sft",
379394
components=[
380395
DatasetComponent(
381396
name="tulu_3_sft_mixture",
382397
step_fn=_tulu_3_sft,
383-
weight=SFT_DATASETS["tulu_3_sft_mixture"]["sample_count"],
398+
weight=SFT_DATASETS["tulu_3_sft_mixture"]["tokens_b"],
384399
),
385400
DatasetComponent(
386401
name="openthoughts_114k_math",
387402
step_fn=_openthoughts_math,
388-
weight=SFT_DATASETS["openthoughts_114k_math"]["sample_count"],
403+
weight=SFT_DATASETS["openthoughts_114k_math"]["tokens_b"],
389404
),
390405
DatasetComponent(
391406
name="verifiable_math_problems",
392407
step_fn=_verifiable_math,
393-
weight=SFT_DATASETS["verifiable_math_problems"]["sample_count"],
408+
weight=SFT_DATASETS["verifiable_math_problems"]["tokens_b"],
394409
),
395410
],
396-
natural_proportion=0.05,
411+
# natural_proportion computed from total_weight (~2.0B tokens: 0.15 + 0.45 + 1.4)
397412
description="Math-focused SFT datasets (Tulu-3 + math reasoning)",
398413
)
399414
)
@@ -406,10 +421,10 @@ def _verifiable_math():
406421
DatasetComponent(
407422
name="tulu_3_sft_mixture",
408423
step_fn=_tulu_3_sft,
409-
weight=1.0,
424+
weight=SFT_DATASETS["tulu_3_sft_mixture"]["tokens_b"],
410425
),
411426
],
412-
natural_proportion=0.05,
427+
# natural_proportion computed from total_weight (~0.15B tokens; ~939K samples)
413428
description="General instruction tuning with Tulu-3 mixture",
414429
)
415430
)

0 commit comments

Comments
 (0)