@@ -75,13 +75,16 @@ def compute_tokens_and_batch(
7575 flops_per_token : float ,
7676 target_steps : int = DEFAULT_TARGET_STEPS ,
7777 min_batch_size : int = MIN_BATCH_SIZE ,
78- seq_len : int = SEQ_LEN ,
7978) -> tuple [float , int , int ]:
80- """Derive (tokens, batch_size, num_steps) from a compute budget and FLOPs-per-token."""
79+ """Derive (tokens, batch_size, num_steps) from a compute budget and FLOPs-per-token.
80+
81+ Uses the module-level `SEQ_LEN` constant (4096) — the whole heuristic is
82+ anchored there; see the module docstring.
83+ """
8184 tokens = budget / (3 * flops_per_token )
82- batch_exact = tokens / (target_steps * seq_len )
85+ batch_exact = tokens / (target_steps * SEQ_LEN )
8386 batch_size = max (min_batch_size , _round_to_power_of_two (batch_exact ))
84- train_steps = max (1 , round (tokens / (batch_size * seq_len )))
87+ train_steps = max (1 , round (tokens / (batch_size * SEQ_LEN )))
8588 return tokens , batch_size , train_steps
8689
8790
@@ -246,7 +249,6 @@ def build_from_heuristic(
246249 heuristic : MoeAdamHHeuristic | None = None ,
247250 target_steps : int = DEFAULT_TARGET_STEPS ,
248251 min_batch_size : int = MIN_BATCH_SIZE ,
249- seq_len : int = SEQ_LEN ,
250252) -> tuple [GrugModelConfig , GrugMoeAdamHConfig , int , int ]:
251253 """Construct (model, optimizer, batch_size, num_steps) for a compute budget.
252254
@@ -263,7 +265,6 @@ def build_from_heuristic(
263265 fpt ,
264266 target_steps = target_steps ,
265267 min_batch_size = min_batch_size ,
266- seq_len = seq_len ,
267268 )
268269 optimizer_cfg = h .build_optimizer_config (batch_size , tokens , hidden_dim )
269270 return model_cfg , optimizer_cfg , batch_size , num_steps