Skip to content

Conversation

@EricLBuehler
Copy link
Member

@EricLBuehler EricLBuehler commented Nov 17, 2025

Adds support for:

  • i32

  • i16

  • f6e2m3

  • f6e3m2

  • f4

  • f8e8m0
    These are "dummy" dtypes: this just means a typed bitbucket essentially.

  • CPU compiles

  • CUDA compiles

  • Metal compiles

@EricLBuehler EricLBuehler marked this pull request as ready for review November 17, 2025 19:07
@zackangelo
Copy link
Contributor

signed dtypes are nice 👌 I've been having to pass u32s as i32s in cuda launch code and have been worried that would blow up in my face at some point

Copy link
Member

@ivarflakstad ivarflakstad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is gonna be a good one! 🙌

DType::F64 => convert_slice::<f64>(data, shape, device),
DType::F8E4M3 => convert_slice::<F8E4M3>(data, shape, device),
DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device),
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't have to be in this PR, but I'd prefer to hoist this out into a helper fn.
Perhaps use convert_slice::<u8>(data, shape, device) and manually change the storage dtype? Might not even need a dedicated fn now that I think about it 🤔

let shape = view.shape();

// Create storage with the appropriate dummy type variant
let storage = match device {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Déjà vu helper fn 👀

#[test]
fn load_i8() {
let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03";
std::fs::write("test_i8.safetensors", bytes).unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, just noting down while I'm here: we should use temp files for these kinds of tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely 👍

@EricLBuehler
Copy link
Member Author

Addressed most of the review comments; left some as unresolved for posterity.

let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I32(storage), DType::F8E4M3) => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have an idea for how to reduce the massive size of this match.
Adding it to the ever growing list of things to do :)

S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You resolved this but looks the same to me?

(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

#[inline(always)]
fn f32(v: f32) -> f32 {
(crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
Self::f64(v as f64) as f32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still revert! ;)

}

fn get_current_seed(&self) -> Result<u64> {
crate::bail!("cannot get the CPU rng seed with get_current_seed")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have a look into this later

Comment on lines -27 to +29
S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?),
S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks slightly off to me

Comment on lines -52 to +59
(S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?),
(S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not wrapping in storage?

Comment on lines -91 to -93
(S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => {
S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not supported?

Comment on lines -126 to +134
(S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F8E4M3(_), S::F8E4M3(_)) => Err(CudaError::InternalError(
"Map2InPlace not supported for F8E4M3",
))?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct?

S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You resolved this but looks the same to me?

(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Comment on lines -183 to +205
DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
DType::F8E4M3 => "asort_desc_f8e4m3",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use the same logic as above, no?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants