Skip to content

Commit fd61cbd

Browse files
committed
add missing cactus transpile flags
Signed-off-by: jakmro <kubamroz124@gmail.com>
1 parent 8a9adf8 commit fd61cbd

4 files changed

Lines changed: 105 additions & 31 deletions

File tree

python/cactus/cli/__init__.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,35 @@ def create_parser():
121121
122122
cactus transpile <model> build a runnable bundle from CQ weights
123123
--weights-dir <path> path to CQ weights (default: weights/<model>)
124-
--task <auto|...> force task type (default: auto)
124+
--task <auto|...> task (default: auto, inferred from model config)
125125
--artifact-dir <path> write bundle here (default: weights/<model>)
126+
--prompt <text> representative prompt for shape capture
127+
--system-prompt <text> system prompt for multimodal chat
128+
--enable-thinking enable thinking markers when supported
129+
--input-ids <a,b,...> token ids for causal-LM shape capture
130+
--image-file <path> representative image (repeatable)
131+
--audio-file <path> representative audio file (WAV)
132+
--max-new-tokens <n> preallocate decode context for causal LM
133+
--component-pipeline auto|on|off force component-pipeline transpilation
134+
--components <a,b,...> subset of components to transpile
135+
--torch-dtype <dtype> float16 | float32 | bfloat16
136+
--token <token> HuggingFace token (defaults to $HF_TOKEN)
137+
--trust-remote-code allow HF remote code at transpile time
138+
--local-files-only require model/processor to be local
139+
--allow-unconverted-weights debug-only: skip the CQ-weights check
140+
--execute-after-transpile run a reference execution after lowering
141+
--graph-filename <name> override saved graph filename
142+
--skip-reference-compare skip PyTorch comparison (with --execute-…)
143+
--no-fuse-rms-norm disable RMSNorm fusion
144+
--no-fuse-rope disable RoPE fusion
145+
--no-fuse-attention disable attention fusion
146+
--no-fuse-attention-block disable attention-block fusion
147+
--no-fuse-add-clipped disable add-clipped fusion
148+
--no-fuse-gated-deltanet disable gated DeltaNet fusion
149+
--npu also emit CoreML .mlpackage(s) for NPU
150+
--npu-quantize 0|4|8 force both NPU encoders to this quant
151+
--npu-audio-quantize 0|4|8 audio encoder quant (default int8)
152+
--npu-vision-quantize 0|4|8 vision encoder quant (default fp16)
126153
127154
cactus serve [model] OpenAI-compatible local HTTP server
128155
--host <addr> bind address (default: 127.0.0.1)
@@ -289,44 +316,71 @@ def create_parser():
289316
transpile_parser = subparsers.add_parser("transpile",
290317
help="Build a runnable bundle from CQ weights")
291318
transpile_parser.add_argument("model_id", type=_hf_id_or_path,
292-
help="HuggingFace model id (e.g. openai/whisper-base) or local PyTorch checkpoint path")
319+
help="HuggingFace model id or local checkpoint path")
293320
transpile_parser.add_argument("--weights-dir",
294-
help="Path to converted CQ weights (default: weights/<model_name>)")
321+
help="CQ weights directory (default: weights/<model>)")
295322
transpile_parser.add_argument("--task", default="auto",
296323
choices=["auto", "causal_lm_logits", "multimodal_causal_lm_logits",
297324
"ctc_logits", "encoder_hidden_states",
298325
"seq2seq_transcription", "tdt_transcription"],
299-
help="Transpile task (default: auto, inferred from weights)")
326+
help="Transpile task (default: auto, from model config)")
300327
transpile_parser.add_argument("--prompt",
301-
help="Representative prompt for causal/multimodal graph shape capture")
328+
help="Prompt for causal/multimodal shape capture")
329+
transpile_parser.add_argument("--system-prompt", default=None,
330+
help="System prompt for multimodal chat formats")
331+
transpile_parser.add_argument("--enable-thinking", action="store_true",
332+
help="Enable thinking markers when the prompt supports them")
333+
transpile_parser.add_argument("--input-ids", default=None,
334+
help="Comma-separated token ids for causal-LM shape capture")
302335
transpile_parser.add_argument("--image-file", action="append", default=[],
303-
help="Representative image file for multimodal transpile (repeatable)")
336+
help="Image for multimodal shape capture (repeatable)")
304337
transpile_parser.add_argument("--audio-file",
305-
help="Representative audio file for audio/multimodal transpile")
338+
help="Audio file (WAV) for audio/multimodal shape capture")
306339
transpile_parser.add_argument("--max-new-tokens", type=_positive_int, default=None,
307-
help="Generation room to preallocate for causal decode graphs")
340+
help="Decode context to preallocate for causal LM (default: 32)")
308341
transpile_parser.add_argument("--component-pipeline", default="auto", choices=["auto", "on", "off"],
309-
help="Use split component graph transpilation when supported")
342+
help="Split-component transpilation when supported (default: auto)")
310343
transpile_parser.add_argument("--components",
311-
help="Comma-separated component subset for component-pipeline models")
344+
help="Comma-separated component subset (e.g. vision_encoder,decoder)")
345+
transpile_parser.add_argument("--torch-dtype", default=None,
346+
choices=["float16", "float32", "bfloat16"],
347+
help="Torch dtype for HF loading (default: float16)")
348+
transpile_parser.add_argument("--token", default=None,
349+
help="HuggingFace token for gated models (default: $HF_TOKEN)")
312350
transpile_parser.add_argument("--trust-remote-code", action="store_true",
313-
help="Allow HF remote code during the transpile phase")
351+
help="Pass trust_remote_code=True to HF loaders")
314352
transpile_parser.add_argument("--local-files-only", action="store_true",
315-
help="Require HF model/processor files to already be local during transpile")
353+
help="Require model/processor to already be local")
316354
transpile_parser.add_argument("--allow-unconverted-weights", action="store_true",
317-
help="Transpile against an unconverted source checkpoint (skip the CQ weights check)")
355+
help="Debug: transpile without CQ weights")
318356
transpile_parser.add_argument("--execute-after-transpile", action="store_true",
319-
help="Run a reference execution against the produced bundle after transpiling")
357+
help="Run a reference execution after lowering")
320358
transpile_parser.add_argument("--artifact-dir",
321359
help="Output directory (default: weights/<model>)")
360+
transpile_parser.add_argument("--graph-filename", default=None,
361+
help="Saved graph filename (default: graph.cactus)")
322362
transpile_parser.add_argument("--skip-reference-compare", action="store_true",
323-
help="Skip PyTorch vs transpiled output comparison")
363+
help="Skip PyTorch comparison (requires --execute-after-transpile)")
324364
transpile_parser.add_argument("--no-fuse-rms-norm", action="store_true",
325365
help="Disable RMSNorm fusion")
326366
transpile_parser.add_argument("--no-fuse-rope", action="store_true",
327367
help="Disable RoPE fusion")
328368
transpile_parser.add_argument("--no-fuse-attention", action="store_true",
329369
help="Disable attention fusion")
370+
transpile_parser.add_argument("--no-fuse-attention-block", action="store_true",
371+
help="Disable attention-block fusion")
372+
transpile_parser.add_argument("--no-fuse-add-clipped", action="store_true",
373+
help="Disable add-clipped fusion")
374+
transpile_parser.add_argument("--no-fuse-gated-deltanet", action="store_true",
375+
help="Disable gated DeltaNet fusion")
376+
transpile_parser.add_argument("--npu", action="store_true",
377+
help="Also emit CoreML .mlpackage(s) for Apple NPU encoders")
378+
transpile_parser.add_argument("--npu-quantize", type=int, choices=[0, 4, 8], default=None,
379+
help="Legacy: force both NPU encoders to same quant (0=fp16, 4=int4, 8=int8)")
380+
transpile_parser.add_argument("--npu-audio-quantize", type=int, choices=[0, 4, 8], default=None,
381+
help="NPU audio encoder quant: 0=fp16, 4=int4, 8=int8 (default: 8)")
382+
transpile_parser.add_argument("--npu-vision-quantize", type=int, choices=[0, 4, 8], default=None,
383+
help="NPU vision encoder quant: 0=fp16, 4=int4, 8=int8 (default: 0; int4 degrades Gemma4 vision)")
330384

331385
return parser
332386

python/cactus/cli/convert.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ def cmd_transpile(args):
8787
extra_args.extend(["--task", args.task])
8888
if args.prompt is not None:
8989
extra_args.extend(["--prompt", args.prompt])
90+
if args.system_prompt is not None:
91+
extra_args.extend(["--system-prompt", args.system_prompt])
92+
if args.enable_thinking:
93+
extra_args.append("--enable-thinking")
94+
if args.input_ids is not None:
95+
extra_args.extend(["--input-ids", args.input_ids])
9096

9197
image_files = list(args.image_file or [])
9298
audio_file = args.audio_file
@@ -107,12 +113,18 @@ def cmd_transpile(args):
107113
extra_args.extend(["--component-pipeline", args.component_pipeline])
108114
if args.components:
109115
extra_args.extend(["--components", args.components])
116+
if args.torch_dtype:
117+
extra_args.extend(["--torch-dtype", args.torch_dtype])
118+
if args.token:
119+
extra_args.extend(["--token", args.token])
110120
if args.trust_remote_code:
111121
extra_args.append("--trust-remote-code")
112122
if args.local_files_only:
113123
extra_args.append("--local-files-only")
114124
if args.artifact_dir:
115125
extra_args.extend(["--artifact-dir", args.artifact_dir])
126+
if args.graph_filename:
127+
extra_args.extend(["--graph-filename", args.graph_filename])
116128
if args.skip_reference_compare:
117129
extra_args.append("--skip-reference-compare")
118130
if args.no_fuse_rms_norm:
@@ -121,6 +133,20 @@ def cmd_transpile(args):
121133
extra_args.append("--no-fuse-rope")
122134
if args.no_fuse_attention:
123135
extra_args.append("--no-fuse-attention")
136+
if args.no_fuse_attention_block:
137+
extra_args.append("--no-fuse-attention-block")
138+
if args.no_fuse_add_clipped:
139+
extra_args.append("--no-fuse-add-clipped")
140+
if args.no_fuse_gated_deltanet:
141+
extra_args.append("--no-fuse-gated-deltanet")
142+
if args.npu:
143+
extra_args.append("--npu")
144+
if args.npu_quantize is not None:
145+
extra_args.extend(["--npu-quantize", str(args.npu_quantize)])
146+
if args.npu_audio_quantize is not None:
147+
extra_args.extend(["--npu-audio-quantize", str(args.npu_audio_quantize)])
148+
if args.npu_vision_quantize is not None:
149+
extra_args.extend(["--npu-vision-quantize", str(args.npu_vision_quantize)])
124150

125151
return run_transpile(
126152
args.model_id,

python/cactus/transpile/npu/audio.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
2626
def _import_coremltools() -> Any:
2727
try:
2828
import coremltools as ct
29-
from .coremltools_patches import apply_all_coremltools_patches
30-
apply_all_coremltools_patches()
31-
return ct
32-
except Exception:
33-
return None
29+
except ImportError as exc:
30+
raise RuntimeError("--npu requires `pip install coremltools`") from exc
31+
from .coremltools_patches import apply_all_coremltools_patches
32+
apply_all_coremltools_patches()
33+
return ct
3434

3535

3636
def _apply_weight_quantization(mlmodel: Any, bits: int) -> Any:
@@ -65,9 +65,6 @@ def emit_audio_encoder_mlpackage(
6565
quantize_bits: int | None = None,
6666
) -> str | None:
6767
ct = _import_coremltools()
68-
if ct is None:
69-
print("npu.audio: coremltools not installed; skipping mlpackage emit")
70-
return None
7168

7269
wrapper = AudioEncoderWrapper(audio_module, baked_inputs)
7370
wrapper.eval()

python/cactus/transpile/npu/vision.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
2626
def _import_coremltools() -> Any:
2727
try:
2828
import coremltools as ct
29-
from .coremltools_patches import apply_all_coremltools_patches
30-
apply_all_coremltools_patches()
31-
return ct
32-
except Exception:
33-
return None
29+
except ImportError as exc:
30+
raise RuntimeError("--npu requires `pip install coremltools`") from exc
31+
from .coremltools_patches import apply_all_coremltools_patches
32+
apply_all_coremltools_patches()
33+
return ct
3434

3535

3636
def _apply_weight_quantization(mlmodel: Any, bits: int) -> Any:
@@ -65,9 +65,6 @@ def emit_vision_encoder_mlpackage(
6565
quantize_bits: int | None = None,
6666
) -> str | None:
6767
ct = _import_coremltools()
68-
if ct is None:
69-
print("npu.vision: coremltools not installed; skipping mlpackage emit")
70-
return None
7168

7269
wrapper = VisionEncoderWrapper(vision_module, baked_inputs)
7370
wrapper.eval()

0 commit comments

Comments
 (0)