Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions candle-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

fn set_seed(&self, _: u64) -> Result<()>;
fn get_current_seed(&self) -> Result<u64>;

/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
Expand Down
19 changes: 17 additions & 2 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Implement conversion traits for tensors
use crate::{DType, Device, Error, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

Expand Down Expand Up @@ -94,6 +93,8 @@ from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(i32);
from_tensor!(i16);
from_tensor!(u32);
from_tensor!(u8);

Expand Down Expand Up @@ -131,6 +132,16 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I16 => {
for v in vs.to_vec1::<i16>()? {
f.write_i16::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
Expand All @@ -141,10 +152,14 @@ impl Tensor {
f.write_all(&vs)?;
}
DType::F8E4M3 => {
for v in vs.to_vec1::<F8E4M3>()? {
let vs = vs.to_vec1::<float8::F8E4M3>()?;
for v in vs {
f.write_u8(v.to_bits())?
}
}
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt())
}
}
Ok(())
}
Expand Down
34 changes: 34 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
/// `VecOps` for `i16`: minimum and maximum are taken from the type's `Ord`
/// implementation.
impl VecOps for i16 {
    /// Returns the smaller of `self` and `other` as ordered by `Ord`.
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        // Name the `Ord` method explicitly so the call cannot be confused
        // with the `VecOps::min` being defined here.
        Ord::min(self, other)
    }

    /// Returns the larger of `self` and `other` as ordered by `Ord`.
    #[inline(always)]
    fn max(self, other: Self) -> Self {
        // Same disambiguation as in `min`.
        Ord::max(self, other)
    }
}
/// `VecOps` for `i32`: minimum and maximum are taken from the type's `Ord`
/// implementation.
impl VecOps for i32 {
    /// Returns the smaller of `self` and `other` as ordered by `Ord`.
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        // Name the `Ord` method explicitly so the call cannot be confused
        // with the `VecOps::min` being defined here.
        Ord::min(self, other)
    }

    /// Returns the larger of `self` and `other` as ordered by `Ord`.
    #[inline(always)]
    fn max(self, other: Self) -> Self {
        // Same disambiguation as in `min`.
        Ord::max(self, other)
    }
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Expand All @@ -163,6 +185,18 @@ impl VecOps for i64 {
}
}

/// `VecOps` implementation for `float8::F8E4M3`; both methods forward to
/// `Self::min` / `Self::max`.
impl VecOps for float8::F8E4M3 {
/// Returns the result of `Self::min(self, other)`.
#[inline(always)]
fn min(self, other: Self) -> Self {
// NOTE(review): presumably this resolves to an inherent `F8E4M3::min`
// from the `float8` crate rather than to this trait method — confirm,
// since otherwise the call would recurse.
Self::min(self, other)
}

/// Returns the result of `Self::max(self, other)`.
#[inline(always)]
fn max(self, other: Self) -> Self {
// NOTE(review): same assumption as in `min` — expects an inherent
// `F8E4M3::max`; confirm against the `float8` crate.
Self::max(self, other)
}
}

#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
Expand Down
Loading
Loading