@@ -26,7 +26,7 @@ def tiny_llama_config(fixture_stub_tokenizer_path):
2626 num_hidden_layers = 2 ,
2727 num_attention_heads = 2 ,
2828 num_key_value_heads = 2 ,
29- max_position_embeddings = 128 ,
29+ max_position_embeddings = 512 ,
3030 )
3131
3232
@@ -91,11 +91,22 @@ def iris_df():
9191
9292@pytest .fixture (scope = "session" )
9393def timeseries_df ():
94- """Minimal timeseries stub: 2 groups, 5 rows each, elapsed_seconds ."""
94+ """Minimal timeseries stub: 2 groups, 5 rows each, 60s intervals ."""
9595 return pd .DataFrame (
9696 {
9797 "group_id" : ["A" , "A" , "A" , "A" , "A" , "B" , "B" , "B" , "B" , "B" ],
98- "elapsed_seconds" : [0 , 60 , 120 , 180 , 240 , 0 , 60 , 120 , 180 , 240 ],
98+ "timestamp" : [
99+ "2024-01-01 00:00:00" ,
100+ "2024-01-01 00:01:00" ,
101+ "2024-01-01 00:02:00" ,
102+ "2024-01-01 00:03:00" ,
103+ "2024-01-01 00:04:00" ,
104+ "2024-01-01 00:00:00" ,
105+ "2024-01-01 00:01:00" ,
106+ "2024-01-01 00:02:00" ,
107+ "2024-01-01 00:03:00" ,
108+ "2024-01-01 00:04:00" ,
109+ ],
99110 "value" : [10 , 20 , 30 , 40 , 50 , 100 , 110 , 120 , 130 , 140 ],
100111 }
101112 )
@@ -149,17 +160,17 @@ def train_with_sdk(config, data_df, save_path):
149160
150161@pytest .fixture
151162def _patch_attn_eager (monkeypatch ):
152- """Override attn_implementation to 'eager' for tiny model compatibility .
163+ """Default attn_implementation to 'sdpa' instead of 'flashinfer' (not a valid HF option), unless the caller already set it.
153164
154- The HuggingFaceBackend defaults to 'flashinfer' which can fail with
155- head_dim=32 (our tiny model: hidden_size=64 / 2 heads) .
165+ The HuggingFaceBackend defaults to 'flashinfer' which is not supported by
166+ HuggingFace's from_pretrained. PyTorch SDPA is universally compatible .
156167 """
157168 from nemo_safe_synthesizer .training .huggingface_backend import HuggingFaceBackend
158169
159- original = HuggingFaceBackend ._build_base_framework_params
170+ original_build = HuggingFaceBackend ._build_base_framework_params
160171
161- def patched (self , model_kwargs ):
162- model_kwargs .setdefault ("attn_implementation" , "eager " )
163- return original (self , model_kwargs )
172+ def patched_build (self , model_kwargs ):
173+ model_kwargs .setdefault ("attn_implementation" , "sdpa " )
174+ return original_build (self , model_kwargs )
164175
165- monkeypatch .setattr (HuggingFaceBackend , "_build_base_framework_params" , patched )
176+ monkeypatch .setattr (HuggingFaceBackend , "_build_base_framework_params" , patched_build )
0 commit comments