Commit cb493b6

support float16 and bfloat16. beware of precision issues! close #1
1 parent 907f1a9 commit cb493b6

5 files changed: +121 -84 lines changed

README.md

Lines changed: 7 additions & 1 deletion
@@ -9,7 +9,7 @@ The scan efficiently solves first-order recurrences of the form `x[t] = gate[t]
The `accelerated_scan.warp` C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitives available
on each level of hierarchy: [warp shuffles](https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/) within warps of 32 threads and shared memory (SRAM) between warps within a thread block. One sequence per channel dimension is confined to one thread block.

- The derivation of [Chunked Scan](https://proger.github.io/posts/scan/chunk.html) has been used to extend tree-level Blelloch algorithm to block
+ The derivation of [Chunked Scan](https://proger.github.io/posts/scan/chunk.html) has been used to extend tree-level Blelloch algorithm to block.

A similar implementation is available in `accelerated_scan.triton` using a Triton's `tl.associative_scan` primitive. It [requires Triton 2.2 for its `enable_fp_fusion` flag](https://twitter.com/darkproger/status/1742663555835363635).
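
As a rough illustration of the chunked idea described in the README text above, here is a minimal pure-Python sketch. It assumes the recurrence has the form `x[t] = gate[t] * x[t-1] + token[t]` with a zero initial state; the hunk header truncates the formula, so that form is an assumption. The names `sequential_scan`, `chunked_scan`, and `chunk_size` are hypothetical and not taken from `accelerated_scan`. The sketch does not model warps, shuffles, or shared memory; it only shows why chunks can be scanned independently and then stitched together with a carry.

```python
def sequential_scan(gates, tokens):
    """Reference loop. Assumed form: x[t] = gates[t] * x[t-1] + tokens[t], x[-1] = 0."""
    out, x = [], 0.0
    for g, k in zip(gates, tokens):
        x = g * x + k
        out.append(x)
    return out


def chunked_scan(gates, tokens, chunk_size=4):
    """Same recurrence, computed chunk by chunk (hypothetical helper).

    Each chunk is first scanned as if its incoming state were zero, while
    also tracking the running product of its gates. The true incoming
    state (the carry) is then folded in as
        x[t] = local_x[t] + gate_prod[t] * carry.
    """
    out, carry = [], 0.0
    for start in range(0, len(gates), chunk_size):
        g_chunk = gates[start:start + chunk_size]
        k_chunk = tokens[start:start + chunk_size]

        # Local scan: independent of every other chunk.
        local_x, gate_prod = [], []
        x, p = 0.0, 1.0
        for g, k in zip(g_chunk, k_chunk):
            x = g * x + k
            p = p * g
            local_x.append(x)
            gate_prod.append(p)

        # Fix-up: apply the state carried in from the previous chunk.
        fixed = [lx + gp * carry for lx, gp in zip(local_x, gate_prod)]
        out.extend(fixed)
        carry = fixed[-1]
    return out
```

In this sketch the local scans run one after another, but they do not depend on each other, which is the property a parallel implementation can exploit. Only the carry has to pass between chunks.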

@@ -58,3 +58,9 @@ forward speed of (8,1536,seqlen), inference mode:
8 32768.0 31.459671 62.557182 5.645697
9 65536.0 66.787331 125.208572 11.297921
```
+
+ ## Notes on Precision
+
+ When gates and tokens are sampled uniformly from 0..1 the lack of bfloat16 precision dominates the error (compared to the recurrent implementation):
+
+ ![max-abs-error.png](max-abs-error.png)
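
To get a feel for the size of bfloat16 rounding error mentioned in the note above, the following sketch emulates bfloat16 in pure Python and compares a rounded recurrence with a double-precision one. It assumes the recurrence form `x[t] = gate[t] * x[t-1] + token[t]` (an assumption, since the diff truncates the formula), and the helpers `to_bfloat16`, `scan`, and `max_abs_error` are hypothetical names. It does not run the library's kernels and is not the procedure behind `max-abs-error.png`.

```python
import random
import struct


def to_bfloat16(value):
    """Round a float to the nearest bfloat16 value (ties to even).

    bfloat16 keeps the top 16 bits of an IEEE float32, so this rounds the
    float32 bit pattern and clears the low 16 bits. NaN, inf, and values
    near the float32 maximum are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scan(gates, tokens, quantize=lambda v: v):
    """Assumed recurrence x[t] = gates[t] * x[t-1] + tokens[t].

    Inputs and the running state are passed through `quantize` each step.
    """
    out, x = [], 0.0
    for g, k in zip(gates, tokens):
        x = quantize(quantize(g) * x + quantize(k))
        out.append(x)
    return out


def max_abs_error(seqlen, seed=0):
    """Largest gap between the float64 and bfloat16-rounded recurrences."""
    rng = random.Random(seed)
    gates = [rng.random() for _ in range(seqlen)]   # uniform in [0, 1)
    tokens = [rng.random() for _ in range(seqlen)]  # uniform in [0, 1)
    exact = scan(gates, tokens)
    rounded = scan(gates, tokens, quantize=to_bfloat16)
    return max(abs(a - b) for a, b in zip(exact, rounded))
```

Calling `max_abs_error(1024)` reports the largest deviation for a single random draw. It is meant only to indicate the order of magnitude of bfloat16 rounding, since bfloat16 carries about 8 significant bits.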
