Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Dependabot configuration — opens weekly PRs for outdated dependencies.
# Docs: https://docs.github.com/code-security/dependabot/dependabot-version-updates
version: 2
updates:
  # Rust crate dependencies (Cargo.toml / Cargo.lock at the repository root).
  - package-ecosystem: cargo
    directory: "/"
    schedule:
      interval: weekly
  # GitHub Actions used in .github/workflows/ (keeps SHA pins fresh).
  - package-ecosystem: github-actions
    directory: "/"
    schedule:
      interval: weekly
13 changes: 10 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches: [main]

permissions:
contents: read

env:
CARGO_TERM_COLOR: always

Expand All @@ -12,11 +15,12 @@ jobs:
name: Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- uses: dtolnay/rust-toolchain@3c5f7ea28cd621ae0bf5283f0e981fb97b8a7af9 # stable
with:
toolchain: stable
components: rustfmt, clippy
- uses: Swatinem/rust-cache@v2
- uses: Swatinem/rust-cache@23869a5bd66c73db3c0ac40331f3206eb23791dc # v2.9.1

- name: Format check
run: cargo fmt --check
Expand All @@ -32,6 +36,9 @@ jobs:
- name: Test
run: cargo test

- name: Security audit
run: cargo install cargo-audit --locked && cargo audit
Comment thread
SaschaOnTour marked this conversation as resolved.
Comment thread
SaschaOnTour marked this conversation as resolved.

- name: Install rustqual
run: cargo install rustqual
Comment thread
SaschaOnTour marked this conversation as resolved.

Expand Down
19 changes: 13 additions & 6 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
push:
branches: [main]

permissions:
contents: read

env:
CARGO_TERM_COLOR: always

Expand All @@ -15,7 +18,7 @@ jobs:
changed: ${{ steps.check.outputs.changed }}
version: ${{ steps.check.outputs.version }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4

- name: Detect version change
id: check
Expand All @@ -39,11 +42,12 @@ jobs:
if: needs.version-check.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- uses: dtolnay/rust-toolchain@3c5f7ea28cd621ae0bf5283f0e981fb97b8a7af9 # stable
with:
toolchain: stable
components: rustfmt, clippy
- uses: Swatinem/rust-cache@v2
- uses: Swatinem/rust-cache@23869a5bd66c73db3c0ac40331f3206eb23791dc # v2.9.1

- name: Format check
run: cargo fmt --check
Expand All @@ -56,6 +60,9 @@ jobs:
- name: Test
run: cargo test

- name: Security audit
run: cargo install cargo-audit --locked && cargo audit
Comment thread
SaschaOnTour marked this conversation as resolved.
Comment thread
SaschaOnTour marked this conversation as resolved.

- name: Install rustqual
run: cargo install rustqual
Comment thread
SaschaOnTour marked this conversation as resolved.

Expand All @@ -75,7 +82,7 @@ jobs:
permissions:
contents: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4

- name: Create git tag
run: |
Expand All @@ -84,7 +91,7 @@ jobs:
git push origin "$TAG"

- name: GitHub Release
uses: softprops/action-gh-release@v2
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
with:
tag_name: v${{ needs.version-check.outputs.version }}
generate_release_notes: true
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Changed

- **CI hardening**: All GitHub Actions pinned to immutable commit SHAs, explicit `permissions: contents: read`, `cargo audit` step added.
- **Dependabot**: Added `.github/dependabot.yml` for Cargo and GitHub Actions weekly updates.
- **Public API safety**: `PqoCache::new()`, `TqCache::new()`, and `compute_qjl_signs()` now return `Result` instead of panicking on invalid input.
- **head_dim guard**: `GpuPrecomputed::new()` returns an error if `head_dim > 1024` (prevents silent CUDA shared memory overflow).
Comment thread
SaschaOnTour marked this conversation as resolved.
Outdated
- **CUDA pack helpers**: Extracted `tq_pack_2bit`, `tq_pack_3bit`, `tq_pack_4bit` into `tq_common.h` — eliminated 3x copy-pasted packing logic in `tq_quant_kernel.cu`.

## [0.2.0] - 2026-03-29

### Added
Expand Down
64 changes: 62 additions & 2 deletions src/cache/cuda/kernels/tq_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,25 @@ float tq_sign(uint64_t seed, int index) {
return (combined & 1ULL) == 0 ? 1.0f : -1.0f;
}

/* -----------------------------------------------------------------------
* Atomic byte-OR helper.
*
* On CUDA device: uses atomicOr on the containing uint32 word.
* On host: plain OR (single-threaded, no atomics needed).
* ----------------------------------------------------------------------- */

static inline __device__ __host__
void tq_atomic_byte_or(uint8_t *packed, int byte_idx, uint32_t bits) {
#ifdef __CUDA_ARCH__
    /* Device path: CUDA has no byte-granular atomicOr, so OR into the
     * aligned 32-bit word containing the target byte, shifting `bits`
     * to that byte's lane (little-endian: byte k of the word is bits
     * 8k..8k+7).  Skipping bits == 0 avoids a useless atomic RMW.
     *
     * NOTE(review): assumes `packed` is at least 4-byte aligned and that
     * the containing word lies within the allocation — TODO confirm at
     * call sites (shared-memory buffers sized to a multiple of 4).
     * `bits` is expected to fit in the low 8 bits; callers mask idx. */
    if (bits) {
        atomicOr((unsigned int *)(packed + (byte_idx & ~3)),
                 bits << ((byte_idx & 3) * 8));
    }
#else
    /* Host path: single-threaded caller, plain OR suffices. */
    packed[byte_idx] |= (uint8_t)bits;
#endif
}

/* -----------------------------------------------------------------------
* 2-bit packing: 4 indices per byte (for quantize kernel)
*
Expand All @@ -97,8 +116,49 @@ static inline __device__ __host__
void tq_pack_2bit(uint8_t *packed, int tid, uint8_t idx) {
int byte_idx = tid >> 2;
int shift = (tid & 3) << 1;
/* Use atomicOr in CUDA, or sequential writes if single-threaded. */
packed[byte_idx] |= (idx & 0x3) << shift;
tq_atomic_byte_or(packed, byte_idx, (uint32_t)((idx & 0x3) << shift));
}

/* -----------------------------------------------------------------------
 * 3-bit packing: 8 indices per 3 bytes (for quantize kernel)
 *
 * Matches the exact bit layout of tq_unpack_3bit above.  Thread `tid`
 * owns bit offset 3*(tid & 7) inside its 3-byte group.  A 3-bit value
 * straddles a byte boundary at positions 2 and 5; the spill-over bits
 * are OR'd into the following byte.  All writes go through
 * tq_atomic_byte_or for thread-safe accumulation.
 * ----------------------------------------------------------------------- */

static inline __device__ __host__
void tq_pack_3bit(uint8_t *packed, int tid, uint8_t idx) {
    int pos     = tid & 7;                         /* slot in 8-index group */
    int bit_off = pos * 3;                         /* bit offset in group   */
    int byte0   = (tid >> 3) * 3 + (bit_off >> 3); /* first byte touched    */
    int shift   = bit_off & 7;                     /* bit pos inside byte0  */
    uint32_t val = (uint32_t)(idx & 0x7);
    uint32_t lo  = (val << shift) & 0xFFu;         /* bits landing in byte0 */
    uint32_t hi  = val >> (8 - shift);             /* spill into byte0 + 1  */
    tq_atomic_byte_or(packed, byte0, lo);
    if (hi) {
        /* Non-zero only at positions 2 and 5, where byte0 + 1 is still
         * inside this group's 3 bytes — never past the packed buffer. */
        tq_atomic_byte_or(packed, byte0 + 1, hi);
    }
}

/* -----------------------------------------------------------------------
 * 4-bit packing: 2 indices per byte (for quantize kernel)
 *
 * Even tids take the low nibble, odd tids the high nibble, of byte tid/2.
 * Writes go through tq_atomic_byte_or so concurrent threads stay safe.
 * ----------------------------------------------------------------------- */

static inline __device__ __host__
void tq_pack_4bit(uint8_t *packed, int tid, uint8_t idx) {
    int dst_byte = tid >> 1;
    int nibble_shift = (tid & 1) ? 4 : 0;
    uint32_t nibble = (uint32_t)(idx & 0xF) << nibble_shift;
    tq_atomic_byte_or(packed, dst_byte, nibble);
}

/* -----------------------------------------------------------------------
Expand Down
117 changes: 10 additions & 107 deletions src/cache/cuda/kernels/tq_quant_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,45 +135,10 @@ __global__ void tq_quant_kernel(
}
__syncthreads();

/* Each thread atomically ORs its index bits into the correct byte */
if (bits == 2) {
int byte_idx = tid >> 2;
int shift = (tid & 3) << 1;
atomicOr((unsigned int *)(s_packed + (byte_idx & ~3)),
(unsigned int)((idx & 0x3) << shift) << ((byte_idx & 3) * 8));
} else if (bits == 3) {
/* 3-bit packing: 8 indices per 3 bytes — use atomicOr on bytes */
int group = tid >> 3;
int pos = tid & 7;
int base = group * 3;
/* Each position writes to specific bit positions.
* We use atomicOr on uint8 via atomicOr on aligned uint32. */
uint32_t b0_bits = 0, b1_bits = 0, b2_bits = 0;
switch (pos) {
case 0: b0_bits = (idx & 0x7); break;
case 1: b0_bits = (idx & 0x7) << 3; break;
case 2: b0_bits = (idx & 0x3) << 6;
b1_bits = (idx >> 2) & 0x1; break;
case 3: b1_bits = (idx & 0x7) << 1; break;
case 4: b1_bits = (idx & 0x7) << 4; break;
case 5: b1_bits = (idx & 0x1) << 7;
b2_bits = (idx >> 1) & 0x3; break;
case 6: b2_bits = (idx & 0x7) << 2; break;
case 7: b2_bits = (idx & 0x7) << 5; break;
}
if (b0_bits) atomicOr((unsigned int *)(s_packed + (base & ~3)),
b0_bits << ((base & 3) * 8));
if (b1_bits) atomicOr((unsigned int *)(s_packed + ((base + 1) & ~3)),
b1_bits << (((base + 1) & 3) * 8));
if (b2_bits) atomicOr((unsigned int *)(s_packed + ((base + 2) & ~3)),
b2_bits << (((base + 2) & 3) * 8));
} else {
/* 4-bit packing */
int byte_idx = tid >> 1;
int shift = (tid & 1) << 2;
atomicOr((unsigned int *)(s_packed + (byte_idx & ~3)),
(unsigned int)((idx & 0xF) << shift) << ((byte_idx & 3) * 8));
}
/* Each thread packs its index bits via shared helpers (tq_common.h) */
if (bits == 2) tq_pack_2bit(s_packed, tid, idx);
else if (bits == 3) tq_pack_3bit(s_packed, tid, idx);
else tq_pack_4bit(s_packed, tid, idx);
__syncthreads();

/* Write packed output */
Expand Down Expand Up @@ -281,40 +246,9 @@ __global__ void tq_quant_maxnorm_kernel(
}
__syncthreads();

if (bits == 2) {
int byte_idx = tid >> 2;
int shift = (tid & 3) << 1;
atomicOr((unsigned int *)(s_packed + (byte_idx & ~3)),
(unsigned int)((idx & 0x3) << shift) << ((byte_idx & 3) * 8));
} else if (bits == 3) {
int group = tid >> 3;
int pos = tid & 7;
int base = group * 3;
uint32_t b0_bits = 0, b1_bits = 0, b2_bits = 0;
switch (pos) {
case 0: b0_bits = (idx & 0x7); break;
case 1: b0_bits = (idx & 0x7) << 3; break;
case 2: b0_bits = (idx & 0x3) << 6;
b1_bits = (idx >> 2) & 0x1; break;
case 3: b1_bits = (idx & 0x7) << 1; break;
case 4: b1_bits = (idx & 0x7) << 4; break;
case 5: b1_bits = (idx & 0x1) << 7;
b2_bits = (idx >> 1) & 0x3; break;
case 6: b2_bits = (idx & 0x7) << 2; break;
case 7: b2_bits = (idx & 0x7) << 5; break;
}
if (b0_bits) atomicOr((unsigned int *)(s_packed + (base & ~3)),
b0_bits << ((base & 3) * 8));
if (b1_bits) atomicOr((unsigned int *)(s_packed + ((base + 1) & ~3)),
b1_bits << (((base + 1) & 3) * 8));
if (b2_bits) atomicOr((unsigned int *)(s_packed + ((base + 2) & ~3)),
b2_bits << (((base + 2) & 3) * 8));
} else {
int byte_idx = tid >> 1;
int shift = (tid & 1) << 2;
atomicOr((unsigned int *)(s_packed + (byte_idx & ~3)),
(unsigned int)((idx & 0xF) << shift) << ((byte_idx & 3) * 8));
}
if (bits == 2) tq_pack_2bit(s_packed, tid, idx);
else if (bits == 3) tq_pack_3bit(s_packed, tid, idx);
else tq_pack_4bit(s_packed, tid, idx);
__syncthreads();

if (tid < bytes_per_block) {
Expand Down Expand Up @@ -469,40 +403,9 @@ __global__ void tq_pack_kernel(
}
__syncthreads();

if (bits == 2) {
int byte_idx = tid >> 2;
int shift = (tid & 3) << 1;
atomicOr((unsigned int *)(s_packed + (byte_idx & ~3)),
(unsigned int)((idx & 0x3) << shift) << ((byte_idx & 3) * 8));
} else if (bits == 3) {
int group = tid >> 3;
int pos = tid & 7;
int base = group * 3;
uint32_t b0_bits = 0, b1_bits = 0, b2_bits = 0;
switch (pos) {
case 0: b0_bits = (idx & 0x7); break;
case 1: b0_bits = (idx & 0x7) << 3; break;
case 2: b0_bits = (idx & 0x3) << 6;
b1_bits = (idx >> 2) & 0x1; break;
case 3: b1_bits = (idx & 0x7) << 1; break;
case 4: b1_bits = (idx & 0x7) << 4; break;
case 5: b1_bits = (idx & 0x1) << 7;
b2_bits = (idx >> 1) & 0x3; break;
case 6: b2_bits = (idx & 0x7) << 2; break;
case 7: b2_bits = (idx & 0x7) << 5; break;
}
if (b0_bits) atomicOr((unsigned int *)(s_packed + (base & ~3)),
b0_bits << ((base & 3) * 8));
if (b1_bits) atomicOr((unsigned int *)(s_packed + ((base + 1) & ~3)),
b1_bits << (((base + 1) & 3) * 8));
if (b2_bits) atomicOr((unsigned int *)(s_packed + ((base + 2) & ~3)),
b2_bits << (((base + 2) & 3) * 8));
} else {
int byte_idx = tid >> 1;
int shift = (tid & 1) << 2;
atomicOr((unsigned int *)(s_packed + (byte_idx & ~3)),
(unsigned int)((idx & 0xF) << shift) << ((byte_idx & 3) * 8));
}
if (bits == 2) tq_pack_2bit(s_packed, tid, idx);
else if (bits == 3) tq_pack_3bit(s_packed, tid, idx);
else tq_pack_4bit(s_packed, tid, idx);
__syncthreads();

if (tid < bytes_per_block) {
Expand Down
25 changes: 12 additions & 13 deletions src/cache/pqo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,27 @@ pub struct PqoCache {
impl PqoCache {
/// Create a new PQO/PQ/TQ cache from configuration.
///
/// # Panics
///
/// Panics if `head_dim` is not divisible by `QUANT_BLOCK_SIZE` (32).
pub fn new(config: CacheConfig) -> Self {
assert!(
config.head_dim % QUANT_BLOCK_SIZE == 0,
"head_dim ({}) must be divisible by QUANT_BLOCK_SIZE ({QUANT_BLOCK_SIZE}). \
Models with head_dim={} are not supported by TurboQuant compression.",
config.head_dim,
config.head_dim
);
/// Returns an error if `head_dim` is not divisible by `QUANT_BLOCK_SIZE` (32).
pub fn new(config: CacheConfig) -> candle_core::Result<Self> {
if config.head_dim % QUANT_BLOCK_SIZE != 0 {
candle_core::bail!(
"head_dim ({}) must be divisible by QUANT_BLOCK_SIZE ({QUANT_BLOCK_SIZE}). \
Models with head_dim={} are not supported by TurboQuant compression.",
config.head_dim,
config.head_dim
);
}
let storage = CompressedStorage::new(
config.num_kv_heads,
config.head_dim,
config.bits,
config.num_layers,
);
Self {
Ok(Self {
config,
storage,
precomputed: None,
}
})
}

/// Ensure precomputed tensors are initialized on the given device.
Expand Down
8 changes: 8 additions & 0 deletions src/cache/precomputed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ impl GpuPrecomputed {
config.bits
)));
}
const CUDA_MAX_HEAD_DIM: usize = 1024;
if config.head_dim > CUDA_MAX_HEAD_DIM {
return Err(super::cache_err(format!(
"head_dim {} exceeds CUDA_MAX_HEAD_DIM ({}). \
CUDA shared memory buffer overflow would occur.",
config.head_dim, CUDA_MAX_HEAD_DIM
)));
Comment thread
SaschaOnTour marked this conversation as resolved.
Outdated
}
let block_dim = QUANT_BLOCK_SIZE;
let polar_bits = config.bits - 1;
let head_dim = config.head_dim;
Expand Down
Loading
Loading