WebAssembly Backend for cellm

This document describes how the cellm LLM inference engine was ported to WebAssembly, the design decisions involved, and how the pieces fit together.

Architecture Overview

The cellm engine already had a clean backend abstraction via the Backend trait in cellm-core. The WASM port slots into this same seam, adding a new backend that compiles to wasm32-unknown-unknown while reusing every layer above it: scheduling, KV cache, model graph traversal, and session management.

flowchart TB
    subgraph JS["Browser (JavaScript)"]
        UI["Demo UI (index.html)"] --> WASM_BINDINGS["wasm-bindgen API"]
    end

    subgraph WASM["WASM Module (cellm-wasm)"]
        WASM_BINDINGS --> ENGINE["CellmEngine (session management)"]
        ENGINE --> SDK["cellm-sdk Engine"]
        SDK --> MODEL_RUNNER["Model Runner (Llama/Gemma/Qwen/LFM)"]

        MODEL_RUNNER --> BACKEND["Backend trait"]
        BACKEND --> WASM_KERNELS["WASM SIMD Kernels"]
        BACKEND --> WEBGPU_KERNELS["WebGPU (WGSL) Shaders"]
        BACKEND --> SCALAR_FALLBACK["Scalar Fallback"]
    end

    subgraph BROWSER_APIS["Browser APIs"]
        MODEL_FILE["Model .cellm file (fetch / file picker)"]
        TOKENIZER_FILE["Tokenizer JSON (fetch / file picker)"]
    end

    MODEL_FILE --> ENGINE
    TOKENIZER_FILE --> ENGINE

What Was Built

1. WASM SIMD Kernels (cellm-kernels/src/wasm.rs)

Nine kernel functions mirror the existing NEON-accelerated CPU kernels, using std::arch::wasm32 v128 SIMD intrinsics in place of the std::arch::aarch64 NEON intrinsics.

The translation is nearly one-to-one because both architectures use 128-bit SIMD registers with 4x f32 lanes:

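| NEON intrinsic | WASM SIMD equivalent (std::arch::wasm32) |
| --- | --- |
| vld1q_f32(ptr) | v128_load(ptr as *const v128) |
| vdupq_n_f32(x) | f32x4_splat(x) |
| vmlaq_f32(acc, a, b) | f32x4_add(acc, f32x4_mul(a, b)) |
| vgetq_lane_f32(v, i) | f32x4_extract_lane::<i>(v) |
| vst1q_f32(ptr, v) | v128_store(ptr as *mut v128, v) |
| vld1q_s8(ptr) | v128_load(ptr as *const v128) |
| vmovl_s8(wv) | i16x8_extend_low_i8x16(wv) / i16x8_extend_high_i8x16(wv) |
| vcvtq_f32_s32(wv) | f32x4_convert_i32x4(wv) |

The Rust intrinsics drop the _s suffix that the underlying WASM instructions carry (i16x8.extend_low_i8x16_s, f32x4.convert_i32x4_s); the unsigned variants are spelled u8x16 and u32x4 instead.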

Matrix-vector products (the decode-time bottleneck) use a 4-way unrolled dot product that processes 16 floats per loop iteration. The int8 matmul loads 16 int8 weight values, extends them through i16 to i32, converts to f32, and multiply-accumulates against 16 f32 activations -- all with register-only operations.
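A minimal sketch of the 4-way unrolled f32 dot product described above. The function name dot_f32 and the extra target_feature = "simd128" gate are assumptions for illustration and are not taken from the cellm source. Four independent accumulators cover 16 floats per iteration, and a scalar tail handles lengths that are not a multiple of 16.

```rust
/// WASM SIMD path: 4 accumulators x 4 lanes = 16 floats per iteration.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::wasm32::*;
    assert_eq!(a.len(), b.len());
    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    let mut acc = [f32x4_splat(0.0); 4];
    let mut i = 0;
    while i + 16 <= n {
        for k in 0..4 {
            // SAFETY: i + 4 * k + 4 <= i + 16 <= n, so both 16-byte loads are
            // in bounds; v128_load accepts unaligned addresses.
            let (va, vb) = unsafe {
                (
                    v128_load(pa.add(i + 4 * k) as *const v128),
                    v128_load(pb.add(i + 4 * k) as *const v128),
                )
            };
            acc[k] = f32x4_add(acc[k], f32x4_mul(va, vb));
        }
        i += 16;
    }
    // Reduce the four vector accumulators, then the four lanes.
    let v = f32x4_add(f32x4_add(acc[0], acc[1]), f32x4_add(acc[2], acc[3]));
    let mut sum = f32x4_extract_lane::<0>(v)
        + f32x4_extract_lane::<1>(v)
        + f32x4_extract_lane::<2>(v)
        + f32x4_extract_lane::<3>(v);
    // Scalar tail for the remaining 0..15 elements.
    while i < n {
        sum += a[i] * b[i];
        i += 1;
    }
    sum
}

/// Scalar fallback for every other target.
#[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

The two paths sum in a different order, so results can differ in the last few bits for inputs that are not exactly representable.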

Each function has #[cfg(target_arch = "wasm32")] SIMD blocks and scalar fallbacks under #[cfg(not(target_arch = "wasm32"))], so the file compiles on any target.
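The int8 path can be sketched the same way. The name dot_i8_f32 is hypothetical, and any per-row or per-block scale factor is left out because this document does not describe the quantisation scheme. The SIMD variant widens 16 int8 weights through i16 and i32 to f32 before the multiply-accumulate; the fallback does the same widening one element at a time.

```rust
/// WASM SIMD path: widen 16 int8 weights to f32 and accumulate against
/// 16 f32 activations per iteration.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
pub fn dot_i8_f32(w: &[i8], x: &[f32]) -> f32 {
    use std::arch::wasm32::*;
    assert_eq!(w.len(), x.len());
    let n = w.len();
    let mut acc = f32x4_splat(0.0);
    let mut i = 0;
    while i + 16 <= n {
        // SAFETY: i + 16 <= n, so the 16-byte weight load and the four
        // 16-byte activation loads stay in bounds.
        unsafe {
            let wv = v128_load(w.as_ptr().add(i) as *const v128);
            let lo = i16x8_extend_low_i8x16(wv);
            let hi = i16x8_extend_high_i8x16(wv);
            let quads = [
                i32x4_extend_low_i16x8(lo),
                i32x4_extend_high_i16x8(lo),
                i32x4_extend_low_i16x8(hi),
                i32x4_extend_high_i16x8(hi),
            ];
            for (k, q) in quads.iter().enumerate() {
                let xv = v128_load(x.as_ptr().add(i + 4 * k) as *const v128);
                acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(*q), xv));
            }
        }
        i += 16;
    }
    let mut sum = f32x4_extract_lane::<0>(acc)
        + f32x4_extract_lane::<1>(acc)
        + f32x4_extract_lane::<2>(acc)
        + f32x4_extract_lane::<3>(acc);
    // Scalar tail for the remaining 0..15 elements.
    while i < n {
        sum += f32::from(w[i]) * x[i];
        i += 1;
    }
    sum
}

/// Scalar fallback: i8 -> f32 conversion is lossless.
#[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
pub fn dot_i8_f32(w: &[i8], x: &[f32]) -> f32 {
    assert_eq!(w.len(), x.len());
    w.iter().zip(x).map(|(&q, &v)| f32::from(q) * v).sum()
}
```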

flowchart LR
    subgraph NEON["Original NEON Kernel"]
        A0["vld1q_f32(row)"] --> B0["vmlaq_f32(sum, a, b)"]
        B0 --> C0["vgetq_lane_f32(sum)"]
    end

    subgraph WASM["WASM SIMD Kernel"]
        A1["v128_load(row)"] --> B1["f32x4_add(sum, f32x4_mul(a, b))"]
        B1 --> C1["f32x4_extract_lane(sum)"]
    end

    NEON -. "one-to-one mapping" .-> WASM

2. Model Loading from Bytes (cellm-model/src/cellm_file.rs)

The existing CellmFile already supported an owned Vec<u8> data variant as a fallback when mmap failed. A new load_from_bytes(bytes: &[u8]) constructor was added that:

  1. Validates the magic header (CELLM + version byte)
  2. Parses the JSON header section
  3. Builds the tensor index map
  4. Wraps the bytes in CellmData::Owned

This avoids any filesystem dependency -- critical for WASM where mmap does not exist.
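The first two steps can be sketched as follows. The byte layout used here (5-byte magic, 1 version byte, a little-endian u32 header length, then the JSON header) is an assumption made for illustration; the real .cellm layout may differ. The names parse_header, ParsedHeader, and ParseError are likewise hypothetical. JSON parsing and the tensor index are left out so the sketch needs no external crates.

```rust
#[derive(Debug, PartialEq)]
pub enum ParseError {
    TooShort,
    BadMagic,
    BadHeader,
}

/// Borrowed view of the header; tensor data would start at `data_offset`.
pub struct ParsedHeader<'a> {
    pub version: u8,
    pub header_json: &'a str,
    pub data_offset: usize,
}

pub fn parse_header(bytes: &[u8]) -> Result<ParsedHeader<'_>, ParseError> {
    const MAGIC: &[u8] = b"CELLM";
    // Assumed prefix: magic (5) + version (1) + header length (4).
    const PREFIX_LEN: usize = 5 + 1 + 4;

    if bytes.len() < PREFIX_LEN {
        return Err(ParseError::TooShort);
    }
    if &bytes[..5] != MAGIC {
        return Err(ParseError::BadMagic);
    }
    let version = bytes[5];
    let header_len = u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]) as usize;

    // Bounds-check before slicing so a corrupt length cannot panic.
    let end = PREFIX_LEN.checked_add(header_len).ok_or(ParseError::BadHeader)?;
    let raw = bytes.get(PREFIX_LEN..end).ok_or(ParseError::TooShort)?;
    let header_json = std::str::from_utf8(raw).map_err(|_| ParseError::BadHeader)?;

    Ok(ParsedHeader { version, header_json, data_offset: end })
}
```

Because the parser only borrows from the input slice, the caller decides whether the bytes end up in an owned buffer or a memory map.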

3. Runner from_file Constructors (cellm-model/src/{llama,gemma,qwen,lfm}.rs)

Each model runner previously required a filesystem path:

pub fn load(path: &Path) -> Result<Self, CoreError> {
    let file = CellmFile::load(path)?;
    // ... extract config, detect prefix ...
}

A parallel from_file(file: CellmFile) constructor was added to each runner that skips the CellmFile::load call and uses the passed-in file directly. The WASM engine calls CellmFile::load_from_bytes then dispatches to the appropriate from_file.
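One way to keep the path-based and in-memory constructors in sync is to have load delegate to from_file. The sketch below uses simplified stand-in types; the real CellmFile and runner structs carry far more state, and whether cellm's load actually delegates this way is not stated here.

```rust
use std::io;
use std::path::Path;

/// Stand-in for the real CellmFile: just the raw bytes.
pub struct CellmFile {
    pub bytes: Vec<u8>,
}

impl CellmFile {
    /// Native path: read from disk.
    pub fn load(path: &Path) -> io::Result<Self> {
        Ok(Self { bytes: std::fs::read(path)? })
    }

    /// WASM path: take bytes that JavaScript already fetched.
    pub fn load_from_bytes(bytes: &[u8]) -> io::Result<Self> {
        Ok(Self { bytes: bytes.to_vec() })
    }
}

/// Stand-in for a model runner.
pub struct Runner {
    pub weight_bytes: usize,
}

impl Runner {
    /// Path-based constructor delegates, so setup logic lives in one place.
    pub fn load(path: &Path) -> io::Result<Self> {
        Self::from_file(CellmFile::load(path)?)
    }

    /// Works on any target because it never touches the filesystem.
    pub fn from_file(file: CellmFile) -> io::Result<Self> {
        Ok(Self { weight_bytes: file.bytes.len() })
    }
}
```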

4. Engine from_bytes Constructor (cellm-sdk/src/lib.rs)

The top-level Engine::new(path, config) was mirrored as Engine::from_bytes(model_bytes, config). It follows the exact same setup sequence -- extract model type, construct runner, compute head dimension, allocate KV cache -- but from in-memory bytes instead of a file path.
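To make the last two setup steps concrete, here is a rough sizing sketch. It assumes the common convention head_dim = hidden_size / num_heads (some architectures store the head dimension explicitly instead) and a dense KV cache with one key and one value vector per layer, KV head, and token. The struct and field names are illustrative, not cellm's.

```rust
/// Hypothetical subset of a model config.
pub struct ModelDims {
    pub hidden_size: usize,
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub num_layers: usize,
}

/// Head dimension under the hidden_size / num_heads convention.
pub fn head_dim(d: &ModelDims) -> usize {
    assert!(d.num_heads > 0 && d.hidden_size % d.num_heads == 0);
    d.hidden_size / d.num_heads
}

/// Bytes needed for keys plus values across all layers.
pub fn kv_cache_bytes(d: &ModelDims, max_tokens: usize, bytes_per_elem: usize) -> usize {
    2 * d.num_layers * d.num_kv_heads * head_dim(d) * max_tokens * bytes_per_elem
}
```

With illustrative dimensions of 576 hidden units, 9 heads, 3 KV heads, and 30 layers, the head dimension is 64, and a 2048-token f16 cache comes to 47,185,920 bytes (45 MiB).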

5. WASM Bindings Crate (cellm-wasm/)

A new workspace crate cellm-wasm exposes the engine to JavaScript via wasm-bindgen:

classDiagram
    class CellmEngine {
        +new(modelBytes, configJson) CellmEngine
        +set_tokenizer(tokenizerJson)
        +tokenize(text) u32[]
        +decode(tokens) string
        +create_session() u64
        +submit_tokens(sessionId, tokens) u32
        +step_decode() [sid, token] or null
        +generate(sessionId, tokens, maxTokens) [sid, token][]
        +cancel_session(sessionId)
        +reset_session(sessionId)
        +suspend_session(sessionId)
        +resume_session(sessionId)
        +pending_tokens(sessionId) u32
        +total_tokens_generated() f64
        +num_active_sessions() u32
        +num_free_blocks() u32
    }

The API is designed for two usage patterns:

Manual stepping -- gives the caller control over decode pacing, useful for responsive UIs:

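const sid = engine.create_session();
// submit_tokens runs prefill and returns the first generated token
let token = engine.submit_tokens(sid, inputIds);
let output = "";

while (true) {
    output += engine.decode(new Uint32Array([token]));
    if (engine.is_stop_token(token)) break;
    const result = engine.step_decode();
    if (!result) break;
    token = result[1]; // result = [sessionId, token]
    // yield to the browser so the UI stays responsive
    await new Promise((resolve) => setTimeout(resolve, 0));
}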

Batched generate -- convenience wrapper that runs the decode loop inside the WASM module and returns all tokens at once:

const results = engine.generate(sid, inputIds, 64);

6. Threading with Web Workers

The SIMD kernels use rayon for parallel iteration across rows in matrix operations. Under WASM, this requires SharedArrayBuffer support, which is enabled by two HTTP headers on the serving page:

Cross-Origin-Opener-Policy: same-origin
Cross-Origin-Embedder-Policy: require-corp

The wasm-bindgen-rayon crate bridges Rayon's thread pool to browser Web Workers, spawning a worker pool at WASM initialisation time.
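The row-parallel split can be illustrated without any dependencies. In this sketch std::thread::scope (Rust 1.63 or later) stands in for rayon's parallel iterators, and the function name matvec_parallel is hypothetical. It is meant to show how output rows are partitioned across workers on a native target, not how wasm-bindgen-rayon schedules work in the browser.

```rust
/// y = W * x for a row-major `rows x cols` matrix, split by rows.
pub fn matvec_parallel(
    w: &[f32],
    x: &[f32],
    rows: usize,
    cols: usize,
    workers: usize,
) -> Vec<f32> {
    assert_eq!(w.len(), rows * cols);
    assert_eq!(x.len(), cols);
    let mut out = vec![0.0f32; rows];
    let workers = workers.max(1);
    // Ceiling division so every row lands in exactly one chunk.
    let rows_per = ((rows + workers - 1) / workers).max(1);

    std::thread::scope(|s| {
        for (chunk_idx, out_chunk) in out.chunks_mut(rows_per).enumerate() {
            let first_row = chunk_idx * rows_per;
            // Each thread owns a disjoint slice of the output, so no locking
            // is needed; the weights and activations are only read.
            s.spawn(move || {
                for (r, o) in out_chunk.iter_mut().enumerate() {
                    let start = (first_row + r) * cols;
                    let row = &w[start..start + cols];
                    *o = row.iter().zip(x).map(|(a, b)| a * b).sum();
                }
            });
        }
    });
    out
}
```

Each output row is an independent dot product, which is why this decomposition needs no synchronisation beyond joining the workers at the end.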

Data Flow During Inference

sequenceDiagram
    participant JS as Browser JS
    participant WASM as WASM Module
    participant KERNELS as WASM SIMD Kernels

    JS->>WASM: CellmEngine.new(modelBytes, config)
    WASM->>WASM: CellmFile.load_from_bytes(modelBytes)
    WASM->>WASM: Runner.from_file(file)
    WASM->>WASM: Create KV cache
    WASM-->>JS: engine instance

    JS->>WASM: engine.create_session()
    WASM-->>JS: sessionId

    JS->>WASM: engine.submit_tokens(sid, [1, 304, ...])
    WASM->>WASM: runner.prefill(tokens, ...)
    WASM->>WASM: runner.step_topk(lastTok, ...)
    WASM->>KERNELS: matmul_f32, rms_norm_f32, ...
    KERNELS-->>WASM: logits
    WASM-->>JS: nextToken

    loop decode loop
        JS->>WASM: engine.step_decode()
        WASM->>WASM: runner.decode_one_for_session(...)
        WASM->>KERNELS: matmul_f32 (n=1), attention, rope, softmax
        KERNELS-->>WASM: logits
        WASM-->>JS: [sessionId, token]
    end

Performance Considerations

WASM SIMD is competitive for small models. The 128-bit SIMD width matches NEON exactly, and the 4-way unrolled dot product is the same algorithm. For models under 1B parameters (like SmolLM2-135M), the decode path is compute-bound within the matmul, and WASM SIMD delivers roughly equivalent per-cycle throughput to native ARM NEON.

WebGPU Acceleration. In addition to WASM SIMD, cellm supports WebGPU compute shaders (WGSL) for hardware acceleration. By calling engine.try_init_webgpu(), the engine can offload heavy matrix multiplications to the GPU, often resulting in 10-50x speedups over SIMD for larger models. If WebGPU is unavailable, it gracefully falls back to WASM SIMD.

Threading is bounded by worker count. Rayon in WASM uses a fixed worker pool (typically navigator.hardwareConcurrency workers). Each worker is a Web Worker that runs on its own thread and issues v128 SIMD instructions independently, so parallelisation across matrix rows can scale close to linearly with core count until memory bandwidth becomes the limit.

Model loading is an upfront cost. Because WASM cannot mmap, the entire model file must be copied into WASM linear memory before inference begins. For a 135M-parameter model this is roughly 135 MB at int8, or roughly 270 MB at f16. For models above 1B parameters, chunked loading from IndexedDB would be needed to avoid excessive memory usage.
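The weight footprint follows directly from parameter count and bits per weight. The helper below is a back-of-the-envelope estimate only: it ignores the file header, quantisation scales, and any tensors stored at a different precision, and the name weight_bytes is hypothetical.

```rust
/// Approximate bytes of linear memory taken by the weights alone.
pub fn weight_bytes(num_params: u64, bits_per_weight: u64) -> u64 {
    // Round up to whole bytes for sub-byte formats.
    (num_params * bits_per_weight + 7) / 8
}
```

For 135M parameters this gives 135,000,000 bytes at 8 bits and 270,000,000 bytes at 16 bits.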

Build and Test

# Install tools
cargo install wasm-pack

# Build the WASM module
./scripts/build-wasm.sh --release

# Output goes to crates/cellm-wasm/pkg/
# Serve the demo page:
python3 -m http.server 8080 \
    --directory crates/cellm-wasm/www/

# Open http://localhost:8080 in a browser

The demo page requires Cross-Origin-Opener-Policy and Cross-Origin-Embedder-Policy headers for SharedArrayBuffer support. When serving with python3 -m http.server, these headers are not set and the page will fall back to single-threaded execution. For full threading, serve with a server that adds these headers, or use the wasm-pack test server:
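For local development, a server that adds the two headers can be very small. The following is a standard-library-only sketch, not part of the cellm repository; the function names are hypothetical, it handles one connection at a time, and a single I/O error stops the loop. It is not hardened for anything beyond localhost use.

```rust
use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use std::path::Path;

/// Build an HTTP/1.1 response that opts the page into cross-origin isolation.
pub fn isolated_response(status: &str, content_type: &str, body: &[u8]) -> Vec<u8> {
    let mut out = format!(
        "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\n\
         Cross-Origin-Opener-Policy: same-origin\r\n\
         Cross-Origin-Embedder-Policy: require-corp\r\n\
         Connection: close\r\n\r\n",
        status,
        content_type,
        body.len()
    )
    .into_bytes();
    out.extend_from_slice(body);
    out
}

/// .wasm must be served as application/wasm for streaming instantiation.
pub fn content_type_for(path: &str) -> &'static str {
    match path.rsplit('.').next() {
        Some("html") => "text/html; charset=utf-8",
        Some("js") | Some("mjs") => "text/javascript",
        Some("wasm") => "application/wasm",
        Some("json") => "application/json",
        _ => "application/octet-stream",
    }
}

/// Serve files under `root` on `addr`, e.g. "127.0.0.1:8080".
pub fn serve(root: &Path, addr: &str) -> std::io::Result<()> {
    let listener = TcpListener::bind(addr)?;
    for stream in listener.incoming() {
        let mut stream = stream?;
        let request_line = {
            let mut reader = BufReader::new(&stream);
            let mut first = String::new();
            reader.read_line(&mut first)?;
            // Drain the remaining request headers up to the blank line.
            let mut line = String::new();
            while reader.read_line(&mut line)? > 2 {
                line.clear();
            }
            first
        };
        // "GET /pkg/x.wasm?v=1 HTTP/1.1" -> "pkg/x.wasm"
        let target = request_line.split_whitespace().nth(1).unwrap_or("/");
        let rel = target.split('?').next().unwrap_or("/").trim_start_matches('/');
        let rel = if rel.is_empty() { "index.html" } else { rel };
        let response = if rel.split('/').any(|part| part == "..") {
            isolated_response("403 Forbidden", "text/plain", b"forbidden")
        } else {
            match fs::read(root.join(rel)) {
                Ok(body) => isolated_response("200 OK", content_type_for(rel), &body),
                Err(_) => isolated_response("404 Not Found", "text/plain", b"not found"),
            }
        };
        stream.write_all(&response)?;
    }
    Ok(())
}
```

Calling serve(Path::new("crates/cellm-wasm/www"), "127.0.0.1:8080") from a small binary would take the place of the python3 command above.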

wasm-pack test --firefox --headless

Testing with Node.js

The WASM module can also be tested directly from Node.js without a browser. This is useful for benchmarking, CI pipelines, and debugging.

# Requirements: Node >= 18 (WASM SIMD enabled by default)
# WebGPU: Node >= 22 with --experimental-webgpu

node docs/wasm/test-node.mjs <model.cellm> <tokenizer.json>

The test script (docs/wasm/test-node.mjs) loads the WASM module, creates an engine, tokenizes a prompt, and runs a decode loop using step_decode().

Example runs:

# NanoWhale-100M (~18 tok/s on Apple M-series)
$ node docs/wasm/test-node.mjs models/nanowhale-100m.cellm models/nanowhale-100m/tokenizer.json
Model: 220.8MB, Tokens: 9.2MB
 Your response should contain at least 3...
50 tokens in 2738.1ms (18.3 tok/s)

# Qwen2.5-0.5B int8 (~1.7 tok/s)
$ node docs/wasm/test-node.mjs models/to-huggingface/qwen2.5-0.5b-int8-v1/qwen2.5-0.5b-int8-v1.cellm models/to-huggingface/qwen2.5-0.5b-int8-v1/tokenizer.json
Model: 495.1MB, Tokens: 6.2MB
50 tokens in 29845.0ms (1.7 tok/s)

Note: engine.generate() currently returns an empty result on WASM. Use the manual step_decode() loop instead:

const sid = engine.create_session();
let tok = engine.submit_tokens(sid, inputIds);
let count = 0;
while (count < maxTokens) {
  count++;
  process.stdout.write(engine.decode(new Uint32Array([tok])));
  if (engine.is_stop_token(tok)) break;
  const r = engine.step_decode();
  if (!r) break;
  tok = r[1]; // r = [session_id, token_id]
}

For WebGPU testing (Node >= 22):

node --experimental-webgpu docs/wasm/test-node.mjs --webgpu \
  models/nanowhale-100m.cellm models/nanowhale-100m/tokenizer.json

| Model | Size | WASM SIMD throughput | Notes |
| --- | --- | --- | --- |
| NanoWhale-100M (f16) | 220 MB | ~18 tok/s | Compute-bound, good WASM fit |
| Qwen2.5-0.5B (int8) | 495 MB | ~1.7 tok/s | Memory-bound, int8 dequant overhead |

Limitations and Future Work

  • Speculative decoding to reduce per-token latency
  • Chunked model loading from IndexedDB for models above 1B parameters
  • Streaming model fetch (decode while still downloading weights)
  • SIMD within a worker is limited to 128-bit vectors; WASM relaxed SIMD and possible future AMX-like extensions may narrow the gap to native Apple Silicon