Add wp.bfloat16 scalar data type #1332
Open
Labels: feature request
Summary
Add bfloat16 (brain floating-point 16) as a first-class scalar type in Warp, on par with the existing wp.float16. The bfloat16 format uses 8 exponent bits and 7 mantissa bits, giving it the same dynamic range as float32 at half the memory, which makes it widely used in ML training and inference workloads.
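For context, bfloat16 is simply the upper 16 bits of a float32's bit pattern (sign bit, all 8 exponent bits, top 7 of the 23 mantissa bits), which is why it keeps float32's dynamic range. A quick stdlib-only illustration (the function names are ours, not Warp API):

```python
import struct

def f32_bits(x: float) -> int:
    """Return the IEEE 754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def truncate_to_bf16(x: float) -> int:
    # bfloat16 is the upper half of a float32: drop the low 16 mantissa bits.
    return f32_bits(x) >> 16

def bf16_to_f32(b: int) -> float:
    # Widening is exact: pad the low 16 bits with zeros.
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

# Dynamic range matches float32: a value near float32's maximum survives the
# round trip (losing only low mantissa bits), where float16 would overflow.
assert bf16_to_f32(truncate_to_bf16(3.0e38)) != float("inf")
```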
Scope
Python type system
- New `wp.bfloat16` scalar type registered alongside existing float types
- Round-trip conversions between float32 and bfloat16 (round-to-nearest-even, with correct NaN handling)
- Direct array construction from Python floats, lists, tuples, and NumPy float32 data
- Support in vectors, matrices, and structs
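The round-to-nearest-even narrowing conversion requested above is commonly implemented with a bias-and-truncate trick on the float32 bits, plus a NaN check so payloads don't truncate to infinity. A minimal sketch (one common implementation, not necessarily what Warp would ship):

```python
import struct

def f32_to_bf16_rne(x: float) -> int:
    """Convert a Python float to a bfloat16 bit pattern with
    round-to-nearest-even, preserving NaN-ness."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        # NaN: keep the sign, force a quiet-NaN mantissa bit so the
        # result is still a NaN after the low 16 bits are dropped.
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # Round to nearest even: add 0x7FFF plus the LSB of the kept half,
    # then truncate. Infinities pass through unchanged.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF
```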
Native C++/CUDA
- `wp_bfloat16` struct with float/half conversions and arithmetic operators
- Platform-specific conversion paths for CUDA (`__CUDA_ARCH__`), Clang, and generic C++
- Comparison operators with float promotion for correct IEEE 754 NaN/zero semantics
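Promoting to float before comparing matters because comparing raw 16-bit patterns gets IEEE 754 semantics wrong. A small demonstration of the two cases named above (helper names are illustrative, not Warp's C++ API):

```python
import struct

def bf16_to_f32(b: int) -> float:
    # Widen a bfloat16 bit pattern to float32 (exact).
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def bf16_eq(a: int, b: int) -> bool:
    # Promote both operands and compare as floats, not as bit patterns.
    return bf16_to_f32(a) == bf16_to_f32(b)

# +0.0 (0x0000) and -0.0 (0x8000) differ bitwise but must compare equal:
assert bf16_eq(0x0000, 0x8000)
# A NaN (e.g. 0x7FC0) must not compare equal to itself:
assert not bf16_eq(0x7FC0, 0x7FC0)
```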
Kernel support
- Arithmetic, comparison, and casting builtins
- Math function overloads (abs, sqrt, log, exp, trig functions, etc.)
- Autodiff support
- Atomic operations: add, min, max (with CAS fallback)
- Tile operations including matmul via cuBLASDx
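One straightforward way to provide such math overloads (and a typical fallback on targets without native bfloat16 math) is promote-compute-demote: widen to float32, call the existing float routine, round back. A sketch under that assumption, with hypothetical helper names:

```python
import math
import struct

def bf16_to_f32(b: int) -> float:
    # Exact widening: bfloat16 occupies the upper 16 bits of a float32.
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def f32_to_bf16(x: float) -> int:
    # Narrow with round-to-nearest-even (NaN handling omitted for brevity).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def bf16_sqrt(b: int) -> int:
    # Promote to float32, use the float routine, round back down.
    return f32_to_bf16(math.sqrt(bf16_to_f32(b)))

# sqrt(4.0) == 2.0 survives exactly: 0x4080 -> 0x4000
assert bf16_sqrt(0x4080) == 0x4000
```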
Framework interop
- DLPack import/export
- PyTorch `torch.bfloat16` tensors (zero-copy where possible)
- JAX `jax.numpy.bfloat16` arrays
- Fabric and sparse array support
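Zero-copy exchange hinges on the dtype descriptor: DLPack gives bfloat16 its own type code (`kDLBfloat` = 4 in `dlpack.h`), distinct from `kDLFloat`, so both producer and consumer must recognize it. A minimal Python sketch of the relevant `DLDataType` fields (the dataclass is ours; the codes come from the DLPack header):

```python
from dataclasses import dataclass

# Type-family codes from DLPack's DLDataTypeCode enum.
kDLFloat = 2
kDLBfloat = 4

@dataclass
class DLDataType:
    code: int   # type family (float vs. bfloat, etc.)
    bits: int   # bits per lane
    lanes: int  # vector lanes (1 for scalar arrays)

float16_dtype = DLDataType(code=kDLFloat, bits=16, lanes=1)
bfloat16_dtype = DLDataType(code=kDLBfloat, bits=16, lanes=1)

# Same width, different code: a consumer that keyed only on bit width
# would silently misinterpret the payload.
assert float16_dtype.bits == bfloat16_dtype.bits
assert float16_dtype.code != bfloat16_dtype.code
```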
Compile-time optimization
- `WP_NO_BFLOAT16` guard in `builtin.h`, `cuda_crt.h`, and `tile.h` so modules that don't use bfloat16 skip compiling its overloads, reducing LLVM codegen time