Skip to content

Commit 07ed42b

Browse files
authored
[GPU] Fix CoreWeave FA4 canaries (#6237)
Add the FA4 CuTe and THD runtime dependencies to the GPU extras used by CoreWeave task images, and make explicit-resource StepRunner jobs infer the accelerator extra when callers do not override dependency groups.

Wire the CoreWeave canary workflow to dispatch either gpu_fa4_cute or gpu_fa4_thd, and allow the THD backend to use the upstream SM90 kernels needed on H100 while keeping unsupported-image failures explicit.

Fixes #6226
1 parent d600675 commit 07ed42b

17 files changed

Lines changed: 652 additions & 176 deletions

File tree

.github/workflows/marin-canary-ferry-coreweave.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ on:
2626
type: number
2727
default: 1
2828
required: false
29+
attention_implementation:
30+
description: 'GPU attention backend'
31+
type: choice
32+
options:
33+
- gpu_fa4_cute
34+
- gpu_fa4_thd
35+
- reference
36+
default: gpu_fa4_cute
37+
required: false
2938

3039
permissions:
3140
contents: read # actions/checkout
@@ -48,6 +57,7 @@ jobs:
4857
env:
4958
RUN_ID: canary-gpu-${{ github.run_id }}-${{ github.run_attempt }}
5059
CANARY_ACCELERATOR: gpu
60+
CANARY_ATTENTION_IMPLEMENTATION: ${{ github.event_name == 'workflow_dispatch' && inputs.attention_implementation || 'gpu_fa4_cute' }}
5161
CANARY_BATCH_SIZE: "16"
5262
CANARY_GPU_REPLICAS: ${{ github.event_name == 'workflow_dispatch' && format('{0}', inputs.gpu_replicas) || '1' }}
5363
# TODO(#5524): remove this override once Levanter profiler stop is
@@ -123,6 +133,7 @@ jobs:
123133
-e MARIN_PREFIX s3://marin-na/marin/ \
124134
-e RUN_ID "$RUN_ID" \
125135
-e CANARY_ACCELERATOR "$CANARY_ACCELERATOR" \
136+
-e CANARY_ATTENTION_IMPLEMENTATION "$CANARY_ATTENTION_IMPLEMENTATION" \
126137
-e CANARY_BATCH_SIZE "$CANARY_BATCH_SIZE" \
127138
-e CANARY_GPU_REPLICAS "$CANARY_GPU_REPLICAS" \
128139
-e CANARY_PROFILER_NUM_STEPS "$CANARY_PROFILER_NUM_STEPS" \

experiments/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@
4848
lm_mixture_data_config,
4949
)
5050
from marin.processing.tokenize.tokenize import HfTokenizeConfig, TokenizeConfigBase
51+
from marin.training.run_environment import extras_for_resources
5152
from marin.training.training import (
5253
TrainDpoOnPodConfig,
5354
TrainLmOnPodConfig,
5455
bake_output_path,
5556
check_train_config_paths,
56-
extras_for_resources,
5757
impute_run_id,
5858
resolve_training_env,
5959
run_levanter_train_dpo,

experiments/ferries/canary_ferry.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
to the Iris container. workflow_dispatch inputs override CANARY_TARGET_TOKENS.
99
1010
CANARY_ACCELERATOR tpu | gpu
11+
CANARY_ATTENTION_IMPLEMENTATION gpu-only attention backend, e.g. gpu_fa4_cute
1112
CANARY_TPU_TYPE tpu-only comma-separated slice types, primary first (default v5p-8,v4-8)
1213
CANARY_BATCH_SIZE per-device batch size
1314
CANARY_CACHE_COPY_MAX_WORKERS gpu-only cache-copy worker cap
@@ -26,11 +27,15 @@
2627
RUN_ID unique run identifier
2728
"""
2829

30+
import dataclasses
2931
import datetime
3032
import os
33+
from typing import cast
3134

3235
from fray.cluster import ResourceConfig
3336
from levanter.callbacks.profiler import ProfilerConfig
37+
from levanter.data.text import DatasetComponent
38+
from levanter.grug.attention import GrugAttentionImplementation
3439
from levanter.optim import AdamConfig
3540
from levanter.tracker.json_logger import JsonLoggerConfig
3641
from levanter.tracker.wandb import WandbConfig
@@ -62,6 +67,13 @@
6267
ema_beta=None,
6368
log_every=1,
6469
)
70+
_GPU_FA4_CUTE_ATTENTION: GrugAttentionImplementation = "gpu_fa4_cute"
71+
_GPU_FA4_THD_ATTENTION: GrugAttentionImplementation = "gpu_fa4_thd"
72+
_GPU_ATTENTION_IMPLEMENTATIONS: tuple[GrugAttentionImplementation, ...] = (
73+
"reference",
74+
_GPU_FA4_CUTE_ATTENTION,
75+
_GPU_FA4_THD_ATTENTION,
76+
)
6577

6678
# Compute budget passed to the heuristic when CANARY_HIDDEN_DIM scales the model.
6779
# Only the model *shape* (from hidden_dim) is used here; the budget-derived batch
@@ -130,10 +142,37 @@ def _build_step_from_env() -> ExecutorStep:
130142
else:
131143
model, _, _, _ = build_from_heuristic(budget=_HEURISTIC_BUDGET, hidden_dim=hidden_dim)
132144

145+
attention_implementation = os.environ.get("CANARY_ATTENTION_IMPLEMENTATION", _GPU_FA4_CUTE_ATTENTION)
146+
if attention_implementation not in _GPU_ATTENTION_IMPLEMENTATIONS:
147+
raise ValueError(
148+
f"Unknown CANARY_ATTENTION_IMPLEMENTATION={attention_implementation!r}, expected one of "
149+
f"{_GPU_ATTENTION_IMPLEMENTATIONS}"
150+
)
151+
attention_implementation = cast(GrugAttentionImplementation, attention_implementation)
152+
model = dataclasses.replace(
153+
model,
154+
attention_implementation=attention_implementation,
155+
# The THD backend only handles full causal windows. Setting the model
156+
# window to 2x seq_len makes Grug's short-window mask a full window.
157+
sliding_window=(
158+
model.max_seq_len * 2 if attention_implementation == _GPU_FA4_THD_ATTENTION else model.sliding_window
159+
),
160+
)
161+
133162
batch_size = env_int("CANARY_BATCH_SIZE", 32)
134163
target_tokens = env_int("CANARY_TARGET_TOKENS", batch_size * model.max_seq_len * 50)
135164

136165
data = slimpajama_6b_data()
166+
if attention_implementation == _GPU_FA4_THD_ATTENTION:
167+
data = dataclasses.replace(
168+
data,
169+
components={
170+
name: (
171+
dataclasses.replace(component, pack=1) if isinstance(component, DatasetComponent) else component
172+
)
173+
for name, component in data.components.items()
174+
},
175+
)
137176
resources = ResourceConfig.with_gpu(
138177
gpu_type,
139178
count=gpu_count,
@@ -142,16 +181,17 @@ def _build_step_from_env() -> ExecutorStep:
142181
disk="256g",
143182
replicas=gpu_replicas,
144183
)
145-
name = f"canary-ferry-cw-{gpu_type.lower()}x{gpu_count}-r{gpu_replicas}-d{hidden_dim}"
146-
wandb_group = f"canary-ferry-moe-gpu-{gpu_type.lower()}-r{gpu_replicas}"
147-
wandb_tags = ["canary", "ferry", "grug", "moe", "gpu", gpu_type.lower()]
184+
attention_tag = attention_implementation.removeprefix("gpu_")
185+
name = f"canary-ferry-cw-{gpu_type.lower()}x{gpu_count}-r{gpu_replicas}-d{hidden_dim}-{attention_tag}"
186+
wandb_group = f"canary-ferry-moe-gpu-{gpu_type.lower()}-r{gpu_replicas}-{attention_tag}"
187+
wandb_tags = ["canary", "ferry", "grug", "moe", "gpu", gpu_type.lower(), f"d{hidden_dim}", attention_tag]
148188
eval_config = None
149189

150190
num_steps = env_int("CANARY_STEPS", target_tokens // (batch_size * model.max_seq_len))
151191
if num_steps <= 0:
152192
raise ValueError(
153193
f"CANARY_STEPS={num_steps} invalid; set CANARY_STEPS or CANARY_TARGET_TOKENS high enough for "
154-
f"batch_size={batch_size} x seq_len={GRUG_MOE_TRIAL_MODEL.max_seq_len}"
194+
f"batch_size={batch_size} x seq_len={model.max_seq_len}"
155195
)
156196
if os.environ.get("CANARY_TRACKER", "wandb").lower() == "json_logger":
157197
tracker = JsonLoggerConfig(logger_name=os.environ.get("CANARY_JSON_LOGGER", "canary_ferry.metrics"))

experiments/grug/moe/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def _batch_reshard(x: jax.Array) -> jax.Array:
6969
return reshard(x, _batch_spec())
7070

7171

72+
def _layer_attention_masks(mask: AttentionMask, *, sliding_window: int) -> tuple[AttentionMask, AttentionMask]:
73+
return mask.with_sliding_window(sliding_window // 2), mask.with_sliding_window(sliding_window)
74+
75+
7276
@dataclass(frozen=True)
7377
class GrugModelConfig:
7478
"""Hyperparameters for the grug MoE transformer.
@@ -518,9 +522,9 @@ def __call__(
518522
hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec)
519523
hidden = self.embed_gated_norm(self.embed_norm(hidden))
520524

521-
segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None
522-
short_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window // 2, segment_ids=segment_ids)
523-
long_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window, segment_ids=segment_ids)
525+
if not isinstance(mask, AttentionMask):
526+
mask = AttentionMask.causal()
527+
short_mask, long_mask = _layer_attention_masks(mask, sliding_window=cfg.sliding_window)
524528

525529
moe_router_stats: list[dict[str, jax.Array]] = []
526530
for i, block in enumerate(self.blocks):

experiments/tutorials/train_tiny_sweep_tpu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from levanter.main.train_lm import TrainLmConfig
1919
from marin.execution.sweep import SweepTarget, claim_and_run
2020
from marin.execution.types import versioned
21-
from marin.training.training import extras_for_resources, resolve_training_env
21+
from marin.training.run_environment import extras_for_resources
22+
from marin.training.training import resolve_training_env
2223

2324
from experiments.defaults import _run_training_on_worker, prepare_lm_train
2425
from experiments.evals.task_configs import CORE_TASKS

lib/levanter/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ gpu = [
8484
"nvidia-cublas>=13.2.0.9; sys_platform == 'linux'",
8585
# Preserve the CoreWeave H100 all-to-all guard under CUDA 13.
8686
"nvidia-nccl-cu13>=2.28.3; sys_platform == 'linux'",
87+
# FA4 CuTe/THD attention backends.
88+
"nvidia-cutlass-dsl[cu13]>=4.5.2,<4.6; sys_platform == 'linux'",
89+
"flash-attn-4[cu13]>=4.0.0b16,<4.1; sys_platform == 'linux'",
8790
# Optional raw Sonic MoE gather/combine backend. Keep jax-triton exact:
8891
# 0.3.0 is missing the CUDA backend API needed by our current stack.
8992
"jax-triton==0.3.1; sys_platform == 'linux'",

0 commit comments

Comments
 (0)