Skip to content

Commit 68d17ab

Browse files
committed
Residual fixes
1 parent ccaa447 commit 68d17ab

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

candle-core/src/cuda_backend/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,7 @@ impl BackendStorage for CudaStorage {
13491349
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
13501350
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
13511351
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
1352-
S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8e4m3"),
1352+
S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8_e4m3"),
13531353
S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => {
13541354
return Err(CudaError::UnsupportedDtype {
13551355
dtype: self.dtype(),

candle-core/src/safetensors.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -630,15 +630,15 @@ mod tests {
630630
}
631631

632632
#[test]
633-
fn load_i8() {
634-
let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03";
635-
std::fs::write("test_i8.safetensors", bytes).unwrap();
636-
let weights = load("test_i8.safetensors", &Device::Cpu).unwrap();
633+
fn load_u8() {
634+
let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"U8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03";
635+
std::fs::write("test_u8.safetensors", bytes).unwrap();
636+
let weights = load("test_u8.safetensors", &Device::Cpu).unwrap();
637637
let tensor = weights.get("x").unwrap();
638638
assert_eq!(tensor.dims(), &[2]);
639-
assert_eq!(tensor.dtype(), DType::I64);
640-
let data: Vec<i64> = tensor.to_vec1().unwrap();
639+
assert_eq!(tensor.dtype(), DType::U8);
640+
let data: Vec<u8> = tensor.to_vec1().unwrap();
641641
assert_eq!(data, vec![1, 3]);
642-
std::fs::remove_file("test_i8.safetensors").unwrap();
642+
std::fs::remove_file("test_u8.safetensors").unwrap();
643643
}
644644
}

0 commit comments

Comments
 (0)