Skip to content

Commit e871be4

Browse files
authored
chore: try to use kernels lib to get flash attention kernel (#30)
In an attempt to ease some deps and open up more potential paths for easier benchmarking, this let's us specify flash attention implementations/backends on the main config objects for both training and generation. There's some typing-related changes in this too that affect the artifact structure's internals. I'm going to do a follow up to fix _a lot_ of typing issues throughout the repo soon; i have a draft in progress. --------- Signed-off-by: Aaron Gonzales <aagonzales@nvidia.com> Signed-off-by: aagonzales <aagonzales@nvidia.com>
1 parent efec5c3 commit e871be4

12 files changed

Lines changed: 227 additions & 56 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ $(info local system architecture: $(PLATFORM)/$(ARCH))
3939
.PHONY: help
4040
help:
4141
@echo "Makefile commands:"
42-
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
42+
@grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
4343

4444

4545
### BOOTSTRAP AND SETUP ###

README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,47 @@ Commands:
191191
validate Validate a Safe Synthesizer configuration.
192192
```
193193
194+
## Attention Configuration
195+
196+
Safe Synthesizer exposes attention implementation settings for both training and generation.
197+
198+
### Training (`attn_implementation`)
199+
200+
Controls the HuggingFace attention backend used during model loading for training. Set via config YAML, CLI, or SDK:
201+
202+
```yaml
203+
# config.yaml
204+
training:
205+
attn_implementation: "kernels-community/vllm-flash-attn3"
206+
```
207+
208+
```bash
209+
# CLI override
210+
safe-synthesizer run --training__attn_implementation sdpa --url my_data.csv
211+
```
212+
213+
| Value | Description | Requires |
214+
|-------|-------------|----------|
215+
| `kernels-community/vllm-flash-attn3` | Flash Attention 3 via HuggingFace Kernels Hub (default) | `kernels` pip package |
216+
| `kernels-community/flash-attn2` | Flash Attention 2 via HuggingFace Kernels Hub | `kernels` pip package |
217+
| `flash_attention_2` | Flash Attention 2 (traditional) | `flash-attn` pip package |
218+
| `sdpa` | PyTorch scaled dot product attention | None (built-in) |
219+
| `eager` | Standard PyTorch attention | None (built-in) |
220+
221+
If the default `kernels-community/vllm-flash-attn3` is configured but the `kernels` package is not installed, the backend automatically falls back to `sdpa`.
222+
223+
### Generation (`attention_backend`)
224+
225+
Controls the vLLM attention backend used during synthetic data generation. Defaults to `"auto"`, which lets vLLM auto-select the best available backend.
226+
227+
```yaml
228+
# config.yaml
229+
generation:
230+
attention_backend: "FLASH_ATTN"
231+
```
232+
233+
Common values: `FLASHINFER`, `FLASH_ATTN`, `TORCH_SDPA`, `TRITON_ATTN`, `FLEX_ATTENTION`.
234+
194235
## Artifacts and Workdirs
195236
196237
Safe Synthesizer uses a structured directory format to manage artifacts (trained models, synthetic data, logs).

src/nemo_safe_synthesizer/cli/artifact_structure.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import os
89
from dataclasses import dataclass, field
910
from datetime import datetime
1011
from pathlib import Path
@@ -196,11 +197,11 @@ def __get__(self, obj: object | None, objtype: type | None = None) -> DirNode |
196197
raise TypeError(f"DirNode can only be used with BoundDir or Workdir, got {type(obj)}")
197198

198199

199-
class BoundDir:
200+
class BoundDir(os.PathLike[str]):
200201
"""Runtime class representing a bound directory path.
201202
202-
Provides access to child FileNode and DirNode descriptors as attributes,
203-
and implements __fspath__ for use with os.path functions.
203+
Provides access to child FileNode and DirNode descriptors as attributes.
204+
Implements os.PathLike[str] so instances can be used wherever paths are expected.
204205
"""
205206

206207
def __init__(self, path: Path, children: dict[str, FileNode | DirNode]):
@@ -240,12 +241,6 @@ def __eq__(self, other: object) -> bool:
240241
def __hash__(self) -> int:
241242
return hash(self._path)
242243

243-
def __getattribute__(self, name: str) -> Path | BoundDir:
244-
# Allow access to special methods, private attrs, and the path property
245-
if name.startswith("_") or name == "path":
246-
return super().__getattribute__(name)
247-
return self.__getattr__(name)
248-
249244
def __getattr__(self, name: str) -> Path | BoundDir:
250245
if name.startswith("_"):
251246
raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")

src/nemo_safe_synthesizer/config/generate.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class GenerateParameters(Parameters, BaseModel):
7575
patience: Number of invalid records fraction before stopping.
7676
invalid_fraction_threshold: "The fraction of invalid records that will stop generation after the `patience` limit is reached."
7777
use_structured_generation: Whether to use structured generation for better format control.
78+
attention_backend: The attention backend for the vLLM engine. If None, vLLM will
79+
auto-select the best available backend.
7880
7981
"""
8082

@@ -179,3 +181,15 @@ class GenerateParameters(Parameters, BaseModel):
179181
description="Validation parameters controlling validation logic and automatic fixes when parsing LLM output and converting to tabular data.",
180182
default_factory=ValidationParameters,
181183
)
184+
185+
attention_backend: Annotated[
186+
str | None,
187+
Field(
188+
title="attention_backend",
189+
description=(
190+
"The attention backend for the vLLM engine. Common values: 'FLASHINFER', "
191+
"'FLASH_ATTN', 'TRITON_ATTN', 'FLEX_ATTENTION'. "
192+
"If None or 'auto', vLLM will auto-select the best available backend."
193+
),
194+
),
195+
] = "auto"

src/nemo_safe_synthesizer/config/training.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ class TrainingHyperparams(Parameters):
7979
peft_implementation: The PEFT (Parameter-Efficient Fine-Tuning) implementation to use.
8080
Options include 'lora' for Low-Rank Adaptation, QLoRA for Quantized LoRA. Each method has its own trade-offs in terms of performance
8181
and resource requirements.
82+
attn_implementation: The attention implementation to use for model loading.
83+
Default uses Flash Attention 3 via the HuggingFace Kernels Hub. Falls back to 'sdpa'
84+
if the kernels package is not installed. Other common values include 'flash_attention_2',
85+
'sdpa', and 'eager'.
8286
"""
8387

8488
num_input_records_to_sample: Annotated[
@@ -285,3 +289,19 @@ class TrainingHyperparams(Parameters):
285289
description="The fraction of the total VRAM to use for training. Default is 0.9. Modify this to allow longer sequences to be used.",
286290
),
287291
] = 0.80
292+
293+
attn_implementation: Annotated[
294+
str,
295+
Field(
296+
title="attn_implementation",
297+
description=(
298+
"The attention implementation to use for model loading. "
299+
"Default uses Flash Attention 3 via the HuggingFace Kernels Hub "
300+
"(requires the 'kernels' pip package; falls back to 'sdpa' if unavailable). "
301+
"Other common values: 'flash_attention_2' (requires flash-attn pip package), "
302+
"'sdpa' (PyTorch scaled dot product attention), 'eager' (standard PyTorch). "
303+
"Custom HuggingFace Kernels Hub paths (e.g. 'kernels-community/flash-attn2') "
304+
"are also supported."
305+
),
306+
),
307+
] = "kernels-community/vllm-flash-attn3"

src/nemo_safe_synthesizer/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,5 @@
9292
EPS = 1e-15
9393
NUM_SPECIAL_TOKENS = 2
9494
DEFAULT_CACHE_PREFIX = "safe-synthesizer-dataset-cache"
95+
DEFAULT_ATTN_IMPLEMENTATION = "kernels-community/vllm-flash-attn3"
96+
BACKUP_ATTN_IMPLEMENTATION = "sdpa"

src/nemo_safe_synthesizer/generation/vllm_backend.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@ def __del__(self) -> None:
8181

8282
def initialize(self, **kwargs) -> None:
8383
"""Initialize and load the model into memory."""
84-
max_vram = get_max_vram(as_fraction=True)
85-
max_vram = max_vram.get(0)
84+
# vLLM 0.11.x uses an environment variable for attention backend selection.
85+
# When vLLM is upgraded to 0.12+, migrate to the attention_backend constructor arg.
86+
if self.config.generation.attention_backend not in [None, "auto"]:
87+
os.environ["VLLM_ATTENTION_BACKEND"] = self.config.generation.attention_backend
88+
89+
max_vram = get_max_vram()
90+
# note this only works for single GPU setups
91+
max_vram = max_vram.get(0, 0.8)
8692

8793
# vllm requires this "config" to set the backend ahead of time.
8894
structured_outputs_config = StructuredOutputsConfig(
@@ -91,7 +97,7 @@ def initialize(self, **kwargs) -> None:
9197
)
9298
self.llm = vLLM(
9399
model=self.config.training.pretrained_model,
94-
gpu_memory_utilization=float(max_vram),
100+
gpu_memory_utilization=max_vram,
95101
enable_lora=True,
96102
max_lora_rank=self.config.training.lora_r,
97103
structured_outputs_config=structured_outputs_config,

src/nemo_safe_synthesizer/llm/utils.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,35 +63,30 @@ def round_gb(value: float) -> float:
6363
logger.info(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
6464

6565

66-
def get_max_vram(
67-
memory_fraction: float | None = None, as_string: bool = True, as_fraction: bool = False
68-
) -> dict[int, float | str]:
66+
def get_max_vram(max_vram_fraction: float | None = None) -> dict[int, float]:
6967
"""
70-
Calculate max memory allocation for each available GPU and CPU.
68+
Calculate max memory allocation for each available GPU and CPU as a fraction of total GPU memory.
7169
7270
Args:
7371
memory_fraction: Fraction of total GPU memory to allocate (default 0.8 for 80%)
7472
7573
Returns:
76-
Dictionary mapping device IDs to memory limits
74+
Dictionary mapping device IDs to memory limits as a fraction of total GPU memory
7775
"""
78-
if memory_fraction is None:
79-
memory_fraction = 0.8
76+
if max_vram_fraction is None:
77+
max_vram_fraction = 0.8
8078
max_memory = {}
8179

8280
if torch.cuda.is_available():
8381
num_gpus = torch.cuda.device_count()
8482
for i in range(num_gpus):
8583
free, total = torch.cuda.mem_get_info(device=i)
8684
safe_free = free - (2 * 1024**3)
87-
gpu_memory_utilization = min(memory_fraction, safe_free / total)
85+
gpu_memory_utilization = min(max_vram_fraction, safe_free / total)
8886
memory_gib = gpu_memory_utilization * total / (1024**3)
89-
if as_fraction:
90-
max_memory[i] = gpu_memory_utilization
91-
else:
92-
max_memory[i] = memory_gib if not as_string else f"{memory_gib:.2f}GiB"
87+
max_memory[i] = gpu_memory_utilization
9388
logger.info(
94-
f"GPU {i}: Will allocate {memory_gib:.2f}GiB ({memory_fraction * 100}% of {total / (1024**3):.2f}GiB)"
89+
f"GPU {i}: Will allocate {memory_gib:.2f}GiB ({max_vram_fraction * 100}% of {total / (1024**3):.2f}GiB)"
9590
)
9691

9792
return max_memory

src/nemo_safe_synthesizer/training/backend.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from datasets import Dataset
1414
from peft import PeftModel
1515
from transformers import (
16-
AutoTokenizer,
1716
PreTrainedModel,
1817
PreTrainedTokenizer,
1918
Trainer,
@@ -50,7 +49,7 @@ class NSSTrainerResult:
5049

5150
class TrainingBackend(metaclass=abc.ABCMeta):
5251
model: PreTrainedModel | PeftModel
53-
tokenizer: AutoTokenizer | PreTrainedTokenizer
52+
tokenizer: PreTrainedTokenizer
5453
quant_params: dict
5554
load_params: dict
5655
trainer_type: type[OpacusDPTrainer | Trainer | FastLanguageModel]
@@ -59,7 +58,7 @@ class TrainingBackend(metaclass=abc.ABCMeta):
5958
results: NSSTrainerResult
6059
training_examples: TrainingExamples
6160
df_train: pd.DataFrame
62-
df_test: pd.DataFrame
61+
df_test: pd.DataFrame | None
6362
dataset_schema: dict | None
6463
training_output_dir: Path
6564
workdir: Workdir

0 commit comments

Comments
 (0)