2 changes: 0 additions & 2 deletions candle-core/src/backprop.rs
@@ -51,7 +51,7 @@
} else if node.dtype().is_int() {
nodes
} else if let Some(op) = node.op() {
match op {

Check failure on line 54 in candle-core/src/backprop.rs (GitHub Actions: Check (ubuntu-latest, stable), Clippy, Test Suite (macOS-latest, stable)):
non-exhaustive patterns: `&op::Op::UnaryScalar(_, _, _)` not covered
Op::IndexAdd(t1, t2, t3, _)
| Op::Scatter(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
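
The non-exhaustive-pattern failure above means the new `Op::UnaryScalar` variant still needs an arm in this `walk` match. A minimal sketch of what that arm could look like, assuming the scalar operand (the second tensor) is a constant that does not need to be walked for gradients; the placement among the existing arms is a guess, not the PR's code:

Op::UnaryScalar(node, _scalar, _) => {
    let (tg, nodes) = walk(node, nodes, already_seen);
    track_grad |= tg;
    nodes
}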
@@ -129,8 +129,6 @@
| Op::Permute(node, _)
| Op::Narrow(node, _, _, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::Powf(node, _)
| Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
@@ -639,7 +637,7 @@
let silu_grad = &sigmoid_arg * (1. - *node) + *node;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => {

Check failure on line 640 in candle-core/src/backprop.rs (GitHub Actions: Check (ubuntu-latest, stable), Clippy, Test Suite (macOS-latest, stable)):
no variant or associated item named `Elu` found for enum `op::Op` in the current scope
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
@@ -648,10 +646,10 @@
// node == alpha * (e^x - 1) for x <= 0, reuse it
let negative_exp_mask = (negative_mask * (*node + *alpha))?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?

Check failure on line 649 in candle-core/src/backprop.rs (GitHub Actions: Check (ubuntu-latest, stable), Clippy, Test Suite (macOS-latest, stable)):
the `?` operator can only be applied to values that implement `Try`
}
Op::Powf(arg, e) => {

Check failure on line 651 in candle-core/src/backprop.rs (GitHub Actions: Check (ubuntu-latest, stable), Clippy, Test Suite (macOS-latest, stable)):
no variant or associated item named `Powf` found for enum `op::Op` in the current scope
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;

Check failure on line 652 in candle-core/src/backprop.rs (GitHub Actions: Check (ubuntu-latest, stable), Clippy, Test Suite (macOS-latest, stable)):
the `?` operator can only be applied to values that implement `Try`
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
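
Per the `no variant` failures above, the old `Op::Elu` and `Op::Powf` arms presumably need to be rewritten against `Op::UnaryScalar`, whose scalar now arrives as a tensor rather than an `f64`. A hedged sketch of one way to do that, assuming a rank-0 scalar tensor and using `Tensor::to_scalar` to recover the value; the mask construction mirrors the existing Elu code above, and none of this is the PR's actual fix:

Op::UnaryScalar(arg, scalar, UnaryScalarOp::Elu) => {
    // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
    let alpha = scalar.to_scalar::<f64>()?;
    let sum_grad = grads.or_insert(arg)?;
    let zeros = arg.zeros_like()?;
    let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
    let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
    // node == alpha * (e^x - 1) for x <= 0, reuse it
    let negative_exp_mask = (negative_mask * (*node + alpha))?;
    let combined_mask = (positive_mask + negative_exp_mask)?;
    *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::UnaryScalar(arg, scalar, UnaryScalarOp::Powf) => {
    let e = scalar.to_scalar::<f64>()?;
    let arg_grad = (&(grad * arg.powf(e - 1.)?)? * e)?;
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&arg_grad)?
}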
80 changes: 77 additions & 3 deletions candle-core/src/op.rs
@@ -4,7 +4,7 @@
use crate::Tensor;
use float8::F8E4M3;
use half::{bf16, f16};
use num_traits::float::Float;
use num_traits::{float::Float, PrimInt};

Check warning on line 7 in candle-core/src/op.rs (GitHub Actions: Check (ubuntu-latest, stable), Test Suite (macOS-latest, stable)) and Check failure (GitHub Actions: Clippy):
unused import: `PrimInt`

#[derive(Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
@@ -72,10 +72,18 @@
Sign,
}

// A unary op applied elementwise that also takes a tensor-wide scalar argument
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryScalarOp {
Elu,
Powf,
}
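
Unlike the old `Elu(Tensor, f64)` and `Powf(Tensor, f64)` variants, the scalar operand is now carried as a second `Tensor` in the `Op::UnaryScalar` variant below. A rough sketch of how a call site might record the op, assuming the scalar is materialized as a rank-0 tensor and reusing the crate's `BackpropOp::new2` helper; the actual call-site plumbing is not part of this excerpt:

// Hypothetical call site inside a `Tensor::powf`-style method.
let scalar = Tensor::full(e, (), self.device())?;
let op = BackpropOp::new2(self, &scalar, |t1, t2| Op::UnaryScalar(t1, t2, UnaryScalarOp::Powf));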

#[derive(Clone)]
pub enum Op {
Binary(Tensor, Tensor, BinaryOp),
Unary(Tensor, UnaryOp),
UnaryScalar(Tensor, Tensor, UnaryScalarOp),
Cmp(Tensor, CmpOp),
// The third argument is the reduced shape with `keepdim=true`.
Reduce(Tensor, ReduceOp, Vec<usize>),
@@ -164,8 +172,6 @@
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
@@ -241,6 +247,33 @@
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}

pub trait UnaryScalarOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
fn i64(v1: i64, v2: i64) -> i64;

// There is no very good way to represent optional functions in traits, so we go for an
// explicit boolean flag to mark the function as existing.
const BF16_VEC: bool = false;
fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16], _: bf16) {}
const F16_VEC: bool = false;
fn f16_vec(_xs: &[f16], _ys: &mut [f16], _: f16) {}
const F8E4M3_VEC: bool = false;
fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3], _: F8E4M3) {}
const F32_VEC: bool = false;
fn f32_vec(_xs: &[f32], _ys: &mut [f32], _: f32) {}
const F64_VEC: bool = false;
fn f64_vec(_xs: &[f64], _ys: &mut [f64], _: f64) {}
}
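
To make the trait's shape concrete, here is a self-contained sketch of how a CPU backend might dispatch one of these ops over a buffer, preferring the vectorized kernel when the corresponding flag is set; this helper is illustrative, not code from the PR:

fn apply_f32<O: UnaryScalarOpT>(xs: &[f32], scalar: f32) -> Vec<f32> {
    if O::F32_VEC {
        // Vectorized path: the op fills the output buffer in one call.
        let mut ys = vec![0f32; xs.len()];
        O::f32_vec(xs, &mut ys, scalar);
        ys
    } else {
        // Scalar fallback: apply the op elementwise.
        xs.iter().map(|&x| O::f32(x, scalar)).collect()
    }
}

For example, `apply_f32::<Powf>(&[2.0, 3.0], 2.0)` yields `[4.0, 9.0]`.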

pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
@@ -266,6 +299,8 @@
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
pub(crate) struct Elu;
pub(crate) struct Powf;

macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -929,6 +964,45 @@
}
}

impl UnaryScalarOpT for Powf {
const NAME: &'static str = "powf";
const KERNEL: &'static str = "upowf";
const V: Self = Powf;
#[inline(always)]
fn bf16(v: bf16, exponent: bf16) -> bf16 {
v.powf(exponent)
}
#[inline(always)]
fn f16(v: f16, exponent: f16) -> f16 {
v.powf(exponent)
}
#[inline(always)]
fn f8e4m3(v: F8E4M3, exponent: F8E4M3) -> F8E4M3 {
v.powf(exponent)
}
#[inline(always)]
fn f32(v: f32, exponent: f32) -> f32 {
v.powf(exponent)
}
#[inline(always)]
fn f64(v: f64, exponent: f64) -> f64 {
v.powf(exponent)
}
#[inline(always)]
fn u8(v: u8, exponent: u8) -> u8 {
v.pow(exponent.into())
}
#[inline(always)]
fn u32(v: u32, exponent: u32) -> u32 {
v.pow(exponent.into())
}
#[inline(always)]
fn i64(v: i64, exponent: i64) -> i64 {
v.pow(exponent.try_into().expect("exponent must be non-negative"))
}
}
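
An `Elu` implementation is not visible in this excerpt; a sketch of what it could look like, mirroring the `Powf` impl above (the `uelu` kernel name, the half/f8 round-tripping through `f32`, and the identity behavior on integer dtypes are all assumptions):

impl UnaryScalarOpT for Elu {
    const NAME: &'static str = "elu";
    const KERNEL: &'static str = "uelu";
    const V: Self = Elu;
    #[inline(always)]
    fn f32(v: f32, alpha: f32) -> f32 {
        if v > 0. { v } else { alpha * (v.exp() - 1.) }
    }
    #[inline(always)]
    fn f64(v: f64, alpha: f64) -> f64 {
        if v > 0. { v } else { alpha * (v.exp() - 1.) }
    }
    #[inline(always)]
    fn bf16(v: bf16, alpha: bf16) -> bf16 {
        bf16::from_f32(Self::f32(v.to_f32(), alpha.to_f32()))
    }
    #[inline(always)]
    fn f16(v: f16, alpha: f16) -> f16 {
        f16::from_f32(Self::f32(v.to_f32(), alpha.to_f32()))
    }
    #[inline(always)]
    fn f8e4m3(v: F8E4M3, alpha: F8E4M3) -> F8E4M3 {
        F8E4M3::from_f32(Self::f32(v.to_f32(), alpha.to_f32()))
    }
    // Integer dtypes: elu(x) = x for x >= 0, and the negative branch has no
    // sensible integer value; plain identity is used here as a placeholder.
    #[inline(always)]
    fn u8(v: u8, _alpha: u8) -> u8 { v }
    #[inline(always)]
    fn u32(v: u32, _alpha: u32) -> u32 { v }
    #[inline(always)]
    fn i64(v: i64, _alpha: i64) -> i64 { v }
}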

/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value
#[derive(Clone)]