88to the Iris container. workflow_dispatch inputs override CANARY_TARGET_TOKENS.
99
1010 CANARY_ACCELERATOR tpu | gpu
11+ CANARY_ATTENTION_IMPLEMENTATION gpu-only attention backend, e.g. gpu_fa4_cute
1112 CANARY_TPU_TYPE tpu-only comma-separated slice types, primary first (default v5p-8,v4-8)
1213 CANARY_BATCH_SIZE per-device batch size
1314 CANARY_CACHE_COPY_MAX_WORKERS gpu-only cache-copy worker cap
2627 RUN_ID unique run identifier
2728"""
2829
30+ import dataclasses
2931import datetime
3032import os
33+ from typing import cast
3134
3235from fray .cluster import ResourceConfig
3336from levanter .callbacks .profiler import ProfilerConfig
37+ from levanter .data .text import DatasetComponent
38+ from levanter .grug .attention import GrugAttentionImplementation
3439from levanter .optim import AdamConfig
3540from levanter .tracker .json_logger import JsonLoggerConfig
3641from levanter .tracker .wandb import WandbConfig
6267 ema_beta = None ,
6368 log_every = 1 ,
6469)
70+ _GPU_FA4_CUTE_ATTENTION : GrugAttentionImplementation = "gpu_fa4_cute"
71+ _GPU_FA4_THD_ATTENTION : GrugAttentionImplementation = "gpu_fa4_thd"
72+ _GPU_ATTENTION_IMPLEMENTATIONS : tuple [GrugAttentionImplementation , ...] = (
73+ "reference" ,
74+ _GPU_FA4_CUTE_ATTENTION ,
75+ _GPU_FA4_THD_ATTENTION ,
76+ )
6577
6678# Compute budget passed to the heuristic when CANARY_HIDDEN_DIM scales the model.
6779# Only the model *shape* (from hidden_dim) is used here; the budget-derived batch
@@ -130,10 +142,37 @@ def _build_step_from_env() -> ExecutorStep:
130142 else :
131143 model , _ , _ , _ = build_from_heuristic (budget = _HEURISTIC_BUDGET , hidden_dim = hidden_dim )
132144
145+ attention_implementation = os .environ .get ("CANARY_ATTENTION_IMPLEMENTATION" , _GPU_FA4_CUTE_ATTENTION )
146+ if attention_implementation not in _GPU_ATTENTION_IMPLEMENTATIONS :
147+ raise ValueError (
148+ f"Unknown CANARY_ATTENTION_IMPLEMENTATION={ attention_implementation !r} , expected one of "
149+ f"{ _GPU_ATTENTION_IMPLEMENTATIONS } "
150+ )
151+ attention_implementation = cast (GrugAttentionImplementation , attention_implementation )
152+ model = dataclasses .replace (
153+ model ,
154+ attention_implementation = attention_implementation ,
155+ # The THD backend only handles full causal windows. Setting the model
156+ # window to 2x seq_len makes Grug's short-window mask a full window.
157+ sliding_window = (
158+ model .max_seq_len * 2 if attention_implementation == _GPU_FA4_THD_ATTENTION else model .sliding_window
159+ ),
160+ )
161+
133162 batch_size = env_int ("CANARY_BATCH_SIZE" , 32 )
134163 target_tokens = env_int ("CANARY_TARGET_TOKENS" , batch_size * model .max_seq_len * 50 )
135164
136165 data = slimpajama_6b_data ()
166+ if attention_implementation == _GPU_FA4_THD_ATTENTION :
167+ data = dataclasses .replace (
168+ data ,
169+ components = {
170+ name : (
171+ dataclasses .replace (component , pack = 1 ) if isinstance (component , DatasetComponent ) else component
172+ )
173+ for name , component in data .components .items ()
174+ },
175+ )
137176 resources = ResourceConfig .with_gpu (
138177 gpu_type ,
139178 count = gpu_count ,
@@ -142,16 +181,17 @@ def _build_step_from_env() -> ExecutorStep:
142181 disk = "256g" ,
143182 replicas = gpu_replicas ,
144183 )
145- name = f"canary-ferry-cw-{ gpu_type .lower ()} x{ gpu_count } -r{ gpu_replicas } -d{ hidden_dim } "
146- wandb_group = f"canary-ferry-moe-gpu-{ gpu_type .lower ()} -r{ gpu_replicas } "
147- wandb_tags = ["canary" , "ferry" , "grug" , "moe" , "gpu" , gpu_type .lower ()]
184+ attention_tag = attention_implementation .removeprefix ("gpu_" )
185+ name = f"canary-ferry-cw-{ gpu_type .lower ()} x{ gpu_count } -r{ gpu_replicas } -d{ hidden_dim } -{ attention_tag } "
186+ wandb_group = f"canary-ferry-moe-gpu-{ gpu_type .lower ()} -r{ gpu_replicas } -{ attention_tag } "
187+ wandb_tags = ["canary" , "ferry" , "grug" , "moe" , "gpu" , gpu_type .lower (), f"d{ hidden_dim } " , attention_tag ]
148188 eval_config = None
149189
150190 num_steps = env_int ("CANARY_STEPS" , target_tokens // (batch_size * model .max_seq_len ))
151191 if num_steps <= 0 :
152192 raise ValueError (
153193 f"CANARY_STEPS={ num_steps } invalid; set CANARY_STEPS or CANARY_TARGET_TOKENS high enough for "
154- f"batch_size={ batch_size } x seq_len={ GRUG_MOE_TRIAL_MODEL .max_seq_len } "
194+ f"batch_size={ batch_size } x seq_len={ model .max_seq_len } "
155195 )
156196 if os .environ .get ("CANARY_TRACKER" , "wandb" ).lower () == "json_logger" :
157197 tracker = JsonLoggerConfig (logger_name = os .environ .get ("CANARY_JSON_LOGGER" , "canary_ferry.metrics" ))
0 commit comments