4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -44,6 +44,10 @@ fn unsqueeze_like<B: Backend>(
 }

 impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C> {
+    fn float_dtypes(device: &Device<Self>) -> Vec<FloatDType> {
+        B::float_dtypes(device)
+    }
+
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
         AutodiffTensor::new(B::float_from_data(data, device))
     }
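Decorator backends such as Autodiff (and Fusion further down) keep no dtype list of their own; they forward the query to the wrapped backend. A minimal sketch of what that buys the caller, assuming NdArray as the inner backend and that FloatDType implements PartialEq and Debug (neither assumption comes from this diff):

use burn_autodiff::Autodiff;
use burn_ndarray::{NdArray, NdArrayDevice};
use burn_tensor::ops::FloatTensorOps;

fn main() {
    // Autodiff<NdArray> shares NdArray's device type, so one value serves both.
    let device = NdArrayDevice::default();
    assert_eq!(
        Autodiff::<NdArray>::float_dtypes(&device),
        NdArray::float_dtypes(&device),
    );
}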
9 changes: 9 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
@@ -15,6 +15,15 @@ use crate::{
 use super::base::{expand, permute, sign, unfold};

 impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
+    fn float_dtypes(_device: &Device<Self>) -> Vec<FloatDType> {
+        vec![
+            FloatDType::F64,
+            FloatDType::F32,
+            FloatDType::F16,
+            FloatDType::BF16,
+        ]
+    }
+
     fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
         match data.dtype {
             burn_tensor::DType::F64 => super::base::from_data::<f64>(data, device),
9 changes: 9 additions & 0 deletions crates/burn-cubecl/src/ops/float_ops.rs
@@ -26,6 +26,15 @@ where
     I: IntElement,
     BT: BoolElement,
 {
+    fn float_dtypes(_device: &Device<Self>) -> Vec<FloatDType> {
+        vec![
+            FloatDType::F64,
+            FloatDType::F32,
+            FloatDType::F16,
+            FloatDType::BF16,
+        ]
+    }
+
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
         match data.dtype {
             DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
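Exposing the list per device lets callers pick a precision before moving data onto the backend. A sketch (not part of this PR) that converts input data to f16 only when the target device reports support, assuming TensorData::convert and the half crate:

use burn_tensor::{Device, FloatDType, TensorData, backend::Backend, ops::FloatTensorOps};
use half::f16;

// Downcast to f16 where the device supports it; otherwise stay at f32.
fn to_supported_precision<B: Backend>(data: TensorData, device: &Device<B>) -> TensorData {
    if B::float_dtypes(device).contains(&FloatDType::F16) {
        data.convert::<f16>()
    } else {
        data.convert::<f32>()
    }
}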
4 changes: 4 additions & 0 deletions crates/burn-fusion/src/ops/float.rs
@@ -15,6 +15,10 @@ use burn_tensor::{
 use std::marker::PhantomData;

 impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
+    fn float_dtypes(device: &Device<Self>) -> Vec<FloatDType> {
+        B::float_dtypes(device)
+    }
+
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
         let stream = StreamId::current();
         let client = get_client::<B>(&device.clone());
6 changes: 5 additions & 1 deletion crates/burn-ndarray/src/ops/tensor.rs
@@ -2,7 +2,7 @@
 use alloc::vec::Vec;
 use burn_tensor::ops::FloatTensor;
 use burn_tensor::ops::InterpolateMode;
-use burn_tensor::{TensorMetadata, cast::ToElement};
+use burn_tensor::{TensorMetadata, cast::ToElement, Device};

 // Current crate
 use super::{
@@ -52,6 +52,10 @@ where
     NdArrayTensor: From<SharedArray<E>>,
     NdArrayTensor: From<SharedArray<I>>,
 {
+    fn float_dtypes(_device: &Device<Self>) -> Vec<FloatDType> {
+        NdArrayTensor::dtypes()
+    }
+
     fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
         NdArrayTensor::from_data(data)
     }
18 changes: 11 additions & 7 deletions crates/burn-ndarray/src/tensor.rs
@@ -1,12 +1,9 @@
 use core::mem;

-use burn_tensor::{
-    DType, Element, Shape, TensorData, TensorMetadata,
-    quantization::{
-        QParams, QTensorPrimitive, QuantLevel, QuantMode, QuantScheme, QuantValue,
-        QuantizationStrategy, SymmetricQuantization,
-    },
-};
+use burn_tensor::{DType, Element, Shape, TensorData, TensorMetadata, quantization::{
+    QParams, QTensorPrimitive, QuantLevel, QuantMode, QuantScheme, QuantValue,
+    QuantizationStrategy, SymmetricQuantization,
+}, FloatDType};

 use alloc::vec::Vec;
 use ndarray::{ArcArray, ArrayD, IxDyn};
@@ -458,6 +455,13 @@ macro_rules! reshape {
 }

 impl NdArrayTensor {
+    pub fn dtypes() -> Vec<FloatDType> {
+        vec![
+            FloatDType::F64,
+            FloatDType::F32,
+        ]
+    }
+
     /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
     pub fn from_data(mut data: TensorData) -> NdArrayTensor {
         let shape = mem::take(&mut data.shape);
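Because NdArrayTensor reports only F64 and F32, generic code cannot assume half precision is available. A hypothetical helper (not part of this PR) that walks a caller-supplied preference list and falls back to F32, which every backend implemented here lists:

use burn_tensor::{Device, FloatDType, backend::Backend, ops::FloatTensorOps};

// Return the first preferred dtype the device supports, or F32 as the fallback.
fn preferred_dtype<B: Backend>(device: &Device<B>, prefs: &[FloatDType]) -> FloatDType {
    let supported = B::float_dtypes(device);
    prefs
        .iter()
        .copied()
        .find(|d| supported.contains(d))
        .unwrap_or(FloatDType::F32)
}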
4 changes: 4 additions & 0 deletions crates/burn-router/src/ops/op_float.rs
@@ -16,6 +16,10 @@ use burn_tensor::{
 };

 impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
+    fn float_dtypes(_device: &Device<Self>) -> Vec<FloatDType> {
+        todo!()
+    }
+
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
         let client = get_client::<R>(device);
         let out = client.register_tensor_data(data);
15 changes: 10 additions & 5 deletions crates/burn-tch/src/ops/tensor.rs
@@ -1,14 +1,19 @@
 use super::TchOps;
 use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
 use burn_tensor::ops::{BoolTensor, FloatTensor};
-use burn_tensor::{
-    DType, Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata,
-    backend::Backend,
-    ops::{FloatTensorOps, IntTensor},
-};
+use burn_tensor::{DType, Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata, backend::Backend, ops::{FloatTensorOps, IntTensor}, Device};
 use half::{bf16, f16};

 impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
+    fn float_dtypes(_device: &Device<Self>) -> Vec<FloatDType> {
+        vec![
+            FloatDType::F64,
+            FloatDType::F32,
+            FloatDType::F16,
+            FloatDType::BF16,
+        ]
+    }
+
     fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
         match data.dtype {
             DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
3 changes: 3 additions & 0 deletions crates/burn-tensor/src/tensor/ops/tensor.rs
@@ -10,6 +10,9 @@ use alloc::vec::Vec;

 /// Operations on float tensors.
 pub trait FloatTensorOps<B: Backend> {
+    /// Returns the `FloatDType`s supported by the device.
+    fn float_dtypes(device: &Device<B>) -> Vec<FloatDType>;
+
     /// Creates a new tensor from the data structure.
     ///
     /// # Arguments
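With the trait method in place, dtype support becomes queryable from generic code. A minimal usage sketch against the NdArray backend, assuming FloatDType implements PartialEq:

use burn_ndarray::{NdArray, NdArrayDevice};
use burn_tensor::{FloatDType, ops::FloatTensorOps};

fn main() {
    let device = NdArrayDevice::default();
    let dtypes = NdArray::float_dtypes(&device);
    // Per the burn-ndarray diff above, only F64 and F32 are reported.
    assert!(dtypes.contains(&FloatDType::F32));
    assert!(!dtypes.contains(&FloatDType::F16));
}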