Skip to content

Commit 29f3816

Browse files
authored
Merge pull request #78 from 0xClandestine/feat/add-dflash
feat: add DFlash speculative decoding
2 parents 05d0b6c + 65d74a9 commit 29f3816

38 files changed

Lines changed: 8401 additions & 94 deletions

.github/workflows/ci.yml

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,99 @@ jobs:
218218
path: /tmp/SwiftLM-test-speculative.log
219219
retention-days: 7
220220

221+
# ── DFlash Speculative Decoding E2E ──
222+
# Uses the standard macos-15 runner (7 GB RAM).
223+
dflash-speculative-decoding:
224+
runs-on: macos-15
225+
timeout-minutes: 45
226+
needs: build_and_unit_test
227+
steps:
228+
- uses: actions/checkout@v4
229+
with:
230+
submodules: recursive
231+
232+
- name: Install Metal Toolchain
233+
run: xcodebuild -downloadComponent MetalToolchain || true
234+
235+
- name: Cache Swift packages
236+
uses: actions/cache@v4
237+
with:
238+
path: .build
239+
key: ${{ runner.os }}-spm-SwiftLM-v3-${{ hashFiles('Package.resolved') }}
240+
restore-keys: |
241+
${{ runner.os }}-spm-SwiftLM-v3-
242+
243+
- name: Clear stale module cache
244+
run: find .build -type d -name ModuleCache -exec rm -rf {} + 2>/dev/null || true
245+
246+
- name: Resolve dependencies
247+
run: swift package resolve
248+
249+
- name: Build (Release)
250+
run: swift build -c release
251+
252+
- name: Compile and install custom MLX Metal library
253+
run: |
254+
if [ -d "mlx-swift/Source/Cmlx/mlx" ]; then
255+
MLX_SRC="mlx-swift/Source/Cmlx/mlx"
256+
else
257+
MLX_SRC=".build/checkouts/mlx-swift/Source/Cmlx/mlx"
258+
fi
259+
mkdir -p .build/metallib_build
260+
pushd .build/metallib_build
261+
cmake "../../$MLX_SRC" \
262+
-DMLX_BUILD_TESTS=OFF \
263+
-DMLX_BUILD_EXAMPLES=OFF \
264+
-DMLX_BUILD_BENCHMARKS=OFF \
265+
-DMLX_BUILD_PYTHON_BINDINGS=OFF \
266+
-DMLX_METAL_JIT=OFF \
267+
-DMLX_ENABLE_NAX=1 \
268+
-DCMAKE_BUILD_TYPE=Release 2>&1 | tail -20
269+
make mlx-metallib -j$(sysctl -n hw.ncpu) 2>&1 | tail -20
270+
popd
271+
BUILT=$(find .build/metallib_build -name "mlx.metallib" | head -1)
272+
cp "$BUILT" .build/release/mlx.metallib
273+
python3 -m venv /tmp/mlx_venv
274+
/tmp/mlx_venv/bin/pip install --quiet huggingface_hub hf
275+
276+
- name: Cache MLX models (dflash + main)
277+
uses: actions/cache@v4
278+
with:
279+
path: ~/.cache/huggingface
280+
key: mlx-dflash-qwen35-4b
281+
282+
- name: Pre-download HuggingFace models
283+
run: |
284+
source /tmp/mlx_venv/bin/activate
285+
hf download mlx-community/Qwen3.5-4B-4bit || true
286+
hf download z-lab/Qwen3.5-4B-DFlash || true
287+
288+
- name: Run DFlash E2E
289+
env:
290+
HF_HUB_DOWNLOAD_TIMEOUT: "900"
291+
run: |
292+
chmod +x tests/test-dflash.sh
293+
for attempt in 1 2 3; do
294+
echo "Attempt $attempt of 3..."
295+
if tests/test-dflash.sh .build/release/SwiftLM 15415; then
296+
exit 0
297+
fi
298+
if [ "$attempt" -lt 3 ]; then
299+
echo "Test failed, retrying in 10s..."
300+
sleep 10
301+
fi
302+
done
303+
echo "All attempts failed"
304+
exit 1
305+
306+
- name: Upload dflash test logs on failure
307+
if: failure()
308+
uses: actions/upload-artifact@v4
309+
with:
310+
name: dflash-test-logs
311+
path: /tmp/SwiftLM-test-dflash.log
312+
retention-days: 7
313+
221314
# ── Speculative Decoding Memory Evaluation ──
222315
# Runs the 2B model with NUM_DRAFT_TOKENS=2 to check peak
223316
# memory compression/efficiency. Emits vm_stat readings as step summary.

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ tmp/
2828
.agents/harness/audio-omni-gemma4/runs/
2929
.venv/
3030
mem-palace/
31+
32+
33+
tests/DFlash/intermediates/

Package.resolved

Lines changed: 22 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ let package = Package(
66
platforms: [.macOS(.v14), .iOS(.v17)],
77
products: [
88
.library(name: "MLXInferenceCore", targets: ["MLXInferenceCore"]),
9+
.library(name: "DFlash", targets: ["DFlash"]),
910
.executable(name: "SwiftLM", targets: ["SwiftLM"]),
10-
.executable(name: "SwiftBuddy", targets: ["SwiftBuddy"])
11+
.executable(name: "SwiftBuddy", targets: ["SwiftBuddy"]),
12+
.executable(name: "DFlashKernelBench", targets: ["DFlashKernelBench"])
1113
],
1214
dependencies: [
1315
// Local Apple MLX Swift fork for C++ extensions
@@ -29,6 +31,7 @@ let package = Package(
2931
name: "SwiftLM",
3032
dependencies: [
3133
"MLXInferenceCore",
34+
"DFlash",
3235
.product(name: "MLX", package: "mlx-swift"),
3336
.product(name: "MLXLLM", package: "mlx-swift-lm"),
3437
.product(name: "MLXVLM", package: "mlx-swift-lm"),
@@ -40,6 +43,16 @@ let package = Package(
4043
],
4144
path: "Sources/SwiftLM"
4245
),
46+
// ── DFlash Kernel Micro-Benchmark ───────────────────────────
47+
.executableTarget(
48+
name: "DFlashKernelBench",
49+
dependencies: [
50+
"DFlash",
51+
.product(name: "MLX", package: "mlx-swift"),
52+
.product(name: "MLXNN", package: "mlx-swift"),
53+
],
54+
path: "Sources/DFlashKernelBench"
55+
),
4356
// ── STFT Audio Profiling Testing Script (macOS only) ───────────
4457
.executableTarget(
4558
name: "SwiftLMTestSTFT",
@@ -86,6 +99,17 @@ let package = Package(
8699
.enableExperimentalFeature("StrictConcurrency")
87100
]
88101
),
102+
// ── DFlash Speculative Decoding ─────────────────────────────
103+
.target(
104+
name: "DFlash",
105+
dependencies: [
106+
.product(name: "MLX", package: "mlx-swift"),
107+
.product(name: "MLXLLM", package: "mlx-swift-lm"),
108+
.product(name: "MLXLMCommon", package: "mlx-swift-lm"),
109+
],
110+
path: "Sources/DFlash",
111+
exclude: ["DFlashKernelsOptimized.swift"]
112+
),
89113
// ── Automated Test Harness ──────────────────────────────────
90114
.testTarget(
91115
name: "SwiftBuddyTests",

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,8 @@ curl http://localhost:5413/v1/chat/completions \
438438
| `--turbo-kv` | `false` | Enable TurboQuant 3-bit KV cache compression (activates after 2048 tokens, server-wide) |
439439
| `--draft-model` | (none) | Draft model path/ID for speculative decoding. When used with `--stream-experts`, `--num-draft-tokens` is auto-capped to 1 to minimise SSD I/O fan-out (see performance note above). |
440440
| `--num-draft-tokens` | `4` | Tokens per speculation round. Auto-capped to 1 when combined with `--stream-experts`. |
441+
| `--dflash` | `false` | Enable DFlash block-diffusion speculative decoding. Requires a compatible DFlash draft model. |
442+
| `--dflash-block-size` | (auto) | Number of tokens per DFlash draft block. Defaults to the value in the draft model's config. |
441443

442444
## 🔧 Per-Request API Parameters
443445

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright 2026 SwiftLM Contributors
2+
// MIT License — see LICENSE file
3+
// Based on DFlash (arXiv:2602.06036)
4+
5+
import Foundation
6+
import MLX
7+
import MLXLMCommon
8+
import MLXNN
9+
10+
// MARK: - Draft Backend
11+
12+
/// Backend for generating draft tokens using the DFlash draft model.
13+
public final class DFlashDraftBackend: @unchecked Sendable {
14+
15+
public init() {}
16+
17+
/// Creates the draft KV cache: one `ContextOnlyDraftKVCache` per entry in `draftModel.layers`, each built with `sinkSize` and `windowSize`.
18+
public func makeCache(
19+
draftModel: DFlashDraftModel,
20+
sinkSize: Int = 64,
21+
windowSize: Int = 1024
22+
) -> [ContextOnlyDraftKVCache] {
23+
(0 ..< draftModel.layers.count).map { _ in
24+
ContextOnlyDraftKVCache(sinkSize: sinkSize, windowSize: windowSize)
25+
}
26+
}
27+
28+
/// Drafts the tail of one block greedily: embeds `[stagedFirst, mask tokens]` via the target, runs the draft model, and decodes through the target's LM head.
29+
///
30+
/// - Parameters:
31+
/// - targetModel: The target model; supplies `dflashEmbedTokens` and `dflashLmHeadLogits` (embed/lm_head access)
32+
/// - draftModel: The DFlash draft model
33+
/// - draftCache: The draft model's KV caches
34+
/// - stagedFirst: The first token (already verified by the target); only its first element is used
35+
/// - targetHidden: The target model's hidden states for context
36+
/// - blockLen: Block length including the staged first token; must be > 1 (enforced by `precondition`)
37+
/// - maskTokenTail: Mask token IDs for positions 1..blockLen-1; only the first `blockLen - 1` entries are used
38+
/// - suppressTokenMask: Optional mask to suppress certain tokens; passed through to `DFlashRuntime.greedyTokensWithMask`
39+
/// - Returns: Draft token IDs [blockLen-1], with evaluation already scheduled via `asyncEval`
40+
public func draftGreedy(
41+
targetModel: any DFlashTargetModel,
42+
draftModel: DFlashDraftModel,
43+
draftCache: [ContextOnlyDraftKVCache],
44+
stagedFirst: MLXArray,
45+
targetHidden: MLXArray,
46+
blockLen: Int,
47+
maskTokenTail: MLXArray,
48+
suppressTokenMask: MLXArray? = nil
49+
) -> MLXArray {
50+
precondition(blockLen > 1, "draftGreedy requires blockLen > 1")
51+
52+
let blockTokenIDs = concatenated(
53+
[stagedFirst[..<1], maskTokenTail[..<(blockLen - 1)]],
54+
axis: 0
55+
)
56+
57+
// Get noise embedding from target model's embed_tokens; `[.newAxis]` adds a leading axis (presumably batch; confirm)
58+
let noiseEmbedding = targetModel.dflashEmbedTokens(blockTokenIDs[.newAxis])
59+
if DFlashDumper.isEnabled {
60+
DFlashDumper.saveInt("swift_block_token_ids", blockTokenIDs[.newAxis])
61+
DFlashDumper.save("swift_noise_embedding", noiseEmbedding)
62+
}
63+
64+
// Run the draft model on the noise embedding, the target's hidden states, and the draft cache
65+
let draftHidden = draftModel(
66+
noiseEmbedding: noiseEmbedding,
67+
targetHidden: targetHidden,
68+
cache: draftCache
69+
)
70+
if DFlashDumper.isEnabled {
71+
DFlashDumper.save("swift_draft_hidden", draftHidden)
72+
}
73+
74+
// Get draft logits via the target model's lm_head, skipping index 0 on the second-to-last axis (presumably the staged-first position)
75+
let draftLogits = targetModel.dflashLmHeadLogits(
76+
draftHidden[.ellipsis, 1..., 0...]
77+
)
78+
if DFlashDumper.isEnabled {
79+
DFlashDumper.save("swift_draft_logits", draftLogits)
80+
}
81+
82+
// Greedy decode, then drop axis 0 from the result
83+
let drafted = DFlashRuntime.greedyTokensWithMask(
84+
logits: draftLogits,
85+
suppressTokenMask: suppressTokenMask
86+
).squeezed(axis: 0)
87+
88+
asyncEval(drafted)
89+
return drafted
90+
}
91+
}

0 commit comments

Comments
 (0)