Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions candle-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

fn set_seed(&self, _: u64) -> Result<()>;
fn get_current_seed(&self) -> Result<u64>;

/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
Expand Down
19 changes: 17 additions & 2 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Implement conversion traits for tensors
use crate::{DType, Device, Error, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

Expand Down Expand Up @@ -94,6 +93,8 @@ from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(i32);
from_tensor!(i16);
from_tensor!(u32);
from_tensor!(u8);

Expand Down Expand Up @@ -131,6 +132,16 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I16 => {
for v in vs.to_vec1::<i16>()? {
f.write_i16::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
Expand All @@ -141,10 +152,14 @@ impl Tensor {
f.write_all(&vs)?;
}
DType::F8E4M3 => {
for v in vs.to_vec1::<F8E4M3>()? {
let vs = vs.to_vec1::<float8::F8E4M3>()?;
for v in vs {
f.write_u8(v.to_bits())?
}
}
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt())
}
}
Ok(())
}
Expand Down
34 changes: 34 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Expand All @@ -163,6 +185,18 @@ impl VecOps for i64 {
}
}

impl VecOps for float8::F8E4M3 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}

#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
Expand Down
Loading
Loading