Fuse expresses AI as a small set of tensor equations: joins + projection (+ optional nonlinearity). This version ships a NumPy execution engine, a Torch FX lowering, and a JAX path that can be JIT-compiled to XLA. It keeps the code small and readable while covering end-to-end flows (sources → IR → backends → sinks).
Paper context: logical rules ≙ Einstein sums; anything not on the LHS is implicitly projected (summed) out; equations with the same LHS implicitly add; files are first-class sources/sinks (reading/writing tensors).
pip install fuse
pip install fuse[torch]   # Torch FX backend
pip install fuse[jax]     # JAX backend
pip install fuse[bench]   # Torch + JAX bundle for benchmarks
pip install fuse[dev]     # Linting, typing, tests

For an editable development install:

pip install -e ".[dev]"

The repository ships a `uv.lock` for reproducible envs. To sync the dev environment:

uv sync --extra dev

You can swap in other extras (e.g. `--extra torch --extra jax`, or `--all-extras`) to mirror the `pip install fuse[...]` flows above.
Install pre-commit and enable the hooks to run Ruff (formatting + lint), pyupgrade, and import sorting automatically:
pip install pre-commit
pre-commit install
pre-commit run --all-files

- All backends accept source tensors at call time via `runner.run(inputs={"Source": array})`; file paths remain as defaults when no runtime tensors are provided (see the sketch after this list).
- Torch and JAX backends currently defer demand-driven execution (`mode="demand"`) and Monte Carlo projection to the NumPy engine, falling back to the reference evaluator in those configurations.
- Torch FX exports keep file-backed defaults baked into the traced graph today; prefer `runner.run` when you need to feed dynamic tensors. The JAX `runner.xla_callable` accepts pytrees of runtime inputs.
- The NumPy runner now supports `ExecutionConfig(fixpoint_strategy="semi_naive")` for delta-driven fixpoint scheduling, plus optional blocked einsums via `ExecutionConfig(block_size=...)`.
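For orientation, here is a minimal sketch of feeding runtime tensors, assuming the `Program` / `ExecutionConfig` / `runner.run` surface described in this README. The `Program.parse` constructor, the `config=` keyword, and the shape of the returned results are illustrative assumptions rather than the exact API.

```python
import numpy as np

from fuse import ExecutionConfig, Program  # assumed import path

# Hypothetical constructor: one join rule; k is not on the LHS, so it is summed out.
prog = Program.parse("T[i,j] = A[i,k] B[k,j]")

# Opt into semi-naive fixpoint scheduling and blocked einsums (NumPy runner features).
runner = prog.compile(
    backend="auto",
    config=ExecutionConfig(fixpoint_strategy="semi_naive", block_size=64),  # `config=` kwarg assumed
)

# Source tensors supplied at call time; file-backed defaults apply when omitted.
result = runner.run(inputs={
    "A": np.random.rand(4, 8).astype(np.float32),
    "B": np.random.rand(8, 3).astype(np.float32),
})
print(result)  # expected to expose T with shape (4, 3)
```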
python examples/01_attention_block.py
python examples/02_rules_and_embeddings.py
python examples/03_zero_input_lm.py

Artifacts (sources, IR, simple plans, outputs) are written into `runs/` under each example’s folder.
| Topic | Description |
|---|---|
| DSL reference | One-page grammar & operator cheatsheet. |
| Backend matrix | Backend capabilities and constraints at a glance. |
| CLI usage | Running programs quickly via `python -m fuse run`. |
- A `Program` that parses a compact Fuse DSL (see the program sketch after this list):
  - Lines: `T[i,j] = A[i,k] B[k,j]`, `Y[i.] = softmax(X[i])`, axis-aware `concat`, `amax`/`avg` projections, and literal/keyword arguments.
  - Sources: `T[i,j] = "file.npy"`, `"out.jsonl" = T[i,j]`, text/CSV autoloaders, plus pluggable weight stores via `RuntimePolicies`.
- Execution engines
  - NumPy runner with fixpoint forward/backward chaining, recursion, and an enhanced `explain()` (einsum canonicalization, projected indices, timing).
  - Torch FX backend that emits a graph module backed by `torch.einsum`/NN ops (access via `runner.fx_module`), honouring caching and policies.
  - JAX backend that evaluates with `jax.numpy` and exposes a lazily-built `runner.xla_callable` for `jax.jit` export.
- Backend selection
  - The Python API defaults to `backend="auto"` in `Program.compile()`. The chooser is hardware‑ and workload‑aware:
    - Picks NumPy for demand mode, Monte Carlo projection, or streaming programs.
    - Prefers Torch on CUDA/MPS for attention/MLP‑style workloads when available.
    - Otherwise considers JAX for heavier batched workloads when JAX is installed.
    - Falls back to NumPy for small programs or when accelerators aren’t available.
- Execution controls via `ExecutionConfig` (see the configuration sketch after this list)
  - `precision` defaults to `fp32`. Mixed-precision runs can request `bf16`, `fp16`, or `auto` (which selects the fastest supported dtype per backend/device). NumPy always stays in `fp32`; Torch refuses `fp16` on CPU and checks CUDA/MPS support; JAX only permits `fp16` on GPU and maps TPU/GPU `auto` runs to `bf16`.
  - `device` chooses where execution happens: `auto`, `cpu`, `cuda[:index]`, or `mps`. NumPy can only target CPU; Torch/JAX resolve and pin all compilation artifacts to the requested accelerator so FX graphs and XLA lowerings stay aligned.
  - `zero_copy` keeps host↔device hand-offs lean. When `True` (the default) the runners reuse host buffers on CPU and skip redundant `.tolist()` conversions; set it to `False` if you need defensive copies before handing tensors to external code.
  - For JAX you can opt into the experimental XLA cache with `ExecutionConfig(jax_enable_xla_cache=True, jax_cache_dir="~/.cache/fuse/jax")` (the path is optional) and grab the lazily-built `runner.xla_callable` for `jax.jit` execution. `validate_device_transfers=True` raises if GPU/TPU runs would implicitly copy NumPy inputs to device memory, forcing explicit `jax.device_put` hand-offs when you want to audit data movement.
- Quantised weights retain scale/zero-point metadata. During dequantisation Fuse enforces float32 accumulation (fp16 at minimum) and broadcast-compatible shapes; values are assumed to be pre-saturated/rounded into the int8 range, so Fuse only rescales without introducing extra clipping (see the dequantisation sketch after this list).
- Caching and policies (see the policies sketch after this list)
  - `Program.compile(..., cache_dir="path")` stores backend artifacts via `CacheManager`.
  - `RuntimePolicies` captures weight stores, sharding metadata, quantisation (e.g. `int8` dequant), and LoRA adapter rules.
- Examples:
  - Masked attention in ~3 equations.
  - A rule (`Aunt`) + reasoning in embedding space with a temperature knob `T`.
  - Zero-input LM head: sources/sinks live in the program; just run the artifact.
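To make the DSL surface and the compile flow concrete, here is a sketch of a zero-input program that reads its sources from files, runs a join, and writes a JSONL sink, using only DSL forms listed above. The multi-line program string, the `Program.parse` constructor, and the argument-free `runner.run()` call are illustrative assumptions; `backend="auto"` and `cache_dir=` are documented above.

```python
from fuse import Program  # assumed import path

# Only documented DSL forms are used: file sources, a join with k projected out,
# and a file sink. The file names are placeholders.
SRC = """
A[i,k] = "data/A.npy"
B[k,j] = "data/B.npy"
T[i,j] = A[i,k] B[k,j]
"out/T.jsonl" = T[i,j]
"""

prog = Program.parse(SRC)                       # constructor name assumed
runner = prog.compile(backend="auto",           # hardware/workload-aware chooser
                      cache_dir="runs/cache")   # artifacts stored via CacheManager
runner.run()                                    # zero-input run: sources/sinks live in the program
```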
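The execution controls above combine into one `ExecutionConfig`. A configuration sketch, assuming the documented field names can be passed together and that `compile` accepts the config via a `config=` keyword (the keyword itself is not confirmed here):

```python
from fuse import ExecutionConfig  # assumed import path

cfg = ExecutionConfig(
    precision="auto",                  # fastest supported dtype per backend/device
    device="cuda:0",                   # or "auto", "cpu", "mps"
    zero_copy=True,                    # default: reuse host buffers, avoid redundant copies
    jax_enable_xla_cache=True,         # experimental XLA cache (JAX backend)
    jax_cache_dir="~/.cache/fuse/jax",
    validate_device_transfers=True,    # raise on implicit NumPy -> device copies
)

# runner = prog.compile(backend="jax", config=cfg)   # `config=` kwarg assumed
# xla_fn = runner.xla_callable                       # lazily built, ready for jax.jit execution
```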
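The int8 dequantisation rule above boils down to a rescale with float32 accumulation over broadcast-compatible scale/zero-point metadata. A plain NumPy sketch of that arithmetic, independent of Fuse's actual weight-store API:

```python
import numpy as np

def dequantise_int8(q: np.ndarray, scale: np.ndarray, zero_point: np.ndarray) -> np.ndarray:
    """Rescale pre-saturated int8 weights to float32 without extra clipping."""
    # scale/zero_point may be per-tensor or per-channel; they must broadcast against q.
    np.broadcast_shapes(q.shape, scale.shape, zero_point.shape)
    # Accumulate in float32 (the README requires at least fp16).
    return (q.astype(np.float32) - zero_point.astype(np.float32)) * scale.astype(np.float32)

# Per-output-channel quantisation for a (out, in) weight matrix.
q = np.random.randint(-128, 128, size=(4, 8), dtype=np.int8)
scale = np.full((4, 1), 0.02, dtype=np.float32)
zero_point = np.zeros((4, 1), dtype=np.float32)
w = dequantise_int8(q, scale, zero_point)
print(w.dtype, w.shape)  # float32 (4, 8)
```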
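Caching and runtime policies from the list above, sketched together. Only `cache_dir=` and the `RuntimePolicies` class name are documented; the `policies=` keyword and every constructor argument shown here are hypothetical placeholders for the kinds of things `RuntimePolicies` is said to capture (weight stores, quantisation, LoRA adapters):

```python
from fuse import Program, RuntimePolicies  # assumed import path

# Hypothetical field names -- the real RuntimePolicies signature may differ.
policies = RuntimePolicies(
    weight_store="weights/",                  # pluggable weight store location
    quantisation="int8",                      # dequantised via scale/zero-point metadata
    lora_adapters=["adapters/lora_r8.npz"],   # LoRA adapter rules
)

prog = Program.parse('T[i,j] = A[i,k] B[k,j]')  # constructor name assumed
runner = prog.compile(
    backend="auto",
    cache_dir="runs/cache",   # backend artifacts stored/reused via CacheManager
    policies=policies,        # keyword name assumed
)
```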
- Parser is still line-oriented: no arithmetic, conditionals, or macro system yet.
- Fixpoint mode is synchronous by default (semi-naïve delta joins are opt-in via `ExecutionConfig(fixpoint_strategy="semi_naive")` on the NumPy runner), so large recursive programs may still run slowly out of the box.
- Torch/JAX backends embed sources as constants; dynamic data loaders will need additional plumbed inputs.
- Policy hooks surface structure but do not yet include end-to-end distributed sharding or quant-aware training loops.