fix(caching): use dtype.itemsize instead of finfo().bits for byte size #262

Open
Gpgabriel25 wants to merge 1 commit into erfanzar:main from Gpgabriel25:pr/kv-cache-dtype-sizing

@Gpgabriel25

## Problem

`jnp.finfo(dtype).bits // 8` does not yield a usable byte size for every
KV-cache dtype. For sub-byte floats such as `float4_e2m1fn` it floor-divides
4 bits by 8 and returns 0, so a later division by that size raises
`ZeroDivisionError`. Integer dtypes have no `finfo` entry at all, so the
call fails before any size is computed. Either way, low-precision KV caches
cannot be used in `UnifiedAttentionCacheConfig`, `RaggedPagesCacheConfig`,
`MlaRaggedPagesCacheConfig`, or `get_dtype_size()` in the eSurge engine.

## Fix

Replace all five occurrences with `jnp.dtype(dtype).itemsize`, which
reports the storage width in bytes for any dtype, including integer and
sub-byte formats.
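
To illustrate the difference between the two patterns, here is a minimal sketch that uses plain NumPy dtypes, so it runs without JAX or `ml_dtypes` installed. The helper names `bytes_via_finfo` and `bytes_via_itemsize` are hypothetical and do not appear in EasyDeL. The sketch assumes that `jnp.dtype` and `jnp.finfo` behave like their NumPy counterparts for the standard dtypes shown.

```python
import numpy as np


def bytes_via_finfo(dtype) -> int:
    """Old pattern: derive the byte width from floating-point metadata."""
    return np.finfo(dtype).bits // 8


def bytes_via_itemsize(dtype) -> int:
    """New pattern: ask the dtype object for its storage width directly."""
    return np.dtype(dtype).itemsize


# Both patterns agree for ordinary floating-point dtypes.
assert bytes_via_finfo(np.float16) == bytes_via_itemsize(np.float16) == 2
assert bytes_via_finfo(np.float32) == bytes_via_itemsize(np.float32) == 4

# Integer dtypes are rejected by finfo, so the old pattern raises.
try:
    bytes_via_finfo(np.int8)
    finfo_int8_error = None
except ValueError as exc:
    finfo_int8_error = exc

# itemsize has no such restriction.
int8_bytes = bytes_via_itemsize(np.int8)
```

Any dtype narrower than 8 bits would additionally floor-divide to 0 under the old pattern, which is the case the probe below exercises with `jnp.float4_e2m1fn`.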

## Evidence

Probe run against `main` vs this branch, same venv, `jnp.float4_e2m1fn`:

| branch | unified num_pages | ragged num_pages | ratio fp4/bf16 |
|--------|------------------|-----------------|----------------|
| main   | `ZeroDivisionError` | `ZeroDivisionError` | n/a |
| **this PR** | bf16: 262 144 → fp4: 524 288 | bf16: 262 144 → fp4: 524 288 | **2.0×** |

fp4 now gets double the page capacity of bf16, as expected from halving
the per-element byte cost (`itemsize` 1 versus 2).
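
The arithmetic behind that ratio can be sketched in pure Python. This is an illustrative model only: the function `num_pages`, its parameters, and the memory budget are assumptions chosen for the example, not EasyDeL's actual sizing formula, so the absolute numbers differ from the probe above.

```python
def num_pages(budget_bytes, page_size, num_kv_heads, head_dim, itemsize):
    """Hypothetical page count: memory budget divided by bytes per page."""
    # One page holds keys and values (factor 2) for page_size tokens.
    bytes_per_page = 2 * page_size * num_kv_heads * head_dim * itemsize
    return budget_bytes // bytes_per_page


budget = 8 * 1024**3  # 8 GiB, an arbitrary illustrative budget
shape = dict(page_size=128, num_kv_heads=8, head_dim=128)

pages_2_byte = num_pages(budget, itemsize=2, **shape)  # e.g. a 2-byte dtype
pages_1_byte = num_pages(budget, itemsize=1, **shape)  # e.g. a 1-byte dtype

# Halving the per-element byte cost doubles the number of pages that fit.
assert pages_1_byte == 2 * pages_2_byte

# A byte size of 0 (the result of 4 // 8) makes the divisor 0.
try:
    num_pages(budget, itemsize=0, **shape)
    zero_size_failed = False
except ZeroDivisionError:
    zero_size_failed = True
```

Under this model the 2.0× ratio follows directly from the per-element size, independent of the budget or cache shape.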

## Files

- `easydel/caching/_specs.py`
- `easydel/caching/mla_ragged_page/cache.py`
- `easydel/caching/ragged_page/cache.py`
- `easydel/caching/unified_attention/cache.py`
- `easydel/inference/esurge/utils.py`

## Tests

- `tests/inference/test_esurge_kvdtype_override.py` (new, 5 tests)
- `tests/layers/caching/test_kv_budget_parallelism.py` (new, 2 tests)

@Gpgabriel25 changed the title from "Pr/kv cache dtype sizing" to "fix(caching): use dtype.itemsize instead of finfo().bits for byte size" on Apr 24, 2026
@Gpgabriel25 marked this pull request as draft on April 24, 2026 17:54
@Gpgabriel25 force-pushed the pr/kv-cache-dtype-sizing branch from 9bca629 to 8d94f20 on April 24, 2026 17:58
@Gpgabriel25 marked this pull request as ready for review on April 24, 2026 18:02