Skip to content

Commit f7a0777

Browse files
forforever73CISC
andauthored
convert : support Step3.7-Flash (#23845)
* feat: support step3.7 * fix: register Step-3.7 BPE pre-tokenizer hash * delete fromjson * register step3.7 arch to Step35Model * drop vit projector in base filter * Apply suggestion from @CISC Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * restore blank line --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 4f3a4be commit f7a0777

3 files changed

Lines changed: 25 additions & 7 deletions

File tree

conversion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
"Starcoder2ForCausalLM": "starcoder",
216216
"Step3p5ForCausalLM": "step3",
217217
"StepVLForConditionalGeneration": "step3",
218+
"Step3p7ForConditionalGeneration": "step3",
218219
"T5EncoderModel": "t5",
219220
"T5ForConditionalGeneration": "t5",
220221
"T5WithLMHeadModel": "t5",
@@ -283,6 +284,7 @@
283284
"Sarashina2VisionForCausalLM": "sarashina2",
284285
"SmolVLMForConditionalGeneration": "smolvlm",
285286
"StepVLForConditionalGeneration": "step3",
287+
"Step3p7ForConditionalGeneration": "step3",
286288
"UltravoxModel": "ultravox",
287289
"VoxtralForConditionalGeneration": "ultravox",
288290
"YoutuVLForConditionalGeneration": "youtuvl",

conversion/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2593,7 +2593,7 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
25932593
# Step3-VL keeps text config under text_config but uses a custom top-level architecture.
25942594
# For text conversion we route to a dedicated text-only class.
25952595
# TODO: refactor this later to avoid adding exception here
2596-
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM", "Exaone4_5_ForConditionalGeneration"):
2596+
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM", "Exaone4_5_ForConditionalGeneration", "Step3p7ForConditionalGeneration"):
25972597
return arch
25982598

25992599
# if "architectures" is found in the sub-config, use that instead

conversion/step3.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .qwen import Qwen3Model
1616

1717

18-
@ModelBase.register("StepVLForConditionalGeneration")
18+
@ModelBase.register("StepVLForConditionalGeneration", "Step3p7ForConditionalGeneration")
1919
class Step3VLVisionModel(MmprojModel):
2020
def __init__(self, *args, **kwargs):
2121
super().__init__(*args, **kwargs)
@@ -95,7 +95,7 @@ class Step3VLTextModel(Qwen3Model):
9595
model_arch = gguf.MODEL_ARCH.QWEN3
9696

9797

98-
@ModelBase.register("Step3p5ForCausalLM")
98+
@ModelBase.register("Step3p5ForCausalLM", "Step3p7ForConditionalGeneration")
9999
class Step35Model(TextModel):
100100
model_arch = gguf.MODEL_ARCH.STEP35
101101

@@ -203,11 +203,23 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
203203
if isinstance(rope_theta, list):
204204
rope_theta = rope_theta[0]
205205
base = float(rope_theta)
206-
if (dim := self.hparams.get("head_dim")) is None:
207-
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
208-
dim = int(dim)
209206

210-
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
207+
if (storage_dim := self.hparams.get("head_dim")) is None:
208+
storage_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
209+
storage_dim = int(storage_dim)
210+
211+
# Llama 3 factors apply only to the rotary dims used by full_attention layers
212+
# (partial_rotary_factor * head_dim). Remaining slots are padded with 1.0 so
213+
# sliding_attention layers remain unaffected. set_gguf_parameters already
214+
# guarantees at least one full_attention layer.
215+
layer_types = (self.hparams.get("layer_types") or [])[: self.block_count]
216+
partial_rotary_factors = (self.hparams.get("partial_rotary_factors") or [])[: self.block_count]
217+
full_attention_factor = next(
218+
float(f) for lt, f in zip(layer_types, partial_rotary_factors) if lt == "full_attention"
219+
)
220+
rotary_dim = int(storage_dim * full_attention_factor)
221+
222+
freqs = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
211223

212224
factor = float(rope_params.get("factor", 8.0))
213225
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
@@ -228,4 +240,8 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
228240
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
229241
rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
230242

243+
# Pad to head_dim/2 with 1.0 so non-scaled layers remain neutral.
244+
if len(rope_factors) < storage_dim // 2:
245+
rope_factors.extend([1.0] * (storage_dim // 2 - len(rope_factors)))
246+
231247
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))

0 commit comments

Comments
 (0)