Add wp.bfloat16 scalar data type #1332

@shi-eric

Description

Summary

Add bfloat16 (brain floating-point 16) as a first-class scalar type in Warp, on par with the existing wp.float16. The bfloat16 format uses 8 exponent bits and 7 mantissa bits, giving it the same dynamic range as float32 in half the storage, which makes it a common choice for ML training and inference workloads.
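
As a rough sketch of why the format behaves this way (plain NumPy, not Warp's implementation; the helper name is illustrative): a bfloat16 value is just the top 16 bits of an IEEE 754 float32, so it keeps float32's exponent field and therefore its dynamic range, while float16's 5-bit exponent overflows far earlier.

```python
import numpy as np

# bfloat16 layout: 1 sign bit | 8 exponent bits | 7 mantissa bits,
# i.e. the high half of a float32. float16 has 5 exponent / 10 mantissa bits.

def bf16_truncate(x: float) -> float:
    """Drop a float32 to bfloat16 precision by zeroing the low 16 bits
    (round-toward-zero; illustrative only)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return np.asarray(bits & 0xFFFF0000, dtype=np.uint32).view(np.float32).item()

# Same dynamic range as float32: 1e38 survives as a finite bfloat16 value,
# while float16 overflows to infinity just above 65504.
print(bf16_truncate(1e38))   # finite, close to 1e38
print(np.float16(1e38))      # inf
```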

Scope

Python type system

  • New wp.bfloat16 scalar type registered alongside existing float types
  • Round-trip conversions between float32 and bfloat16 (round-to-nearest-even, with correct NaN handling)
  • Direct array construction from Python floats, lists, tuples, and NumPy float32 data
  • Support in vectors, matrices, and structs
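
The round-to-nearest-even conversion with NaN handling mentioned above can be modeled in a few lines of NumPy. This is a sketch of the usual bit-level recipe, not Warp's actual code, and the function names are illustrative:

```python
import numpy as np

def float32_to_bfloat16_bits(x: float) -> int:
    """Model of float32 -> bfloat16: round-to-nearest-even, NaN preserved."""
    bits = int(np.float32(x).view(np.uint32))
    if np.isnan(np.float32(x)):
        # Keep sign and exponent, force the top mantissa bit so truncation
        # cannot silently turn a NaN into infinity.
        return (bits >> 16) | 0x0040
    # Round to nearest, ties to even: add 0x7FFF plus the LSB of the
    # surviving mantissa, then keep the high 16 bits.
    lsb = (bits >> 16) & 1
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF

def bfloat16_bits_to_float32(b: int) -> float:
    """bfloat16 -> float32 is exact: widen with 16 zero bits."""
    return np.array(b << 16, dtype=np.uint32).view(np.float32).item()
```

The widening direction is lossless, which is why round-tripping float32 -> bfloat16 -> float32 only loses the low mantissa bits of the original value.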

Native C++/CUDA

  • wp_bfloat16 struct with float/half conversions and arithmetic operators
  • Platform-specific conversion paths for CUDA (__CUDA_ARCH__), Clang, and generic C++
  • Comparison operators with float promotion for correct IEEE 754 NaN/zero semantics
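
One way to read the float-promotion point above: widening bfloat16 to float32 is exact, so delegating comparisons to float32 inherits IEEE 754 semantics for free (NaN compares unequal to everything, +0 equals -0). A Python sketch of that idea, with illustrative names operating on raw bfloat16 bit patterns:

```python
import numpy as np

def bf16_to_f32(b: int) -> float:
    """Exactly widen a bfloat16 bit pattern to float32."""
    return np.array(b << 16, dtype=np.uint32).view(np.float32).item()

# Comparisons via promotion: correct IEEE 754 behavior comes from float32.
def bf16_eq(a: int, b: int) -> bool:
    return bf16_to_f32(a) == bf16_to_f32(b)

def bf16_lt(a: int, b: int) -> bool:
    return bf16_to_f32(a) < bf16_to_f32(b)

print(bf16_eq(0x7FC0, 0x7FC0))  # False: NaN != NaN
print(bf16_eq(0x0000, 0x8000))  # True: +0 == -0 despite different bits
print(bf16_lt(0xBF80, 0x3F80))  # True: -1.0 < 1.0
```

A bitwise-equality shortcut would get both NaN and signed-zero cases wrong, which is presumably why promotion is used.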

Kernel support

  • Arithmetic, comparison, and casting builtins
  • Math function overloads (abs, sqrt, log, exp, trig functions, etc.)
  • Autodiff support
  • Atomic operations: add, min, max (with CAS fallback)
  • Tile operations including matmul via cuBLASDx
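
The CAS fallback for atomics is the standard read-modify-compare-and-swap loop used when hardware lacks a native bfloat16 atomic. A single-threaded Python model of the control flow (names illustrative, NaN handling omitted for brevity):

```python
import numpy as np

class Cell:
    """A single 16-bit memory cell exposing a CAS primitive (illustrative)."""
    def __init__(self, bits: int = 0):
        self.bits = bits

    def compare_and_swap(self, expected: int, new: int) -> int:
        old = self.bits
        if old == expected:
            self.bits = new
        return old

def f32_to_bf16(x: float) -> int:
    bits = int(np.float32(x).view(np.uint32))
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF

def bf16_to_f32(b: int) -> float:
    return np.array(b << 16, dtype=np.uint32).view(np.float32).item()

def atomic_add_bf16(cell: Cell, value: float) -> None:
    """CAS loop: read, add in float32, retry until the cell was unchanged."""
    while True:
        old = cell.bits
        new = f32_to_bf16(bf16_to_f32(old) + value)
        if cell.compare_and_swap(old, new) == old:
            return

cell = Cell(f32_to_bf16(0.0))
for _ in range(4):
    atomic_add_bf16(cell, 0.25)
print(bf16_to_f32(cell.bits))  # 1.0 (0.25 and its partial sums are exact in bfloat16)
```

On real hardware another thread can change the cell between the read and the CAS, in which case the swap fails and the loop retries with the fresh value.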

Framework interop

  • DLPack import/export
  • PyTorch torch.bfloat16 tensors (zero-copy where possible)
  • JAX jax.numpy.bfloat16 arrays
  • Fabric and sparse array support

Compile-time optimization

  • WP_NO_BFLOAT16 guard in builtin.h, cuda_crt.h, and tile.h so modules that don't use bfloat16 skip compiling its overloads, reducing LLVM codegen time

Metadata

Labels

feature request