-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Add dummy dtypes #3195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add dummy dtypes #3195
Conversation
|
signed dtypes are nice 👌 I've been having to pass u32s as i32s in cuda launch code and have been worried that would blow up in my face at some point |
ivarflakstad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is gonna be a good one! 🙌
| DType::F64 => convert_slice::<f64>(data, shape, device), | ||
| DType::F8E4M3 => convert_slice::<F8E4M3>(data, shape, device), | ||
| DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device), | ||
| DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't have to be in this PR, but I'd prefer to hoist this out into a helper fn.
Perhaps use convert_slice::<u8>(data, shape, device) and manually change the storage dtype? Might not even need a dedicated fn now that I think about it 🤔
| let shape = view.shape(); | ||
|
|
||
| // Create storage with the appropriate dummy type variant | ||
| let storage = match device { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Déjà vu helper fn 👀
| #[test] | ||
| fn load_i8() { | ||
| let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03"; | ||
| std::fs::write("test_i8.safetensors", bytes).unwrap(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this PR, just noting down while I'm here: we should use temp files for these kinds of tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Absolutely 👍
|
Addressed most of the review comments; left some as unresolved for posterity. |
| let data = unary_map(storage, layout, |v| v as f64); | ||
| Ok(Self::F64(data)) | ||
| } | ||
| (Self::I32(storage), DType::F8E4M3) => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have an idea for how to reduce the massive size of this match.
Adding it to the ever growing list of things to do :)
| S::F16(s) => self.f(s, d, l, S::F16)?, | ||
| S::F32(s) => self.f(s, d, l, S::F32)?, | ||
| S::F64(s) => self.f(s, d, l, S::F64)?, | ||
| S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You resolved this but looks the same to me?
| (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
candle-core/src/op.rs
Outdated
| #[inline(always)] | ||
| fn f32(v: f32) -> f32 { | ||
| (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v | ||
| Self::f64(v as f64) as f32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still revert! ;)
| } | ||
|
|
||
| fn get_current_seed(&self) -> Result<u64> { | ||
| crate::bail!("cannot get the CPU rng seed with get_current_seed") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a look into this later
| S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), | ||
| S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks slightly off to me
| (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), | ||
| (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not wrapping in storage?
| (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { | ||
| S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this not supported?
| (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_l, src, src_l, d), | ||
| (S::F8E4M3(_), S::F8E4M3(_)) => Err(CudaError::InternalError( | ||
| "Map2InPlace not supported for F8E4M3", | ||
| ))?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct?
| S::F16(s) => self.f(s, d, l, S::F16)?, | ||
| S::F32(s) => self.f(s, d, l, S::F32)?, | ||
| S::F64(s) => self.f(s, d, l, S::F64)?, | ||
| S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You resolved this but looks the same to me?
| (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
| DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), | ||
| DType::F8E4M3 => "asort_desc_f8e4m3", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should use the same logic as above, no?
Adds support for:
i32
i16
f6e2m3
f6e3m2
f4
f8e8m0
These are "dummy" dtypes: this just means a typed bitbucket essentially.
CPU compiles
CUDA compiles
Metal compiles