feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,46 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Core backend interfaces and data structures for executing tensor operations in Burn."
documentation = "https://docs.rs/burn-backend"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-backend"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-backend"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std"]
doc = ["default"]
std = ["rand/std", "num-traits/std", "burn-std/std", "cubecl?/std"]
tracing = ["burn-std/tracing", "cubecl/tracing"]
cubecl = ["dep:cubecl", "burn-std/cubecl"]
cubecl-cuda = ["cubecl", "cubecl/cuda"]
cubecl-hip = ["cubecl", "cubecl/hip"]
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
cubecl-cpu = ["cubecl", "cubecl/cpu"]
[dependencies]
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
cubecl = { workspace = true, optional = true, default-features = false }
bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
derive-new = { workspace = true }
enumset = { workspace = true }
hashbrown = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true, default-features = false }
rand_distr = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }
[dev-dependencies]
rand = { workspace = true, features = ["thread_rng"] }
paste = { workspace = true }

View File

@@ -0,0 +1,4 @@
# Burn Backend
This crate includes the core backend interfaces and data structures for executing tensor operations
in Burn.

View File

@@ -0,0 +1,391 @@
use burn_std::DType;
pub use burn_std::backtrace::BackTrace;
use alloc::string::String;
use enumset::{EnumSet, EnumSetType};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};
use super::DeviceOps;
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse a tensor's buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
FloatTensorOps<Self>
+ BoolTensorOps<Self>
+ IntTensorOps<Self>
+ ModuleOps<Self>
+ ActivationOps<Self>
+ QTensorOps<Self>
+ TransactionOps<Self>
+ Clone
+ Default
+ Sized
+ Send
+ Sync
+ core::fmt::Debug
+ 'static
{
/// Device type.
type Device: DeviceOps;
/// Tensor primitive to be used for all float operations.
type FloatTensorPrimitive: TensorMetadata + 'static;
/// Default float element type.
type FloatElem: Element;
/// Tensor primitive to be used for all int operations.
type IntTensorPrimitive: TensorMetadata + 'static;
/// Int element type.
type IntElem: Element;
/// Tensor primitive to be used for all bool operations.
type BoolTensorPrimitive: TensorMetadata + 'static;
/// Bool element type.
type BoolElem: Element;
/// Tensor primitive to be used for all quantized operations.
type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;
/// Returns whether autodiff is enabled on the given device.
///
/// The default implementation always returns `false`.
fn ad_enabled(_device: &Self::Device) -> bool {
false
}
/// Runs `func(input)` with the allocation mode set to persistent.
///
/// The default implementation does not change any allocation mode: it simply calls
/// `func(input)` and returns its output.
#[allow(unused_variables)]
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
device: &Self::Device,
input: Input,
func: Func,
) -> Output {
func(input)
}
/// Manually triggers a memory cleanup on the given device.
///
/// The default implementation does nothing.
#[allow(unused_variables)]
fn memory_cleanup(device: &Self::Device) {}
/// Name of the backend.
fn name(device: &Self::Device) -> String;
/// Seeds the backend on the specified device.
///
/// There is no guarantee that only the specified device will be seeded, but it is guaranteed
/// that at least the specified device will be seeded.
///
/// In all cases, this should ensure deterministic execution for a single-threaded program.
fn seed(device: &Self::Device, seed: u64);
/// Syncs the backend, ensuring that all computations are finished.
///
/// The default implementation returns `Ok(())` immediately.
fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
Ok(())
}
/// Marks the given data as being used as a staging buffer for transfer between CPU and
/// accelerators like GPUs.
///
/// The given data might be transferred to pinned memory or another format to improve data transfer
/// speed.
///
/// The default implementation does nothing.
fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
where
Iter: Iterator<Item = &'a mut TensorData>,
{
}
/// Whether the type is fully supported by the specified device for general operations.
///
/// A type is considered supported if it can be used for the full suite of tensor
/// operations, including storage, conversion, and basic arithmetic.
///
/// Returning `false` does not necessarily mean the device cannot handle the type at all.
/// For instance, a device might support a type only for specialized hardware
/// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
/// types should return `false` here as they are not globally supported.
///
/// The default implementation checks that [`Backend::dtype_usage`] is a superset of
/// [`DTypeUsage::general`].
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
}
/// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;
}
/// An error that can happen when syncing a device.
///
/// Both variants render the same `Display` message built from their `reason` field.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
/// A generic error happened during execution.
///
/// The backtrace and context information should be included in the reason string.
#[error("An error happened during execution\nCaused by:\n {reason}")]
WithContext {
/// The reason of the error.
reason: String,
},
/// A generic error happened during execution thrown in the Burn project.
///
/// The full context isn't captured by the string alone.
#[error("An error happened during execution\nCaused by:\n {reason}")]
Generic {
/// The reason of the error.
reason: String,
/// The backtrace. It is skipped during (de)serialization.
#[serde(skip)]
backtrace: BackTrace,
},
}
impl core::fmt::Debug for ExecutionError {
    /// Renders the error for `{:?}` using exactly the same text as its `Display` output.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(formatter, "{self}")
    }
}
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
/// The inner backend type.
///
/// It shares the device, float element and int element types of the autodiff backend.
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
/// Gradients type.
type Gradients: Send;
/// Backward pass.
///
/// # Arguments
///
/// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
///
/// # Returns
///
/// The gradients.
fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;
/// Returns the gradients of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to extract the gradients from.
/// * `grads` - The gradients.
///
/// # Returns
///
/// An optional tensor containing the gradient.
fn grad(
tensor: &FloatTensor<Self>,
grads: &Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend>>;
/// Pops the gradients of a tensor and returns them.
///
/// # Arguments
///
/// * `tensor` - The tensor to pop the gradients from.
/// * `grads` - The gradients.
///
/// # Returns
///
/// An optional tensor containing the removed gradient.
fn grad_remove(
tensor: &FloatTensor<Self>,
grads: &mut Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend>>;
/// Replace the gradients of a tensor with the one provided.
///
/// If no gradient existed for the provided tensor, register it.
///
/// # Arguments
///
/// * `tensor` - The tensor whose gradient is replaced.
/// * `grads` - The gradients.
/// * `grad` - The updated grad tensor.
fn grad_replace(
tensor: &FloatTensor<Self>,
grads: &mut Self::Gradients,
grad: FloatTensor<Self::InnerBackend>,
);
/// Returns the float tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The float tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend float tensor.
fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;
/// Returns the int tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The int tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend int tensor.
fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;
/// Returns the bool tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The bool tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend bool tensor.
fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;
/// Returns the quantized tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The quantized tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend quantized tensor.
fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;
/// Converts the inner backend float tensor to the autodiff backend float tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend float tensor to convert.
///
/// # Returns
///
/// The autodiff backend float tensor.
fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;
/// Converts the inner backend int tensor to the autodiff backend int tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend int tensor to convert.
///
/// # Returns
///
/// The autodiff backend int tensor.
fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;
/// Converts the inner backend bool tensor to the autodiff backend bool tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend bool tensor to convert.
///
/// # Returns
///
/// The autodiff backend bool tensor.
fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;
/// Converts the inner backend quantized tensor to the autodiff backend quantized tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend quantized tensor to convert.
///
/// # Returns
///
/// The autodiff backend quantized tensor.
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
///
/// Individual usages are combined into a [`DTypeUsageSet`] to describe everything a
/// data type supports; see [`DTypeUsage::general`] for the set that general-purpose
/// support requires.
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
/// The type can be stored in device memory and converted to and from
/// other supported data types.
Storage,
/// The type supports general-purpose arithmetic and common tensor
/// operations (e.g. elementwise ops, reductions, etc.).
Arithmetic,
/// The type is supported by hardware-accelerated execution paths.
///
/// This typically indicates support for accelerator-backed compute units (e.g., tensor
/// cores executing MMA instructions) for high-performance operations such as matrix
/// multiplication and operations that lower to it.
///
/// # Notes
/// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
/// [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
/// *and* accelerated paths.
/// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
/// suitable for general-purpose tensor operations and may only be used
/// in specific accelerated operations.
///
/// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
/// operations are accelerated or which accelerator features are available.
Accelerated,
}
/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;
impl DTypeUsage {
/// Returns the usage set required for general-purpose tensor support.
pub fn general() -> DTypeUsageSet {
DTypeUsage::Storage | DTypeUsage::Arithmetic
}
}

View File

@@ -0,0 +1,17 @@
pub use burn_std::device::*;
/// Device trait for all burn backend devices.
pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + Device {
/// Returns the [device id](DeviceId).
fn id(&self) -> DeviceId {
self.to_id()
}
/// Returns the inner device without autodiff enabled.
///
/// For most devices this is a no-op that returns `self`. For autodiff-enabled
/// devices, this returns the underlying inner device.
fn inner(&self) -> &Self {
self
}
}

View File

@@ -0,0 +1,10 @@
mod base;
mod device;
mod primitive;
pub use base::*;
pub use device::*;
pub use primitive::*;
/// Backend operations on tensors.
pub mod ops;

View File

@@ -0,0 +1,279 @@
use crate::tensor::FloatTensor;
use crate::{Backend, Scalar, TensorMetadata};
use core::f64::consts::SQRT_2;
/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
/// Applies the LeakyReLU activation function.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `negative_slope` - The factor applied to every value strictly smaller than 0.
///
/// # Returns
///
/// The output tensor, where negative values are replaced by `value * negative_slope`.
fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {
    // Locate the strictly negative entries.
    let is_negative = B::float_lower_elem(tensor.clone(), 0f32.into());
    // Candidate replacement values for those entries.
    let rescaled = B::float_mul_scalar(tensor.clone(), negative_slope);
    // Keep the input where it is `>= 0` and take the rescaled value elsewhere.
    B::float_mask_where(tensor, is_negative, rescaled)
}
/// Applies the ReLU activation function.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// The output tensor, where every value `<= 0` is replaced by 0.
fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
    let non_positive = B::float_lower_equal_elem(tensor.clone(), 0f32.into());
    B::float_mask_fill(tensor, non_positive, 0f32.into())
}
/// Applies the ReLU activation function backward.
///
/// # Arguments
///
/// * `output` - The output tensor of the forward ReLU.
/// * `grad` - The incoming gradient.
///
/// # Returns
///
/// The gradient, zeroed wherever `output` is `<= 0`.
fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
    // Positions where the forward pass produced a non-positive value carry no gradient.
    let blocked = B::float_lower_equal_elem(output, 0f32.into());
    B::float_mask_fill(grad, blocked, 0.into())
}
/// Applies the Gelu activation function.
///
/// Computes `x * (1 + erf(x / sqrt(2))) / 2`.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// The output tensor.
fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
    let scaled = B::float_div_scalar(tensor.clone(), SQRT_2.into());
    let erf = B::float_erf(scaled);
    let one_plus_erf = B::float_add_scalar(erf, 1f32.into());
    let weighted = B::float_mul(tensor, one_plus_erf);
    B::float_div_scalar(weighted, 2f32.into())
}
/// Applies the PReLu activation function.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `alpha` - The weight tensor multiplied with the strictly negative input values.
///
/// # Returns
///
/// The output tensor, where negative values are replaced by `value * alpha`.
fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
    let is_negative = B::float_lower_elem(tensor.clone(), 0f32.into());
    let weighted = B::float_mul(tensor.clone(), alpha);
    B::float_mask_where(tensor, is_negative, weighted)
}
/// Applies the Gelu activation function backward.
///
/// Uses the derivative of the tanh-based approximation of gelu:
/// `0.5 * tanh(u) + v * (1 - tanh(u)^2) + 0.5`, with
/// `u = 0.0356774 * x^3 + 0.797885 * x` and `v = 0.0535161 * x^3 + 0.398942 * x`.
///
/// # Arguments
///
/// * `x` - The input tensor of the forward pass.
/// * `grad` - The incoming gradient.
///
/// # Returns
///
/// The gradient with respect to `x`.
fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
    // Coefficients of the cubic and linear terms of `u` and `v`.
    let u_cubic = 0.0356774;
    let u_linear = 0.797885;
    let v_cubic = 0.0535161;
    let v_linear = 0.398942;

    let x_cubed = B::float_powi_scalar(x.clone(), 3.into());
    let u_cubic_term = B::float_mul_scalar(x_cubed.clone(), u_cubic.into());
    let u_linear_term = B::float_mul_scalar(x.clone(), u_linear.into());
    let v_cubic_term = B::float_mul_scalar(x_cubed, v_cubic.into());
    let v_linear_term = B::float_mul_scalar(x, v_linear.into());

    let u = B::float_add(u_cubic_term, u_linear_term);
    let v = B::float_add(v_cubic_term, v_linear_term);

    let tanh_u = B::float_tanh(u);
    // sech(u)^2 = 1 - tanh(u)^2
    let tanh_sq = B::float_powi_scalar(tanh_u.clone(), 2.into());
    let neg_tanh_sq = B::float_neg(tanh_sq);
    let sech_sq = B::float_add_scalar(neg_tanh_sq, 1.into());

    let half_tanh = B::float_mul_scalar(tanh_u, 0.5.into());
    let v_sech_sq = B::float_mul(v, sech_sq);
    let v_sech_sq_plus_half = B::float_add_scalar(v_sech_sq, 0.5.into());

    let local_grad = B::float_add(half_tanh, v_sech_sq_plus_half);
    B::float_mul(local_grad, grad)
}
/// Applies the Sigmoid activation function.
///
/// The computation is carried out in `F32` as `exp(-log(1 + exp(-x)))`, and the result is
/// cast back to the dtype of the input.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// The output tensor.
fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
    let dtype = tensor.dtype();
    let x = B::float_cast(tensor, burn_std::FloatDType::F32);
    let exp_neg_x = B::float_exp(B::float_neg(x));
    let one_plus_exp = B::float_add_scalar(exp_neg_x, 1.0.into());
    let log_term = B::float_log(one_plus_exp);
    let result = B::float_exp(B::float_neg(log_term));
    B::float_cast(result, dtype.into())
}
/// Applies the Sigmoid activation function backward.
///
/// Computes `output * (1 - output) * grad`.
///
/// # Arguments
///
/// * `output` - The output tensor of the sigmoid function.
/// * `grad` - The incoming gradient.
///
/// # Returns
///
/// The gradient with respect to the sigmoid input.
fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
    let one_minus_output = B::float_add_scalar(B::float_neg(output.clone()), 1.0.into());
    let local_grad = B::float_mul(output, one_minus_output);
    B::float_mul(local_grad, grad)
}
/// Applies the hard Sigmoid activation function.
///
/// Computes `clamp(alpha * x + beta, 0, 1)` in `F32` and casts the result back to the dtype
/// of the input.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `alpha` - The alpha value that the tensor is multiplied with.
/// * `beta` - The beta value that is added to the scaled tensor.
///
/// # Returns
///
/// The output tensor.
fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {
    let dtype = tensor.dtype();
    let x = B::float_cast(tensor, burn_std::FloatDType::F32);
    let affine = B::float_add_scalar(B::float_mul_scalar(x, alpha), beta);
    let clamped = B::float_clamp(affine, 0.0.into(), 1.0.into());
    B::float_cast(clamped, dtype.into())
}
/// Applies the LogSigmoid activation function.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// The output tensor.
fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
    // To avoid overflow, we use the log-sum-exp trick.
    //
    // ```ignore
    // log(sigmoid(x)) = log(1/(1 + exp(-x)))
    //                 = log(1) - log(1 + exp(-x))
    //                 = -log(1 + exp(-x))
    //                 = -log(exp(0) + exp(-x))
    // ```
    // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
    // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
    // following equivalence:
    // ```ignore
    // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
    // ```
    //
    // This extends the range of values for which we obtain accurate results.

    // max(-x, 0)
    let tensor_neg = B::float_neg(tensor);
    let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into());
    let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());
    let max_elem_neg = B::float_neg(max_elem.clone());

    // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
    let z = B::float_add(
        B::float_exp(max_elem_neg.clone()),
        // Last use of `max_elem`: move it rather than cloning, so the buffer is not kept
        // shared for longer than necessary.
        B::float_exp(B::float_sub(tensor_neg, max_elem)),
    );

    // -max(-x, 0) - log(z)
    B::float_sub(max_elem_neg, B::float_log(z))
}
/// Applies the LogSigmoid activation function backward.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `grad` - The gradient.
///
/// # Returns
///
/// The output gradient.
fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
// Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
// -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
// where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
// and max_derive is the derivative of max(-x, 0) (0 if x > 0, -1 if x < 0).
//
// This simplifies to:
// -max_derive + (z-1)/z if x is > 0
// -max_derive - (z-1)/z if x is <= 0
//
// In the code below, the `max_derive` variable holds the negated value `-max_derive`.
let shape = x.shape();
let dtype = x.dtype();
let device = B::float_device(&x);
// max(-x, 0)
let x_neg = B::float_neg(x);
let mask = B::float_lower_elem(x_neg.clone(), 0f32.into()); // -x < 0, i.e. x > 0
let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());
// z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
let z = B::float_add(
B::float_exp(B::float_neg(max_elem.clone())),
B::float_exp(B::float_sub(x_neg, max_elem)),
);
// `max_derive` is 0 where x > 0 and 1 elsewhere; `sign` is -1 where x > 0 and 1 elsewhere.
let ones = B::float_ones(shape, &device, dtype.into());
let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());
let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());
// grad * (max_derive - sign * (1 - (1 / z)))
B::float_mul(
grad,
B::float_sub(
max_derive,
B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
),
)
}
}

View File

@@ -0,0 +1,56 @@
use crate::tensor::{Device, IntTensor};
use crate::{Backend, TensorData, element::ElementConversion};
use alloc::vec::Vec;
use burn_std::Shape;
/// Computes the indices of the elements that are non-zero, grouped by element.
///
/// # Arguments
///
/// * `data` - The input tensor data, read as booleans.
/// * `device` - The device on which the resulting index tensor is created.
///
/// # Returns
///
/// A 2D tensor containing the indices of all non-zero elements of the given tensor.
/// Each row contains the indices of one non-zero element, so the shape is
/// `[number of non-zero elements, number of dimensions]`.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't have the
/// corresponding implementation. Ideally, it is supposed to be implemented by the backend and
/// the backend implementation will be resolved by static dispatch. It is not designed for
/// direct usage by users, and it is not recommended to import or use this function directly.
pub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> {
    /// Converts a flat index into a vector of indices for the specified tensor shape,
    /// treating the last dimension as the fastest-varying one.
    fn unravel_index<B: Backend>(index: usize, shape: &[usize]) -> Vec<B::IntElem> {
        let mut remainder = index;
        let mut coords: Vec<B::IntElem> = Vec::with_capacity(shape.len());
        // Peel off the innermost dimension first, then restore outermost-first order.
        for &extent in shape.iter().rev() {
            coords.push(((remainder % extent) as i64).elem());
            remainder /= extent;
        }
        coords.reverse();
        coords
    }

    let dims = &data.shape;
    let ndims = dims.len();
    let count_nonzero = data.iter::<bool>().filter(|&v| v).count();

    // Flattened list of coordinates: `ndims` entries per non-zero element.
    let mut indices = Vec::new();
    for (flat_index, is_set) in data.iter::<bool>().enumerate() {
        if is_set {
            indices.extend(unravel_index::<B>(flat_index, dims));
        }
    }

    B::int_from_data(
        TensorData::new(indices, Shape::new([count_nonzero, ndims])),
        device,
    )
}

View File

@@ -0,0 +1,563 @@
use super::{
argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
};
use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
use crate::{Backend, TensorData, TensorMetadata};
use crate::{ExecutionError, Scalar};
use alloc::vec::Vec;
use burn_std::{Shape, Slice};
use core::future::Future;
/// Bool Tensor API for basic operations, see
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// for documentation on each function.
pub trait BoolTensorOps<B: Backend> {
/// Creates a new bool tensor with the given shape on the given device.
///
/// This method has no default implementation; each backend provides its own.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The boolean tensor with the given shape.
fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
/// Creates a new bool tensor filled with `false`.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The boolean tensor filled with `false`.
fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
/// Creates a new bool tensor filled with `true`.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The boolean tensor filled with `true`.
fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// A future resolving to the data structure with the tensor's data, or to an
/// [`ExecutionError`] if the data could not be obtained.
fn bool_into_data(
tensor: BoolTensor<B>,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
/// Creates a bool tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The bool tensor holding the given data.
fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
/// Converts a bool tensor to an int tensor.
///
/// # Arguments
///
/// * `tensor` - The bool tensor to convert.
///
/// # Returns
///
/// The int tensor with the same data as the bool tensor.
fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
/// Converts a bool tensor to a float tensor.
///
/// # Arguments
///
/// * `tensor` - The bool tensor to convert.
///
/// # Returns
///
/// The float tensor with the same data as the bool tensor.
fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
/// Gets the device of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to query; it is only borrowed.
///
/// # Returns
///
/// The device of the tensor.
fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
/// Moves the tensor to the device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
/// Reshapes the bool tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to reshape.
/// * `shape` - The new shape.
///
/// # Returns
///
/// The tensor with the new shape.
fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
/// Gets the values from the tensor for the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `slices` - The slices specifying ranges and steps for each dimension.
///
/// # Returns
///
/// The tensor with the values for the given slices.
///
/// # Note
///
/// Empty slices (where start >= end) are handled at the high-level tensor API and will not
/// be passed to this method. Backend implementations do not need to handle empty slices.
fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
/// Sets the values in the tensor for the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `slices` - The slices to set the values for.
/// * `value` - The values to set.
///
/// # Returns
///
/// The tensor with the values set for the given slices.
///
/// # Note
///
/// Empty slice assignments (where any slice range produces 0 elements) are handled at the
/// high-level tensor API and will not be passed to this method. Backend implementations do
/// not need to handle empty slice assignments.
fn bool_slice_assign(
tensor: BoolTensor<B>,
slices: &[Slice],
value: BoolTensor<B>,
) -> BoolTensor<B>;
/// Fills the tensor with values from the value tensor if the mask is true at the given
/// indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `mask` - The mask selecting the positions to overwrite.
/// * `value` - The value tensor providing the replacement values.
///
/// # Returns
///
/// The tensor with the values filled.
fn bool_mask_where(
tensor: BoolTensor<B>,
mask: BoolTensor<B>,
value: BoolTensor<B>,
) -> BoolTensor<B>;
/// Fills the tensor with the given value if the mask is true at the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `mask` - The mask selecting the positions to overwrite.
/// * `value` - The scalar value written at the selected positions.
///
/// # Returns
///
/// The tensor with the values filled.
fn bool_mask_fill(tensor: BoolTensor<B>, mask: BoolTensor<B>, value: Scalar) -> BoolTensor<B>;
/// Gathers elements from the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
///
/// # Returns
///
/// The bool tensor with the gathered elements.
fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
/// Scatters the given values into the tensor at the given indices using boolean `or`
/// reduction.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter to.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
/// * `value` - The values to scatter.
///
/// # Returns
///
/// The tensor with the values scattered.
fn bool_scatter_or(
dim: usize,
tensor: BoolTensor<B>,
indices: IntTensor<B>,
value: BoolTensor<B>,
) -> BoolTensor<B>;
/// Selects tensor elements along the given dimension corresponding to the given indices.
///
/// The default implementation converts the tensor to int, selects with `int_select`, and
/// maps the result back to bool by comparing against 1.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices of the elements to select.
///
/// # Returns
///
/// The tensor with the selected elements.
fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
    let as_int = B::bool_into_int(tensor);
    let picked = B::int_select(as_int, dim, indices);
    // Entries equal to 1 become `true` again.
    B::int_equal_elem(picked, 1.into())
}
/// Assigns the given values to the selected elements along the given dimension, combining
/// them with the existing values using boolean `or` reduction.
///
/// The default implementation converts both tensors to int, accumulates the values with
/// `int_select_add`, and treats every resulting entry greater than 0 as `true`.
///
/// # Arguments
///
/// * `tensor` - The tensor to assign the values to.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices of the elements to assign.
/// * `value` - The values to assign.
///
/// # Returns
///
/// The tensor with the assigned values.
fn bool_select_or(
    tensor: BoolTensor<B>,
    dim: usize,
    indices: IntTensor<B>,
    value: BoolTensor<B>,
) -> BoolTensor<B> {
    let target_int = B::bool_into_int(tensor);
    let value_int = B::bool_into_int(value);
    let summed = B::int_select_add(target_int, dim, indices, value_int);
    // After the sum reduction, any strictly positive entry maps to `true`.
    B::int_greater_elem(summed, 0.into())
}
/// Repeats one dimension of the tensor a given number of times along that dimension.
///
/// The default implementation delegates to the generic `repeat_with_slice_assign` helper,
/// instantiated for the `Bool` tensor kind.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to repeat.
/// * `times` - The number of times to repeat the dimension.
///
/// # Returns
///
/// The tensor with the dimension repeated.
fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
}
/// Concatenates the tensors along the given dimension.
///
/// The default implementation delegates to the generic `cat_with_slice_assign` helper,
/// instantiated for the `Bool` tensor kind.
///
/// # Arguments
///
/// * `tensors` - The tensors to concatenate.
/// * `dim` - The dimension to concatenate along.
///
/// # Returns
///
/// The tensor with the tensors concatenated along the given dimension.
///
/// # Note
///
/// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
/// high-level tensor API and will not be passed to this method. Backend implementations do
/// not need to handle empty tensors.
fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
cat_with_slice_assign::<B, Bool>(tensors, dim)
}
/// Element-wise equality comparison of two boolean tensors.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
/// Element-wise inequality comparison of two boolean tensors.
///
/// The default implementation negates the result of `bool_equal`.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor with the result of the comparison.
fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
    B::bool_not(B::bool_equal(lhs, rhs))
}
/// Element-wise equality comparison with a scalar.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn bool_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B>;
/// Element-wise inequality comparison with a scalar.
///
/// The default implementation negates the result of `bool_equal_elem`.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B> {
    B::bool_not(B::bool_equal_elem(lhs, rhs))
}
/// Inverts boolean values (element-wise logical negation).
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The tensor with the result of the negation.
fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
/// Executes the logical and (`&&`) operation on two boolean tensors.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor with the result of the logical and.
fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
/// Executes the logical or (`||`) operation on two boolean tensors.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor with the result of the logical or.
fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
/// Element-wise exclusive or.
///
/// The default implementation forwards to `bool_not_equal`: for boolean operands, "exactly one
/// of the two is true" is the same as "the two differ".
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor with the result of the exclusive or.
fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
Self::bool_not_equal(lhs, rhs)
}
/// Transposes a bool tensor by swapping its last two dimensions.
///
/// The default implementation forwards to `bool_swap_dims` with the indices `rank - 2` and
/// `rank - 1`, so it assumes the tensor has at least two dimensions.
///
/// # Arguments
///
/// * `tensor` - The tensor to transpose.
///
/// # Returns
///
/// The transposed tensor.
fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
    let rank = tensor.shape().num_dims();
    let (second_last, last) = (rank - 2, rank - 1);
    Self::bool_swap_dims(tensor, second_last, last)
}
/// Swaps two dimensions of a bool tensor.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped.
fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
/// Permutes the dimensions of a bool tensor.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The tensor to permute the dimensions of.
/// * `axes` - The new order of the dimensions.
///
/// # Returns
///
/// The tensor with the dimensions permuted.
fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
/// Reverse the order of elements in a bool tensor along the given axes.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
/// Tests whether at least one element of the boolean `tensor` is True.
///
/// The default implementation converts the tensor to integers, sums every element, and
/// reports whether that sum is strictly positive.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
///
/// # Returns
///
/// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
    let as_int = B::bool_into_int(tensor);
    let true_count = B::int_sum(as_int);
    B::int_greater_elem(true_count, 0.into())
}
/// Tests whether at least one element of the boolean `tensor` is True along dimension `dim`.
///
/// The default implementation converts the tensor to integers, sums along `dim`, and reports
/// for each reduced position whether the sum is strictly positive.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
/// * `dim` - The axis along which to test.
///
/// # Returns
///
/// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
/// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
/// evaluates to True, False otherwise.
fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
    let as_int = B::bool_into_int(tensor);
    let true_count = B::int_sum_dim(as_int, dim);
    B::int_greater_elem(true_count, 0.into())
}
/// Tests whether every element of the boolean `tensor` is True.
///
/// The default implementation converts the tensor to integers, sums every element, and
/// compares the sum with the total number of elements.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
///
/// # Returns
///
/// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
/// evaluate to True, False otherwise.
fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
    // Read the element count before `tensor` is consumed by the conversion.
    let expected = tensor.shape().num_elements() as i64;
    let as_int = B::bool_into_int(tensor);
    let true_count = B::int_sum(as_int);
    B::int_equal_elem(true_count, expected.into())
}
/// Tests whether every element of the boolean `tensor` is True along dimension `dim`.
///
/// The default implementation converts the tensor to integers, sums along `dim`, and compares
/// each sum with the size of that dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
/// * `dim` - The axis along which to test.
///
/// # Returns
///
/// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
/// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
/// evaluates to True, False otherwise.
fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
    // Read the dimension size before `tensor` is consumed by the conversion.
    let expected = tensor.shape()[dim] as i64;
    let as_int = B::bool_into_int(tensor);
    let true_count = B::int_sum_dim(as_int, dim);
    B::int_equal_elem(true_count, expected.into())
}
/// Computes the indices of the elements that are True, grouped by element.
///
/// The default implementation reads the tensor data back through `bool_into_data` and hands
/// it, together with the tensor's device, to the `argwhere_data` helper.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A future resolving to a 2D tensor containing the indices of all non-zero elements of the
/// given tensor. Each row contains the indices of a non-zero element.
///
/// # Panics
///
/// The returned future panics if the tensor data cannot be read.
fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
    async {
        // The output size depends on how many elements are true, so the data has to be read
        // back to count them; this read might cause a sync but is required.
        let target_device = B::bool_device(&tensor);
        let host_data = B::bool_into_data(tensor)
            .await
            .expect("Can read the data without error");
        argwhere_data::<B>(host_data, &target_device)
    }
}
/// Broadcasts the bool `tensor` to the given `shape`.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape.
///
/// # Returns
///
/// The tensor broadcast to `shape`.
fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
/// Unfold windows along a dimension.
///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
///
/// NOTE(review): the crate's `calculate_unfold_windows` helper computes
/// `(shape[dim] - size) / step + 1` (integer division, `0` when `shape[dim] < size`), which
/// differs from the formula above (e.g. `shape[dim] == size` yields 1 window, not 0) — confirm
/// which one backends are expected to follow.
///
/// This method has no default implementation; every backend must provide it.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the selected dim.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with shape ``[pre=..., windows, size, post=...]``.
///
/// NOTE(review): the crate's `calculate_unfold_shape` helper appends `size` as the last
/// dimension (``[pre=..., windows, post=..., size]``) — confirm the layout documented here.
fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
}

View File

@@ -0,0 +1,40 @@
use crate::{
Backend, TensorMetadata,
tensor::{BasicOps, TensorKind},
};
use alloc::vec::Vec;
use burn_std::Slice;
/// Generic concatenation built on `empty` + `slice_assign`.
///
/// An output tensor is allocated whose shape matches the first input, except along `dim`
/// where it is the sum of all inputs' sizes. Each input is then written into its own
/// consecutive range of `dim` with `K::slice_assign`.
///
/// # Arguments
///
/// * `tensors` - The tensor primitives to concatenate, in order.
/// * `dim` - The dimension to concatenate along.
///
/// # Returns
///
/// The concatenated tensor primitive, created on the device and with the dtype of the first
/// input.
///
/// # Panics
///
/// Panics if `tensors` is empty.
pub(crate) fn cat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensors: Vec<K::Primitive>,
    dim: usize,
) -> K::Primitive {
    let head = tensors.first().expect("Tensors should not be empty");
    let mut out_shape = head.shape();
    let device = K::device(head);
    let dtype = head.dtype();

    // Along `dim` the output is as long as all inputs put together.
    out_shape[dim] = tensors.iter().map(|part| part.shape()[dim]).sum::<usize>();

    let mut output = K::empty(out_shape.clone(), &device, dtype);

    // Full-extent range for every dimension; only the entry for `dim` varies per input.
    let full_ranges: Vec<_> = out_shape.iter().map(|extent| 0..*extent).collect();

    let mut cursor = 0;
    for part in tensors {
        let end = cursor + part.shape()[dim];

        let slices: Vec<Slice> = full_ranges
            .iter()
            .enumerate()
            .map(|(axis, range)| {
                let (start, stop) = if axis == dim {
                    (cursor, end)
                } else {
                    (range.start, range.end)
                };
                Slice::new(start as isize, Some(stop as isize), 1)
            })
            .collect();

        cursor = end;
        output = K::slice_assign(output, &slices, part);
    }

    output
}

View File

@@ -0,0 +1,20 @@
// Private modules; their public items are re-exported at this level by the `pub use` lines
// at the bottom of the file.
mod activation;
mod bool_tensor;
mod int_tensor;
mod modules;
mod qtensor;
mod tensor;
mod transaction;
// Crate-internal helper modules; these are not glob re-exported here.
pub(crate) mod argwhere;
pub(crate) mod cat;
pub(crate) mod repeat_dim;
pub(crate) mod sort;
// Flatten the private modules above into this module's public namespace.
pub use activation::*;
pub use bool_tensor::*;
pub use int_tensor::*;
pub use modules::*;
pub use qtensor::*;
pub use tensor::*;
pub use transaction::*;

View File

@@ -0,0 +1,107 @@
use core::f32;
#[allow(unused_imports)]
use num_traits::Float as _;
use burn_std::Shape;
use crate::{
Backend, TensorMetadata,
ops::AttentionModuleOptions,
tensor::{BoolTensor, FloatTensor},
};
/// Reference attention, `softmax(QKᵗ * scale) · V`, assembled from individual backend ops.
/// Serves as a fallback when FlashAttention is not used.
///
/// The attention scores go through the following stages, in this order:
///
/// 1. `QKᵗ` multiplied by `options.scale`, defaulting to `1 / sqrt(d)` where `d` is the size of
///    the last dimension of `query`;
/// 2. optional soft-capping, `softcap * tanh(scores / softcap)`;
/// 3. optional boolean `mask`: positions where the mask is set are filled with `-inf`;
/// 4. optional causal masking when `options.is_causal` is set;
/// 5. optional additive `attn_bias`;
/// 6. softmax over dimension 3, followed by the product with `value`.
///
/// # Arguments
///
/// * `query` - The query tensor; its shape is read as 4-dimensional.
/// * `key` - The key tensor; it is transposed before being multiplied with `query`.
/// * `value` - The value tensor.
/// * `mask` - Optional boolean mask applied to the scores.
/// * `attn_bias` - Optional bias added to the scores after masking.
/// * `options` - Scale, softcap and causality settings.
///
/// # Returns
///
/// The attention output, `softmax(scores) · value`.
///
/// # Panics
///
/// Panics if `options.softcap` is set and is not strictly positive.
pub fn attention_fallback<B: Backend>(
    query: FloatTensor<B>,
    key: FloatTensor<B>,
    value: FloatTensor<B>,
    mask: Option<BoolTensor<B>>,
    attn_bias: Option<FloatTensor<B>>,
    options: AttentionModuleOptions,
) -> FloatTensor<B> {
    if let Some(softcap) = options.softcap {
        assert!(softcap > 0.0, "softcap must be positive, got {softcap}");
    }

    // Scores: QKᵗ * scale, where scale falls back to 1/sqrt(last query dimension).
    let [.., head_dim] = query.shape().dims::<4>();
    let scale = match options.scale {
        Some(scale) => scale,
        None => 1.0 / (head_dim as f64).sqrt(),
    };
    let scores = B::float_mul_scalar(
        B::float_matmul(query, B::float_transpose(key)),
        scale.into(),
    );

    // Soft-capping is applied to the raw logits, before any -inf is written: tanh would turn
    // -inf into a finite value and break the masking semantics.
    let scores = match options.softcap {
        Some(cap) => {
            let squashed = B::float_tanh(B::float_div_scalar(scores, cap.into()));
            B::float_mul_scalar(squashed, cap.into())
        }
        None => scores,
    };

    // User-provided boolean mask.
    let scores = match mask {
        Some(mask) => B::float_mask_fill(scores, mask, f32::NEG_INFINITY.into()),
        None => scores,
    };

    // Causal mask: hide future positions (col > row).
    let scores = if options.is_causal {
        let future_positions = build_causal_mask::<B>(&scores);
        B::float_mask_fill(scores, future_positions, f32::NEG_INFINITY.into())
    } else {
        scores
    };

    // Additive bias (ALiBi, relative position biases, etc.).
    let scores = match attn_bias {
        Some(bias) => B::float_add(scores, bias),
        None => scores,
    };

    // Softmax over dimension 3; the per-row maximum is subtracted before exponentiating.
    let row_max = B::float_max_dim(scores.clone(), 3);
    let exp_shifted = B::float_exp(B::float_sub(scores, row_max));
    let row_sum = B::float_sum_dim(exp_shifted.clone(), 3);
    let weights = B::float_div(exp_shifted, row_sum);

    // Context: weights · V
    B::float_matmul(weights, value)
}
/// Builds the causal bool mask for a 4-dimensional score tensor; `true` means "mask this
/// position".
///
/// The mask has shape `[batch_size, num_heads, seq_q, seq_k]` (the shape of
/// `attention_scores`) and is set wherever `col > row + (seq_k - seq_q)`. The offset aligns
/// the causal boundary with the bottom-right corner, which handles cross-attention
/// (`seq_k > seq_q`).
fn build_causal_mask<B: Backend>(attention_scores: &FloatTensor<B>) -> BoolTensor<B> {
    let device = B::float_device(attention_scores);
    let [batch_size, num_heads, seq_q, seq_k] = attention_scores.shape().dims::<4>();

    let shift = seq_k as i64 - seq_q as i64;

    // Row indices as a [seq_q, 1] column and column indices as a [1, seq_k] row.
    let row_idx = B::int_reshape(
        B::int_arange(0..seq_q as i64, &device),
        Shape::new([seq_q, 1]),
    );
    let col_idx = B::int_reshape(
        B::int_arange(0..seq_k as i64, &device),
        Shape::new([1, seq_k]),
    );

    // Masked where row + shift < col, i.e. the upper triangle.
    let threshold = B::int_add_scalar(row_idx, shift.into());
    let future = B::int_lower(threshold, col_idx);

    // [seq_q, seq_k] -> [1, 1, seq_q, seq_k] -> [batch_size, num_heads, seq_q, seq_k]
    let future = B::bool_reshape(future, Shape::new([1, 1, seq_q, seq_k]));
    B::bool_expand(future, Shape::new([batch_size, num_heads, seq_q, seq_k]))
}

View File

@@ -0,0 +1,312 @@
use crate::{
Backend, TensorMetadata,
ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
tensor::FloatTensor,
};
use alloc::vec;
use burn_std::{Shape, Slice};
/// Reference implementation of `grid_sample_2d`, dispatching on the interpolation mode.
///
/// Only [`InterpolateMode::Bilinear`] is implemented by this reference path.
///
/// # Arguments
///
/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
/// * `options` - Grid sampling options
///
/// # Returns
///
/// A tensor with shape (N, C, H_out, W_out)
///
/// # Panics
///
/// Panics (via `todo!`) for every interpolation mode other than bilinear.
pub fn float_grid_sample_2d_ref<B: Backend>(
    tensor: FloatTensor<B>,
    grid: FloatTensor<B>,
    options: GridSampleOptions,
) -> FloatTensor<B> {
    if let InterpolateMode::Bilinear = options.mode {
        float_grid_sample_2d_bilinear::<B>(
            tensor,
            grid,
            options.padding_mode,
            options.align_corners,
        )
    } else {
        todo!(
            "Default implementation for grid_sample_2d with {:?} unimplemented",
            options.mode
        )
    }
}
/// Bilinear grid sampling implementation.
///
/// For every output location the normalized grid coordinate is converted to a pixel
/// coordinate, the four surrounding input pixels are gathered from the flattened spatial
/// dimension, and the samples are blended with the bilinear weights derived from the
/// fractional part of the coordinate.
///
/// # Arguments
///
/// * `tensor` - The tensor being sampled from; its shape is read as (N, C, H_in, W_in).
/// * `grid` - The sampling locations; its shape is read as (N, H_out, W_out, 2), with `x` at
///   index 0 and `y` at index 1 of the last dimension.
/// * `padding_mode` - How out-of-range coordinates are handled (border clamp, reflection, or
///   zero fill).
/// * `align_corners` - Selects the normalized-to-pixel coordinate mapping.
///
/// # Returns
///
/// A tensor with shape (N, C, H_out, W_out).
fn float_grid_sample_2d_bilinear<B: Backend>(
tensor: FloatTensor<B>,
grid: FloatTensor<B>,
padding_mode: GridSamplePaddingMode,
align_corners: bool,
) -> FloatTensor<B> {
let n = tensor.shape()[0];
let c = tensor.shape()[1];
let h_in = tensor.shape()[2];
let w_in = tensor.shape()[3];
let h_out = grid.shape()[1];
let w_out = grid.shape()[2];
let spatial_in = h_in * w_in;
let spatial_out = h_out * w_out;
// Separate x and y coordinates from grid
// shape: (N, H_out, W_out, 1)
let grid_x_slice = vec![
Slice::new(0, Some(n as isize), 1),
Slice::new(0, Some(h_out as isize), 1),
Slice::new(0, Some(w_out as isize), 1),
Slice::new(0, Some(1), 1),
];
let grid_y_slice = vec![
Slice::new(0, Some(n as isize), 1),
Slice::new(0, Some(h_out as isize), 1),
Slice::new(0, Some(w_out as isize), 1),
Slice::new(1, Some(2), 1),
];
// The trailing unit dimension is dropped and a unit channel dimension is inserted instead,
// so that the coordinates line up with the (N, C, H_out, W_out) samples further down.
let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
// NOTE(review): this is the last use of `grid`, so the clone below looks redundant.
let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
// Convert normalized grid coordinates [-1, 1] to pixel coordinates
let w_in_f = w_in as f64;
let h_in_f = h_in as f64;
let (grid_x, grid_y) = if align_corners {
// align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2
// Maps -1 to 0 and 1 to width - 1
let grid_x = B::float_add_scalar(grid_x, 1f32.into());
let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).into());
let grid_y = B::float_add_scalar(grid_y, 1f32.into());
let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).into());
(grid_x, grid_y)
} else {
// align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5
// Maps -1 to -0.5 and 1 to width - 0.5
let grid_x = B::float_add_scalar(grid_x, 1f32.into());
let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).into());
let grid_x = B::float_sub_scalar(grid_x, 0.5f32.into());
let grid_y = B::float_add_scalar(grid_y, 1f32.into());
let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).into());
let grid_y = B::float_sub_scalar(grid_y, 0.5f32.into());
(grid_x, grid_y)
};
// Apply padding mode to coordinates
let (grid_x, grid_y) = match padding_mode {
GridSamplePaddingMode::Border => {
// Clamp coordinates to valid range [0, size-1]
// NOTE(review): `w_in - 1` / `h_in - 1` look like unsigned subtractions and would
// underflow for an input with a zero-sized spatial dimension — confirm callers
// never pass one.
let grid_x = B::float_clamp(grid_x, 0f32.into(), ((w_in - 1) as f32).into());
let grid_y = B::float_clamp(grid_y, 0f32.into(), ((h_in - 1) as f32).into());
(grid_x, grid_y)
}
GridSamplePaddingMode::Reflection => {
// Reflect coordinates at boundaries
let grid_x = reflect_coordinates::<B>(grid_x, w_in_f, align_corners);
let grid_y = reflect_coordinates::<B>(grid_y, h_in_f, align_corners);
(grid_x, grid_y)
}
GridSamplePaddingMode::Zeros => {
// Keep coordinates as-is, we'll mask out-of-bounds later
(grid_x, grid_y)
}
};
// Get floor indices for the four corners
let grid_x_floored = B::float_floor(grid_x.clone());
let grid_y_floored = B::float_floor(grid_y.clone());
// Compute interpolation weights (fractional part)
let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone());
let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone());
// Convert to integer indices
// (x0, y0) is the corner at the floored coordinate and (x1, y1) is one pixel further along
// each axis.
let x0 = B::float_into_int(grid_x_floored.clone());
let y0 = B::float_into_int(grid_y_floored.clone());
let x1 = B::float_into_int(B::float_add_scalar(grid_x_floored, 1f32.into()));
let y1 = B::float_into_int(B::float_add_scalar(grid_y_floored, 1f32.into()));
// Create masks for out-of-bounds coordinates (only used for zeros padding)
// The masks are computed from the unclamped indices, before the clamping below makes every
// index valid for the gather.
// Mask naming: the first digit refers to x, the second to y (mask_01 = x0 with y1).
let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros {
let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.into());
let x0_valid = B::bool_and(
x0_valid,
B::int_lower_elem(x0.clone(), (w_in as i32).into()),
);
let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.into());
let x1_valid = B::bool_and(
x1_valid,
B::int_lower_elem(x1.clone(), (w_in as i32).into()),
);
let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.into());
let y0_valid = B::bool_and(
y0_valid,
B::int_lower_elem(y0.clone(), (h_in as i32).into()),
);
let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.into());
let y1_valid = B::bool_and(
y1_valid,
B::int_lower_elem(y1.clone(), (h_in as i32).into()),
);
(
Some(B::bool_and(x0_valid.clone(), y0_valid.clone())),
Some(B::bool_and(x0_valid.clone(), y1_valid.clone())),
Some(B::bool_and(x1_valid.clone(), y0_valid)),
Some(B::bool_and(x1_valid, y1_valid)),
)
} else {
(None, None, None, None)
};
// Clamp indices to valid range for gather
// NOTE(review): the `as i32` casts here and below silently truncate spatial sizes that do
// not fit in an i32 — presumably fine for realistic image sizes; confirm.
let x0_clamped = B::int_clamp(x0, 0.into(), ((w_in - 1) as i32).into());
let x1_clamped = B::int_clamp(x1, 0.into(), ((w_in - 1) as i32).into());
let y0_clamped = B::int_clamp(y0, 0.into(), ((h_in - 1) as i32).into());
let y1_clamped = B::int_clamp(y1, 0.into(), ((h_in - 1) as i32).into());
// Linear indices: idx = y * W_in + x
let w_in_scalar: i32 = w_in as i32;
let idx_00 = B::int_add(
B::int_mul_scalar(y0_clamped.clone(), w_in_scalar.into()),
x0_clamped.clone(),
);
let idx_01 = B::int_add(
B::int_mul_scalar(y1_clamped.clone(), w_in_scalar.into()),
x0_clamped,
);
let idx_10 = B::int_add(
B::int_mul_scalar(y0_clamped, w_in_scalar.into()),
x1_clamped.clone(),
);
let idx_11 = B::int_add(
B::int_mul_scalar(y1_clamped, w_in_scalar.into()),
x1_clamped,
);
// [N, 1, H_out, W_out] -> [N, 1, H_out * W_out]
let idx_00 = B::int_reshape(idx_00, Shape::new([n, 1, spatial_out]));
let idx_01 = B::int_reshape(idx_01, Shape::new([n, 1, spatial_out]));
let idx_10 = B::int_reshape(idx_10, Shape::new([n, 1, spatial_out]));
let idx_11 = B::int_reshape(idx_11, Shape::new([n, 1, spatial_out]));
// [N, 1, spatial] -> [N, C, spatial]
let idx_00 = B::int_expand(idx_00, Shape::new([n, c, spatial_out]));
let idx_01 = B::int_expand(idx_01, Shape::new([n, c, spatial_out]));
let idx_10 = B::int_expand(idx_10, Shape::new([n, c, spatial_out]));
let idx_11 = B::int_expand(idx_11, Shape::new([n, c, spatial_out]));
// Flatten the input's spatial dimensions so that one gather along dimension 2 fetches a
// pixel by its linear index.
let tensor_flat = B::float_reshape(tensor, Shape::new([n, c, spatial_in]));
let sample_00 = B::float_gather(2, tensor_flat.clone(), idx_00);
let sample_01 = B::float_gather(2, tensor_flat.clone(), idx_01);
let sample_10 = B::float_gather(2, tensor_flat.clone(), idx_10);
let sample_11 = B::float_gather(2, tensor_flat, idx_11);
// Reshape samples to (N, C, H_out, W_out)
let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
// Apply masks for zeros padding (set out-of-bounds samples to 0)
let (sample_00, sample_01, sample_10, sample_11) =
if padding_mode == GridSamplePaddingMode::Zeros {
// The masks are `Some` exactly when the padding mode is `Zeros` (see above), so these
// unwraps cannot fail in this branch.
let mask_00 = mask_00.unwrap();
let mask_01 = mask_01.unwrap();
let mask_10 = mask_10.unwrap();
let mask_11 = mask_11.unwrap();
// `float_mask_fill` writes the fill value where the mask is true, so the validity masks
// are inverted first and then broadcast over the channel dimension.
let mask_00_inv = B::bool_not(mask_00);
let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out]));
let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out]));
let mask_01_inv = B::bool_not(mask_01);
let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out]));
let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out]));
let mask_10_inv = B::bool_not(mask_10);
let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out]));
let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out]));
let mask_11_inv = B::bool_not(mask_11);
let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out]));
let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out]));
(
B::float_mask_fill(sample_00, mask_00_inv, 0f32.into()),
B::float_mask_fill(sample_01, mask_01_inv, 0f32.into()),
B::float_mask_fill(sample_10, mask_10_inv, 0f32.into()),
B::float_mask_fill(sample_11, mask_11_inv, 0f32.into()),
)
} else {
(sample_00, sample_01, sample_10, sample_11)
};
// Compute bilinear interpolation weights
// one_minus_x = 1 - x_frac and one_minus_y = 1 - y_frac, computed as (-frac) + 1.
let one_minus_x = B::float_neg(x_frac.clone());
let one_minus_x = B::float_add_scalar(one_minus_x, 1f32.into());
let one_minus_y = B::float_neg(y_frac.clone());
let one_minus_y = B::float_add_scalar(one_minus_y, 1f32.into());
let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone());
let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone());
let weight_10 = B::float_mul(x_frac.clone(), one_minus_y);
let weight_11 = B::float_mul(x_frac, y_frac);
// Bilinear interpolation
// NOTE(review): the weights have shape (N, 1, H_out, W_out) while the samples are
// (N, C, H_out, W_out); this presumably relies on `float_mul` broadcasting over the channel
// dimension — confirm for all backends.
let result = B::float_mul(sample_00, weight_00);
let result = B::float_add(result, B::float_mul(sample_01, weight_01));
let result = B::float_add(result, B::float_mul(sample_10, weight_10));
B::float_add(result, B::float_mul(sample_11, weight_11))
}
/// Folds pixel coordinates back into the valid range with a triangle wave.
///
/// The reflection interval depends on `align_corners`:
///
/// * `align_corners = true`: reflects within `[0, size - 1]`
/// * `align_corners = false`: reflects within `[-0.5, size - 0.5]`
///
/// When the interval is empty (span `<= 0`, e.g. `size == 1` with `align_corners = true`)
/// every coordinate is replaced by the lower bound.
///
/// # Arguments
///
/// * `coords` - Pixel coordinates to reflect.
/// * `size` - Size of the sampled dimension.
/// * `align_corners` - Selects the reflection interval (see above).
///
/// # Returns
///
/// The reflected coordinates, with the same shape as `coords`.
fn reflect_coordinates<B: Backend>(
    coords: FloatTensor<B>,
    size: f64,
    align_corners: bool,
) -> FloatTensor<B> {
    let lower: f32 = if align_corners { 0.0 } else { -0.5 };
    let upper = if align_corners {
        (size - 1.0) as f32
    } else {
        (size - 0.5) as f32
    };
    let span = upper - lower;

    if span <= 0.0 {
        // Degenerate interval: zero the coordinates out, then shift them to the lower bound.
        let zeroed = B::float_mul_scalar(coords, 0f32.into());
        return B::float_add_scalar(zeroed, lower.into());
    }

    // Triangle wave: span - |((x mod 2*span) - span)| + lower, with x = |coord - lower|.
    let period = 2.0 * span;
    let dist = B::float_abs(B::float_sub_scalar(coords, lower.into()));

    // dist mod period, expressed as dist - floor(dist / period) * period.
    let whole_periods = B::float_floor(B::float_div_scalar(dist.clone(), period.into()));
    let phase = B::float_sub(dist, B::float_mul_scalar(whole_periods, period.into()));

    // span - |phase - span|, computed as -(|phase - span| - span).
    let folded = B::float_abs(B::float_sub_scalar(phase, span.into()));
    let reflected = B::float_neg(B::float_sub_scalar(folded, span.into()));

    B::float_add_scalar(reflected, lower.into())
}

View File

@@ -0,0 +1,18 @@
/// Module with convolution operations.
pub mod conv;
/// Module with attention operations.
pub mod attention;
/// Module with unfold operations.
pub mod unfold;
/// Module with pooling operations.
pub mod pool;
/// Module with grid_sample operations.
pub mod grid_sample;
// Private module; its public items are re-exported at this level by the glob import below.
mod base;
pub use base::*;

View File

@@ -0,0 +1,176 @@
use crate::tensor::{FloatTensor, IntTensor};
use crate::{Backend, TensorMetadata};
use burn_std::Shape;
use super::{MaxPool1dBackward, MaxPool1dWithIndices};
/// 1D average pooling expressed through the backend's 2D average pooling.
///
/// The `[batch, channels, length]` input gets a trailing unit dimension, is pooled with a
/// `[kernel_size, 1]` window (stride `[stride, 1]`, padding `[padding, 0]`), and the unit
/// dimension is dropped again from the result.
pub(crate) fn avg_pool1d_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
    ceil_mode: bool,
) -> FloatTensor<B> {
    let [n, c, len_in] = x.shape().dims();
    let as_2d = B::float_reshape(x, Shape::from([n, c, len_in, 1]));

    let pooled = B::avg_pool2d(
        as_2d,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        count_include_pad,
        ceil_mode,
    );

    let [n, c, len_out, _] = pooled.shape().dims();
    B::float_reshape(pooled, Shape::from([n, c, len_out]))
}
/// Backward pass of 1D average pooling expressed through the backend's 2D backward pass.
///
/// Both the input `x` and the output gradient `grad` get a trailing unit dimension before
/// `avg_pool2d_backward` is called; the resulting gradient is reshaped back to the
/// `[batch, channels, length]` shape of `x`.
pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    grad: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
    ceil_mode: bool,
) -> FloatTensor<B> {
    let [n, c, len_in] = x.shape().dims();
    let [_, _, len_out] = grad.shape().dims();

    let x_2d = B::float_reshape(x, Shape::from([n, c, len_in, 1]));
    let grad_2d = B::float_reshape(grad, Shape::from([n, c, len_out, 1]));

    let x_grad_2d = B::avg_pool2d_backward(
        x_2d,
        grad_2d,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        count_include_pad,
        ceil_mode,
    );

    B::float_reshape(x_grad_2d, Shape::from([n, c, len_in]))
}
/// 1D adaptive average pooling expressed through the backend's 2D adaptive average pooling.
///
/// The `[batch, channels, length]` input gets a trailing unit dimension, is pooled to an
/// output size of `[output_size, 1]`, and the unit dimension is dropped again.
pub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>(
    x: FloatTensor<B>,
    output_size: usize,
) -> FloatTensor<B> {
    let [n, c, len_in] = x.shape().dims();
    let as_2d = B::float_reshape(x, Shape::from([n, c, len_in, 1]));

    let pooled = B::adaptive_avg_pool2d(as_2d, [output_size, 1]);

    let [n, c, len_out, _] = pooled.shape().dims();
    B::float_reshape(pooled, Shape::from([n, c, len_out]))
}
/// Backward pass of 1D adaptive average pooling expressed through the 2D backward pass.
///
/// Both the input `x` and the output gradient `grad` get a trailing unit dimension before
/// `adaptive_avg_pool2d_backward` is called; the resulting gradient is reshaped back to the
/// `[batch, channels, length]` shape of `x`.
pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [n, c, len_in] = x.shape().dims();
    let [_, _, len_out] = grad.shape().dims();

    let x_2d = B::float_reshape(x, Shape::from([n, c, len_in, 1]));
    let grad_2d = B::float_reshape(grad, Shape::from([n, c, len_out, 1]));

    let x_grad_2d = B::adaptive_avg_pool2d_backward(x_2d, grad_2d);
    B::float_reshape(x_grad_2d, Shape::from([n, c, len_in]))
}
/// 1D max pooling expressed through the backend's 2D max pooling.
///
/// The `[batch, channels, length]` input gets a trailing unit dimension, is pooled with a
/// `[kernel_size, 1]` window (stride `[stride, 1]`, padding `[padding, 0]`, dilation
/// `[dilation, 1]`), and the unit dimension is dropped again from the result.
pub(crate) fn max_pool1d_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> FloatTensor<B> {
    let [n, c, len_in] = x.shape().dims();
    let as_2d = B::float_reshape(x, Shape::from([n, c, len_in, 1]));

    let pooled = B::max_pool2d(
        as_2d,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
        ceil_mode,
    );

    let [n, c, len_out, _] = pooled.shape().dims();
    B::float_reshape(pooled, Shape::from([n, c, len_out]))
}
/// 1D max pooling with indices expressed through the backend's 2D max pooling with indices.
///
/// Unlike the other 1D helpers in this file, the unit dimension is inserted *before* the
/// length (`[batch, channels, 1, length]`) and the pooling parameters are placed in the
/// second slot (`[1, kernel_size]`, `[1, stride]`, `[0, padding]`, `[1, dilation]`). Both the
/// pooled output and the indices are reshaped back to three dimensions.
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> MaxPool1dWithIndices<B> {
    let [n, c, len_in] = x.shape().dims();
    let as_2d = B::float_reshape(x, Shape::from([n, c, 1, len_in]));

    let pooled = B::max_pool2d_with_indices(
        as_2d,
        [1, kernel_size],
        [1, stride],
        [0, padding],
        [1, dilation],
        ceil_mode,
    );

    let [n, c, _, len_out] = pooled.output.shape().dims();
    let output = B::float_reshape(pooled.output, Shape::from([n, c, len_out]));
    let indices = B::int_reshape(pooled.indices, Shape::from([n, c, len_out]));

    MaxPool1dWithIndices::new(output, indices)
}
/// Backward pass of 1D max pooling with indices expressed through the 2D backward pass.
///
/// The input `x`, the output gradient and the pooling indices each get a trailing unit
/// dimension before `max_pool2d_with_indices_backward` is called; the `x_grad` field of its
/// result is reshaped back to the `[batch, channels, length]` shape of `x`.
#[allow(clippy::too_many_arguments)]
pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
    output_grad: FloatTensor<B>,
    indices: IntTensor<B>,
) -> MaxPool1dBackward<B> {
    let [n, c, len_in] = x.shape().dims();
    let [_, _, len_out] = output_grad.shape().dims();

    let x_2d = B::float_reshape(x, Shape::from([n, c, len_in, 1]));
    let grad_2d = B::float_reshape(output_grad, Shape::from([n, c, len_out, 1]));
    let indices_2d = B::int_reshape(indices, Shape::from([n, c, len_out, 1]));

    let backward_2d = B::max_pool2d_with_indices_backward(
        x_2d,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
        ceil_mode,
        grad_2d,
        indices_2d,
    );

    let x_grad = B::float_reshape(backward_2d.x_grad, Shape::from([n, c, len_in]));
    MaxPool1dBackward::new(x_grad)
}

View File

@@ -0,0 +1,146 @@
use super::{ConvOptions, UnfoldOptions};
use crate::tensor::FloatTensor;
use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
use alloc::vec;
use alloc::vec::Vec;
use burn_std::Shape;
/// Builds the one-hot convolution weight that turns `conv2d` into an unfold operation.
///
/// The weight has shape
/// `[in_channels * kernel_size[0] * kernel_size[1], in_channels, kernel_size[0], kernel_size[1]]`
/// and is zero everywhere except for a single `1` per output channel: output channel
/// `k * kh * kw + i * kw + j` holds its `1` at input channel `k`, kernel row `i`, kernel
/// column `j`.
///
/// # Notes
///
/// Convolution already slides a window over the input. With this weight pattern each output
/// channel simply copies one element of the window, so the convolution output contains the
/// values of the unfolding operation.
pub(crate) fn create_unfolding_weight<B: Backend>(
    in_channels: usize,
    kernel_size: [usize; 2],
    device: &B::Device,
) -> FloatTensor<B> {
    let [kernel_h, kernel_w] = kernel_size;
    let shape = Shape::new([
        in_channels * kernel_h * kernel_w,
        in_channels,
        kernel_h,
        kernel_w,
    ]);

    // Strides of the weight, accumulated from the innermost dimension outwards.
    let mut strides = [0usize; 4];
    let mut running = 1usize;
    for (axis, extent) in shape.iter().enumerate().rev() {
        strides[axis] = running;
        running *= extent;
    }

    let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); shape.num_elements()];

    for channel in 0..in_channels {
        for row in 0..kernel_h {
            for col in 0..kernel_w {
                let output_channel = channel * kernel_h * kernel_w + row * kernel_w + col;
                let flat_index = output_channel * strides[0]
                    + channel * strides[1]
                    + row * strides[2]
                    + col * strides[3];
                weight[flat_index] = 1.elem();
            }
        }
    }

    B::float_from_data(TensorData::new(weight, shape), device)
}
/// Computes the unfold4d operation with a single `conv2d` call.
///
/// The input is convolved with the one-hot weight from `create_unfolding_weight` (no bias,
/// one group, and the stride/padding/dilation taken from `options`). The spatial dimensions
/// of the convolution output are then flattened, giving a tensor of shape
/// `[batch_size, channels_out, out_height * out_width]`.
pub(crate) fn unfold4d_using_conv2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: [usize; 2],
    options: UnfoldOptions,
) -> FloatTensor<B> {
    let [_, in_channels, _, _] = x.shape().dims();
    let device = B::float_device(&x);
    let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &device);

    let conv_options = ConvOptions::new(options.stride, options.padding, options.dilation, 1);
    let windows = B::conv2d(x, weight, None, conv_options);

    let [batch_size, channels_out, out_height, out_width] = windows.shape().dims();
    let flattened = Shape::new([batch_size, channels_out, out_height * out_width]);
    B::float_reshape(windows, flattened)
}
/// Calculate the number of unfolding windows that can be extracted from a dimension of given size.
///
/// Only complete windows are counted: the result is `0` when `dim_size < window_size`, and
/// `(dim_size - window_size) / step_size + 1` otherwise.
///
/// # Arguments
///
/// * `dim_size` - The size of the dimension being unfolded.
/// * `window_size` - The size of each window.
/// * `step_size` - The distance between the starts of two consecutive windows.
///
/// # Panics
///
/// Panics if `step_size` is zero.
pub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize {
    assert!(step_size > 0, "step_size must be greater than zero");
    if dim_size < window_size {
        0
    } else {
        // Subtract before dividing so that large `dim_size` / `step_size` values cannot
        // overflow, which `dim_size + step_size` could.
        (dim_size - window_size) / step_size + 1
    }
}
/// Calculate the output shape for an unfold operation.
///
/// The operation yields a view with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `0` when `shape[dim] < size`, and
/// `(shape[dim] - size) / step + 1` (integer division) otherwise.
///
/// # Arguments
///
/// * `shape` - The input shape to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the dimension to unfold.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A shape with ``[pre=..., windows, post=..., size]``.
///
/// # Panics
///
/// Panics if `step` is zero (see [`calculate_unfold_windows`]).
pub fn calculate_unfold_shape<S: Into<Shape>>(
    shape: S,
    dim: usize,
    size: usize,
    step: usize,
) -> Shape {
    let mut shape = shape.into();
    // The unfolded dimension now counts windows; the window contents go in a new last dimension.
    let windows = calculate_unfold_windows(shape[dim], size, step);
    shape[dim] = windows;
    shape.push(size);
    shape
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_unfold_windows() {
        // (dim_size, window_size, step_size, expected number of windows)
        let cases = [
            (2, 5, 1, 0),
            (2, 3, 1, 0),
            (3, 3, 1, 1),
            (4, 3, 1, 2),
            (5, 3, 1, 3),
            (2, 3, 2, 0),
            (3, 3, 2, 1),
            (4, 3, 2, 1),
            (5, 3, 2, 2),
        ];
        for &(dim_size, window_size, step_size, expected) in &cases {
            assert_eq!(
                calculate_unfold_windows(dim_size, window_size, step_size),
                expected
            );
        }
    }

    #[test]
    fn test_calculate_unfold_shape() {
        // Dimension 1 (size 6) with windows of 3 and step 2 yields 2 windows.
        let unfolded = calculate_unfold_shape([2, 6, 6], 1, 3, 2);
        assert_eq!(unfolded, Shape::new([2, 2, 6, 3]));
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,39 @@
use crate::{
Backend, TensorMetadata,
tensor::{BasicOps, TensorKind},
};
use alloc::vec::Vec;
use burn_std::Slice;
/// Repeats `tensor` `times` times along `dim` by writing copies into an empty output tensor.
///
/// An output of the repeated shape is allocated, then the input is slice-assigned into
/// consecutive blocks of length `shape[dim]` along `dim`.
pub(crate) fn repeat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    times: usize,
) -> K::Primitive {
    let input_shape = tensor.shape();
    let device = K::device(&tensor);
    let dtype = tensor.dtype();

    let block_len = input_shape[dim];
    let output_shape = input_shape.repeat(dim, times).unwrap();

    let mut output = K::empty(output_shape.clone(), &device, dtype);

    for repetition in 0..times {
        let block_start = repetition * block_len;
        let block_end = block_start + block_len;

        // Select everything on every axis except `dim`, where only the current block is targeted.
        let slices: Vec<Slice> = output_shape
            .iter()
            .enumerate()
            .map(|(axis, extent)| {
                if axis == dim {
                    Slice::new(block_start as isize, Some(block_end as isize), 1)
                } else {
                    Slice::new(0, Some(*extent as isize), 1)
                }
            })
            .collect();

        output = K::slice_assign(output, &slices, tensor.clone());
    }

    output
}

View File

@@ -0,0 +1,377 @@
use core::cmp::Ordering;
use crate::{
Backend, DType, TensorData,
element::{ElementConversion, ElementOrdered},
tensor::{BasicOps, IntElem, IntTensor},
};
use alloc::{vec, vec::Vec};
use burn_std::reader::try_read_sync;
use burn_std::{bf16, f16};
/// Macro used to dispatch sort operations based on dtype.
///
/// Expands to a `match` on `$data.dtype` that calls `$fn::<B, E>($data, $args...)` with `E`
/// set to the Rust element type matching the dtype. A generic parameter named `B` must be in
/// scope where the macro is invoked, since the expansion refers to it directly.
///
/// `DType::Bool` and `DType::QFloat(_)` are not handled and hit `unimplemented!`.
macro_rules! sort_dispatch_dtype {
($fn:ident, $data:ident, $($args:expr),*) => {
match $data.dtype {
// Floating point element types (`Flex32` is dispatched as `f32`).
DType::F64 => $fn::<B, f64>($data, $($args),*),
DType::F32 | DType::Flex32 => $fn::<B, f32>($data, $($args),*),
DType::F16 => $fn::<B, f16>($data, $($args),*),
DType::BF16 => $fn::<B, bf16>($data, $($args),*),
// Signed integer element types.
DType::I64 => $fn::<B, i64>($data, $($args),*),
DType::I32 => $fn::<B, i32>($data, $($args),*),
DType::I16 => $fn::<B, i16>($data, $($args),*),
DType::I8 => $fn::<B, i8>($data, $($args),*),
// Unsigned integer element types.
DType::U64 => $fn::<B, u64>($data, $($args),*),
DType::U32 => $fn::<B, u32>($data, $($args),*),
DType::U16 => $fn::<B, u16>($data, $($args),*),
DType::U8 => $fn::<B, u8>($data, $($args),*),
DType::Bool | DType::QFloat(_) => unimplemented!("not supported for sorting operations"),
}
};
}
/// Sorts the values of `tensor` along dimension `dim`.
///
/// The sort is unstable: equal elements may be reordered.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - Sort from largest to smallest when `true`, smallest to largest otherwise.
///
/// # Returns
///
/// A tensor with the same shape as the input, with its elements sorted by value.
///
/// # Panics
///
/// Panics if the tensor data cannot be read synchronously, and for dtypes that the sort
/// dispatch does not support.
///
/// # Remarks
///
/// This is a fallback that reads the tensor data back and sorts it on the host. It is only
/// meant to be used when the backend has no implementation of its own; backends should
/// ideally provide one, resolved by static dispatch. It is not designed for direct use and
/// importing it directly is not recommended.
pub fn sort<B: Backend, K: BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> K::Primitive {
    const READ_ERROR: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";

    let device = K::device(&tensor);
    let pending_read = K::into_data_async(tensor);
    let data = try_read_sync(pending_read)
        .expect(READ_ERROR)
        .expect(READ_ERROR);

    let sorted = sort_dispatch_dtype!(sort_data, data, dim, descending);
    K::from_data(sorted, &device)
}
/// Sorts tensor data of element type `E` in place along `dim` and returns it.
///
/// One-dimensional data is sorted directly as a flat slice; higher ranks go through the
/// strided per-group sort. The sort is unstable.
///
/// # Panics
///
/// Panics if the data cannot be viewed as a mutable slice of `E`.
pub fn sort_data<B: Backend, E: ElementOrdered>(
    mut data: TensorData,
    dim: usize,
    descending: bool,
) -> TensorData {
    let shape = data.shape.clone();
    let values: &mut [E] = data.as_mut_slice().unwrap();

    match shape.len() {
        // 1D sort
        1 => values.sort_unstable_by(|lhs, rhs| compare(lhs, rhs, descending)),
        _ => sort_slice::<B, E>(values, &shape, dim, None, false, descending),
    }

    data
}
/// Sorts the values of `tensor` along dimension `dim` and also returns the sorting indices.
///
/// The sort is unstable: equal elements may be reordered.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - Sort from largest to smallest when `true`, smallest to largest otherwise.
///
/// # Returns
///
/// A pair `(values, indices)`, both with the same shape as the input: the values sorted along
/// `dim`, and the indices that map each sorted element back to its position in the input.
///
/// # Panics
///
/// Panics if the tensor data cannot be read synchronously, and for dtypes that the sort
/// dispatch does not support.
///
/// # Remarks
///
/// This is a fallback that reads the tensor data back and sorts it on the host. It is only
/// meant to be used when the backend has no implementation of its own; backends should
/// ideally provide one, resolved by static dispatch. It is not designed for direct use and
/// importing it directly is not recommended.
pub fn sort_with_indices<B: Backend, K: BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> (K::Primitive, IntTensor<B>) {
    const READ_ERROR: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";

    let device = K::device(&tensor);
    let pending_read = K::into_data_async(tensor);
    let data = try_read_sync(pending_read)
        .expect(READ_ERROR)
        .expect(READ_ERROR);

    let (sorted_values, sorted_indices) =
        sort_dispatch_dtype!(sort_data_with_indices, data, dim, descending);

    let values = K::from_data(sorted_values, &device);
    let indices = B::int_from_data(sorted_indices, &device);
    (values, indices)
}
/// Sorts tensor data of element type `E` along `dim`, returning the sorted data together with
/// the indices that map every sorted element back to its position in the input.
///
/// One-dimensional data is handled directly: the index vector is sorted by the values it points
/// to, then the data is permuted in place to match. Higher ranks go through the strided
/// per-group sort, which permutes data and indices together. The sort is unstable.
///
/// # Panics
///
/// Panics if the data cannot be viewed as a mutable slice of `E`.
fn sort_data_with_indices<B: Backend, E: ElementOrdered>(
    mut data: TensorData,
    dim: usize,
    descending: bool,
) -> (TensorData, TensorData) {
    let dims = data.shape.clone();
    let mut indices_data = dim_indices::<B>(&dims, dim);
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
        // 1D sort
        indices_data.sort_unstable_by(|&a, &b| {
            compare(
                &data_slice[a.elem::<i64>() as usize],
                &data_slice[b.elem::<i64>() as usize],
                descending,
            )
        });

        // Permute data in-place by the sorted indices.
        // A scratch copy of the indices as `usize` is consumed while following the permutation
        // cycles; `indices_data` itself is kept intact because it is part of the result.
        let mut indices = indices_data
            .iter()
            .map(|i| i.elem::<i64>() as usize)
            .collect::<Vec<_>>();

        for idx in 0..indices.len() {
            if indices[idx] != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = indices[current_idx];
                    // Mark the current slot as resolved before moving along the cycle.
                    indices[current_idx] = current_idx;
                    if indices[target_idx] == target_idx {
                        // correct position
                        break;
                    }

                    // Permute data by indices
                    data_slice.swap(current_idx, target_idx);
                    current_idx = target_idx;
                }
            }
        }
    } else {
        sort_slice::<B, E>(
            data_slice,
            &dims,
            dim,
            Some(&mut indices_data),
            true,
            descending,
        );
    }

    (data, TensorData::new(indices_data, dims))
}
/// Computes the indices that would sort `tensor` along dimension `dim`.
///
/// The sort is unstable: equal elements may be reordered.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - Sort from largest to smallest when `true`, smallest to largest otherwise.
///
/// # Returns
///
/// An integer tensor with the same shape as the input, whose entries are the positions in
/// the input of the elements in sorted order.
///
/// # Panics
///
/// Panics if the tensor data cannot be read synchronously, and for dtypes that the sort
/// dispatch does not support.
///
/// # Remarks
///
/// This is a fallback that reads the tensor data back and sorts it on the host. It is only
/// meant to be used when the backend has no implementation of its own; backends should
/// ideally provide one, resolved by static dispatch. It is not designed for direct use and
/// importing it directly is not recommended.
pub fn argsort<B: Backend, K: BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> IntTensor<B> {
    const READ_ERROR: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";

    let device = K::device(&tensor);
    let pending_read = K::into_data_async(tensor);
    let data = try_read_sync(pending_read)
        .expect(READ_ERROR)
        .expect(READ_ERROR);

    let sorted_indices = sort_dispatch_dtype!(argsort_data, data, dim, descending);
    B::int_from_data(sorted_indices, &device)
}
/// Computes, for tensor data of element type `E`, the indices that sort it along `dim`.
///
/// Only the index vector is reordered; the values are used purely as sort keys. The sort is
/// unstable.
///
/// # Panics
///
/// Panics if the data cannot be viewed as a slice of `E`.
fn argsort_data<B: Backend, E: ElementOrdered>(
    mut data: TensorData,
    dim: usize,
    descending: bool,
) -> TensorData {
    let shape = data.shape.clone();
    let mut order = dim_indices::<B>(&shape, dim);

    if shape.len() != 1 {
        sort_slice::<B, E>(
            data.as_mut_slice().unwrap(),
            &shape,
            dim,
            Some(&mut order),
            false,
            descending,
        );
    } else {
        // 1D sort: order the indices by the values they refer to.
        let values = data.as_slice::<E>().unwrap();
        order.sort_unstable_by(|&lhs, &rhs| {
            let lhs_value = &values[lhs.elem::<i64>() as usize];
            let rhs_value = &values[rhs.elem::<i64>() as usize];
            compare(lhs_value, rhs_value, descending)
        });
    }

    TensorData::new(order, shape)
}
/// Sort the elements by value along a given dimension.
///
/// When `indices` are not provided, the `data` is sorted.
/// Otherwise, the `indices` are sorted based on the value of the elements in `data`,
/// and if `permute_both` is enabled then the data is also sorted.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `data` - Flat element buffer, indexed with the contiguous strides computed from `dims`.
/// * `dims` - The dimensions used to compute strides and the number of groups to sort.
/// * `dim` - The dimension along which each group is sorted.
/// * `indices` - Optional flat index buffer, permuted at the same flat positions as `data`.
/// * `permute_both` - When `indices` is provided, also permute `data`.
/// * `descending` - The sorting order.
fn sort_slice<B: Backend, E: ElementOrdered>(
data: &mut [E],
dims: &[usize],
dim: usize,
mut indices: Option<&mut [IntElem<B>]>,
permute_both: bool,
descending: bool,
) {
let ndims = dims.len();
let strides = compute_strides(dims);
// Dimensions to access elements to sort
// (same as `dims` but with the sorted dimension collapsed to 1, so that a group id can be
// decomposed into coordinates over the remaining dimensions)
let mut sort_dims = dims.to_vec();
sort_dims[dim] = 1;
let strides_out = compute_strides(&sort_dims);
// Number of groups to sort
let num_sorts: usize = dims
.iter()
.enumerate()
.filter(|&(i, _)| i != dim)
.map(|(_, d)| d)
.product();
// TODO: run each sort in parallel
// run_par!(|| {
// iter_range_par!(0, num_sorts).for_each(|id| {...})
for id in 0..num_sorts {
// Flat offset of the first element of this group, plus the stride and length of the
// sorted dimension.
let mut index_offset = 0;
let mut stride_dim = 0;
let mut shape_dim = 0;
for d in 0..ndims {
let stride_input = strides[d];
let stride_output = strides_out[d];
let shape_output = sort_dims[d];
// Coordinate of this group along dimension `d`
// (always 0 for the sorted dimension, whose collapsed size is 1).
let num_block = id / stride_output % shape_output;
if d != dim {
index_offset += num_block * stride_input;
} else {
let shape_input = dims[d];
stride_dim = stride_input;
shape_dim = shape_input;
index_offset += num_block;
}
}
// For each group, sort the indices based on the element values
// NOTE: Sorting methods like `sort_unstable_by` are in-place but we need to sort
// different views/groups of the underlying data, so the swap is performed on the elements
// of the (flat index, element value) collection.
let mut elements = (0..shape_dim)
.map(|d| {
let flat_index = d * stride_dim + index_offset;
let elem = data[flat_index];
(d, flat_index, elem)
})
.collect::<Vec<_>>();
elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));
// Permute data in-place by the sorted indices
// After sorting, `elements[i].0` is the original position (within the group) of the element
// that belongs at position `i`; the loop below follows each permutation cycle, marking a
// slot as resolved by setting `elements[i].0 = i`.
for idx in 0..elements.len() {
if elements[idx].0 != idx {
let mut current_idx = idx;
loop {
let target_idx = elements[current_idx].0;
elements[current_idx].0 = current_idx;
if elements[target_idx].0 == target_idx {
// correct position
break;
}
if indices.is_none() || permute_both {
// Permute data by indices
data.swap(elements[current_idx].1, elements[target_idx].1);
}
if let Some(ref mut indices_data) = indices {
// Permute data element indices
indices_data.swap(elements[current_idx].1, elements[target_idx].1);
}
current_idx = target_idx;
}
}
}
}
}
/// Computes the contiguous (row-major) stride of every dimension in `dims`.
///
/// The last dimension has stride 1 and every other stride is the product of the sizes of the
/// dimensions that follow it.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut running = 1;
    // Walk from the innermost dimension outwards, accumulating the product of the sizes seen.
    let mut strides: Vec<usize> = dims
        .iter()
        .rev()
        .map(|&extent| {
            let stride = running;
            running *= extent;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
/// Builds, for a contiguous tensor of shape `dims`, the flat buffer holding each element's
/// coordinate along `dim`.
///
/// For a 1D shape this is simply `0..dims[dim]`. Otherwise every coordinate is repeated once
/// per element of the trailing dimensions, and the resulting block is repeated once per
/// element of the leading dimensions.
fn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> {
    if dims.len() == 1 {
        return (0..dims[dim])
            .map(|i| (i as i64).elem::<IntElem<B>>())
            .collect::<Vec<_>>();
    }

    // Dimension indices tensor
    let leading: usize = dims[..dim].iter().product();
    let trailing: usize = dims[dim + 1..].iter().product();

    let block: Vec<IntElem<B>> = (0..dims[dim])
        .flat_map(|i| core::iter::repeat((i as i64).elem::<IntElem<B>>()).take(trailing))
        .collect();

    block.repeat(leading)
}
/// Orders `a` relative to `b`, swapping the operands when `descending` is set.
fn compare<E: ElementOrdered>(a: &E, b: &E, descending: bool) -> Ordering {
    match descending {
        true => b.cmp(a),
        false => a.cmp(b),
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,139 @@
use alloc::vec::Vec;
use core::future::Future;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{Backend, ExecutionError, TensorData, TensorPrimitive};
// Records which per-kind list a registered tensor was pushed to and at which position, so
// that results can be returned in registration order.
enum Order {
// Position in `read_floats`.
Float(usize),
// Position in `read_qfloats`.
QFloat(usize),
// Position in `read_ints`.
Int(usize),
// Position in `read_bools`.
Bool(usize),
}
#[derive(Default)]
/// Contains all tensor primitives that are going to be read.
///
/// Tensors are grouped by kind; the order in which they were registered is tracked
/// separately so that the results can be returned in that order.
pub struct TransactionPrimitive<B: Backend> {
/// Float tensors.
pub read_floats: Vec<FloatTensor<B>>,
/// Quantized tensors.
pub read_qfloats: Vec<QuantizedTensor<B>>,
/// Int tensors.
pub read_ints: Vec<IntTensor<B>>,
/// Bool tensors.
pub read_bools: Vec<BoolTensor<B>>,
// Registration order of the tensors above (one entry per registered tensor).
orders: Vec<Order>,
}
#[derive(Default)]
/// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive).
///
/// Each field holds the data read from the tensors of the matching kind.
pub struct TransactionPrimitiveData {
/// Float tensor data.
pub read_floats: Vec<TensorData>,
/// Quantized tensor data.
pub read_qfloats: Vec<TensorData>,
/// Int tensor data.
pub read_ints: Vec<TensorData>,
/// Bool tensor data.
pub read_bools: Vec<TensorData>,
}
/// Operations that are sync by nature and that can be batch together in transactions to improve
/// compute utilization with efficient laziness.
pub trait TransactionOps<B: Backend> {
/// Executes a [transaction](TransactionPrimitive) and return its
/// [data](TransactionPrimitiveData).
fn tr_execute(
transaction: TransactionPrimitive<B>,
) -> impl Future<Output = Result<TransactionPrimitiveData, ExecutionError>> + Send {
async move {
let mut floats = Vec::new();
let mut qfloats = Vec::new();
let mut ints = Vec::new();
let mut bools = Vec::new();
for t in transaction.read_floats {
floats.push(B::float_into_data(t).await?);
}
for t in transaction.read_qfloats {
qfloats.push(B::q_into_data(t).await?);
}
for t in transaction.read_ints {
ints.push(B::int_into_data(t).await?);
}
for t in transaction.read_bools {
bools.push(B::bool_into_data(t).await?);
}
Ok(TransactionPrimitiveData {
read_floats: floats,
read_qfloats: qfloats,
read_ints: ints,
read_bools: bools,
})
}
}
}
impl<B: Backend> TransactionPrimitive<B> {
    /// Creates a new transaction from already collected tensors.
    ///
    /// No registration order is recorded for tensors passed this way.
    pub fn new(
        read_floats: Vec<FloatTensor<B>>,
        read_qfloats: Vec<QuantizedTensor<B>>,
        read_ints: Vec<IntTensor<B>>,
        read_bools: Vec<BoolTensor<B>>,
    ) -> Self {
        let orders = Vec::default();
        Self {
            read_floats,
            read_qfloats,
            read_ints,
            read_bools,
            orders,
        }
    }

    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
    /// in which they were [registered](crate::tensor::BasicOps::register_transaction).
    pub async fn execute_async(mut self) -> Result<Vec<TensorData>, ExecutionError> {
        /// Wraps every entry in `Some` so that it can later be moved out by position.
        fn into_slots(values: Vec<TensorData>) -> Vec<Option<TensorData>> {
            values.into_iter().map(Some).collect()
        }

        // The registration order must be taken out before `self` is consumed by the backend.
        let orders = core::mem::take(&mut self.orders);
        let data = B::tr_execute(self).await?;

        let mut floats = into_slots(data.read_floats);
        let mut qfloats = into_slots(data.read_qfloats);
        let mut ints = into_slots(data.read_ints);
        let mut bools = into_slots(data.read_bools);

        let mut ordered = Vec::with_capacity(orders.len());
        for order in orders {
            let slot = match order {
                Order::Float(index) => floats.get_mut(index),
                Order::QFloat(index) => qfloats.get_mut(index),
                Order::Int(index) => ints.get_mut(index),
                Order::Bool(index) => bools.get_mut(index),
            };
            ordered.push(slot.unwrap().take().unwrap());
        }

        Ok(ordered)
    }

    /// Registers a float tensor, routing quantized tensors to their own list.
    pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let position = self.read_qfloats.len();
                self.orders.push(Order::QFloat(position));
                self.read_qfloats.push(quantized);
            }
            TensorPrimitive::Float(float) => {
                let position = self.read_floats.len();
                self.orders.push(Order::Float(position));
                self.read_floats.push(float);
            }
        }
    }

    /// Registers an int tensor.
    pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {
        let position = self.read_ints.len();
        self.orders.push(Order::Int(position));
        self.read_ints.push(tensor);
    }

    /// Registers a bool tensor.
    pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {
        let position = self.read_bools.len();
        self.orders.push(Order::Bool(position));
        self.read_bools.push(tensor);
    }
}

View File

@@ -0,0 +1,77 @@
use crate::Backend;
use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};
use burn_std::{DType, Shape};
#[derive(Debug, Clone)]
/// A primitive tensor representation.
///
/// Wraps either the backend's float tensor primitive or its quantized tensor primitive.
pub enum TensorPrimitive<B: Backend> {
/// Float tensor primitive.
Float(B::FloatTensorPrimitive),
/// Quantized float tensor primitive.
QFloat(B::QuantizedTensorPrimitive),
}
impl<B: Backend> TensorPrimitive<B> {
/// Returns the full tensor representation.
pub fn tensor(self) -> B::FloatTensorPrimitive {
match self {
Self::QFloat(tensor) => B::dequantize(tensor),
Self::Float(tensor) => tensor,
}
}
}
/// Metadata queries are forwarded to whichever primitive is wrapped.
impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
    fn dtype(&self) -> DType {
        match self {
            Self::QFloat(quantized) => quantized.dtype(),
            Self::Float(float) => float.dtype(),
        }
    }

    fn shape(&self) -> Shape {
        match self {
            Self::QFloat(quantized) => quantized.shape(),
            Self::Float(float) => float.shape(),
        }
    }

    fn rank(&self) -> usize {
        match self {
            Self::QFloat(quantized) => quantized.rank(),
            Self::Float(float) => float.rank(),
        }
    }
}
/// Tensor metadata trait for tensor primitive.
///
/// Gives access to the dtype, shape and rank of a tensor primitive.
pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
/// The dtype of the tensor.
fn dtype(&self) -> DType;
/// The shape of the tensor.
fn shape(&self) -> Shape;
/// The number of dimensions of the tensor.
///
/// The default implementation derives it from [`shape`](TensorMetadata::shape).
fn rank(&self) -> usize {
self.shape().num_dims()
}
}
/// Quantized tensor primitive.
///
/// Only [`scheme`](QTensorPrimitive::scheme) is required; the other methods have defaults.
pub trait QTensorPrimitive {
/// Returns the quantization settings for the given tensor.
fn scheme(&self) -> &QuantScheme;
/// The precision used for the accumulation in various kernels.
///
/// Defaults to [`QuantAcc::F32`].
fn acc_precision(&self) -> QuantAcc {
QuantAcc::F32
}
/// How quantization is propagated during computation.
///
/// Defaults to [`QuantPropagation::Inhibit`].
fn propagation(&self) -> QuantPropagation {
QuantPropagation::Inhibit
}
/// Returns the default tensor quantization scheme.
///
/// Defaults to `QuantScheme::default()`.
fn default_scheme() -> QuantScheme {
QuantScheme::default()
}
}

View File

@@ -0,0 +1,427 @@
use alloc::format;
use alloc::string::String;
use burn_std::{DType, bf16, f16};
use num_traits::{Float, ToPrimitive};
use super::TensorData;
use crate::{Element, ElementOrdered};
/// The tolerance used to compare two floating point numbers.
///
/// Generally, two numbers `x` and `y` are approximately equal if
///
/// ```text
/// |x - y| < max(R * max(|x|, |y|), A)
/// ```
///
/// where `R` is the relative tolerance and `A` is the absolute tolerance.
///
///
/// The most common way to initialize this struct is to use `Tolerance::<F>::default()`,
/// which is the same as `Tolerance::<F>::balanced()`: a relative tolerance of 0.5% and an
/// absolute tolerance of `1e-5`.
///
/// Another common initialization is `Tolerance::<F>::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`.
/// This will use a sane default to manage values too close to 0.0 and
/// use different relative tolerances depending on the floating point precision.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
// Relative tolerance `R`, scaled by the larger magnitude of the two compared values.
relative: F,
// Absolute tolerance `A`.
absolute: F,
}
impl<F: Float> Default for Tolerance<F> {
fn default() -> Self {
Self::balanced()
}
}
impl<F: Float> Tolerance<F> {
    /// Create a tolerance with strict precision setting.
    ///
    /// The relative tolerance is zero and the absolute tolerance is
    /// `64 * F::min_positive_value()`.
    pub fn strict() -> Self {
        Self {
            relative: F::from(0.00).unwrap(),
            absolute: F::from(64).unwrap() * F::min_positive_value(),
        }
    }

    /// Create a tolerance with balanced precision setting.
    ///
    /// The relative tolerance is 0.5% and the absolute tolerance is `1e-5`.
    pub fn balanced() -> Self {
        Self {
            relative: F::from(0.005).unwrap(), // 0.5%
            absolute: F::from(1e-5).unwrap(),
        }
    }

    /// Create a tolerance with permissive precision setting.
    ///
    /// The relative tolerance is 1% and the absolute tolerance is `0.01`.
    pub fn permissive() -> Self {
        Self {
            relative: F::from(0.01).unwrap(), // 1.0%
            absolute: F::from(0.01).unwrap(),
        }
    }

    /// When comparing two numbers, this uses both the relative and absolute differences.
    ///
    /// That is, `x` and `y` are approximately equal if
    ///
    /// ```text
    /// |x - y| < max(R * max(|x|, |y|), A)
    /// ```
    ///
    /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance.
    ///
    /// # Panics
    ///
    /// Panics if a tolerance cannot be converted to `F`, if `relative` is greater than one,
    /// or if `absolute` is negative.
    pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
        let relative = Self::check_relative(relative);
        let absolute = Self::check_absolute(absolute);

        Self { relative, absolute }
    }

    /// When comparing two numbers, this uses only the relative difference.
    ///
    /// That is, `x` and `y` are approximately equal if
    ///
    /// ```text
    /// |x - y| < R * max(|x|, |y|)
    /// ```
    ///
    /// where `R` is the relative `tolerance`.
    ///
    /// # Panics
    ///
    /// Panics if `tolerance` cannot be converted to `F` or is greater than one.
    pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
        let relative = Self::check_relative(tolerance);

        Self {
            relative,
            absolute: F::from(0.0).unwrap(),
        }
    }

    /// When comparing two numbers, this uses only the absolute difference.
    ///
    /// That is, `x` and `y` are approximately equal if
    ///
    /// ```text
    /// |x - y| < A
    /// ```
    ///
    /// where `A` is the absolute `tolerance`.
    ///
    /// # Panics
    ///
    /// Panics if `tolerance` cannot be converted to `F` or is negative.
    pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
        let absolute = Self::check_absolute(tolerance);

        Self {
            relative: F::from(0.0).unwrap(),
            absolute,
        }
    }

    /// Change the relative tolerance to the given one.
    pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        self.relative = Self::check_relative(tolerance);
        self
    }

    /// Change the relative tolerance to the given one only if `F` is half precision.
    pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if Self::has_byte_width(2) {
            self.relative = Self::check_relative(tolerance);
        }
        self
    }

    /// Change the relative tolerance to the given one only if `F` is single precision.
    pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if Self::has_byte_width(4) {
            self.relative = Self::check_relative(tolerance);
        }
        self
    }

    /// Change the relative tolerance to the given one only if `F` is double precision.
    pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if Self::has_byte_width(8) {
            self.relative = Self::check_relative(tolerance);
        }
        self
    }

    /// Change the absolute tolerance to the given one.
    pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        self.absolute = Self::check_absolute(tolerance);
        self
    }

    /// Change the absolute tolerance to the given one only if `F` is half precision.
    pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if Self::has_byte_width(2) {
            self.absolute = Self::check_absolute(tolerance);
        }
        self
    }

    /// Change the absolute tolerance to the given one only if `F` is single precision.
    pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if Self::has_byte_width(4) {
            self.absolute = Self::check_absolute(tolerance);
        }
        self
    }

    /// Change the absolute tolerance to the given one only if `F` is double precision.
    pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if Self::has_byte_width(8) {
            self.absolute = Self::check_absolute(tolerance);
        }
        self
    }

    /// Checks if `x` and `y` are approximately equal given the tolerance.
    ///
    /// Returns `true` when `x == y`, or when
    /// `|x - y| < max(absolute, relative * max(|x|, |y|))`.
    pub fn approx_eq(&self, x: F, y: F) -> bool {
        // See the accepted answer here
        // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison

        // This also handles the case where both x and y are the same infinity so that we
        // don't need to manage it in the rest of the function.
        if x == y {
            return true;
        }

        let diff = (x - y).abs();
        let max = F::max(x.abs(), y.abs());

        diff < self.absolute.max(self.relative * max)
    }

    /// Returns `true` when `F` occupies exactly `bytes` bytes (2 = half, 4 = single, 8 = double).
    fn has_byte_width(bytes: usize) -> bool {
        core::mem::size_of::<F>() == bytes
    }

    /// Converts a relative tolerance to `F`, asserting that it is at most one.
    fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
        let tolerance = F::from(tolerance).unwrap();
        assert!(tolerance <= F::one());
        tolerance
    }

    /// Converts an absolute tolerance to `F`, asserting that it is not negative.
    fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
        let tolerance = F::from(tolerance).unwrap();
        assert!(tolerance >= F::zero());
        tolerance
    }
}
impl TensorData {
    /// Asserts the data is equal to another data.
    ///
    /// # Arguments
    ///
    /// * `other` - The other data.
    /// * `strict` - If true, the data types must be the same.
    ///   Otherwise, the comparison is done in the current data type.
    ///
    /// # Panics
    ///
    /// Panics if the data is not equal.
    #[track_caller]
    pub fn assert_eq(&self, other: &Self, strict: bool) {
        if strict {
            assert_eq!(
                self.dtype, other.dtype,
                "Data types differ ({:?} != {:?})",
                self.dtype, other.dtype
            );
        }

        match self.dtype {
            DType::F64 => self.assert_eq_elem::<f64>(other),
            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
            DType::F16 => self.assert_eq_elem::<f16>(other),
            DType::BF16 => self.assert_eq_elem::<bf16>(other),
            DType::I64 => self.assert_eq_elem::<i64>(other),
            DType::I32 => self.assert_eq_elem::<i32>(other),
            DType::I16 => self.assert_eq_elem::<i16>(other),
            DType::I8 => self.assert_eq_elem::<i8>(other),
            DType::U64 => self.assert_eq_elem::<u64>(other),
            DType::U32 => self.assert_eq_elem::<u32>(other),
            DType::U16 => self.assert_eq_elem::<u16>(other),
            DType::U8 => self.assert_eq_elem::<u8>(other),
            DType::Bool => self.assert_eq_elem::<bool>(other),
            DType::QFloat(q) => {
                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
                let q_other = if let DType::QFloat(q_other) = other.dtype {
                    q_other
                } else {
                    panic!("Quantized data differs from other not quantized data")
                };

                // Data equality mostly depends on input quantization type, but we also check level
                if q.value == q_other.value && q.level == q_other.level {
                    self.assert_eq_elem::<i8>(other)
                } else {
                    panic!("Quantization schemes differ ({q:?} != {q_other:?})")
                }
            }
        }
    }

    /// Compares both data element by element as `E`, panicking with a report of the shape
    /// mismatch (if any) and of the first few differing positions.
    #[track_caller]
    fn assert_eq_elem<E: Element>(&self, other: &Self) {
        let mut message = String::new();
        if self.shape != other.shape {
            message += format!(
                "\n => Shape is different: {:?} != {:?}",
                self.shape, other.shape
            )
            .as_str();
        }

        let mut num_diff = 0;
        let max_num_diff = 5;
        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
            if !a.eq(&b) {
                // Only print the first 5 different values.
                if num_diff < max_num_diff {
                    message += format!("\n => Position {i}: {a} != {b}").as_str();
                }
                num_diff += 1;
            }
        }

        // Only mention the remainder when some differences were actually left out.
        if num_diff > max_num_diff {
            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
        }

        if !message.is_empty() {
            panic!("Tensors are not eq:{message}");
        }
    }

    /// Asserts the data is approximately equal to another data.
    ///
    /// Pairs where both values are NaN, or both are the same infinity, are considered equal.
    ///
    /// # Arguments
    ///
    /// * `other` - The other data.
    /// * `tolerance` - The tolerance of the comparison.
    ///
    /// # Panics
    ///
    /// Panics if the data is not approximately equal.
    #[track_caller]
    pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
        let mut message = String::new();
        if self.shape != other.shape {
            message += format!(
                "\n => Shape is different: {:?} != {:?}",
                self.shape, other.shape
            )
            .as_str();
        }

        let iter = self.iter::<F>().zip(other.iter::<F>());

        let mut num_diff = 0;
        let max_num_diff = 5;

        for (i, (a, b)) in iter.enumerate() {
            // If they are both NaN, then they are equally NaN.
            let both_nan = a.is_nan() && b.is_nan();
            // This works for both infinities (the signs must match).
            let both_inf =
                a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));

            if both_nan || both_inf {
                continue;
            }

            if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
                // Only print the first 5 different values.
                if num_diff < max_num_diff {
                    let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
                    let max = F::max(a.abs(), b.abs());
                    let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();

                    let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
                    let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();

                    message += format!(
                        "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
                    )
                    .as_str();
                }
                num_diff += 1;
            }
        }

        // Only mention the remainder when some differences were actually left out.
        if num_diff > max_num_diff {
            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
        }

        if !message.is_empty() {
            panic!("Tensors are not approx eq:{message}");
        }
    }

    /// Asserts each value is within a given range.
    ///
    /// # Arguments
    ///
    /// * `range` - The range.
    ///
    /// # Panics
    ///
    /// If any value is not within the half-open range bounded inclusively below
    /// and exclusively above (`start..end`).
    pub fn assert_within_range<E: ElementOrdered>(&self, range: core::ops::Range<E>) {
        for elem in self.iter::<E>() {
            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }

    /// Asserts each value is within a given inclusive range.
    ///
    /// # Arguments
    ///
    /// * `range` - The range.
    ///
    /// # Panics
    ///
    /// If any value is not within the closed range bounded inclusively on both ends
    /// (`start..=end`).
    pub fn assert_within_range_inclusive<E: ElementOrdered>(
        &self,
        range: core::ops::RangeInclusive<E>,
    ) {
        let start = range.start();
        let end = range.end();

        for elem in self.iter::<E>() {
            if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_assert_appox_eq_limit() {
        // The values differ by 0.03 in the first position only.
        let expected = TensorData::from([[3.0, 5.0, 6.0]]);
        let actual = TensorData::from([[3.03, 5.0, 6.0]]);

        expected.assert_approx_eq::<f32>(&actual, Tolerance::absolute(3e-2));
        expected.assert_approx_eq::<f16>(&actual, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        // A difference of 0.031 exceeds the absolute tolerance of 0.01.
        let expected = TensorData::from([[3.0, 5.0, 6.0]]);
        let actual = TensorData::from([[3.031, 5.0, 6.0]]);

        expected.assert_approx_eq::<f32>(&actual, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        // Same leading values, but the shapes differ.
        let longer = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let shorter = TensorData::from([[3.0, 5.0, 6.0]]);

        longer.assert_approx_eq::<f32>(&shorter, Tolerance::absolute(1e-2));
    }
}

View File

@@ -0,0 +1,5 @@
mod compare;
mod tensor;
pub use compare::*;
pub use tensor::*;

View File

@@ -0,0 +1,815 @@
use core::f32;
use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
use rand::Rng;
use thiserror::Error;
use crate::Scalar;
use crate::distribution::Distribution;
use crate::element::{Element, ElementConversion};
use burn_std::tensor::DType;
use burn_std::{Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, bf16, f16};
/// Data structure for tensors.
///
/// Stores the element values as raw bytes together with the shape and the data type
/// needed to interpret those bytes.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
/// The values of the tensor (as bytes).
pub bytes: Bytes,
/// The shape of the tensor (one entry per dimension).
pub shape: Vec<usize>,
/// The data type of the tensor, i.e. how `bytes` must be interpreted.
pub dtype: DType,
}
impl TensorData {
/// Creates tensor data from a vector of elements and a shape.
///
/// The dtype is taken from the element type `E`.
///
/// # Panics
///
/// If the number of elements in `value` differs from the product of the dimensions in
/// `shape`.
pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
    let shape: Vec<usize> = shape.into();
    // Reject a shape that does not describe exactly `value.len()` elements.
    Self::check_data_len(&value, &shape);
    let bytes = Bytes::from_elems(value);
    let dtype = E::dtype();
    Self {
        bytes,
        shape,
        dtype,
    }
}
/// Creates a new quantized tensor data structure.
pub fn quantized<E: Element, S: Into<Vec<usize>>>(
value: Vec<E>,
shape: S,
scheme: QuantScheme,
qparams: &[f32],
) -> Self {
let shape = shape.into();
Self::check_data_len(&value, &shape);
let q_bytes = QuantizedBytes::new(value, scheme, qparams);
Self {
bytes: q_bytes.bytes,
shape,
dtype: DType::QFloat(q_bytes.scheme),
}
}
/// Creates tensor data directly from raw bytes, a shape and a dtype.
///
/// No validation is performed here: the caller is responsible for `bytes` being a valid
/// representation of `dtype` with the given `shape`.
pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
    let shape: Vec<usize> = shape.into();
    Self {
        bytes,
        shape,
        dtype,
    }
}
/// Creates tensor data from raw bytes held in a `Vec<u8>`.
///
/// No validation is performed here. Prefer [`TensorData::new`] or
/// [`TensorData::quantized`] unless you are certain that the byte representation is valid
/// for the given `dtype` and `shape`.
pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
    let shape: Vec<usize> = shape.into();
    let bytes = Bytes::from_bytes_vec(bytes);
    Self {
        bytes,
        shape,
        dtype,
    }
}
/// Checks that `data` holds exactly as many elements as `shape` describes.
///
/// # Panics
///
/// If `data.len()` differs from the product of the dimensions in `shape`.
fn check_data_len<E: Element>(data: &[E], shape: &[usize]) {
    let expected_data_len = Self::numel(shape);
    let num_data = data.len();
    assert_eq!(
        expected_data_len, num_data,
        "Shape {shape:?} is invalid for input of size {num_data:?}",
    );
}
/// Borrows the tensor data as an immutable, typed slice.
///
/// # Errors
///
/// Returns [`DataError::TypeMismatch`] when `E` does not correspond to the stored dtype,
/// and [`DataError::CastError`] when the bytes cannot be reinterpreted as the target type.
pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
    if E::dtype() != self.dtype {
        return Err(DataError::TypeMismatch(format!(
            "Invalid target element type (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        )));
    }

    if matches!(E::dtype(), DType::Bool) {
        // The only way to create a bool `TensorData` with invalid values is by unsafely
        // modifying the dtype. That is already unsafe, so the bytes are viewed as `u8` and
        // reinterpreted without bit validation, which would otherwise walk the whole buffer.
        let raw = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
            .map_err(DataError::CastError)?;
        // SAFETY (relies on an assumption): `E::dtype()` is `DType::Bool` here, so `E` is
        // presumably `bool`, which has the same size and alignment as `u8`.
        // NOTE(review): confirm no other `Element` implementation reports `DType::Bool`.
        Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(raw) })
    } else {
        bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
    }
}
/// Borrows the tensor data as a mutable, typed slice.
///
/// # Errors
///
/// Returns [`DataError::TypeMismatch`] when `E` does not correspond to the stored dtype,
/// and [`DataError::CastError`] when the bytes cannot be reinterpreted as the target type.
pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
    if E::dtype() != self.dtype {
        return Err(DataError::TypeMismatch(format!(
            "Invalid target element type (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        )));
    }

    if matches!(E::dtype(), DType::Bool) {
        // The only way to create a bool `TensorData` with invalid values is by unsafely
        // modifying the dtype. That is already unsafe, so the bytes are viewed as `u8` and
        // reinterpreted without bit validation, which would otherwise walk the whole buffer.
        let raw = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
            .map_err(DataError::CastError)?;
        // SAFETY (relies on an assumption): `E::dtype()` is `DType::Bool` here, so `E` is
        // presumably `bool`, which has the same size and alignment as `u8`.
        // NOTE(review): confirm no other `Element` implementation reports `DType::Bool`.
        Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(raw) })
    } else {
        bytemuck::checked::try_cast_slice_mut(&mut self.bytes).map_err(DataError::CastError)
    }
}
/// Copies the tensor data into a new vector of scalar values, leaving `self` untouched.
///
/// # Errors
///
/// Propagates any error reported by [`TensorData::as_slice`].
pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
    self.as_slice::<E>().map(|values| values.to_vec())
}
/// Consumes the tensor data and returns its values as a vector of scalars.
///
/// # Errors
///
/// Returns [`DataError::TypeMismatch`] when `E` does not correspond to the stored dtype,
/// and [`DataError::CastError`] when the bytes cannot be reinterpreted as the target type.
pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
    let target = E::dtype();
    // This means we cannot call `into_vec` for QFloat
    if target != self.dtype {
        return Err(DataError::TypeMismatch(format!(
            "Invalid target element type (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        )));
    }

    if matches!(target, DType::Bool) {
        // The only way to create a bool `TensorData` with invalid values is by unsafely
        // modifying the dtype. That is already unsafe, so the bytes are extracted as `u8`
        // and reinterpreted without bit validation, which would otherwise walk the whole
        // buffer.
        let raw = self.into_vec_unchecked::<u8>()?;
        // SAFETY (relies on an assumption): `E::dtype()` is `DType::Bool` here, so `E` is
        // presumably `bool`, which has the same size and alignment as `u8`.
        // NOTE(review): confirm no other `Element` implementation reports `DType::Bool`.
        Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(raw) })
    } else {
        self.into_vec_unchecked()
    }
}
/// Consumes the tensor data and returns its values as a vector, without checking the dtype.
///
/// The byte buffer is first converted in place; when that is refused, the values are
/// copied into a freshly allocated vector instead.
///
/// # Errors
///
/// Returns [`DataError::CastError`] when the fallback copy cannot reinterpret the bytes
/// as `E`.
fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
    let bytes = match self.bytes.try_into_vec::<E>() {
        Ok(elems) => return Ok(elems),
        Err(bytes) => bytes,
    };
    // The bytes might have been deserialized and allocated with a different align.
    // In that case, we have to memcopy the data into a new vector, more suitably allocated
    let copied = bytemuck::checked::try_cast_slice::<u8, E>(&bytes)
        .map(|values| values.to_vec())
        .map_err(DataError::CastError);
    copied
}
/// Returns an iterator over the values of the tensor data.
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
if E::dtype() == self.dtype {
Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
} else {
match self.dtype {
DType::I8 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i8| e.elem::<E>()),
),
DType::I16 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i16| e.elem::<E>()),
),
DType::I32 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i32| e.elem::<E>()),
),
DType::I64 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i64| e.elem::<E>()),
),
DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
DType::U16 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u16| e.elem::<E>()),
),
DType::U32 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u32| e.elem::<E>()),
),
DType::U64 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u64| e.elem::<E>()),
),
DType::BF16 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &bf16| e.elem::<E>()),
),
DType::F16 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &f16| e.elem::<E>()),
),
DType::F32 | DType::Flex32 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &f32| e.elem::<E>()),
),
DType::F64 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &f64| e.elem::<E>()),
),
// bool is a byte value equal to either 0 or 1
DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
DType::QFloat(scheme) => match scheme {
QuantScheme {
level: QuantLevel::Tensor | QuantLevel::Block(_),
mode: QuantMode::Symmetric,
value:
QuantValue::Q8F
| QuantValue::Q8S
// Represent sub-byte values as i8
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S,
..
} => {
// Quantized int8 values
let q_bytes = QuantizedBytes {
bytes: self.bytes.clone(),
scheme,
num_elements: self.num_elements(),
};
let (values, _) = q_bytes.into_vec_i8();
Box::new(
values
.iter()
.map(|e: &i8| e.elem::<E>())
.collect::<Vec<_>>()
.into_iter(),
)
}
QuantScheme {
level: QuantLevel::Tensor | QuantLevel::Block(_),
mode: QuantMode::Symmetric,
value:
QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
..
} => {
unimplemented!("Not yet implemented for iteration");
}
},
}
}
}
/// Returns the rank (the number of dimensions).
pub fn rank(&self) -> usize {
self.shape.len()
}
/// Returns the total number of elements described by the shape (the product of all
/// dimensions).
pub fn num_elements(&self) -> usize {
    Self::numel(self.shape.as_slice())
}
/// Multiplies all dimensions of `shape`; an empty shape yields `1`.
fn numel(shape: &[usize]) -> usize {
    shape.iter().copied().product()
}
/// Creates tensor data of the given shape filled with values drawn from `distribution`.
///
/// One value is sampled per element, in order, using `rng`.
pub fn random<E: Element, R: Rng, S: Into<Vec<usize>>>(
    shape: S,
    distribution: Distribution,
    rng: &mut R,
) -> Self {
    let shape: Vec<usize> = shape.into();
    let data: Vec<E> = (0..Self::numel(&shape))
        .map(|_| E::random(distribution, rng))
        .collect();
    TensorData::new(data, shape)
}
/// Creates tensor data of the given shape where every element is zero.
pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
    let shape: Vec<usize> = shape.into();
    let data: Vec<E> = (0..Self::numel(&shape))
        .map(|_| 0.elem::<E>())
        .collect();
    TensorData::new(data, shape)
}
/// Creates tensor data of the given shape where every element is one.
pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
    let shape: Vec<usize> = shape.into();
    let data: Vec<E> = (0..Self::numel(&shape))
        .map(|_| 1.elem::<E>())
        .collect();
    TensorData::new(data, shape)
}
/// Populates the data with the given value
pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
let shape = shape.into();
let num_elements = Self::numel(&shape);
let mut data = Vec::<E>::with_capacity(num_elements);
for _ in 0..num_elements {
data.push(fill_value)
}
TensorData::new(data, shape)
}
/// Populates the data with the given value, using a dtype chosen at runtime.
///
/// The fill value is first turned into a [`Scalar`], then converted to the element type
/// matching `dtype` before delegating to [`TensorData::full`]. `DType::Flex32` is filled
/// with `f32` elements.
///
/// # Panics
///
/// If `dtype` is `DType::QFloat`, which is not supported here.
pub fn full_dtype<E: Into<Scalar>, S: Into<Vec<usize>>>(
shape: S,
fill_value: E,
dtype: DType,
) -> TensorData {
let fill_value = fill_value.into();
// One arm per supported dtype; each picks the concrete element type for `full`.
match dtype {
DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
// Quantized data cannot be built from a single fill value here; reaching this arm panics.
DType::QFloat(_) => unreachable!(),
}
}
/// Converts the data to the element type `E`.
///
/// This is a convenience wrapper around [`TensorData::convert_dtype`] using the dtype
/// reported by `E`.
pub fn convert<E: Element>(self) -> Self {
    let target = E::dtype();
    self.convert_dtype(target)
}
/// Converts the data to a different element type, selected at runtime.
///
/// Three cases are handled:
/// - the target dtype equals the current one: the data is returned unchanged;
/// - both dtypes have the same size and neither is `Bool` nor `QFloat`: the values are
///   converted in place, reusing the existing buffer;
/// - otherwise: the values are converted into a newly allocated buffer.
///
/// # Panics
///
/// If the current or the target dtype is `DType::QFloat` (and they differ), since
/// quantized data is not handled by this conversion.
pub fn convert_dtype(self, dtype: DType) -> Self {
if dtype == self.dtype {
self
} else if dtype.size() == self.dtype.size()
&& !matches!(self.dtype, DType::Bool | DType::QFloat(_))
&& !matches!(dtype, DType::Bool | DType::QFloat(_))
{
// Same element size: dispatch on the current dtype and convert within the buffer.
// NOTE(review): a `Flex32` target is converted through `f32`, so the resulting dtype
// appears to be `F32` rather than `Flex32` — confirm this is intended.
match self.dtype {
DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
// Excluded by the condition guarding this branch.
DType::Bool | DType::QFloat(_) => unreachable!(),
}
} else {
// Different element size (or a `Bool` on either side): dispatch on the current dtype
// and convert into a new buffer.
match self.dtype {
DType::F64 => self.convert_clone_dtype::<f64>(dtype),
DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
DType::F16 => self.convert_clone_dtype::<f16>(dtype),
DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
DType::I64 => self.convert_clone_dtype::<i64>(dtype),
DType::I32 => self.convert_clone_dtype::<i32>(dtype),
DType::I16 => self.convert_clone_dtype::<i16>(dtype),
DType::I8 => self.convert_clone_dtype::<i8>(dtype),
DType::U64 => self.convert_clone_dtype::<u64>(dtype),
DType::U32 => self.convert_clone_dtype::<u32>(dtype),
DType::U16 => self.convert_clone_dtype::<u16>(dtype),
DType::U8 => self.convert_clone_dtype::<u8>(dtype),
DType::Bool => self.convert_clone_dtype::<bool>(dtype),
// Quantized sources are not supported; reaching this arm panics.
DType::QFloat(_) => unreachable!(),
}
}
}
/// Selects the target element type for an in-place conversion from `Current`.
///
/// Maps the runtime `dtype` to a concrete `Target` type and forwards to
/// `convert_inplace::<Current, Target>`.
///
/// # Panics
///
/// If `dtype` is `DType::Bool` or `DType::QFloat`; `convert_dtype` filters those out
/// before calling this function.
fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
match dtype {
DType::F64 => self.convert_inplace::<Current, f64>(),
DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
DType::F16 => self.convert_inplace::<Current, f16>(),
DType::BF16 => self.convert_inplace::<Current, bf16>(),
DType::I64 => self.convert_inplace::<Current, i64>(),
DType::I32 => self.convert_inplace::<Current, i32>(),
DType::I16 => self.convert_inplace::<Current, i16>(),
DType::I8 => self.convert_inplace::<Current, i8>(),
DType::U64 => self.convert_inplace::<Current, u64>(),
DType::U32 => self.convert_inplace::<Current, u32>(),
DType::U16 => self.convert_inplace::<Current, u16>(),
DType::U8 => self.convert_inplace::<Current, u8>(),
DType::Bool | DType::QFloat(_) => unreachable!(),
}
}
/// Converts every element from `Current` to `Target` inside the existing byte buffer.
///
/// Each slot is read as `Current`, converted, and written back as `Target` through the
/// same memory location; the dtype is then updated to `Target`'s.
///
/// NOTE(review): `bytemuck::cast_mut` is expected to panic when `Current` and `Target`
/// differ in size or alignment — callers must only pair compatible types.
fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
    mut self,
) -> Self {
    bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes)
        .iter_mut()
        .for_each(|slot| {
            // Convert first, then reinterpret the slot to store the converted value.
            let converted: Target = slot.elem();
            *cast_mut::<_, Target>(slot) = converted;
        });
    self.dtype = Target::dtype();
    self
}
/// Selects the target element type for a copying conversion from `Current`.
///
/// Maps the runtime `dtype` to a concrete `Target` type and forwards to
/// `convert_clone::<Current, Target>`.
///
/// # Panics
///
/// If `dtype` is `DType::QFloat`, which is not a supported conversion target.
fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
match dtype {
DType::F64 => self.convert_clone::<Current, f64>(),
DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
DType::F16 => self.convert_clone::<Current, f16>(),
DType::BF16 => self.convert_clone::<Current, bf16>(),
DType::I64 => self.convert_clone::<Current, i64>(),
DType::I32 => self.convert_clone::<Current, i32>(),
DType::I16 => self.convert_clone::<Current, i16>(),
DType::I8 => self.convert_clone::<Current, i8>(),
DType::U64 => self.convert_clone::<Current, u64>(),
DType::U32 => self.convert_clone::<Current, u32>(),
DType::U16 => self.convert_clone::<Current, u16>(),
DType::U8 => self.convert_clone::<Current, u8>(),
DType::Bool => self.convert_clone::<Current, bool>(),
DType::QFloat(_) => unreachable!(),
}
}
/// Converts every element from `Current` to `Target` into a newly allocated buffer.
///
/// The output is created zeroed with one slot per element of the shape, then filled
/// pairwise from the source values; the result is rebuilt with [`TensorData::new`].
fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
    self,
) -> Self {
    let source = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
    let mut converted: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
    converted
        .iter_mut()
        .zip(source)
        .for_each(|(dst, src)| *dst = src.elem());
    Self::new(converted, self.shape)
}
/// Returns the data as a slice of bytes.
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
/// Returns the bytes representation of the data.
pub fn into_bytes(self) -> Bytes {
self.bytes
}
}
/// Builds rank-1 tensor data of shape `[A]` from a fixed-size array of elements.
impl<E: Element, const A: usize> From<[E; A]> for TensorData {
    fn from(elems: [E; A]) -> Self {
        let data: Vec<E> = Vec::from(elems);
        TensorData::new(data, [A])
    }
}
/// Builds rank-1 tensor data of shape `[A]` from an array of `usize`, stored as `i64`.
impl<const A: usize> From<[usize; A]> for TensorData {
    fn from(elems: [usize; A]) -> Self {
        let data: Vec<i64> = elems.into_iter().map(|e| e as i64).collect();
        TensorData::new(data, [A])
    }
}
/// Builds rank-1 tensor data from a slice of `usize`, stored as `i64`.
impl From<&[usize]> for TensorData {
    fn from(elems: &[usize]) -> Self {
        let data: Vec<i64> = elems.iter().map(|&e| e as i64).collect();
        TensorData::new(data, [elems.len()])
    }
}
/// Builds rank-1 tensor data by copying a slice of elements.
impl<E: Element> From<&[E]> for TensorData {
    fn from(elems: &[E]) -> Self {
        let data: Vec<E> = elems.to_vec();
        TensorData::new(data, [elems.len()])
    }
}
/// Builds rank-2 tensor data of shape `[A, B]` from nested arrays, in row-major order.
impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
    fn from(elems: [[E; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B);
        data.extend(elems.into_iter().flatten());
        TensorData::new(data, [A, B])
    }
}
/// Builds rank-3 tensor data of shape `[A, B, C]` from nested arrays, innermost array
/// varying fastest.
impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
    for TensorData
{
    fn from(elems: [[[E; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C);
        data.extend(elems.into_iter().flatten().flatten());
        TensorData::new(data, [A, B, C])
    }
}
/// Builds rank-4 tensor data of shape `[A, B, C, D]` from nested arrays, innermost array
/// varying fastest.
impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
    From<[[[[E; D]; C]; B]; A]> for TensorData
{
    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C * D);
        data.extend(elems.into_iter().flatten().flatten().flatten());
        TensorData::new(data, [A, B, C, D])
    }
}
/// Builds rank-5 tensor data of shape `[A, B, C, D, E]` from nested arrays, innermost
/// array varying fastest.
impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
{
    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C * D * E);
        data.extend(
            elems
                .into_iter()
                .flatten()
                .flatten()
                .flatten()
                .flatten(),
        );
        TensorData::new(data, [A, B, C, D, E])
    }
}
/// Formats the tensor values as a flat, debug-style list (the shape is not printed).
///
/// Quantized data with integer quantized values is printed as its `i8` values followed by
/// the quantization scheme.
///
/// # Panics
///
/// - If the stored bytes cannot be viewed as the stored dtype (the `unwrap` calls below).
/// - If the data is quantized with an `E4M3`, `E5M2` or `E2M1` quantized value, which is
///   not implemented yet.
impl core::fmt::Display for TensorData {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// Pick the concrete element type matching the stored dtype and debug-format the slice.
let fmt = match self.dtype {
DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
DType::QFloat(scheme) => match scheme {
QuantScheme {
level: QuantLevel::Tensor | QuantLevel::Block(_),
mode: QuantMode::Symmetric,
value:
QuantValue::Q8F
| QuantValue::Q8S
// Display sub-byte values as i8
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S,
..
} => {
// Values come from `iter::<i8>()`, followed by the scheme for context.
format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
},
QuantScheme {
level: QuantLevel::Tensor | QuantLevel::Block(_),
mode: QuantMode::Symmetric,
value:
QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
..
} => {
unimplemented!("Can't format yet");
}
},
};
f.write_str(fmt.as_str())
}
}
/// The things that can go wrong when manipulating tensor data.
#[derive(Debug, Error)]
pub enum DataError {
/// Failed to cast the values to a specified element type.
///
/// Wraps the `bytemuck` error reported when the stored bytes cannot be reinterpreted as
/// the requested type.
#[error("Failed to cast values to the specified element type.\nError:\n {0}")]
CastError(CheckedCastError),
/// Invalid target element type.
///
/// Carries a message naming the stored dtype and the requested one.
#[error("{0}")]
TypeMismatch(String),
}
// Unit tests for `TensorData` construction, access and conversion.
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
use rand::{
SeedableRng,
rngs::{StdRng, SysRng},
};
// The rank equals the number of dimensions in the shape.
#[test]
fn should_have_rank() {
let shape = [3, 5, 6];
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::try_from_rng(&mut SysRng).unwrap(),
);
assert_eq!(data.rank(), 3);
}
// `iter` and `into_vec` must expose the same values in the same order.
#[test]
fn into_vec_should_yield_same_value_as_iter() {
let shape = [3, 5, 6];
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::try_from_rng(&mut SysRng).unwrap(),
);
let expected = data.iter::<f32>().collect::<Vec<f32>>();
let actual = data.into_vec::<f32>().unwrap();
assert_eq!(expected, actual);
}
// Requesting the wrong element type yields an `Err`; the `unwrap` turns it into the
// expected panic.
#[test]
#[should_panic]
fn into_vec_should_assert_wrong_dtype() {
let shape = [3, 5, 6];
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::try_from_rng(&mut SysRng).unwrap(),
);
data.into_vec::<i32>().unwrap();
}
// The byte buffer and the typed slice must both account for every element of the shape.
#[test]
fn should_have_right_num_elements() {
let shape = [3, 5, 6];
let num_elements: usize = shape.iter().product();
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::try_from_rng(&mut SysRng).unwrap(),
);
assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
}
// Shapes are derived from the nesting of the source arrays.
#[test]
fn should_have_right_shape() {
let data = TensorData::from([[3.0, 5.0, 6.0]]);
assert_eq!(data.shape, vec![1, 3]);
let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
assert_eq!(data.shape, vec![2, 3]);
let data = TensorData::from([3.0, 5.0, 6.0]);
assert_eq!(data.shape, vec![3]);
}
// Both the length and the spare capacity of the source vector are carried over, scaled
// by the element size in bytes.
#[test]
fn should_convert_bytes_correctly() {
let mut vector: Vec<f32> = Vec::with_capacity(5);
vector.push(2.0);
vector.push(3.0);
let data1 = TensorData::new(vector, vec![2]);
let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
assert_eq!(data1.bytes.len(), 2 * factor);
assert_eq!(data1.bytes.capacity(), 5 * factor);
}
// Converting small integers to several element types must preserve their values.
#[test]
fn should_convert_bytes_correctly_inplace() {
fn test_precision<E: Element>() {
let data = TensorData::new((0..32).collect(), [32]);
for (i, val) in data
.clone()
.convert::<E>()
.into_vec::<E>()
.unwrap()
.into_iter()
.enumerate()
{
assert_eq!(i as u32, val.elem::<u32>())
}
}
test_precision::<f32>();
test_precision::<f16>();
test_precision::<i64>();
test_precision::<i32>();
}
// Generates one test per listed type, checking that `full_dtype` (runtime dtype) and
// `full` (compile-time element type) produce equal data.
macro_rules! test_dtypes {
($test_name:ident, $($dtype:ty),*) => {
$(
paste::paste! {
#[test]
fn [<$test_name _ $dtype:snake>]() {
let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
assert_eq!(full_dtype, full);
}
}
)*
};
}
test_dtypes!(
should_create_with_dtype,
bool,
i8,
i16,
i32,
i64,
u8,
u16,
u32,
u64,
f16,
bf16,
f32,
f64
);
}

View File

@@ -0,0 +1,125 @@
//! Random value distributions used to initialize and populate tensor data.
use rand::{Rng, RngExt, distr::StandardUniform};
use super::element::{Element, ElementConversion};
/// Distribution for random value of a tensor.
///
/// Use [`Distribution::sampler`] to obtain a sampler that draws values from it.
#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
/// Uniform distribution from 0 (inclusive) to 1 (exclusive).
#[default]
Default,
/// Bernoulli distribution with the given probability.
Bernoulli(f64),
/// Uniform distribution `[low, high)`.
Uniform(f64, f64),
/// Normal distribution with the given mean and standard deviation.
Normal(f64, f64),
}
/// Distribution sampler for random value of a tensor.
///
/// Pairs a concrete sampling strategy with a mutable borrow of the random number generator
/// used to draw values.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
R: Rng,
{
// The concrete distribution to sample from.
kind: DistributionSamplerKind<E>,
// The random number generator, borrowed for the lifetime of the sampler.
rng: &'a mut R,
}
/// Distribution sampler kind for random value of a tensor.
///
/// Each variant wraps the `rand` / `rand_distr` distribution object that performs the
/// actual sampling.
pub enum DistributionSamplerKind<E>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
{
/// Standard distribution.
Standard(rand::distr::StandardUniform),
/// Uniform distribution.
Uniform(rand::distr::Uniform<E>),
/// Bernoulli distribution.
Bernoulli(rand::distr::Bernoulli),
/// Normal distribution (always sampled as `f64`).
Normal(rand_distr::Normal<f64>),
}
impl<E, R> DistributionSampler<'_, E, R>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
    E: Element,
    R: Rng,
{
    /// Samples a random value from the distribution.
    ///
    /// Standard and uniform kinds produce an `E` directly. A Bernoulli draw is mapped to
    /// one (success) or zero (failure), and a normal draw is sampled as `f64` and then
    /// converted to `E`.
    pub fn sample(&mut self) -> E {
        match &self.kind {
            DistributionSamplerKind::Standard(standard) => self.rng.sample(standard),
            DistributionSamplerKind::Uniform(uniform) => self.rng.sample(uniform),
            DistributionSamplerKind::Bernoulli(bernoulli) => {
                let success: bool = self.rng.sample(bernoulli);
                match success {
                    true => 1.elem(),
                    false => 0.elem(),
                }
            }
            DistributionSamplerKind::Normal(normal) => {
                let drawn: f64 = self.rng.sample(normal);
                drawn.elem()
            }
        }
    }
}
impl Distribution {
    /// Creates a sampler that draws values of type `E` from this distribution.
    ///
    /// # Arguments
    ///
    /// * `rng` - The random number generator the sampler will draw from.
    ///
    /// # Returns
    ///
    /// The distribution sampler, borrowing `rng` for as long as it lives.
    ///
    /// # Panics
    ///
    /// If the underlying `rand` / `rand_distr` constructor rejects the parameters of a
    /// uniform, Bernoulli or normal distribution.
    pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
    where
        R: Rng,
        E: Element + rand::distr::uniform::SampleUniform,
        StandardUniform: rand::distr::Distribution<E>,
    {
        let kind = match self {
            Distribution::Default => {
                DistributionSamplerKind::Standard(rand::distr::StandardUniform {})
            }
            Distribution::Uniform(low, high) => {
                // The bounds are converted to the element type before building the range.
                let (low, high) = (low.elem::<E>(), high.elem::<E>());
                let uniform = rand::distr::Uniform::new(low, high).unwrap();
                DistributionSamplerKind::Uniform(uniform)
            }
            Distribution::Bernoulli(prob) => {
                let bernoulli = rand::distr::Bernoulli::new(prob).unwrap();
                DistributionSamplerKind::Bernoulli(bernoulli)
            }
            Distribution::Normal(mean, std_dev) => {
                let normal = rand_distr::Normal::new(mean, std_dev).unwrap();
                DistributionSamplerKind::Normal(normal)
            }
        };
        DistributionSampler::new(kind, rng)
    }
}
// Unit tests for `Distribution`.
#[cfg(test)]
mod tests {
use super::*;
// The derived `Default` must select the `Distribution::Default` variant.
#[test]
fn test_distribution_default() {
let dist: Distribution = Default::default();
assert_eq!(dist, Distribution::Default);
assert_eq!(Distribution::default(), Distribution::Default);
}
}

View File

@@ -0,0 +1,295 @@
use core::cmp::Ordering;
use rand::Rng;
use crate::distribution::Distribution;
use burn_std::{DType, bf16, f16};
#[cfg(feature = "cubecl")]
use burn_std::flex32;
use super::cast::ToElement;
/// Core element trait for tensor values.
///
/// This trait defines the minimal set of capabilities required for a type to be
/// stored and manipulated as a tensor element across all backends.
///
/// Besides conversion, random generation, equality and limits, an element must be plain
/// data: the `bytemuck` bounds allow it to be viewed as, checked from, and zero-initialized
/// as raw bytes.
pub trait Element:
ToElement
+ ElementRandom
+ ElementConversion
+ ElementEq
+ ElementLimits
+ bytemuck::CheckedBitPattern
+ bytemuck::NoUninit
+ bytemuck::Zeroable
+ core::fmt::Debug
+ core::fmt::Display
+ Default
+ Send
+ Sync
+ Copy
+ 'static
{
/// The dtype of the element.
fn dtype() -> DType;
}
/// Ordered element trait for tensor values.
///
/// This trait extends [`Element`] with ordering semantics, enabling comparison
/// and order-dependent operations in generic Rust implementations.
///
/// Backends that implement these operations entirely at the device level do
/// not rely on this trait. It only constrains the scalar type for generic Rust code.
///
/// It is a marker trait: the ordering itself is provided by [`ElementComparison`].
pub trait ElementOrdered: Element + ElementComparison {}
/// Element conversion trait for tensor.
///
/// Provides conversions in both directions: building `Self` from any [`ToElement`] value,
/// and turning `self` into any other [`Element`] type.
pub trait ElementConversion {
/// Converts an element to another element.
///
/// # Arguments
///
/// * `elem` - The element to convert.
///
/// # Returns
///
/// The converted element.
fn from_elem<E: ToElement>(elem: E) -> Self;
/// Converts and returns the converted element.
///
/// The target type `E` is chosen by the caller, e.g. `value.elem::<f32>()`.
fn elem<E: Element>(self) -> E;
}
/// Element trait for random value of a tensor.
///
/// Implemented by element types that can be drawn from a [`Distribution`].
pub trait ElementRandom {
/// Returns a random value for the given distribution.
///
/// # Arguments
///
/// * `distribution` - The distribution to sample from.
/// * `rng` - The random number generator.
///
/// # Returns
///
/// The random value.
fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self;
}
/// Element trait for equality of a tensor.
///
/// Lets generic code compare two elements of the same type for equality.
pub trait ElementEq {
/// Returns whether `self` and `other` are equal.
fn eq(&self, other: &Self) -> bool;
}
/// Element ordering trait.
///
/// Lets generic code order two elements of the same type.
pub trait ElementComparison {
/// Returns an [Ordering] between `self` and `other`.
fn cmp(&self, other: &Self) -> Ordering;
}
/// Element limits trait.
///
/// Exposes the smallest and largest values of an element type as associated constants.
pub trait ElementLimits {
/// The minimum representable value
const MIN: Self;
/// The maximum representable value
const MAX: Self;
}
/// Macro to implement the element trait for a type.
///
/// For the given type it implements `Element`, `ElementEq`, `ElementConversion`,
/// `ElementRandom`, `ElementComparison`, `ElementLimits` and `ElementOrdered`.
///
/// # Arguments
///
/// * `ty` - The type to implement the traits for.
/// * `convert` - Callable turning a `&impl ToElement` into the type.
/// * `random` - Callable producing a random value from a distribution and a generator.
/// * `cmp` - Callable returning the `Ordering` of two values of the type.
/// * `dtype` - The `DType` reported by `Element::dtype`.
/// * `min` / `max` - Optional limits; when omitted, `<ty>::MIN` and `<ty>::MAX` are used.
///
/// NOTE(review): the expansion refers to `Element`, `ToElement`, `Rng`, `Distribution` and
/// `Ordering` by their unqualified names, so those presumably must be in scope wherever
/// the macro is invoked — confirm before using it from another module or crate.
#[macro_export]
macro_rules! make_element {
// Short form: limits default to the type's own `MIN` / `MAX` constants.
(
ty $type:ident,
convert $convert:expr,
random $random:expr,
cmp $cmp:expr,
dtype $dtype:expr
) => {
make_element!(ty $type, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);
};
// Full form: limits are given explicitly.
(
ty $type:ident,
convert $convert:expr,
random $random:expr,
cmp $cmp:expr,
dtype $dtype:expr,
min $min:expr,
max $max:expr
) => {
impl Element for $type {
#[inline(always)]
fn dtype() -> burn_std::DType {
$dtype
}
}
impl ElementEq for $type {
fn eq(&self, other: &Self) -> bool {
self == other
}
}
impl ElementConversion for $type {
#[inline(always)]
fn from_elem<E: ToElement>(elem: E) -> Self {
#[allow(clippy::redundant_closure_call)]
$convert(&elem)
}
#[inline(always)]
fn elem<E: Element>(self) -> E {
E::from_elem(self)
}
}
impl ElementRandom for $type {
fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self {
#[allow(clippy::redundant_closure_call)]
$random(distribution, rng)
}
}
impl ElementComparison for $type {
fn cmp(&self, other: &Self) -> Ordering {
// Both operands are converted to the type itself before being compared.
let a = self.elem::<$type>();
let b = other.elem::<$type>();
#[allow(clippy::redundant_closure_call)]
$cmp(&a, &b)
}
}
impl ElementLimits for $type {
const MIN: Self = $min;
const MAX: Self = $max;
}
impl ElementOrdered for $type {}
};
}
// Element implementations for the supported scalar types.
//
// Floating-point types are ordered with `total_cmp`; integer types and `bool` use `Ord`.

// 64-bit and 32-bit floats: sampled directly in their own type.
make_element!(
ty f64,
convert ToElement::to_f64,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &f64, b: &f64| a.total_cmp(b),
dtype DType::F64
);
make_element!(
ty f32,
convert ToElement::to_f32,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &f32, b: &f32| a.total_cmp(b),
dtype DType::F32
);
// Signed and unsigned integers: sampled directly in their own type.
make_element!(
ty i64,
convert ToElement::to_i64,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i64, b: &i64| Ord::cmp(a, b),
dtype DType::I64
);
make_element!(
ty u64,
convert ToElement::to_u64,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u64, b: &u64| Ord::cmp(a, b),
dtype DType::U64
);
make_element!(
ty i32,
convert ToElement::to_i32,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i32, b: &i32| Ord::cmp(a, b),
dtype DType::I32
);
make_element!(
ty u32,
convert ToElement::to_u32,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u32, b: &u32| Ord::cmp(a, b),
dtype DType::U32
);
make_element!(
ty i16,
convert ToElement::to_i16,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i16, b: &i16| Ord::cmp(a, b),
dtype DType::I16
);
make_element!(
ty u16,
convert ToElement::to_u16,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u16, b: &u16| Ord::cmp(a, b),
dtype DType::U16
);
make_element!(
ty i8,
convert ToElement::to_i8,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i8, b: &i8| Ord::cmp(a, b),
dtype DType::I8
);
make_element!(
ty u8,
convert ToElement::to_u8,
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u8, b: &u8| Ord::cmp(a, b),
dtype DType::U8
);
// Half-precision floats: random values are sampled as `f32`, then converted.
make_element!(
ty f16,
convert ToElement::to_f16,
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
f16::from_elem(sample)
},
cmp |a: &f16, b: &f16| a.total_cmp(b),
dtype DType::F16
);
make_element!(
ty bf16,
convert ToElement::to_bf16,
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
bf16::from_elem(sample)
},
cmp |a: &bf16, b: &bf16| a.total_cmp(b),
dtype DType::BF16
);
// `flex32` is only available with the `cubecl` feature. It is converted through `f32`,
// and its limits are taken from `f16`'s minimum and maximum.
#[cfg(feature = "cubecl")]
make_element!(
ty flex32,
convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
flex32::from_elem(sample)
},
cmp |a: &flex32, b: &flex32| a.total_cmp(b),
dtype DType::Flex32,
min flex32::from_f32(f16::MIN.to_f32_const()),
max flex32::from_f32(f16::MAX.to_f32_const())
);
// `bool`: random values are sampled as `u8`, then converted; limits are `false` / `true`.
make_element!(
ty bool,
convert ToElement::to_bool,
random |distribution: Distribution, rng: &mut R| {
let sample: u8 = distribution.sampler(rng).sample();
bool::from_elem(sample)
},
cmp |a: &bool, b: &bool| Ord::cmp(a, b),
dtype DType::Bool,
min false,
max true
);

View File

@@ -0,0 +1,706 @@
use core::mem::size_of;
use burn_std::{bf16, f16};
/// A generic trait for converting a value to a number.
/// Adapted from num_traits::ToPrimitive to support [bool].
///
/// A value can be represented by the target type when it lies within
/// the range of scalars supported by the target type.
/// For example, a negative integer cannot be represented by an unsigned
/// integer type, and an `i64` with a very high magnitude might not be
/// convertible to an `i32`.
/// On the other hand, conversions with possible precision loss or truncation
/// are admitted, like an `f32` with a decimal part to an integer type, or
/// even a large `f64` saturating to `f32` infinity.
///
/// The methods *panic* when the value cannot be represented by the target type.
///
/// Only `to_i64` and `to_u64` are required. The default signed-integer methods
/// convert through `to_i64()`; the default unsigned-integer, float and bool
/// methods ultimately convert through `to_u64()`.
pub trait ToElement {
/// Converts the value of `self` to an `isize`.
#[inline]
fn to_isize(&self) -> isize {
ToElement::to_isize(&self.to_i64())
}
/// Converts the value of `self` to an `i8`.
#[inline]
fn to_i8(&self) -> i8 {
ToElement::to_i8(&self.to_i64())
}
/// Converts the value of `self` to an `i16`.
#[inline]
fn to_i16(&self) -> i16 {
ToElement::to_i16(&self.to_i64())
}
/// Converts the value of `self` to an `i32`.
#[inline]
fn to_i32(&self) -> i32 {
ToElement::to_i32(&self.to_i64())
}
/// Converts the value of `self` to an `i64`.
fn to_i64(&self) -> i64;
/// Converts the value of `self` to an `i128`.
///
/// The default implementation converts through `to_i64()`. Types implementing
/// this trait should override this method if they can represent a greater range.
#[inline]
fn to_i128(&self) -> i128 {
i128::from(self.to_i64())
}
/// Converts the value of `self` to a `usize`.
#[inline]
fn to_usize(&self) -> usize {
ToElement::to_usize(&self.to_u64())
}
/// Converts the value of `self` to a `u8`.
#[inline]
fn to_u8(&self) -> u8 {
ToElement::to_u8(&self.to_u64())
}
/// Converts the value of `self` to a `u16`.
#[inline]
fn to_u16(&self) -> u16 {
ToElement::to_u16(&self.to_u64())
}
/// Converts the value of `self` to a `u32`.
#[inline]
fn to_u32(&self) -> u32 {
ToElement::to_u32(&self.to_u64())
}
/// Converts the value of `self` to a `u64`.
fn to_u64(&self) -> u64;
/// Converts the value of `self` to a `u128`.
///
/// The default implementation converts through `to_u64()`. Types implementing
/// this trait should override this method if they can represent a greater range.
#[inline]
fn to_u128(&self) -> u128 {
u128::from(self.to_u64())
}
/// Converts the value of `self` to an `f16`. Overflows may map to positive
/// or negative infinity.
///
/// The default implementation converts through `to_f32()`.
#[inline]
fn to_f16(&self) -> f16 {
f16::from_f32(self.to_f32())
}
/// Converts the value of `self` to a `bf16`. Overflows may map to positive
/// or negative infinity.
///
/// The default implementation converts through `to_f32()`.
#[inline]
fn to_bf16(&self) -> bf16 {
bf16::from_f32(self.to_f32())
}
/// Converts the value of `self` to an `f32`. Overflows may map to positive
/// or negative infinity.
///
/// The default implementation converts through `to_f64()`.
#[inline]
fn to_f32(&self) -> f32 {
ToElement::to_f32(&self.to_f64())
}
/// Converts the value of `self` to an `f64`. Overflows may map to positive
/// or negative infinity.
///
/// The default implementation converts through `to_u64()` only (there is no
/// `to_i64()` fallback). Types that can hold negative or fractional values,
/// or a greater range, should override this method.
#[inline]
fn to_f64(&self) -> f64 {
ToElement::to_f64(&self.to_u64())
}
/// Converts the value of `self` to a bool.
/// Rust only considers 0 and 1 to be valid booleans, but for compatibility, C semantics are
/// adopted (anything that's not 0 is true).
///
/// The default implementation converts through `to_u64()` only (there is no
/// `to_i64()` fallback). Types that can hold negative or fractional values,
/// or a greater range, should override this method.
#[inline]
fn to_bool(&self) -> bool {
ToElement::to_bool(&self.to_u64())
}
}
// Generates checked signed-integer -> signed-integer conversion methods for `$SrcT`.
// Each generated method returns `*self as $DstT` when the value fits and panics otherwise.
macro_rules! impl_to_element_int_to_int {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let min = $DstT::MIN as $SrcT;
let max = $DstT::MAX as $SrcT;
// The size test comes first: when the source is no wider than the destination
// every value fits, and `min`/`max` (truncated by the casts above) are not consulted.
if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) {
*self as $DstT
} else {
panic!(
"Element cannot be represented in the target type: {:?}({:?}) => {:?}",
core::any::type_name::<$SrcT>(),
self,
core::any::type_name::<$DstT>(),
)
}
}
)*}
}
// Generates checked signed-integer -> unsigned-integer conversion methods for `$SrcT`.
// Negative values always panic; non-negative values are cast when they fit in `$DstT`.
macro_rules! impl_to_element_int_to_uint {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let max = $DstT::MAX as $SrcT;
// `max` is only compared when the source is wider than the destination,
// i.e. when `$DstT::MAX` is representable in `$SrcT`.
if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) {
*self as $DstT
} else {
panic!(
"Element cannot be represented in the target type: {:?}({:?}) => {:?}",
core::any::type_name::<$SrcT>(),
self,
core::any::type_name::<$DstT>(),
)
}
}
)*}
}
// Implements `ToElement` for the signed integer type `$T`:
// - integer targets use the checked conversions generated by
//   `impl_to_element_int_to_int!` and `impl_to_element_int_to_uint!`;
// - float targets use plain `as` casts;
// - `to_bool` is true for any non-zero value.
macro_rules! impl_to_element_int {
($T:ident) => {
impl ToElement for $T {
impl_to_element_int_to_int! { $T:
fn to_isize -> isize;
fn to_i8 -> i8;
fn to_i16 -> i16;
fn to_i32 -> i32;
fn to_i64 -> i64;
fn to_i128 -> i128;
}
impl_to_element_int_to_uint! { $T:
fn to_usize -> usize;
fn to_u8 -> u8;
fn to_u16 -> u16;
fn to_u32 -> u32;
fn to_u64 -> u64;
fn to_u128 -> u128;
}
#[inline]
fn to_f32(&self) -> f32 {
*self as f32
}
#[inline]
fn to_f64(&self) -> f64 {
*self as f64
}
#[inline]
fn to_bool(&self) -> bool {
*self != 0
}
}
};
}
// Signed integer types covered by the implementation above.
impl_to_element_int!(isize);
impl_to_element_int!(i8);
impl_to_element_int!(i16);
impl_to_element_int!(i32);
impl_to_element_int!(i64);
impl_to_element_int!(i128);
// Generates checked unsigned-integer -> signed-integer conversion methods for `$SrcT`.
// Each generated method returns `*self as $DstT` when the value fits and panics otherwise.
macro_rules! impl_to_element_uint_to_int {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let max = $DstT::MAX as $SrcT;
// Strict `<` here: a signed type of the same width cannot hold the upper
// half of the unsigned range, so equal sizes still need the `max` check.
if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max {
*self as $DstT
} else {
panic!(
"Element cannot be represented in the target type: {:?}({:?}) => {:?}",
core::any::type_name::<$SrcT>(),
self,
core::any::type_name::<$DstT>(),
)
}
}
)*}
}
// Generates checked unsigned-integer -> unsigned-integer conversion methods for `$SrcT`.
// Each generated method returns `*self as $DstT` when the value fits and panics otherwise.
macro_rules! impl_to_element_uint_to_uint {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let max = $DstT::MAX as $SrcT;
// A source no wider than the destination always fits; otherwise compare
// against the destination maximum expressed in the source type.
if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max {
*self as $DstT
} else {
panic!(
"Element cannot be represented in the target type: {:?}({:?}) => {:?}",
core::any::type_name::<$SrcT>(),
self,
core::any::type_name::<$DstT>(),
)
}
}
)*}
}
// Implements `ToElement` for the unsigned integer type `$T`:
// - integer targets use the checked conversions generated by
//   `impl_to_element_uint_to_int!` and `impl_to_element_uint_to_uint!`;
// - float targets use plain `as` casts;
// - `to_bool` is true for any non-zero value.
macro_rules! impl_to_element_uint {
($T:ident) => {
impl ToElement for $T {
impl_to_element_uint_to_int! { $T:
fn to_isize -> isize;
fn to_i8 -> i8;
fn to_i16 -> i16;
fn to_i32 -> i32;
fn to_i64 -> i64;
fn to_i128 -> i128;
}
impl_to_element_uint_to_uint! { $T:
fn to_usize -> usize;
fn to_u8 -> u8;
fn to_u16 -> u16;
fn to_u32 -> u32;
fn to_u64 -> u64;
fn to_u128 -> u128;
}
#[inline]
fn to_f32(&self) -> f32 {
*self as f32
}
#[inline]
fn to_f64(&self) -> f64 {
*self as f64
}
#[inline]
fn to_bool(&self) -> bool {
*self != 0
}
}
};
}
// Unsigned integer types covered by the implementation above.
impl_to_element_uint!(usize);
impl_to_element_uint!(u8);
impl_to_element_uint!(u16);
impl_to_element_uint!(u32);
impl_to_element_uint!(u64);
impl_to_element_uint!(u128);
// Generates float -> float conversion methods for `$SrcT` using a plain `as` cast.
// These conversions never panic.
macro_rules! impl_to_element_float_to_float {
($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
fn $method(&self) -> $DstT {
// We can safely cast all values, whether NaN, +-inf, or finite.
// Finite values that are reducing size may saturate to +-inf.
*self as $DstT
}
)*}
}
// Wraps `to_int_unchecked` for use by the float -> integer macros below.
// The caller is responsible for the range check; this macro performs none.
macro_rules! float_to_int_unchecked {
// SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating.
// We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`.
($float:expr => $int:ty) => {
unsafe { $float.to_int_unchecked::<$int>() }
};
}
// Generates float -> signed-integer conversion methods for the float type `$f`.
// Values are truncated toward zero when they fall in the accepted range; anything
// else (including NaN and infinities, which fail every comparison below or lie
// outside the bounds) reaches the final `panic!`.
macro_rules! impl_to_element_float_to_signed_int {
($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $i {
// Float as int truncates toward zero, so we want to allow values
// in the exclusive range `(MIN-1, MAX+1)`.
if size_of::<$f>() > size_of::<$i>() {
// With a larger size, we can represent the range exactly.
const MIN_M1: $f = $i::MIN as $f - 1.0;
const MAX_P1: $f = $i::MAX as $f + 1.0;
if *self > MIN_M1 && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $i);
}
} else {
// We can't represent `MIN-1` exactly, but there's no fractional part
// at this magnitude, so we can just use a `MIN` inclusive boundary.
const MIN: $f = $i::MIN as $f;
// We can't represent `MAX` exactly, but it will round up to exactly
// `MAX+1` (a power of two) when we cast it.
const MAX_P1: $f = $i::MAX as $f;
if *self >= MIN && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $i);
}
}
panic!("Float cannot be represented in the target signed int type")
}
)*}
}
// Generates float -> unsigned-integer conversion methods for the float type `$f`.
// Values in the accepted range are truncated toward zero (so e.g. `-0.5` maps to 0);
// anything else, including NaN, reaches the final `panic!`.
macro_rules! impl_to_element_float_to_unsigned_int {
($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $u {
// Float as int truncates toward zero, so we want to allow values
// in the exclusive range `(-1, MAX+1)`.
if size_of::<$f>() > size_of::<$u>() {
// With a larger size, we can represent the range exactly.
const MAX_P1: $f = $u::MAX as $f + 1.0;
if *self > -1.0 && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $u);
}
} else {
// We can't represent `MAX` exactly, but it will round up to exactly
// `MAX+1` (a power of two) when we cast it.
// (`u128::MAX as f32` is infinity, but this is still ok.)
const MAX_P1: $f = $u::MAX as $f;
if *self > -1.0 && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $u);
}
}
panic!("Float cannot be represented in the target unsigned int type")
}
)*}
}
// Implements `ToElement` for the float type `$T`:
// - integer targets use the range-checked, truncating conversions generated by
//   `impl_to_element_float_to_signed_int!` / `impl_to_element_float_to_unsigned_int!`;
// - float targets use `impl_to_element_float_to_float!` (plain `as` cast);
// - `to_bool` is `*self != 0.0`, so NaN converts to `true`.
macro_rules! impl_to_element_float {
($T:ident) => {
impl ToElement for $T {
impl_to_element_float_to_signed_int! { $T:
fn to_isize -> isize;
fn to_i8 -> i8;
fn to_i16 -> i16;
fn to_i32 -> i32;
fn to_i64 -> i64;
fn to_i128 -> i128;
}
impl_to_element_float_to_unsigned_int! { $T:
fn to_usize -> usize;
fn to_u8 -> u8;
fn to_u16 -> u16;
fn to_u32 -> u32;
fn to_u64 -> u64;
fn to_u128 -> u128;
}
impl_to_element_float_to_float! { $T:
fn to_f32 -> f32;
fn to_f64 -> f64;
}
#[inline]
fn to_bool(&self) -> bool {
*self != 0.0
}
}
};
}
// Float types covered by the implementation above.
impl_to_element_float!(f32);
impl_to_element_float!(f64);
// `ToElement` for `f16`. Integer targets are reached by widening to `f32` first and
// then calling the `f32` conversion, so range checks and panics follow the `f32` rules.
// `to_isize`, `to_usize`, `to_i128` and `to_u128` are not overridden and therefore use
// the trait defaults (through `to_i64()` / `to_u64()`).
//
// `Self::to_f32(*self)` / `Self::to_f64(*self)` pass `self` by value, whereas the trait
// methods take `&self`, so these calls select the type's own by-value conversions
// instead of recursing into this trait.
// NOTE(review): do not rewrite them as `self.to_f32()` without checking method resolution.
impl ToElement for f16 {
#[inline]
fn to_i64(&self) -> i64 {
Self::to_f32(*self).to_i64()
}
#[inline]
fn to_u64(&self) -> u64 {
Self::to_f32(*self).to_u64()
}
#[inline]
fn to_i8(&self) -> i8 {
Self::to_f32(*self).to_i8()
}
#[inline]
fn to_u8(&self) -> u8 {
Self::to_f32(*self).to_u8()
}
#[inline]
fn to_i16(&self) -> i16 {
Self::to_f32(*self).to_i16()
}
#[inline]
fn to_u16(&self) -> u16 {
Self::to_f32(*self).to_u16()
}
#[inline]
fn to_i32(&self) -> i32 {
Self::to_f32(*self).to_i32()
}
#[inline]
fn to_u32(&self) -> u32 {
Self::to_f32(*self).to_u32()
}
// Identity: avoids the default round trip through `f32`.
#[inline]
fn to_f16(&self) -> f16 {
*self
}
#[inline]
fn to_f32(&self) -> f32 {
Self::to_f32(*self)
}
#[inline]
fn to_f64(&self) -> f64 {
Self::to_f64(*self)
}
// Any value that compares unequal to zero is `true`.
#[inline]
fn to_bool(&self) -> bool {
*self != f16::from_f32_const(0.0)
}
}
// `ToElement` for `bf16`. Integer targets are reached by widening to `f32` first and
// then calling the `f32` conversion, so range checks and panics follow the `f32` rules.
// `to_isize`, `to_usize`, `to_i128` and `to_u128` are not overridden and therefore use
// the trait defaults (through `to_i64()` / `to_u64()`).
//
// `Self::to_f32(*self)` / `Self::to_f64(*self)` pass `self` by value, whereas the trait
// methods take `&self`, so these calls select the type's own by-value conversions
// instead of recursing into this trait.
// NOTE(review): do not rewrite them as `self.to_f32()` without checking method resolution.
impl ToElement for bf16 {
#[inline]
fn to_i64(&self) -> i64 {
Self::to_f32(*self).to_i64()
}
#[inline]
fn to_u64(&self) -> u64 {
Self::to_f32(*self).to_u64()
}
#[inline]
fn to_i8(&self) -> i8 {
Self::to_f32(*self).to_i8()
}
#[inline]
fn to_u8(&self) -> u8 {
Self::to_f32(*self).to_u8()
}
#[inline]
fn to_i16(&self) -> i16 {
Self::to_f32(*self).to_i16()
}
#[inline]
fn to_u16(&self) -> u16 {
Self::to_f32(*self).to_u16()
}
#[inline]
fn to_i32(&self) -> i32 {
Self::to_f32(*self).to_i32()
}
#[inline]
fn to_u32(&self) -> u32 {
Self::to_f32(*self).to_u32()
}
// Identity: avoids the default round trip through `f32`.
#[inline]
fn to_bf16(&self) -> bf16 {
*self
}
#[inline]
fn to_f32(&self) -> f32 {
Self::to_f32(*self)
}
#[inline]
fn to_f64(&self) -> f64 {
Self::to_f64(*self)
}
// Any value that compares unequal to zero is `true`.
#[inline]
fn to_bool(&self) -> bool {
*self != bf16::from_f32_const(0.0)
}
}
// `ToElement` for `flex32`, only compiled with the `cubecl` feature. Integer targets
// are reached by converting to `f32` first and then calling the `f32` conversion, so
// range checks and panics follow the `f32` rules. Methods that are not overridden
// here (`to_isize`, `to_usize`, `to_i128`, `to_u128`, `to_f16`, `to_bf16`) use the
// trait defaults.
//
// `Self::to_f32(*self)` / `Self::to_f64(*self)` pass `self` by value, whereas the trait
// methods take `&self`, so these calls select the type's own by-value conversions
// instead of recursing into this trait.
// NOTE(review): do not rewrite them as `self.to_f32()` without checking method resolution.
#[cfg(feature = "cubecl")]
impl ToElement for burn_std::flex32 {
#[inline]
fn to_i64(&self) -> i64 {
Self::to_f32(*self).to_i64()
}
#[inline]
fn to_u64(&self) -> u64 {
Self::to_f32(*self).to_u64()
}
#[inline]
fn to_i8(&self) -> i8 {
Self::to_f32(*self).to_i8()
}
#[inline]
fn to_u8(&self) -> u8 {
Self::to_f32(*self).to_u8()
}
#[inline]
fn to_i16(&self) -> i16 {
Self::to_f32(*self).to_i16()
}
#[inline]
fn to_u16(&self) -> u16 {
Self::to_f32(*self).to_u16()
}
#[inline]
fn to_i32(&self) -> i32 {
Self::to_f32(*self).to_i32()
}
#[inline]
fn to_u32(&self) -> u32 {
Self::to_f32(*self).to_u32()
}
#[inline]
fn to_f32(&self) -> f32 {
Self::to_f32(*self)
}
#[inline]
fn to_f64(&self) -> f64 {
Self::to_f64(*self)
}
// Any value that compares unequal to zero is `true`.
#[inline]
fn to_bool(&self) -> bool {
*self != burn_std::flex32::from_f32(0.0)
}
}
// `ToElement` for `bool`: `false` maps to 0 and `true` maps to 1 for every numeric
// target, so none of these conversions can fail.
impl ToElement for bool {
    #[inline]
    fn to_i8(&self) -> i8 {
        i8::from(*self)
    }

    #[inline]
    fn to_i16(&self) -> i16 {
        i16::from(*self)
    }

    #[inline]
    fn to_i32(&self) -> i32 {
        i32::from(*self)
    }

    #[inline]
    fn to_i64(&self) -> i64 {
        i64::from(*self)
    }

    #[inline]
    fn to_u8(&self) -> u8 {
        u8::from(*self)
    }

    #[inline]
    fn to_u16(&self) -> u16 {
        u16::from(*self)
    }

    #[inline]
    fn to_u32(&self) -> u32 {
        u32::from(*self)
    }

    #[inline]
    fn to_u64(&self) -> u64 {
        u64::from(*self)
    }

    // Floats go through `u8` first (0 or 1), which widens losslessly.
    #[inline]
    fn to_f32(&self) -> f32 {
        f32::from(u8::from(*self))
    }

    #[inline]
    fn to_f64(&self) -> f64 {
        f64::from(u8::from(*self))
    }

    // Identity conversion.
    #[inline]
    fn to_bool(&self) -> bool {
        *self
    }
}
// Unit tests for the `ToElement` conversions. Gated on `cfg(test)` so the module
// (and its glob import) is not compiled into regular builds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Narrowing float conversions saturate to infinity, keep exact values, and preserve NaN.
    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    // Negative signed values must panic for every unsigned target.

    #[test]
    #[should_panic]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    // Unsigned values one past the destination maximum must panic. The literals carry
    // explicit unsigned suffixes so that the unsigned -> unsigned path is the one
    // under test (an unsuffixed literal would fall back to `i32`).

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256u16.to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536u32.to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    /// Small integers convert to floats exactly.
    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    /// Floats convert to integers by truncating toward zero.
    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }
}

View File

@@ -0,0 +1,10 @@
//! Traits and helpers for working with element types and conversions.
mod base;
mod scalar;
/// Tensor element casting.
pub mod cast;
pub use base::*;
pub use scalar::*;

View File

@@ -0,0 +1,105 @@
use burn_std::{DType, bf16, f16};
use num_traits::ToPrimitive;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use crate::{Element, ElementConversion};
/// A scalar element.
///
/// Each variant stores the value in the widest primitive of its category.
#[derive(Clone, Copy, Debug)]
#[allow(missing_docs)]
pub enum Scalar {
// Floating-point value, stored as `f64`.
Float(f64),
// Signed integer value, stored as `i64`.
Int(i64),
// Unsigned integer value, stored as `u64`.
UInt(u64),
// Boolean value.
Bool(bool),
}
impl Scalar {
    /// Creates a scalar with the specified data type.
    ///
    /// The variant is chosen from `dtype`, and `value` is converted to the
    /// variant's storage type.
    ///
    /// # Note
    /// [`QFloat`](DType::QFloat) scalars are represented as float for element-wise operations.
    ///
    /// # Panics
    /// Panics when `dtype` is not a float, quantized float, int, uint or bool type.
    pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {
        if dtype.is_float() || matches!(dtype, &DType::QFloat(_)) {
            Self::Float(value.elem())
        } else if dtype.is_int() {
            Self::Int(value.elem())
        } else if dtype.is_uint() {
            Self::UInt(value.elem())
        } else if dtype.is_bool() {
            Self::Bool(value.elem())
        } else {
            unimplemented!("Scalar not supported for {dtype:?}")
        }
    }

    /// Converts and returns the converted element.
    pub fn elem<E: Element>(self) -> E {
        match self {
            Self::Float(x) => x.elem(),
            Self::Int(x) => x.elem(),
            Self::UInt(x) => x.elem(),
            Self::Bool(x) => x.elem(),
        }
    }

    /// Returns the exact integer value, if valid.
    ///
    /// * `Float` yields `Some(Int(_))` only when the value has no fractional part
    ///   and fits in an `i64`. Fractional values, NaN, infinities and out-of-range
    ///   magnitudes yield `None`.
    /// * `Int` and `UInt` are returned unchanged.
    /// * `Bool` yields `Int(0)` or `Int(1)`.
    pub fn try_as_integer(&self) -> Option<Self> {
        match self {
            // `to_i64()` already reports unrepresentable values as `None`; forward
            // that instead of unwrapping, so large integral floats and infinities
            // (which also satisfy `floor() == x`) do not panic.
            Scalar::Float(x) if x.floor() == *x => x.to_i64().map(Self::Int),
            Scalar::Float(_) => None,
            Scalar::Int(_) | Scalar::UInt(_) => Some(*self),
            Scalar::Bool(x) => Some(Scalar::Int(*x as i64)),
        }
    }
}
// Generates `impl From<$ty> for Scalar` for each `type => Variant` pair.
// The value is converted with `.elem()` to the storage type of the chosen variant.
macro_rules! impl_from_scalar {
($($ty:ty => $variant:ident),+ $(,)?) => {
$(
impl From<$ty> for Scalar {
fn from(value: $ty) -> Self {
Scalar::$variant(value.elem())
}
}
)+
};
}
// Float types map to `Float`, signed integers to `Int`, unsigned integers to `UInt`,
// and `bool` to `Bool`.
impl_from_scalar! {
f64 => Float, f32 => Float, f16 => Float, bf16 => Float,
i64 => Int, i32 => Int, i16 => Int, i8 => Int,
u64 => UInt, u32 => UInt, u16 => UInt, u8 => UInt, bool => Bool,
}
// CubeCL requirement
//
// Each method forwards to the `ToPrimitive` conversion of the stored value, so a
// value that does not fit the target yields `None`. Same-type variants and `Bool`
// (0 or 1) always succeed.
impl ToPrimitive for Scalar {
    fn to_i64(&self) -> Option<i64> {
        match *self {
            Scalar::Float(value) => value.to_i64(),
            Scalar::Int(value) => Some(value),
            Scalar::UInt(value) => value.to_i64(),
            Scalar::Bool(flag) => Some(i64::from(flag)),
        }
    }

    fn to_u64(&self) -> Option<u64> {
        match *self {
            Scalar::Float(value) => value.to_u64(),
            Scalar::Int(value) => value.to_u64(),
            Scalar::UInt(value) => Some(value),
            Scalar::Bool(flag) => Some(u64::from(flag)),
        }
    }

    fn to_f64(&self) -> Option<f64> {
        match *self {
            Scalar::Float(value) => Some(value),
            Scalar::Int(value) => value.to_f64(),
            Scalar::UInt(value) => value.to_f64(),
            Scalar::Bool(flag) => u8::from(flag).to_f64(),
        }
    }
}

View File

@@ -0,0 +1,122 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! This library provides the core types that define how Burn tensor data is represented, stored, and interpreted.
#[macro_use]
extern crate derive_new;
extern crate alloc;
mod data;
pub use data::*;
pub mod distribution;
pub use distribution::*;
pub mod element;
pub use element::*;
/// [`Backend`] trait and required types.
pub mod backend;
pub use backend::*;
/// Backend tensor primitives and operations.
pub mod tensor;
// Re-exported types
pub use burn_std::reader::*; // Useful so that backends don't have to add `burn_std` as a dependency.
pub use burn_std::{
AllocationProperty, Bytes, DType, FloatDType, IntDType, bf16, f16, stream_id::StreamId,
};
// The modules below re-export `burn_std` items under this crate's namespace.
// Each module's contents are also glob re-exported at the crate root right after it.

/// Shape definition.
pub mod shape {
pub use burn_std::shape::*;
}
pub use shape::*;
/// Slice utilities.
pub mod slice {
// Includes the `s` item alongside everything from `burn_std::slice`.
pub use burn_std::{s, slice::*};
}
pub use slice::*;
/// Indexing utilities.
pub mod indexing {
pub use burn_std::indexing::*;
}
pub use indexing::*;
/// Quantization data representation.
pub mod quantization {
// Combines this crate's tensor quantization items with the scheme types from `burn_std`.
// Unlike the modules above, this one is not glob re-exported at the crate root here.
pub use crate::tensor::quantization::*;
pub use burn_std::quantization::{
BlockSize, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantStore,
QuantValue, QuantizedBytes,
};
}
// Feature-gated marker implementations: each module implements `DeviceOps` (with an
// empty body) for the device type of one CubeCL runtime, and is only compiled when
// the matching `cubecl-*` feature is enabled.

// WGPU runtime device.
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
use crate::backend::DeviceOps;
use cubecl::wgpu::WgpuDevice;
impl DeviceOps for WgpuDevice {}
}
// CUDA runtime device.
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
use crate::backend::DeviceOps;
use cubecl::cuda::CudaDevice;
impl DeviceOps for CudaDevice {}
}
// CPU runtime device.
#[cfg(feature = "cubecl-cpu")]
mod cube_cpu {
use crate::backend::DeviceOps;
use cubecl::cpu::CpuDevice;
impl DeviceOps for CpuDevice {}
}
// HIP runtime device (`AmdDevice`).
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
use crate::backend::DeviceOps;
use cubecl::hip::AmdDevice;
impl DeviceOps for AmdDevice {}
}
/// Convenience macro to link to the `burn-tensor` docs for this crate version.
///
/// Expands to a string literal containing a Markdown link, intended for use in
/// `#[doc = ...]` attributes.
///
/// Usage:
/// ```rust,ignore
/// # use burn_backend::doc_tensor;
/// doc_tensor!(); // Links to `Tensor` struct
/// doc_tensor!("zeros"); // Links to `Tensor::zeros` method
/// ```
// NOTE(review): `env!("CARGO_PKG_VERSION")` is evaluated where the macro is expanded,
// so when invoked from another crate the link presumably uses that crate's version.
#[macro_export]
macro_rules! doc_tensor {
// No argument: link to the `Tensor` struct page.
() => {
concat!(
"[`Tensor`](https://docs.rs/burn-tensor/",
env!("CARGO_PKG_VERSION"),
"/burn_tensor/struct.Tensor.html)"
)
};
// Method name literal: link to the `#method.<name>` anchor on the `Tensor` page.
($method:literal) => {
concat!(
"[`Tensor::",
$method,
"`](",
"https://docs.rs/burn-tensor/",
env!("CARGO_PKG_VERSION"),
"/burn_tensor/struct.Tensor.html#method.",
$method,
")"
)
};
}

View File

@@ -0,0 +1,23 @@
use crate::backend::Backend;
// We provide some type aliases to improve the readability of using associated types without
// having to use the disambiguation syntax.
// Each alias `X<B>` is shorthand for the corresponding `<B as Backend>::...` associated type.
/// Device type used by the backend.
pub type Device<B> = <B as Backend>::Device;
/// Float element type used by the backend.
pub type FloatElem<B> = <B as Backend>::FloatElem;
/// Integer element type used by the backend.
pub type IntElem<B> = <B as Backend>::IntElem;
/// Boolean element type used by the backend.
pub type BoolElem<B> = <B as Backend>::BoolElem;
/// Float tensor primitive type used by the backend.
pub type FloatTensor<B> = <B as Backend>::FloatTensorPrimitive;
/// Integer tensor primitive type used by the backend.
pub type IntTensor<B> = <B as Backend>::IntTensorPrimitive;
/// Boolean tensor primitive type used by the backend.
pub type BoolTensor<B> = <B as Backend>::BoolTensorPrimitive;
/// Quantized tensor primitive type used by the backend.
pub type QuantizedTensor<B> = <B as Backend>::QuantizedTensorPrimitive;

View File

@@ -0,0 +1,92 @@
use alloc::boxed::Box;
use core::any::Any;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[cfg(feature = "std")]
use std::collections::HashMap;
use crate::{TensorPrimitive, backend::Backend};
/// Contains tensors of arbitrary dimension, keyed by an identifier of type `ID`.
#[derive(Debug)]
pub struct TensorContainer<ID> {
// Values are type-erased (`dyn Any + Send`), so tensors of different backends can
// share one map; they are downcast back to a concrete type on access.
tensors: HashMap<ID, Box<dyn Any + Send>>,
}
impl<ID> Default for TensorContainer<ID>
where
ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
{
fn default() -> Self {
Self::new()
}
}
impl<ID> TensorContainer<ID>
where
    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
{
    /// Create an empty container.
    pub fn new() -> Self {
        let tensors = HashMap::new();
        Self { tensors }
    }

    /// Get a clone of the tensor registered under the given ID, or `None` if the ID
    /// is unknown.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a `TensorPrimitive<B>`.
    pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
    where
        B: Backend,
    {
        self.tensors.get(id).map(|stored| {
            stored
                .downcast_ref::<TensorPrimitive<B>>()
                .unwrap()
                .clone()
        })
    }

    /// Register a new tensor for the given ID.
    ///
    /// # Notes
    ///
    /// If a tensor is already registered for the given ID, it will be replaced.
    pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
    where
        B: Backend,
    {
        let boxed = Box::new(value);
        self.tensors.insert(id, boxed);
    }

    /// Remove the tensor registered under the given ID and return it, or `None` if
    /// the ID is unknown.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a `TensorPrimitive<B>`.
    pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
    where
        B: Backend,
    {
        let boxed = self.tensors.remove(id)?;
        let primitive = boxed.downcast::<TensorPrimitive<B>>().unwrap();
        Some(*primitive)
    }

    /// The number of tensors registered.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` when no tensor is registered.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Get the ID of every tensor in the container (in the map's iteration order).
    pub fn ids(&self) -> Vec<&ID> {
        let mut keys = Vec::with_capacity(self.tensors.len());
        keys.extend(self.tensors.keys());
        keys
    }
}

View File

@@ -0,0 +1,44 @@
use crate::{Backend, TensorMetadata, TensorPrimitive};
/// A type-level representation of the kind of a float tensor.
#[derive(Clone, Debug)]
pub struct Float;
/// A type-level representation of the kind of an int tensor.
#[derive(Clone, Debug)]
pub struct Int;
/// A type-level representation of the kind of a bool tensor.
#[derive(Clone, Debug)]
pub struct Bool;
/// A type-level representation of the kind of a tensor.
/// Metadata access is lazy.
///
/// Implemented by the marker types of this module to select, for a backend `B`,
/// which primitive tensor type backs tensors of that kind.
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
/// The primitive type of the tensor.
type Primitive: TensorMetadata;
/// The name of the tensor kind.
fn name() -> &'static str;
}
// Float tensors use the `TensorPrimitive<B>` wrapper rather than a backend
// associated type directly.
impl<B: Backend> TensorKind<B> for Float {
type Primitive = TensorPrimitive<B>;
fn name() -> &'static str {
"Float"
}
}
// Int tensors use the backend's integer tensor primitive.
impl<B: Backend> TensorKind<B> for Int {
type Primitive = B::IntTensorPrimitive;
fn name() -> &'static str {
"Int"
}
}
// Bool tensors use the backend's boolean tensor primitive.
impl<B: Backend> TensorKind<B> for Bool {
type Primitive = B::BoolTensorPrimitive;
fn name() -> &'static str {
"Bool"
}
}

View File

@@ -0,0 +1,12 @@
mod alias;
mod container;
mod kind;
mod ops;
pub use alias::*;
pub use container::*;
pub use kind::*;
pub use ops::*;
/// Tensor quantization module.
pub mod quantization;

View File

@@ -0,0 +1,49 @@
use crate::{
AutodiffBackend,
tensor::{BasicOps, TensorKind},
};
/// Trait that lists all operations that can be applied to all tensors on an autodiff backend.
///
/// Implementors must provide the basic operations both for the autodiff backend `B`
/// and for its inner backend `B::InnerBackend`.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by the
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct.
pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
/// Tensor kind used on the inner backend.
type InnerKind: BasicOps<B::InnerBackend>;
/// Returns the inner tensor without the autodiff information.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// Users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("inner"))]
#[cfg_attr(not(doc), doc = "`Tensor::inner`")]
/// function, which is more high-level and designed for public use.
fn inner(
tensor: <Self as TensorKind<B>>::Primitive,
) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive;
/// Converts a tensor of the inner backend to the autodiff backend.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// Users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("from_inner"))]
#[cfg_attr(not(doc), doc = "`Tensor::from_inner`")]
/// function, which is more high-level and designed for public use.
fn from_inner(
inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive,
) -> <Self as TensorKind<B>>::Primitive;
}

View File

@@ -0,0 +1,807 @@
use alloc::vec::Vec;
use burn_std::{DType, Shape, Slice};
use crate::{
Backend, ExecutionError, Scalar, TensorData, TensorMetadata,
element::Element,
ops::TransactionPrimitive,
tensor::{IndexingUpdateOp, IntTensor, TensorKind},
};
/// Trait that list all operations that can be applied on all tensors.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by the
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct.
pub trait BasicOps<B: Backend>: TensorKind<B> {
/// The type of the tensor elements.
type Elem: Element;
/// Creates an empty tensor with the given shape.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device on which the tensor will be allocated.
/// * `dtype` - The target data type.
///
/// # Returns
///
/// The empty tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating empty tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("empty"))]
#[cfg_attr(not(doc), doc = "`Tensor::empty`")]
/// function, which is more high-level and designed for public use.
fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
/// Creates a tensor filled with zeros.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device on which the tensor will be allocated.
/// * `dtype` - The target data type.
///
/// # Returns
///
/// The tensor filled with zeros.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor filled with zeros, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("zeros"))]
#[cfg_attr(not(doc), doc = "`Tensor::zeros`")]
/// function, which is more high-level and designed for public use.
fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
/// Creates a tensor filled with ones.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device on which the tensor will be allocated.
/// * `dtype` - The target data type.
///
/// # Returns
///
/// The tensor filled with ones.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor filled with ones, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("ones"))]
#[cfg_attr(not(doc), doc = "`Tensor::ones`")]
/// function, which is more high-level and designed for public use.
fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
/// Creates a tensor of the given shape where each element is equal to the provided value.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `fill_value` - The value with which to fill the tensor.
/// * `device` - The device on which the tensor will be allocated.
/// * `dtype` - The target data type.
///
/// # Returns
///
/// The tensor filled with the specified value.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating full tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("full"))]
#[cfg_attr(not(doc), doc = "`Tensor::full`")]
/// function, which is more high-level and designed for public use.
fn full(shape: Shape, fill_value: Scalar, device: &B::Device, dtype: DType) -> Self::Primitive;
/// Reshapes the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `shape` - The new shape of the tensor.
///
/// # Returns
///
/// The reshaped tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For reshaping a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("reshape"))]
#[cfg_attr(not(doc), doc = "`Tensor::reshape`")]
/// function, which is more high-level and designed for public use.
fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
/// Transposes a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to transpose.
///
/// # Returns
///
/// The transposed tensor.
fn transpose(tensor: Self::Primitive) -> Self::Primitive;
/// Swaps two dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped.
fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
/// Permutes the dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to permute the dimensions of.
/// * `axes` - The new order of the dimensions.
///
/// # Returns
///
/// The tensor with the dimensions permuted.
fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
/// Flips the tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to flip.
/// * `axes` - The axes to flip the tensor along.
///
/// # Returns
///
/// The tensor with the axes flipped.
fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
/// Select tensor elements corresponding to the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `slices` - The slices specifying ranges and steps for each dimension.
///
/// # Returns
///
/// The selected elements.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("slice"))]
#[cfg_attr(not(doc), doc = "`Tensor::slice`")]
/// function, which is more high-level and designed for public use.
fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
/// Assigns the given value to the tensor elements corresponding to the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `slices` - The slices specifying which elements to assign, including support for steps.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the assigned values.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For assigning values to elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("slice_assign"))]
#[cfg_attr(not(doc), doc = "`Tensor::slice_assign`")]
/// function, which is more high-level and designed for public use.
fn slice_assign(
tensor: Self::Primitive,
slices: &[Slice],
value: Self::Primitive,
) -> Self::Primitive;
/// Select tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension along which to select.
/// * `indices` - The indices of the elements to select.
///
/// # Returns
///
/// The selected tensor elements.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements from a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("select"))]
#[cfg_attr(not(doc), doc = "`Tensor::select`")]
/// function, which is more high-level and designed for public use.
fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive;
/// Assign the selected elements along the given dimension corresponding to the given indices
/// from the value tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to assign elements to.
/// * `dim` - The axis along which to assign elements.
/// * `indices` - The indices of the elements to assign.
/// * `values` - The values to assign to the tensor.
/// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis,
/// except for the elements at the specified indices, which are taken from the corresponding
/// element of the values tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For assigning elements to a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("select_assign"))]
#[cfg_attr(not(doc), doc = "`Tensor::select_assign`")]
/// function, which is more high-level and designed for public use.
fn select_assign(
tensor: Self::Primitive,
dim: usize,
indices: IntTensor<B>,
values: Self::Primitive,
update: IndexingUpdateOp,
) -> Self::Primitive;
/// Selects elements from a tensor based on a boolean mask.
///
/// # Arguments
///
/// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
/// * `mask` - The boolean mask to use for selecting elements.
/// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
///
/// # Returns
///
/// A tensor with the same shape as the input tensors, where each element is taken from the
/// corresponding element of the left hand side tensor if the corresponding element of the mask
/// is true, and from the corresponding element of the right hand side tensor otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements from a tensor based on a boolean mask, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("mask_where"))]
#[cfg_attr(not(doc), doc = "`Tensor::mask_where`")]
/// function, which is more high-level and designed for public use.
fn mask_where(
tensor: Self::Primitive,
mask: B::BoolTensorPrimitive,
source: Self::Primitive,
) -> Self::Primitive;
/// Fills elements of a tensor based on a boolean mask.
///
/// # Arguments
///
/// * `tensor` - The tensor whose elements are overwritten with the value
///   when the corresponding element of the mask is true.
/// * `mask` - The boolean mask to use for filling elements.
/// * `value` - The value to fill elements with when the corresponding element of the mask is true.
///
/// # Returns
///
/// A tensor with the same shape as the input tensors, where each element is left unmodified
/// if the corresponding element of the mask is false, and replaced by the value otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For filling elements of a tensor based on a boolean mask, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("mask_fill"))]
#[cfg_attr(not(doc), doc = "`Tensor::mask_fill`")]
/// function, which is more high-level and designed for public use.
fn mask_fill(
tensor: Self::Primitive,
mask: B::BoolTensorPrimitive,
value: Scalar,
) -> Self::Primitive;
/// Gathers elements from a tensor along an axis.
///
/// # Arguments
///
/// * `dim` - The axis along which to gather elements.
/// * `tensor` - The tensor to gather elements from.
/// * `indices` - The indices of the elements to gather.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For gathering elements from a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("gather"))]
#[cfg_attr(not(doc), doc = "`Tensor::gather`")]
/// function, which is more high-level and designed for public use.
fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;
/// Scatters elements into a tensor along an axis.
///
/// # Arguments
///
/// * `dim` - The axis along which to scatter elements.
/// * `tensor` - The tensor to scatter elements into.
/// * `indices` - The indices of the elements to scatter.
/// * `values` - The values to scatter into the tensor.
/// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis,
/// except for the elements at the specified indices, which are taken from the corresponding
/// element of the values tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For scattering elements into a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("scatter"))]
#[cfg_attr(not(doc), doc = "`Tensor::scatter`")]
/// function, which is more high-level and designed for public use.
fn scatter(
dim: usize,
tensor: Self::Primitive,
indices: IntTensor<B>,
values: Self::Primitive,
update: IndexingUpdateOp,
) -> Self::Primitive;
/// Returns the device on which the tensor is allocated.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device on which the tensor is allocated.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the device of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("device"))]
#[cfg_attr(not(doc), doc = "`Tensor::device`")]
/// function, which is more high-level and designed for public use.
fn device(tensor: &Self::Primitive) -> B::Device;
/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device on which the tensor will be moved.
///
/// # Returns
///
/// The tensor on the given device.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For moving a tensor to a device, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("to_device"))]
#[cfg_attr(not(doc), doc = "`Tensor::to_device`")]
/// function, which is more high-level and designed for public use.
#[allow(clippy::wrong_self_convention)]
fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
/// Extracts the data from the tensor asynchronously.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// A future that resolves to the data of the tensor (`TensorData`) on success, or to an
/// `ExecutionError` on failure.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For extracting the data of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("into_data"))]
#[cfg_attr(not(doc), doc = "`Tensor::into_data`")]
/// function, which is more high-level and designed for public use.
#[allow(clippy::wrong_self_convention)]
fn into_data_async(
tensor: Self::Primitive,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
/// Read the data from the tensor using a transaction.
///
/// # Arguments
///
/// * `tr` - The transaction the tensor is registered with.
/// * `tensor` - The tensor whose data should be read.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive);
/// Creates a tensor from the given data.
///
/// # Arguments
///
/// * `data` - The data of the tensor.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// The tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor from data, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("from_data"))]
#[cfg_attr(not(doc), doc = "`Tensor::from_data`")]
/// function, which is more high-level and designed for public use.
fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
/// Creates a tensor from the given data enforcing the given data type.
///
/// # Arguments
///
/// * `data` - The data of the tensor.
/// * `device` - The device on which the tensor will be allocated.
/// * `dtype` - The data type to enforce for the created tensor.
///
/// # Returns
///
/// The tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor from data, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("from_data_dtype"))]
#[cfg_attr(not(doc), doc = "`Tensor::from_data_dtype`")]
/// function, which is more high-level and designed for public use.
fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
/// Repeat the tensor along the given dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension along which the tensor will be repeated.
/// * `times` - The number of times the tensor will be repeated.
///
/// # Returns
///
/// The repeated tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For repeating a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("repeat_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::repeat_dim`")]
/// function, which is more high-level and designed for public use.
fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
/// Concatenates the given tensors along the given dimension.
///
/// # Arguments
///
/// * `vectors` - The tensors to concatenate.
/// * `dim` - The dimension along which the tensors will be concatenated.
///
/// # Returns
///
/// The concatenated tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For concatenating tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("cat"))]
#[cfg_attr(not(doc), doc = "`Tensor::cat`")]
/// function, which is more high-level and designed for public use.
fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
/// Applies element-wise equality comparison between the given tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor of booleans indicating whether the corresponding elements are equal.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For equality comparison of tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("equal"))]
#[cfg_attr(not(doc), doc = "`Tensor::equal`")]
/// function, which is more high-level and designed for public use.
fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
/// Element-wise equality comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the input tensor is equal to the scalar, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise equality between a tensor and a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("equal_elem"))]
#[cfg_attr(not(doc), doc = "`Tensor::equal_elem`")]
/// function, which is more high-level and designed for public use.
fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
/// Applies element-wise non-equality comparison between the given tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor of booleans indicating whether the corresponding elements are not equal.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For non-equality comparison of tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("not_equal"))]
#[cfg_attr(not(doc), doc = "`Tensor::not_equal`")]
/// function, which is more high-level and designed for public use.
fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
/// Element-wise non-equality comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the input tensor is not equal to the scalar, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise non-equality between a tensor and a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("not_equal_elem"))]
#[cfg_attr(not(doc), doc = "`Tensor::not_equal_elem`")]
/// function, which is more high-level and designed for public use.
fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
/// Returns the name of the element type.
///
/// # Returns
///
/// The string produced by `core::any::type_name` for `Self::Elem`.
fn elem_type_name() -> &'static str {
core::any::type_name::<Self::Elem>()
}
/// Returns the tensor data type.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data type reported by the primitive's own `dtype()` method.
fn dtype(tensor: &Self::Primitive) -> DType {
tensor.dtype()
}
/// Tests if any element in the `tensor` evaluates to True.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
///
/// # Returns
///
/// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. Users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("any"))]
#[cfg_attr(not(doc), doc = "`Tensor::any`")]
/// function, which is more high-level and designed for public use.
fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
/// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
/// * `dim` - The axis along which to test.
///
/// # Returns
///
/// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
/// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. Users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("any_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::any_dim`")]
/// function, which is more high-level and designed for public use.
fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
/// Tests if all elements in the `tensor` evaluate to True.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
///
/// # Returns
///
/// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. Users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("all"))]
#[cfg_attr(not(doc), doc = "`Tensor::all`")]
/// function, which is more high-level and designed for public use.
fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
/// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
///
/// # Arguments
///
/// * `tensor` - The tensor to test.
/// * `dim` - The axis along which to test.
///
/// # Returns
///
/// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
/// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. Users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("all_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::all_dim`")]
/// function, which is more high-level and designed for public use.
fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
/// Broadcasts the given tensor to the specified shape.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The shape to broadcast to.
///
/// # Returns
///
/// The broadcasted tensor.
fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
/// Unfold windows along a dimension.
///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
///
/// # Warning
///
/// For the `ndarray` and `candle` backends; this is not a view but a full copy.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the dimension to unfold.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with shape ``[pre=..., windows, post=..., size]``.
fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
}

View File

@@ -0,0 +1,218 @@
use alloc::vec::Vec;
use burn_std::{DType, Shape, Slice};
use crate::{
AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
element::Element,
ops::TransactionPrimitive,
tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind},
};
/// [`BasicOps`] implementation for the [`Bool`] tensor kind.
///
/// Every method forwards to the matching `bool_*` operation of the backend `B`.
///
/// # Panics
///
/// The methods that receive an explicit `dtype` (`empty`, `zeros`, `ones`, `full` and
/// `from_data_dtype`) panic when it differs from the dtype of `B::BoolElem`.
impl<B: Backend> BasicOps<B> for Bool {
    type Elem = B::BoolElem;

    /// Creates an empty bool tensor via `B::bool_empty`; panics on a non-bool `dtype`.
    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
        if dtype != Self::Elem::dtype() {
            panic!("Expected bool data type, got {dtype:?}");
        }
        B::bool_empty(shape, device)
    }

    /// Creates a bool tensor via `B::bool_zeros`; panics on a non-bool `dtype`.
    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
        if dtype != Self::Elem::dtype() {
            panic!("Expected bool data type, got {dtype:?}");
        }
        B::bool_zeros(shape, device)
    }

    /// Creates a bool tensor via `B::bool_ones`; panics on a non-bool `dtype`.
    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
        if dtype != Self::Elem::dtype() {
            panic!("Expected bool data type, got {dtype:?}");
        }
        B::bool_ones(shape, device)
    }

    /// Creates a bool tensor filled with `fill_value`; panics on a non-bool `dtype`.
    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
        if dtype != Self::Elem::dtype() {
            panic!("Expected bool data type, got {dtype:?}");
        }
        // The scalar is converted to a `bool` (the type required by the `if` condition), which
        // selects between the backend's `bool_ones` and `bool_zeros` constructors.
        if fill_value.elem() {
            B::bool_ones(shape, device)
        } else {
            B::bool_zeros(shape, device)
        }
    }

    /// Registers the bool tensor with the transaction through `register_bool`.
    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
        tr.register_bool(tensor);
    }

    /// Forwards to `B::bool_reshape`.
    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_reshape(tensor, shape)
    }

    /// Forwards to `B::bool_transpose`.
    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
        B::bool_transpose(tensor)
    }

    /// Forwards to `B::bool_swap_dims`.
    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
        B::bool_swap_dims(tensor, dim1, dim2)
    }

    /// Forwards to `B::bool_slice`.
    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
        B::bool_slice(tensor, slices)
    }

    /// Forwards to `B::bool_slice_assign`.
    fn slice_assign(
        tensor: Self::Primitive,
        slices: &[Slice],
        value: Self::Primitive,
    ) -> Self::Primitive {
        B::bool_slice_assign(tensor, slices, value)
    }

    /// Forwards to `B::bool_select`.
    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
        B::bool_select(tensor, dim, indices)
    }

    /// Applies the indexed update along `dim`.
    ///
    /// For bool tensors the `Add` update is dispatched to `B::bool_select_or`.
    fn select_assign(
        tensor: Self::Primitive,
        dim: usize,
        indices: IntTensor<B>,
        values: Self::Primitive,
        update: IndexingUpdateOp,
    ) -> Self::Primitive {
        match update {
            IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
        }
    }

    /// Forwards to `B::bool_mask_where`.
    fn mask_where(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        source: Self::Primitive,
    ) -> Self::Primitive {
        B::bool_mask_where(tensor, mask, source)
    }

    /// Forwards to `B::bool_mask_fill`.
    fn mask_fill(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        value: Scalar,
    ) -> Self::Primitive {
        B::bool_mask_fill(tensor, mask, value)
    }

    /// Forwards to `B::bool_gather`.
    fn gather(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
    ) -> Self::Primitive {
        B::bool_gather(dim, tensor, indices)
    }

    /// Applies the indexed scatter update along `dim`.
    ///
    /// For bool tensors the `Add` update is dispatched to `B::bool_scatter_or`.
    fn scatter(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
        values: Self::Primitive,
        update: IndexingUpdateOp,
    ) -> Self::Primitive {
        match update {
            IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
        }
    }

    /// Forwards to `B::bool_device`.
    fn device(tensor: &Self::Primitive) -> Device<B> {
        B::bool_device(tensor)
    }

    /// Forwards to `B::bool_to_device`.
    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
        B::bool_to_device(tensor, device)
    }

    /// Awaits `B::bool_into_data` and returns its result unchanged.
    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
        B::bool_into_data(tensor).await
    }

    /// Converts `data` to `B::BoolElem` and forwards to `B::bool_from_data`.
    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
        B::bool_from_data(data.convert::<B::BoolElem>(), device)
    }

    /// Creates a bool tensor from `data`, enforcing that `dtype` is the backend's bool dtype.
    ///
    /// # Panics
    ///
    /// Panics when `dtype` differs from the dtype of `B::BoolElem`, exactly like the other
    /// constructors of this impl, instead of silently ignoring the requested data type.
    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
        if dtype != Self::Elem::dtype() {
            panic!("Expected bool data type, got {dtype:?}");
        }
        B::bool_from_data(data.convert::<B::BoolElem>(), device)
    }

    /// Forwards to `B::bool_repeat_dim`.
    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
        B::bool_repeat_dim(tensor, dim, times)
    }

    /// Forwards to `B::bool_equal`.
    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_equal(lhs, rhs)
    }

    /// Forwards to `B::bool_not_equal`.
    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_not_equal(lhs, rhs)
    }

    /// Forwards to `B::bool_equal_elem`.
    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::bool_equal_elem(lhs, rhs)
    }

    /// Forwards to `B::bool_not_equal_elem`.
    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::bool_not_equal_elem(lhs, rhs)
    }

    /// Forwards to `B::bool_cat`.
    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
        B::bool_cat(vectors, dim)
    }

    /// Forwards to `B::bool_any`.
    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_any(tensor)
    }

    /// Forwards to `B::bool_any_dim`.
    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_any_dim(tensor, dim)
    }

    /// Forwards to `B::bool_all`.
    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_all(tensor)
    }

    /// Forwards to `B::bool_all_dim`.
    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_all_dim(tensor, dim)
    }

    /// Forwards to `B::bool_permute`.
    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_permute(tensor, axes)
    }

    /// Forwards to `B::bool_expand`.
    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_expand(tensor, shape)
    }

    /// Forwards to `B::bool_flip`.
    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_flip(tensor, axes)
    }

    /// Forwards to `B::bool_unfold`.
    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
        B::bool_unfold(tensor, dim, size, step)
    }
}
/// [`BasicAutodiffOps`] implementation for the [`Bool`] tensor kind.
///
/// Both conversions forward to the corresponding `bool_*` function of the autodiff backend `B`.
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
// A bool tensor keeps the `Bool` kind on the inner backend.
type InnerKind = Bool;
/// Converts a bool primitive of the autodiff backend `B` into the primitive of its inner
/// backend by calling `B::bool_inner`.
fn inner(
tensor: <Self as TensorKind<B>>::Primitive,
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
B::bool_inner(tensor)
}
/// Converts a bool primitive of the inner backend into the primitive of the autodiff
/// backend `B` by calling `B::bool_from_inner`.
fn from_inner(
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
) -> <Self as TensorKind<B>>::Primitive {
B::bool_from_inner(inner)
}
}

View File

@@ -0,0 +1,700 @@
use alloc::vec::Vec;
use burn_std::{DType, Shape, Slice};
use crate::{
AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorPrimitive,
ops::TransactionPrimitive,
tensor::{
BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
TensorKind,
},
};
/// Dispatches a binary operation over every `Float` / `QFloat` pairing of two
/// `TensorPrimitive` operands.
///
/// * `$lhs`, `$rhs` - Identifiers bound to the two `TensorPrimitive` operands (both are moved).
/// * `$op` - Name of the backend function applied to two float operands; its result is wrapped
///   in `TensorPrimitive::Float`.
/// * `$q_op` - Name of the backend function applied when both operands are quantized; its
///   result is returned as is.
///
/// In the mixed cases the quantized operand is first passed through `B::dequantize` and the
/// float operation `$op` is used.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            // Both operands quantized: the quantized operation decides the output variant.
            (TensorPrimitive::QFloat(q_left), TensorPrimitive::QFloat(q_right)) => {
                B::$q_op(q_left, q_right)
            }
            // Only the left operand is quantized: dequantize it, then use the float op.
            (TensorPrimitive::QFloat(q_left), TensorPrimitive::Float(right)) => {
                let left = B::dequantize(q_left);
                TensorPrimitive::Float(B::$op(left, right))
            }
            // Only the right operand is quantized: dequantize it, then use the float op.
            (TensorPrimitive::Float(left), TensorPrimitive::QFloat(q_right)) => {
                let right = B::dequantize(q_right);
                TensorPrimitive::Float(B::$op(left, right))
            }
            // Plain float operands.
            (TensorPrimitive::Float(left), TensorPrimitive::Float(right)) => {
                TensorPrimitive::Float(B::$op(left, right))
            }
        }
    };
}
impl<B: Backend> BasicOps<B> for Float {
type Elem = B::FloatElem;
fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
}
fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
}
fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
}
fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
}
fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
tr.register_float(tensor);
}
fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_reshape(tensor, shape))
}
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
}
}
fn transpose(tensor: Self::Primitive) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
}
}
fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
}
TensorPrimitive::QFloat(tensor) => {
TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
}
}
}
fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_slice(tensor, slices))
}
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
}
}
fn slice_assign(
tensor: Self::Primitive,
slices: &[Slice],
value: Self::Primitive,
) -> Self::Primitive {
TensorPrimitive::Float(B::float_slice_assign(
tensor.tensor(),
slices,
value.tensor(),
))
}
fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_select(tensor, dim, indices))
}
TensorPrimitive::QFloat(tensor) => {
TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
}
}
}
/// Applies `update` to the entries of `tensor` selected along `dim` by `indices`.
///
/// `IndexingUpdateOp::Add` maps to `B::float_select_add`. Both primitives are converted with
/// `.tensor()` and the result is always a `TensorPrimitive::Float`.
fn select_assign(
    tensor: Self::Primitive,
    dim: usize,
    indices: IntTensor<B>,
    values: Self::Primitive,
    update: IndexingUpdateOp,
) -> Self::Primitive {
    // Select assign is ambiguous for QFloat
    match update {
        IndexingUpdateOp::Add => {
            let base = tensor.tensor();
            let addend = values.tensor();
            TensorPrimitive::Float(B::float_select_add(base, dim, indices, addend))
        }
    }
}
/// Forwards to `B::float_mask_where` with `tensor`, `mask` and `source`.
///
/// Both primitives are converted with `.tensor()`; the result is always a
/// `TensorPrimitive::Float`.
fn mask_where(
    tensor: Self::Primitive,
    mask: B::BoolTensorPrimitive,
    source: Self::Primitive,
) -> Self::Primitive {
    let base = tensor.tensor();
    let replacement = source.tensor();
    TensorPrimitive::Float(B::float_mask_where(base, mask, replacement))
}
/// Forwards to `B::float_mask_fill` with `tensor`, `mask` and the scalar `value`.
///
/// The primitive is converted with `.tensor()`; the result is always a
/// `TensorPrimitive::Float`.
fn mask_fill(
    tensor: Self::Primitive,
    mask: B::BoolTensorPrimitive,
    value: Scalar,
) -> Self::Primitive {
    let base = tensor.tensor();
    TensorPrimitive::Float(B::float_mask_fill(base, mask, value))
}
/// Gathers entries of `tensor` along `dim` using `indices`, keeping the incoming variant.
///
/// Float primitives go through `B::float_gather`, quantized ones through `B::q_gather`.
fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
    match tensor {
        TensorPrimitive::QFloat(quantized) => {
            let gathered = B::q_gather(dim, quantized, indices);
            TensorPrimitive::QFloat(gathered)
        }
        TensorPrimitive::Float(plain) => {
            let gathered = B::float_gather(dim, plain, indices);
            TensorPrimitive::Float(gathered)
        }
    }
}
/// Applies `update` while scattering `values` into `tensor` along `dim` at `indices`.
///
/// `IndexingUpdateOp::Add` maps to `B::float_scatter_add`. Both primitives are converted with
/// `.tensor()` and the result is always a `TensorPrimitive::Float`.
fn scatter(
    dim: usize,
    tensor: Self::Primitive,
    indices: IntTensor<B>,
    values: Self::Primitive,
    update: IndexingUpdateOp,
) -> Self::Primitive {
    match update {
        IndexingUpdateOp::Add => {
            let base = tensor.tensor();
            let addend = values.tensor();
            TensorPrimitive::Float(B::float_scatter_add(dim, base, indices, addend))
        }
    }
}
/// Returns the device reported by `B::float_device` or `B::q_device`, depending on the
/// variant of `tensor`.
fn device(tensor: &Self::Primitive) -> Device<B> {
    match tensor {
        TensorPrimitive::QFloat(quantized) => B::q_device(quantized),
        TensorPrimitive::Float(plain) => B::float_device(plain),
    }
}
/// Moves `tensor` to `device`, keeping the variant it came in with.
///
/// Float primitives go through `B::float_to_device`, quantized ones through `B::q_to_device`.
fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
    match tensor {
        TensorPrimitive::QFloat(quantized) => {
            let moved = B::q_to_device(quantized, device);
            TensorPrimitive::QFloat(moved)
        }
        TensorPrimitive::Float(plain) => {
            let moved = B::float_to_device(plain, device);
            TensorPrimitive::Float(moved)
        }
    }
}
/// Awaits the backend conversion of `tensor` into `TensorData`.
///
/// Float primitives use `B::float_into_data`, quantized ones use `B::q_into_data`; the
/// backend's `Result` is returned unchanged.
async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
    match tensor {
        TensorPrimitive::QFloat(quantized) => B::q_into_data(quantized).await,
        TensorPrimitive::Float(plain) => B::float_into_data(plain).await,
    }
}
/// Creates a primitive on `device` from `data`.
///
/// Data whose dtype is `DType::QFloat` is passed untouched to `B::q_from_data`; any other
/// dtype is converted to `B::FloatElem` and passed to `B::float_from_data`.
fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
    let is_quantized = matches!(data.dtype, DType::QFloat(_));
    if is_quantized {
        TensorPrimitive::QFloat(B::q_from_data(data, device))
    } else {
        let converted = data.convert::<B::FloatElem>();
        TensorPrimitive::Float(B::float_from_data(converted, device))
    }
}
/// Creates a primitive on `device` from `data` after converting it to `dtype`.
///
/// A `DType::QFloat` dtype yields a `TensorPrimitive::QFloat` built by `B::q_from_data`; any
/// dtype for which `is_float()` holds yields a `TensorPrimitive::Float` built by
/// `B::float_from_data`.
///
/// # Panics
///
/// If `dtype` is neither quantized nor a float dtype.
fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
    if matches!(dtype, DType::QFloat(_)) {
        let converted = data.convert_dtype(dtype);
        TensorPrimitive::QFloat(B::q_from_data(converted, device))
    } else if dtype.is_float() {
        let converted = data.convert_dtype(dtype);
        TensorPrimitive::Float(B::float_from_data(converted, device))
    } else {
        panic!("Expected float dtype, got {dtype:?}")
    }
}
/// Repeats `tensor` `times` times along `dim`, keeping the variant it came in with.
///
/// Float primitives go through `B::float_repeat_dim`, quantized ones through
/// `B::q_repeat_dim`.
fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
    match tensor {
        TensorPrimitive::QFloat(quantized) => {
            let repeated = B::q_repeat_dim(quantized, dim, times);
            TensorPrimitive::QFloat(repeated)
        }
        TensorPrimitive::Float(plain) => {
            let repeated = B::float_repeat_dim(plain, dim, times);
            TensorPrimitive::Float(repeated)
        }
    }
}
/// Concatenates `vectors` along `dim`.
///
/// The variant of the first element picks the backend op: `B::float_cat` (every element is
/// converted with `.tensor()`) or `B::q_cat` (every element must be a
/// `TensorPrimitive::QFloat`).
///
/// # Panics
///
/// If `vectors` is empty, or if the first element is quantized and any other element is not.
fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
    // Only the first element is inspected to decide which backend op handles the batch.
    let first_is_quantized = matches!(vectors.first().unwrap(), TensorPrimitive::QFloat(_));
    if first_is_quantized {
        TensorPrimitive::QFloat(B::q_cat(
            vectors
                .into_iter()
                .map(|item| match item {
                    TensorPrimitive::QFloat(quantized) => quantized,
                    TensorPrimitive::Float(_) => {
                        panic!("Concatenation only works with vector of QFloat")
                    }
                })
                .collect(),
            dim,
        ))
    } else {
        TensorPrimitive::Float(B::float_cat(
            vectors.into_iter().map(|item| item.tensor()).collect(),
            dim,
        ))
    }
}
/// Compares `lhs` and `rhs` with `B::float_equal` after converting both with `.tensor()`.
fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
    let left = lhs.tensor();
    let right = rhs.tensor();
    B::float_equal(left, right)
}
/// Compares `lhs` and `rhs` with `B::float_not_equal` after converting both with `.tensor()`.
fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
    let left = lhs.tensor();
    let right = rhs.tensor();
    B::float_not_equal(left, right)
}
/// Compares `lhs` (converted with `.tensor()`) against the scalar `rhs` using
/// `B::float_equal_elem`.
fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
    let left = lhs.tensor();
    B::float_equal_elem(left, rhs)
}
/// Compares `lhs` (converted with `.tensor()`) against the scalar `rhs` using
/// `B::float_not_equal_elem`.
fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
    let left = lhs.tensor();
    B::float_not_equal_elem(left, rhs)
}
/// Forwards `tensor` (converted with `.tensor()`) to `B::float_any`.
fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
    let input = tensor.tensor();
    B::float_any(input)
}
/// Forwards `tensor` (converted with `.tensor()`) and `dim` to `B::float_any_dim`.
fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
    let input = tensor.tensor();
    B::float_any_dim(input, dim)
}
/// Forwards `tensor` (converted with `.tensor()`) to `B::float_all`.
fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
    let input = tensor.tensor();
    B::float_all(input)
}
/// Forwards `tensor` (converted with `.tensor()`) and `dim` to `B::float_all_dim`.
fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
    let input = tensor.tensor();
    B::float_all_dim(input, dim)
}
/// Permutes the dimensions of `tensor` according to `axes`, keeping the incoming variant.
///
/// Float primitives go through `B::float_permute`, quantized ones through `B::q_permute`.
fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
    match tensor {
        TensorPrimitive::QFloat(quantized) => {
            TensorPrimitive::QFloat(B::q_permute(quantized, axes))
        }
        TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_permute(plain, axes)),
    }
}
/// Forwards `tensor` (converted with `.tensor()`) and `shape` to `B::float_expand`.
///
/// The result is always a `TensorPrimitive::Float`.
fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
    let input = tensor.tensor();
    TensorPrimitive::Float(B::float_expand(input, shape))
}
/// Flips `tensor` along `axes`, keeping the variant it came in with.
///
/// Float primitives go through `B::float_flip`, quantized ones through `B::q_flip`.
fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
    match tensor {
        TensorPrimitive::QFloat(quantized) => TensorPrimitive::QFloat(B::q_flip(quantized, axes)),
        TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_flip(plain, axes)),
    }
}
/// Forwards `tensor` (converted with `.tensor()`), `dim`, `size` and `step` to
/// `B::float_unfold`.
///
/// The result is always a `TensorPrimitive::Float`.
fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
    let input = tensor.tensor();
    TensorPrimitive::Float(B::float_unfold(input, dim, size, step))
}
}
impl<B: Backend> Numeric<B> for Float {
fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
q_bin_ops!(lhs, rhs, float_add, q_add)
}
fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
match lhs {
TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
}
}
fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
q_bin_ops!(lhs, rhs, float_sub, q_sub)
}
fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
match lhs {
TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
}
}
fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
q_bin_ops!(lhs, rhs, float_div, q_div)
}
fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
match lhs {
TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
}
}
fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
}
fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
}
fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
q_bin_ops!(lhs, rhs, float_mul, q_mul)
}
fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
match lhs {
TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
}
}
fn neg(tensor: Self::Primitive) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
}
}
fn sum(tensor: Self::Primitive) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
}
}
fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
}
}
fn prod(tensor: Self::Primitive) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
}
}
fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
}
TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
}
}
fn mean(tensor: Self::Primitive) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
}
}
fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
}
TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
}
}
fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
}
}
fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
}
}
fn abs(tensor: Self::Primitive) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
}
}
fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
q_bin_ops!(lhs, rhs, float_powf, q_powf)
}
fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
match lhs {
TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
}
}
fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
q_bin_ops!(lhs, rhs, float_powf, q_powf)
}
fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
match lhs {
TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
}
}
fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
TensorPrimitive::Float(B::float_random(shape, distribution, device))
}
fn sign(tensor: Self::Primitive) -> Self::Primitive {
TensorPrimitive::Float(B::float_sign(tensor.tensor()))
}
/// Applies the matrix multiplication operation.
///
/// `C = AB`
///
/// # Panics
///
/// If the two tensors don't have a compatible shape.
fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
match (lhs, rhs) {
(TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
TensorPrimitive::Float(B::float_matmul(lhs, rhs))
}
(lhs, rhs) => B::q_matmul(lhs, rhs),
}
}
}
/// Ordering-related operations for the `Float` tensor kind.
///
/// Methods that match on the `TensorPrimitive` variant use the backend's `float_*` op for
/// float primitives and the matching `q_*` op for quantized ones. Comparison methods convert
/// their operands with `.tensor()` and always use the `float_*` op.
impl<B: Backend> Ordered<B> for Float {
    /// Forwards to `B::float_sort` or `B::q_sort`, keeping the incoming variant.
    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let sorted = B::q_sort(quantized, dim, descending);
                TensorPrimitive::QFloat(sorted)
            }
            TensorPrimitive::Float(plain) => {
                let sorted = B::float_sort(plain, dim, descending);
                TensorPrimitive::Float(sorted)
            }
        }
    }

    /// Forwards to `B::float_sort_with_indices` or `B::q_sort_with_indices`.
    ///
    /// The values keep the incoming variant; the indices are returned as produced.
    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, IntTensor<B>) {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let (sorted, order) = B::q_sort_with_indices(quantized, dim, descending);
                (TensorPrimitive::QFloat(sorted), order)
            }
            TensorPrimitive::Float(plain) => {
                let (sorted, order) = B::float_sort_with_indices(plain, dim, descending);
                (TensorPrimitive::Float(sorted), order)
            }
        }
    }

    /// Forwards to `B::float_argsort` or `B::q_argsort`.
    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_argsort(quantized, dim, descending),
            TensorPrimitive::Float(plain) => B::float_argsort(plain, dim, descending),
        }
    }

    /// Forwards to `B::float_cummin` or `B::q_cummin`; the quantized op's result is returned
    /// as-is.
    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_cummin(quantized, dim),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_cummin(plain, dim)),
        }
    }

    /// Forwards to `B::float_cummax` or `B::q_cummax`; the quantized op's result is returned
    /// as-is.
    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_cummax(quantized, dim),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_cummax(plain, dim)),
        }
    }

    /// Forwards both operands (converted with `.tensor()`) to `B::float_greater`.
    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        let right = rhs.tensor();
        B::float_greater(left, right)
    }

    /// Forwards `lhs` (converted with `.tensor()`) and `rhs` to `B::float_greater_elem`.
    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        B::float_greater_elem(left, rhs)
    }

    /// Forwards both operands (converted with `.tensor()`) to `B::float_greater_equal`.
    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        let right = rhs.tensor();
        B::float_greater_equal(left, right)
    }

    /// Forwards `lhs` (converted with `.tensor()`) and `rhs` to `B::float_greater_equal_elem`.
    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        B::float_greater_equal_elem(left, rhs)
    }

    /// Forwards both operands (converted with `.tensor()`) to `B::float_lower`.
    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        let right = rhs.tensor();
        B::float_lower(left, right)
    }

    /// Forwards `lhs` (converted with `.tensor()`) and `rhs` to `B::float_lower_elem`.
    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        B::float_lower_elem(left, rhs)
    }

    /// Forwards both operands (converted with `.tensor()`) to `B::float_lower_equal`.
    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        let right = rhs.tensor();
        B::float_lower_equal(left, right)
    }

    /// Forwards `lhs` (converted with `.tensor()`) and `rhs` to `B::float_lower_equal_elem`.
    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        let left = lhs.tensor();
        B::float_lower_equal_elem(left, rhs)
    }

    /// Forwards to `B::float_argmax` or `B::q_argmax` with `dim`.
    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_argmax(quantized, dim),
            TensorPrimitive::Float(plain) => B::float_argmax(plain, dim),
        }
    }

    /// Forwards to `B::float_argmin` or `B::q_argmin` with `dim`.
    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_argmin(quantized, dim),
            TensorPrimitive::Float(plain) => B::float_argmin(plain, dim),
        }
    }

    /// Forwards to `B::float_max` or `B::q_max`, keeping the incoming variant.
    fn max(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => TensorPrimitive::QFloat(B::q_max(quantized)),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_max(plain)),
        }
    }

    /// Forwards to `B::float_max_dim` or `B::q_max_dim`, keeping the incoming variant.
    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                TensorPrimitive::QFloat(B::q_max_dim(quantized, dim))
            }
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_max_dim(plain, dim)),
        }
    }

    /// Forwards to `B::float_max_dim_with_indices` or `B::q_max_dim_with_indices`.
    ///
    /// The values keep the incoming variant; the indices are returned as produced.
    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let (maxima, positions) = B::q_max_dim_with_indices(quantized, dim);
                (TensorPrimitive::QFloat(maxima), positions)
            }
            TensorPrimitive::Float(plain) => {
                let (maxima, positions) = B::float_max_dim_with_indices(plain, dim);
                (TensorPrimitive::Float(maxima), positions)
            }
        }
    }

    /// Forwards to `B::float_min` or `B::q_min`, keeping the incoming variant.
    fn min(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => TensorPrimitive::QFloat(B::q_min(quantized)),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_min(plain)),
        }
    }

    /// Forwards to `B::float_min_dim` or `B::q_min_dim`, keeping the incoming variant.
    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                TensorPrimitive::QFloat(B::q_min_dim(quantized, dim))
            }
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_min_dim(plain, dim)),
        }
    }

    /// Forwards to `B::float_min_dim_with_indices` or `B::q_min_dim_with_indices`.
    ///
    /// The values keep the incoming variant; the indices are returned as produced.
    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let (minima, positions) = B::q_min_dim_with_indices(quantized, dim);
                (TensorPrimitive::QFloat(minima), positions)
            }
            TensorPrimitive::Float(plain) => {
                let (minima, positions) = B::float_min_dim_with_indices(plain, dim);
                (TensorPrimitive::Float(minima), positions)
            }
        }
    }

    /// Forwards to `B::float_clamp` or `B::q_clamp`; the quantized op's result is returned
    /// as-is.
    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_clamp(quantized, min, max),
            TensorPrimitive::Float(plain) => {
                TensorPrimitive::Float(B::float_clamp(plain, min, max))
            }
        }
    }

    /// Forwards to `B::float_clamp_min` or `B::q_clamp_min`; the quantized op's result is
    /// returned as-is.
    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_clamp_min(quantized, min),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_clamp_min(plain, min)),
        }
    }

    /// Forwards to `B::float_clamp_max` or `B::q_clamp_max`; the quantized op's result is
    /// returned as-is.
    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => B::q_clamp_max(quantized, max),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_clamp_max(plain, max)),
        }
    }

    /// Forwards to `B::float_max_abs` or `B::q_max_abs`, keeping the incoming variant.
    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => TensorPrimitive::QFloat(B::q_max_abs(quantized)),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::float_max_abs(plain)),
        }
    }

    /// Forwards to `B::float_max_abs_dim` or `B::q_max_abs_dim`, keeping the incoming variant.
    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let reduced = B::q_max_abs_dim(quantized, dim);
                TensorPrimitive::QFloat(reduced)
            }
            TensorPrimitive::Float(plain) => {
                let reduced = B::float_max_abs_dim(plain, dim);
                TensorPrimitive::Float(reduced)
            }
        }
    }
}
/// Conversions between a `Float` primitive on an autodiff backend and the same kind on its
/// inner backend.
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
    type InnerKind = Float;

    /// Converts to the inner backend's primitive via `B::inner` or `B::q_inner`, keeping the
    /// incoming variant.
    fn inner(
        tensor: <Self as TensorKind<B>>::Primitive,
    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => TensorPrimitive::QFloat(B::q_inner(quantized)),
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::inner(plain)),
        }
    }

    /// Converts from the inner backend's primitive via `B::from_inner` or `B::q_from_inner`,
    /// keeping the incoming variant.
    fn from_inner(
        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
    ) -> <Self as TensorKind<B>>::Primitive {
        match inner {
            TensorPrimitive::QFloat(quantized) => {
                TensorPrimitive::QFloat(B::q_from_inner(quantized))
            }
            TensorPrimitive::Float(plain) => TensorPrimitive::Float(B::from_inner(plain)),
        }
    }
}

View File

@@ -0,0 +1,426 @@
use alloc::vec::Vec;
use burn_std::{DType, Shape, Slice};
use crate::{
AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
ops::TransactionPrimitive,
tensor::{
BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
Ordered, TensorKind,
},
};
/// Basic operations for the `Int` tensor kind.
///
/// Every method delegates to the backend's `int_*` op of the same name.
impl<B: Backend> BasicOps<B> for Int {
    type Elem = B::IntElem;

    /// Forwards to `B::int_empty` with the converted `dtype`.
    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
        B::int_empty(shape, device, dtype.into())
    }

    /// Forwards to `B::int_zeros` with the converted `dtype`.
    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
        B::int_zeros(shape, device, dtype.into())
    }

    /// Forwards to `B::int_ones` with the converted `dtype`.
    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
        B::int_ones(shape, device, dtype.into())
    }

    /// Forwards to `B::int_full` with `fill_value` and the converted `dtype`.
    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
        B::int_full(shape, fill_value, device, dtype.into())
    }

    /// Hands `tensor` over to the transaction `tr` by calling its `register_int` method.
    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
        tr.register_int(tensor);
    }

    /// Forwards to `B::int_reshape`.
    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::int_reshape(tensor, shape)
    }

    /// Forwards to `B::int_transpose`.
    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
        B::int_transpose(tensor)
    }

    /// Forwards to `B::int_swap_dims`.
    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
        B::int_swap_dims(tensor, dim1, dim2)
    }

    /// Forwards to `B::int_slice`.
    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
        B::int_slice(tensor, slices)
    }

    /// Forwards to `B::int_slice_assign`.
    fn slice_assign(
        tensor: Self::Primitive,
        slices: &[Slice],
        value: Self::Primitive,
    ) -> Self::Primitive {
        B::int_slice_assign(tensor, slices, value)
    }

    /// Forwards to `B::int_select`.
    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
        B::int_select(tensor, dim, indices)
    }

    /// Applies `update` to the selected entries; `IndexingUpdateOp::Add` maps to
    /// `B::int_select_add`.
    fn select_assign(
        tensor: Self::Primitive,
        dim: usize,
        indices: IntTensor<B>,
        values: Self::Primitive,
        update: IndexingUpdateOp,
    ) -> Self::Primitive {
        match update {
            IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
        }
    }

    /// Forwards to `B::int_mask_where`.
    fn mask_where(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        source: Self::Primitive,
    ) -> Self::Primitive {
        B::int_mask_where(tensor, mask, source)
    }

    /// Forwards to `B::int_mask_fill`.
    fn mask_fill(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        value: Scalar,
    ) -> Self::Primitive {
        B::int_mask_fill(tensor, mask, value)
    }

    /// Forwards to `B::int_gather`.
    fn gather(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
    ) -> Self::Primitive {
        B::int_gather(dim, tensor, indices)
    }

    /// Applies `update` while scattering; `IndexingUpdateOp::Add` maps to
    /// `B::int_scatter_add`.
    fn scatter(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
        values: Self::Primitive,
        update: IndexingUpdateOp,
    ) -> Self::Primitive {
        match update {
            IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
        }
    }

    /// Forwards to `B::int_device`.
    fn device(tensor: &Self::Primitive) -> Device<B> {
        B::int_device(tensor)
    }

    /// Forwards to `B::int_to_device`.
    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
        B::int_to_device(tensor, device)
    }

    /// Awaits `B::int_into_data` and returns its `Result` unchanged.
    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
        B::int_into_data(tensor).await
    }

    /// Converts `data` to `B::IntElem`, then forwards it to `B::int_from_data`.
    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
        let converted = data.convert::<B::IntElem>();
        B::int_from_data(converted, device)
    }

    /// Converts `data` to `dtype`, then forwards it to `B::int_from_data`.
    ///
    /// # Panics
    ///
    /// If `dtype.is_int()` is false.
    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
        assert!(dtype.is_int(), "Expected int dtype, got {dtype:?}");
        let converted = data.convert_dtype(dtype);
        B::int_from_data(converted, device)
    }

    /// Forwards to `B::int_repeat_dim`.
    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
        B::int_repeat_dim(tensor, dim, times)
    }

    /// Forwards to `B::int_equal`.
    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
        B::int_equal(lhs, rhs)
    }

    /// Forwards to `B::int_not_equal`.
    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
        B::int_not_equal(lhs, rhs)
    }

    /// Forwards to `B::int_equal_elem`.
    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::int_equal_elem(lhs, rhs)
    }

    /// Forwards to `B::int_not_equal_elem`.
    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::int_not_equal_elem(lhs, rhs)
    }

    /// Forwards to `B::int_cat`.
    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
        B::int_cat(vectors, dim)
    }

    /// Forwards to `B::int_any`.
    fn any(tensor: Self::Primitive) -> BoolTensor<B> {
        B::int_any(tensor)
    }

    /// Forwards to `B::int_any_dim`.
    fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
        B::int_any_dim(tensor, dim)
    }

    /// Forwards to `B::int_all`.
    fn all(tensor: Self::Primitive) -> BoolTensor<B> {
        B::int_all(tensor)
    }

    /// Forwards to `B::int_all_dim`.
    fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
        B::int_all_dim(tensor, dim)
    }

    /// Forwards to `B::int_permute`.
    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::int_permute(tensor, axes)
    }

    /// Forwards to `B::int_expand`.
    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::int_expand(tensor, shape)
    }

    /// Forwards to `B::int_flip`.
    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::int_flip(tensor, axes)
    }

    /// Forwards to `B::int_unfold`.
    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
        B::int_unfold(tensor, dim, size, step)
    }
}
/// Numeric operations for the `Int` tensor kind.
///
/// Every method delegates to the backend's `int_*` op of the same name.
impl<B: Backend> Numeric<B> for Int {
    /// Forwards to `B::int_add`.
    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_add(lhs, rhs)
    }

    /// Forwards to `B::int_add_scalar`.
    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_add_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_sub`.
    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_sub(lhs, rhs)
    }

    /// Forwards to `B::int_sub_scalar`.
    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_sub_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_div`.
    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_div(lhs, rhs)
    }

    /// Forwards to `B::int_div_scalar`.
    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_div_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_remainder`.
    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_remainder(lhs, rhs)
    }

    /// Forwards to `B::int_remainder_scalar`.
    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_remainder_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_mul`.
    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_mul(lhs, rhs)
    }

    /// Forwards to `B::int_mul_scalar`.
    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_mul_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_neg`.
    fn neg(tensor: Self::Primitive) -> Self::Primitive {
        B::int_neg(tensor)
    }

    /// Forwards to `B::int_sum`.
    fn sum(tensor: Self::Primitive) -> Self::Primitive {
        B::int_sum(tensor)
    }

    /// Forwards to `B::int_sum_dim`.
    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_sum_dim(tensor, dim)
    }

    /// Forwards to `B::int_prod`.
    fn prod(tensor: Self::Primitive) -> Self::Primitive {
        B::int_prod(tensor)
    }

    /// Forwards to `B::int_prod_dim`.
    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_prod_dim(tensor, dim)
    }

    /// Forwards to `B::int_mean`.
    fn mean(tensor: Self::Primitive) -> Self::Primitive {
        B::int_mean(tensor)
    }

    /// Forwards to `B::int_mean_dim`.
    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_mean_dim(tensor, dim)
    }

    /// Forwards to `B::int_cumsum`.
    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_cumsum(tensor, dim)
    }

    /// Forwards to `B::int_cumprod`.
    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_cumprod(tensor, dim)
    }

    /// Forwards to `B::int_abs`.
    fn abs(tensor: Self::Primitive) -> Self::Primitive {
        B::int_abs(tensor)
    }

    /// Converts the exponent `rhs` with `B::int_into_float`, then forwards to `B::int_powf`.
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        let exponent = B::int_into_float(rhs);
        B::int_powf(lhs, exponent)
    }

    /// Forwards to `B::int_powf_scalar`.
    fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_powf_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_powi`.
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_powi(lhs, rhs)
    }

    /// Forwards to `B::int_powi_scalar`.
    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
        B::int_powi_scalar(lhs, rhs)
    }

    /// Forwards to `B::int_random`.
    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
        B::int_random(shape, distribution, device)
    }

    /// Forwards to `B::int_sign`.
    fn sign(tensor: Self::Primitive) -> Self::Primitive {
        B::int_sign(tensor)
    }

    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// Delegates to `B::int_matmul`.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_matmul(lhs, rhs)
    }
}
/// Ordering-related operations for the `Int` tensor kind.
///
/// Every method delegates to the backend's `int_*` op of the same name.
impl<B: Backend> Ordered<B> for Int {
    /// Forwards to `B::int_sort`.
    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
        B::int_sort(tensor, dim, descending)
    }

    /// Forwards to `B::int_sort_with_indices`.
    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, IntTensor<B>) {
        B::int_sort_with_indices(tensor, dim, descending)
    }

    /// Forwards to `B::int_argsort`.
    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
        B::int_argsort(tensor, dim, descending)
    }

    /// Forwards to `B::int_cummin`.
    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_cummin(tensor, dim)
    }

    /// Forwards to `B::int_cummax`.
    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_cummax(tensor, dim)
    }

    /// Forwards to `B::int_greater`.
    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_greater(lhs, rhs)
    }

    /// Forwards to `B::int_greater_elem`.
    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::int_greater_elem(lhs, rhs)
    }

    /// Forwards to `B::int_greater_equal`.
    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_greater_equal(lhs, rhs)
    }

    /// Forwards to `B::int_greater_equal_elem`.
    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::int_greater_equal_elem(lhs, rhs)
    }

    /// Forwards to `B::int_lower`.
    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_lower(lhs, rhs)
    }

    /// Forwards to `B::int_lower_elem`.
    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::int_lower_elem(lhs, rhs)
    }

    /// Forwards to `B::int_lower_equal`.
    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_lower_equal(lhs, rhs)
    }

    /// Forwards to `B::int_lower_equal_elem`.
    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
        B::int_lower_equal_elem(lhs, rhs)
    }

    /// Forwards to `B::int_argmax`.
    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        B::int_argmax(tensor, dim)
    }

    /// Forwards to `B::int_argmin`.
    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        B::int_argmin(tensor, dim)
    }

    /// Forwards to `B::int_max`.
    fn max(tensor: Self::Primitive) -> Self::Primitive {
        B::int_max(tensor)
    }

    /// Forwards to `B::int_max_dim`.
    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_max_dim(tensor, dim)
    }

    /// Forwards to `B::int_max_dim_with_indices`.
    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        B::int_max_dim_with_indices(tensor, dim)
    }

    /// Forwards to `B::int_max_abs`.
    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
        B::int_max_abs(tensor)
    }

    /// Forwards to `B::int_max_abs_dim`.
    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_max_abs_dim(tensor, dim)
    }

    /// Forwards to `B::int_min`.
    fn min(tensor: Self::Primitive) -> Self::Primitive {
        B::int_min(tensor)
    }

    /// Forwards to `B::int_min_dim`.
    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_min_dim(tensor, dim)
    }

    /// Forwards to `B::int_min_dim_with_indices`.
    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        B::int_min_dim_with_indices(tensor, dim)
    }

    /// Forwards to `B::int_clamp`.
    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
        B::int_clamp(tensor, min, max)
    }

    /// Forwards to `B::int_clamp_min`.
    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
        B::int_clamp_min(tensor, min)
    }

    /// Forwards to `B::int_clamp_max`.
    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
        B::int_clamp_max(tensor, max)
    }
}
/// Conversions between an `Int` primitive on an autodiff backend and the same kind on its
/// inner backend.
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
    type InnerKind = Int;

    /// Converts to the inner backend's primitive via `B::int_inner`.
    fn inner(
        tensor: <Self as TensorKind<B>>::Primitive,
    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
        B::int_inner(tensor)
    }

    /// Converts from the inner backend's primitive via `B::int_from_inner`.
    fn from_inner(
        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
    ) -> <Self as TensorKind<B>>::Primitive {
        B::int_from_inner(inner)
    }
}

View File

@@ -0,0 +1,21 @@
mod autodiff;
mod base;
mod bool;
mod float;
mod int;
mod numeric;
mod ordered;
pub use autodiff::*;
pub use base::*;
pub use numeric::*;
pub use ordered::*;
/// Computation to be used to update the existing values in indexed assignment operations (scatter/select).
///
/// `Add` is currently the only variant of this enum.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum IndexingUpdateOp {
// NOTE(review): the commented-out `Assign` and `Mul` entries look like placeholders for
// variants that are not available yet — confirm before relying on them.
// Assign,
/// Performs an addition.
Add,
// Mul
}

View File

@@ -0,0 +1,556 @@
use burn_std::Shape;
use crate::{Backend, Distribution, Scalar, element::Element, tensor::BasicOps};
/// Trait that lists all operations that can be applied to all numerical tensors.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by the
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct.
pub trait Numeric<B: Backend>: BasicOps<B>
where
Self::Elem: Element,
{
/// Adds two tensors together.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The sum of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For adding tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("add"))]
#[cfg_attr(not(doc), doc = "`Tensor::add`")]
/// function, which is more high-level and designed for public use.
fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Adds a scalar to a tensor element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The sum of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For adding a scalar to a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("add_scalar"))]
#[cfg_attr(not(doc), doc = "`Tensor::add_scalar`")]
/// function, which is more high-level and designed for public use.
fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Subtracts two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The difference of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For subtracting tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sub"))]
#[cfg_attr(not(doc), doc = "`Tensor::sub`")]
/// function, which is more high-level and designed for public use.
fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Subtracts a scalar from a tensor element-wise (`lhs - rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The difference of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For subtracting a scalar from a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sub_scalar"))]
#[cfg_attr(not(doc), doc = "`Tensor::sub_scalar`")]
/// function, which is more high-level and designed for public use.
fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Divides two tensors (`lhs / rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The quotient of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For dividing tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("div"))]
#[cfg_attr(not(doc), doc = "`Tensor::div`")]
/// function, which is more high-level and designed for public use.
fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Divides a tensor by a scalar element-wise (`lhs / rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The quotient of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For dividing a tensor by a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("div_scalar"))]
#[cfg_attr(not(doc), doc = "`Tensor::div_scalar`")]
/// function, which is more high-level and designed for public use.
fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Computes the modulo element-wise.
///
/// The result is the *signed* remainder of the division and its absolute value is
/// less than that of the divisor.
///
/// # Arguments
///
/// * `lhs` - The dividend.
/// * `rhs` - The divisor.
///
/// # Returns
///
/// The modulo of the input tensor with the divisor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For performing the modulo operation, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("remainder"))]
#[cfg_attr(not(doc), doc = "`Tensor::remainder`")]
/// function, which is more high-level and designed for public use.
fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Computes the modulo element-wise with a scalar divisor.
///
/// The result is the *signed* remainder of the division and its absolute value is
/// less than that of the divisor.
///
/// # Arguments
///
/// * `lhs` - The dividend tensor.
/// * `rhs` - The scalar divisor.
///
/// # Returns
///
/// The modulo of the input tensor with the scalar divisor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For performing the modulo operation with a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("remainder_scalar"))]
#[cfg_attr(not(doc), doc = "`Tensor::remainder_scalar`")]
/// function, which is more high-level and designed for public use.
fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Multiplies two tensors (`lhs * rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The product of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For multiplying tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("mul"))]
#[cfg_attr(not(doc), doc = "`Tensor::mul`")]
/// function, which is more high-level and designed for public use.
fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Multiplies a tensor by a scalar element-wise (`lhs * rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The product of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For multiplying a tensor by a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("mul_scalar"))]
#[cfg_attr(not(doc), doc = "`Tensor::mul_scalar`")]
/// function, which is more high-level and designed for public use.
fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Negates a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to negate.
///
/// # Returns
///
/// The negated tensor, of the same primitive type as the input.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For negating a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("neg"))]
#[cfg_attr(not(doc), doc = "`Tensor::neg`")]
/// function, which is more high-level and designed for public use.
fn neg(tensor: Self::Primitive) -> Self::Primitive;
/// Returns the signs of the elements of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor whose element signs are computed.
///
/// # Returns
///
/// A tensor containing the signs of the elements of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the signs of the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sign"))]
#[cfg_attr(not(doc), doc = "`Tensor::sign`")]
/// function, which is more high-level and designed for public use.
fn sign(tensor: Self::Primitive) -> Self::Primitive;
/// Sums all the elements of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// A tensor containing the sum of all the elements of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For summing all the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sum"))]
#[cfg_attr(not(doc), doc = "`Tensor::sum`")]
/// function, which is more high-level and designed for public use.
fn sum(tensor: Self::Primitive) -> Self::Primitive;
/// Sums all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension along which to sum.
///
/// # Returns
///
/// A tensor containing the sums of the elements of the input tensor along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For summing all the elements of a tensor along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sum_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::sum_dim`")]
/// function, which is more high-level and designed for public use.
fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Computes the product of all the elements of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// A tensor containing the product of all the elements of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the product of all the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("prod"))]
#[cfg_attr(not(doc), doc = "`Tensor::prod`")]
/// function, which is more high-level and designed for public use.
fn prod(tensor: Self::Primitive) -> Self::Primitive;
/// Computes the product of all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
/// * `dim` - The dimension along which to compute the product.
///
/// # Returns
///
/// A tensor containing the products of the elements of the input tensor along the specified
/// dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the product of all the elements of a tensor along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("prod_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::prod_dim`")]
/// function, which is more high-level and designed for public use.
fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Computes the mean of all the elements of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
///
/// # Returns
///
/// A tensor containing the mean of all the elements of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the mean of all the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("mean"))]
#[cfg_attr(not(doc), doc = "`Tensor::mean`")]
/// function, which is more high-level and designed for public use.
fn mean(tensor: Self::Primitive) -> Self::Primitive;
/// Computes the mean of all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension along which to compute the mean.
///
/// # Returns
///
/// A tensor containing the means of the elements of the input tensor along the specified
/// dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the mean of all the elements of a tensor along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("mean_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::mean_dim`")]
/// function, which is more high-level and designed for public use.
fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Computes the cumulative sum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative sum of.
/// * `dim` - The dimension along which to compute the cumulative sum.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the sum
/// of all elements up to and including that position along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the cumulative sum of elements along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("cumsum"))]
#[cfg_attr(not(doc), doc = "`Tensor::cumsum`")]
/// function, which is more high-level and designed for public use.
fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Computes the cumulative product of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative product of.
/// * `dim` - The dimension along which to compute the cumulative product.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the product
/// of all elements up to and including that position along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the cumulative product of elements along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("cumprod"))]
#[cfg_attr(not(doc), doc = "`Tensor::cumprod`")]
/// function, which is more high-level and designed for public use.
fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Calculates the absolute value of all elements of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to apply abs to.
///
/// # Returns
///
/// A tensor with the absolute values of the elements of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For calculating the absolute values of the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("abs"))]
#[cfg_attr(not(doc), doc = "`Tensor::abs`")]
/// function, which is more high-level and designed for public use.
fn abs(tensor: Self::Primitive) -> Self::Primitive;
/// Element-wise power of a tensor to a float tensor.
///
/// # Arguments
///
/// * `lhs` - The tensor to apply power to.
/// * `rhs` - The float tensor holding the power to apply to `lhs`.
///
/// # Returns
///
/// A tensor containing the elements of `lhs` raised to the power given by `rhs`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Element-wise power of a tensor.
///
/// # Arguments
///
/// * `lhs` - The tensor to apply power to.
/// * `rhs` - The tensor holding the power to apply to `lhs`.
///
/// # Returns
///
/// A tensor containing the elements of `lhs` raised to the power given by `rhs`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
/// Element-wise power of a tensor to a scalar float.
///
/// # Arguments
///
/// * `lhs` - The tensor to apply power to.
/// * `rhs` - The scalar power to apply to the tensor.
///
/// # Returns
///
/// A tensor containing the elements of `lhs` raised to the power `rhs`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Element-wise power of a tensor to a scalar int.
///
/// # Arguments
///
/// * `lhs` - The tensor to apply power to.
/// * `rhs` - The scalar power to apply to the tensor.
///
/// # Returns
///
/// A tensor containing the elements of `lhs` raised to the power `rhs`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
/// Creates a random tensor whose values are sampled from a distribution.
///
/// # Arguments
///
/// * `shape` - The shape of the output tensor.
/// * `distribution` - The distribution used to sample.
/// * `device` - The device to use.
///
/// # Returns
///
/// A new tensor with the given shape.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a random tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("random"))]
#[cfg_attr(not(doc), doc = "`Tensor::random`")]
/// function, which is more high-level and designed for public use.
fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
/// Applies the matrix multiplication operation.
///
/// ```math
/// C = AB
/// ```
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor (`A` in the formula above).
/// * `rhs` - The right hand side tensor (`B` in the formula above).
///
/// # Returns
///
/// The tensor resulting from the matrix multiplication (`C` in the formula above).
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
}

View File

@@ -0,0 +1,650 @@
use crate::{
Backend, Scalar,
tensor::{IntTensor, Numeric},
};
/// Trait that lists all operations that can be applied on all numerical tensors
/// whose elements have a well-defined ordering.
///
/// This includes operations such as comparisons, minimum/maximum reductions,
/// and other order-dependent computations that are not strictly valid for all numerical
/// types.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by the
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct.
pub trait Ordered<B: Backend>: Numeric<B> {
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order (whether to sort in descending order).
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For sorting the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sort"))]
#[cfg_attr(not(doc), doc = "`Tensor::sort`")]
/// function, which is more high-level and designed for public use.
fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
/// Sort the elements of the input `tensor` by value along a given dimension, also returning
/// the indices of the sorted elements.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order (whether to sort in descending order).
///
/// # Returns
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For sorting the elements of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("sort_with_indices"))]
#[cfg_attr(not(doc), doc = "`Tensor::sort_with_indices`")]
/// function, which is more high-level and designed for public use.
fn sort_with_indices(
tensor: Self::Primitive,
dim: usize,
descending: bool,
) -> (Self::Primitive, IntTensor<B>);
/// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order (whether to sort in descending order).
///
/// # Returns
///
/// An integer tensor with the same shape as the input tensor, where the indices map back to the
/// original input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indices that sort a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("argsort"))]
#[cfg_attr(not(doc), doc = "`Tensor::argsort`")]
/// function, which is more high-level and designed for public use.
fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B>;
/// Computes the cumulative (running) minimum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative minimum of.
/// * `dim` - The dimension along which to compute the cumulative minimum.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the minimum
/// of all elements up to and including that position along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the cumulative minimum of elements along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("cummin"))]
#[cfg_attr(not(doc), doc = "`Tensor::cummin`")]
/// function, which is more high-level and designed for public use.
fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Computes the cumulative (running) maximum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative maximum of.
/// * `dim` - The dimension along which to compute the cumulative maximum.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the maximum
/// of all elements up to and including that position along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the cumulative maximum of elements along a dimension, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("cummax"))]
#[cfg_attr(not(doc), doc = "`Tensor::cummax`")]
/// function, which is more high-level and designed for public use.
fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Element-wise greater than comparison between two tensors (`lhs > rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is `true` if the
/// corresponding element of the left hand side tensor is greater than the corresponding element
/// of the right hand side tensor, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than comparison between two tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("greater"))]
#[cfg_attr(not(doc), doc = "`Tensor::greater`")]
/// function, which is more high-level and designed for public use.
fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
/// Element-wise greater than comparison between a tensor and a scalar (`lhs > rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is `true` if the
/// corresponding element of the left hand side tensor is greater than the right hand side
/// scalar, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than comparison between a tensor and a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("greater_elem"))]
#[cfg_attr(not(doc), doc = "`Tensor::greater_elem`")]
/// function, which is more high-level and designed for public use.
fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
/// Element-wise greater than or equal comparison between two tensors (`lhs >= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is `true` if the
/// corresponding element of the left hand side tensor is greater than or equal to the
/// corresponding element of the right hand side tensor, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than or equal comparison between two tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal"))]
#[cfg_attr(not(doc), doc = "`Tensor::greater_equal`")]
/// function, which is more high-level and designed for public use.
fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
/// Element-wise greater than or equal comparison between a tensor and a scalar (`lhs >= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is `true` if the
/// corresponding element of the left hand side tensor is greater than or equal to the right
/// hand side scalar, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal_elem"))]
#[cfg_attr(not(doc), doc = "`Tensor::greater_equal_elem`")]
/// function, which is more high-level and designed for public use.
fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
/// Element-wise less than comparison between two tensors (`lhs < rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is `true` if the
/// corresponding element of the left hand side tensor is less than the corresponding element of
/// the right hand side tensor, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than comparison between two tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("lower"))]
#[cfg_attr(not(doc), doc = "`Tensor::lower`")]
/// function, which is more high-level and designed for public use.
fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
/// Element-wise less than comparison between a tensor and a scalar (`lhs < rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is `true` if the
/// corresponding element of the left hand side tensor is less than the right hand side scalar,
/// and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than comparison between a tensor and a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("lower_elem"))]
#[cfg_attr(not(doc), doc = "`Tensor::lower_elem`")]
/// function, which is more high-level and designed for public use.
fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
/// Element-wise less than or equal comparison between two tensors (`lhs <= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is `true` if the
/// corresponding element of the left hand side tensor is less than or equal to the corresponding
/// element of the right hand side tensor, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than or equal comparison between two tensors, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal"))]
#[cfg_attr(not(doc), doc = "`Tensor::lower_equal`")]
/// function, which is more high-level and designed for public use.
fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
/// Element-wise less than or equal comparison between a tensor and a scalar (`lhs <= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is `true` if the
/// corresponding element of the left hand side tensor is less than or equal to the right hand
/// side scalar, and `false` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal_elem"))]
#[cfg_attr(not(doc), doc = "`Tensor::lower_equal_elem`")]
/// function, which is more high-level and designed for public use.
fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the indices of the maximum elements from.
/// * `dim` - The axis along which to get the indices of the maximum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the index of the
/// maximum element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("argmax"))]
#[cfg_attr(not(doc), doc = "`Tensor::argmax`")]
/// function, which is more high-level and designed for public use.
fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the indices of the minimum elements from.
/// * `dim` - The axis along which to get the indices of the minimum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the index of the
/// minimum element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("argmin"))]
#[cfg_attr(not(doc), doc = "`Tensor::argmin`")]
/// function, which is more high-level and designed for public use.
fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
/// Gets the maximum element of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum element from.
///
/// # Returns
///
/// A single-element tensor containing the maximum element of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum element of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("max"))]
#[cfg_attr(not(doc), doc = "`Tensor::max`")]
/// function, which is more high-level and designed for public use.
fn max(tensor: Self::Primitive) -> Self::Primitive;
/// Gets the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements from.
/// * `dim` - The axis along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the same rank as the input tensor, but with the size of `dim` reduced to 1.
/// Each element is the maximum element of the input tensor along `dim`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum elements of a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("max_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::max_dim`")]
/// function, which is more high-level and designed for public use.
fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Gets the maximum elements of a tensor along an axis, together with their indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements from.
/// * `dim` - The axis along which to get the maximum elements.
///
/// # Returns
///
/// A tuple `(values, indices)`: a tensor containing the maximum elements of the input tensor
/// along `dim`, and an integer tensor where each element is the index of the corresponding
/// maximum element along `dim`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum elements of a tensor along an axis with their indices, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("max_dim_with_indices"))]
#[cfg_attr(not(doc), doc = "`Tensor::max_dim_with_indices`")]
/// function, which is more high-level and designed for public use.
fn max_dim_with_indices(tensor: Self::Primitive, dim: usize)
-> (Self::Primitive, IntTensor<B>);
/// Gets the maximum absolute element of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum absolute element from.
///
/// # Returns
///
/// A single-element tensor containing the maximum absolute element of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum absolute element of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("max_abs"))]
#[cfg_attr(not(doc), doc = "`Tensor::max_abs`")]
/// function, which is more high-level and designed for public use.
fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
/// Gets the maximum absolute elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum absolute elements from.
/// * `dim` - The axis along which to get the maximum absolute elements.
///
/// # Returns
///
/// A tensor with the same rank as the input tensor, but with the size of `dim` reduced to 1.
/// Each element is the maximum absolute element of the input tensor along `dim`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum absolute elements of a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("max_abs_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::max_abs_dim`")]
/// function, which is more high-level and designed for public use.
fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Gets the minimum element of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum element from.
///
/// # Returns
///
/// A single-element tensor containing the minimum element of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the minimum element of a tensor, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("min"))]
#[cfg_attr(not(doc), doc = "`Tensor::min`")]
/// function, which is more high-level and designed for public use.
fn min(tensor: Self::Primitive) -> Self::Primitive;
/// Gets the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements from.
/// * `dim` - The axis along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the same rank as the input tensor, but with the size of `dim` reduced to 1.
/// Each element is the minimum element of the input tensor along `dim`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the minimum elements of a tensor along an axis, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("min_dim"))]
#[cfg_attr(not(doc), doc = "`Tensor::min_dim`")]
/// function, which is more high-level and designed for public use.
fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
/// Gets the minimum elements of a tensor along an axis, together with their indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements from.
/// * `dim` - The axis along which to get the minimum elements.
///
/// # Returns
///
/// A tuple `(values, indices)`: a tensor containing the minimum elements of the input tensor
/// along `dim`, and an integer tensor where each element is the index of the corresponding
/// minimum element along `dim`.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the minimum elements of a tensor along an axis with their indices, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("min_dim_with_indices"))]
#[cfg_attr(not(doc), doc = "`Tensor::min_dim_with_indices`")]
/// function, which is more high-level and designed for public use.
fn min_dim_with_indices(tensor: Self::Primitive, dim: usize)
-> (Self::Primitive, IntTensor<B>);
/// Clamps the tensor between the given min and max values.
///
/// # Arguments
///
/// * `tensor` - The tensor to clamp.
/// * `min` - The minimum value.
/// * `max` - The maximum value.
///
/// # Returns
///
/// A new tensor with the values clamped between the given min and max values.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users.
///
/// For clamping a tensor between the given min and max values, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("clamp"))]
#[cfg_attr(not(doc), doc = "`Tensor::clamp`")]
/// function, which is more high-level and designed for public use.
fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive;
/// Clamps a tensor using the given minimum value as a lower bound.
///
/// # Arguments
///
/// * `tensor` - The tensor to clamp.
/// * `min` - The minimum value.
///
/// # Returns
///
/// A new tensor with the values bounded below by the given min value.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users.
///
/// For clamping a tensor with a minimum value, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("clamp_min"))]
#[cfg_attr(not(doc), doc = "`Tensor::clamp_min`")]
/// function, which is more high-level and designed for public use.
fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive;
/// Clamps a tensor using the given maximum value as an upper bound.
///
/// # Arguments
///
/// * `tensor` - The tensor to clamp.
/// * `max` - The maximum value.
///
/// # Returns
///
/// A new tensor with the values bounded above by the given max value.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users.
///
/// For clamping a tensor with a maximum value, users should prefer the
#[cfg_attr(doc, doc = crate::doc_tensor!("clamp_max"))]
#[cfg_attr(not(doc), doc = "`Tensor::clamp_max`")]
/// function, which is more high-level and designed for public use.
fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive;
}

View File

@@ -0,0 +1,5 @@
/// Calibration method used to compute the quantization range mapping.
///
/// The enum is fieldless, so it derives the common value-type traits: it can be
/// debug-printed, copied freely, and compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Calibration {
    /// Computes quantization range mapping based on the min and max values.
    MinMax,
}

View File

@@ -0,0 +1,7 @@
mod calibration;
mod parameters;
mod scheme;
pub use calibration::*;
pub use parameters::*;
pub use scheme::*;

View File

@@ -0,0 +1,15 @@
use crate::Backend;
pub use burn_std::quantization::{QParamTensor, QParams};
/// The quantization parameters primitive.
///
/// Bundles the backend tensor primitives that describe how values are quantized.
/// It currently holds a single field, the scales.
///
/// # Remarks
///
/// This is a low-level struct used internally by the library to provide the quantization parameters
/// to the backends. It is not designed for direct usage by users, and not recommended to import
/// or use this struct directly.
pub struct QuantizationParametersPrimitive<B: Backend> {
/// The scaling factor, stored as a float tensor primitive of backend `B`.
pub scales: B::FloatTensorPrimitive,
}

View File

@@ -0,0 +1,70 @@
pub use burn_std::{QPARAM_ALIGN, params_shape};
use burn_std::{QuantLevel, QuantMode, QuantScheme, Shape};
use super::{Calibration, QuantizationParametersPrimitive};
use crate::{Backend, TensorMetadata};
/// Computes the `(min, max)` range of a tensor for the given quantization scheme.
///
/// # Arguments
///
/// * `scheme` - The quantization scheme; its `level` selects per-tensor or per-block ranges.
/// * `tensor` - The float tensor to compute the range of.
/// * `calibration` - The calibration method used to compute the range.
///
/// # Returns
///
/// A `(min, max)` pair of float tensors. For [`QuantLevel::Tensor`] they are the results of
/// `float_min` / `float_max` on the whole tensor. For [`QuantLevel::Block`] they hold one value
/// per block, reshaped to the shape returned by [`params_shape`].
///
/// # Panics
///
/// Panics for block-level quantization when the number of tensor elements is not a multiple of
/// the number of elements in a block.
pub fn compute_range<B: Backend>(
    scheme: &QuantScheme,
    tensor: B::FloatTensorPrimitive,
    calibration: &Calibration,
) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
    match (calibration, scheme.level) {
        (Calibration::MinMax, QuantLevel::Tensor) => {
            let lowest = B::float_min(tensor.clone());
            let highest = B::float_max(tensor);
            (lowest, highest)
        }
        (Calibration::MinMax, QuantLevel::Block(block_size)) => {
            let block_elems = block_size.num_elements();
            let shape = tensor.shape();
            let total_elems = shape.num_elements();
            assert_eq!(
                total_elems % block_elems,
                0,
                "Tensor {shape:?} must be evenly divisible by block size {block_elems}"
            );

            let block_count = total_elems / block_elems;
            let qparams_shape = params_shape(&shape, scheme.level);

            // View the data as one row per block so that a reduction over axis 1
            // yields a single value per block.
            let per_block = B::float_reshape(tensor, Shape::new([block_count, block_elems]));

            let row_min = B::float_min_dim(per_block.clone(), 1);
            let lowest = B::float_reshape(row_min, qparams_shape.clone());

            let row_max = B::float_max_dim(per_block, 1);
            let highest = B::float_reshape(row_max, qparams_shape);

            (lowest, highest)
        }
    }
}
/// Computes the quantization parameters from a `(min, max)` range.
///
/// # Arguments
///
/// * `scheme` - The quantization scheme; its `value` provides the quantized range `[a, b]`.
/// * `min` - The minimum value(s) of the range to quantize.
/// * `max` - The maximum value(s) of the range to quantize.
///
/// # Returns
///
/// The quantization parameters, whose `scales` are `2 * max(|min|, |max|) / (b - a)`.
pub fn compute_q_params<B: Backend>(
    scheme: &QuantScheme,
    min: B::FloatTensorPrimitive,
    max: B::FloatTensorPrimitive,
) -> QuantizationParametersPrimitive<B> {
    // The match is kept exhaustive on purpose: a new level or mode must be handled here explicitly.
    match (&scheme.level, &scheme.mode) {
        (QuantLevel::Tensor | QuantLevel::Block(_), QuantMode::Symmetric) => {
            // Bounds of the quantized range `[a, b]`.
            let (q_lo, q_hi) = scheme.value.range();

            // The symmetric input range is `[-alpha, alpha]`, where `alpha` is the larger of the
            // two magnitudes, i.e. `abs_min.max_pair(abs_max)`.
            let abs_min = B::float_abs(min);
            let abs_max = B::float_abs(max);
            let min_is_smaller = B::float_lower(abs_min.clone(), abs_max.clone());
            let alpha = B::float_mask_where(abs_min, min_is_smaller, abs_max);

            // Width of `[-alpha, alpha]`.
            let full_range = B::float_mul_scalar(alpha, 2f32.into());

            let scales = B::float_div_scalar(full_range, (q_hi - q_lo).into());
            QuantizationParametersPrimitive { scales }
        }
    }
}