feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,62 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Tensor library with user-friendly APIs and automatic differentiation support"
documentation = "https://docs.rs/burn-tensor"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-tensor"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std"]
doc = ["default"]
std = [
"num-traits/std",
"burn-std/std",
"burn-backend/std",
"colored",
]
tracing = [
"burn-std/tracing",
"burn-backend/tracing",
]
cubecl = ["burn-std/cubecl", "burn-backend/cubecl"]
cubecl-cuda = ["burn-backend/cubecl-cuda"]
cubecl-hip = ["burn-backend/cubecl-hip"]
cubecl-wgpu = ["burn-backend/cubecl-wgpu"]
cubecl-cpu = ["burn-backend/cubecl-cpu"]
experimental-named-tensor = []
[dependencies]
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
colored = { workspace = true, optional = true }
derive-new = { workspace = true }
num-traits = { workspace = true }
# Device
hashbrown = { workspace = true }
spin = { workspace = true }
thiserror = { workspace = true }
# Serialization
serde = { workspace = true }
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
[dev-dependencies]
serial_test = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,12 @@
# Burn Tensor
> [Burn](https://github.com/tracel-ai/burn) Tensor Library
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tensor.svg)](https://crates.io/crates/burn-tensor)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tensor/blob/master/README.md)
This library provides the core abstractions required to run tensor operations with Burn.
`Tensor`s are generic over the backend to allow users to perform operations using different
`Backend` implementations. Burn's tensors also support auto-differentiation thanks to the
`AutodiffBackend` trait.

View File

@@ -0,0 +1 @@
../../docs/katex-header.html

View File

@@ -0,0 +1,465 @@
use alloc::format;
use alloc::string::String;
use burn_backend::{Backend, Device, DeviceId, DeviceOps};
use burn_std::stub::RwLock;
use burn_std::{DType, FloatDType, IntDType};
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
use thiserror::Error;
use core::any::TypeId;
#[cfg(feature = "std")]
pub use std::collections::HashMap;
#[cfg(feature = "std")]
use std::sync::LazyLock;
#[cfg(not(feature = "std"))]
pub use hashbrown::HashMap;
#[cfg(not(feature = "std"))]
use spin::Lazy as LazyLock;
/// Policy controlling default device behavior.
///
/// This includes default data types used for tensor creation.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct DevicePolicy {
/// Default floating-point data type for tensor creation.
float_dtype: Option<FloatDType>,
/// Default integer data type for tensor creation.
int_dtype: Option<IntDType>,
}
impl DevicePolicy {
/// Returns the default floating-point data type used for tensor creation.
pub(crate) fn float_dtype(&self) -> Option<FloatDType> {
self.float_dtype
}
/// Returns the default integer data type used for tensor creation.
pub(crate) fn int_dtype(&self) -> Option<IntDType> {
self.int_dtype
}
/// Sets the default floating-point data type.
pub(crate) fn set_float_dtype(&mut self, dtype: FloatDType) {
self.float_dtype = Some(dtype);
}
/// Sets the default integer data type.
pub(crate) fn set_int_dtype(&mut self, dtype: IntDType) {
self.int_dtype = Some(dtype);
}
}
/// Key for the registry: physical device type + device id
type RegistryKey = (DeviceId, TypeId);
/// Global registry mapping devices to their policies.
static REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<DevicePolicy>>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
/// Device policy management for controlling default tensor creation behavior.
///
/// # Policy Semantics
///
/// Device policies use snapshot semantics: when you retrieve a policy with
/// [`get_device_policy`], you get an immutable snapshot of the current configuration.
/// Updates to the policy (via [`set_default_dtypes`], [`set_default_float_dtype`], etc.)
/// only affect future policy retrievals, not existing references.
///
/// This is intended for the common case where policies are set once during
/// initialization and then read frequently during tensor creation.
struct DevicePolicyRegistry;
impl DevicePolicyRegistry {
/// Get the policy for a physical device type and device id.
///
/// If no policy exists yet, a default one is created and stored.
fn get<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
let key = Self::key(device);
if let Some(policy) = REGISTRY.read().unwrap().get(&key) {
return Arc::clone(policy);
}
let mut map = REGISTRY.write().unwrap();
Arc::clone(
map.entry(key)
.or_insert_with(|| Arc::new(DevicePolicy::default())),
)
}
/// Mutate the policy for a given device.
fn update<D: DeviceOps>(device: &D, update_fn: impl FnOnce(&mut DevicePolicy)) {
let key = Self::key(device);
let mut map = REGISTRY.write().unwrap();
let policy = map
.entry(key)
.or_insert_with(|| Arc::new(DevicePolicy::default()));
// Update the policy
let policy_mut = Arc::make_mut(policy);
update_fn(policy_mut);
}
/// Returns the device registry key.
fn key<D: Device>(device: &D) -> RegistryKey {
(device.to_id(), TypeId::of::<D>())
}
}
/// Get the [`device`'s policy](DevicePolicy).
///
/// Returns an immutable snapshot of the device's current policy. If the policy
/// is updated after retrieval, this snapshot will not reflect those changes.
pub(crate) fn get_device_policy<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
DevicePolicyRegistry::get(device)
}
/// Errors that can occur during device-related operations.
///
/// This covers errors related to hardware capability mismatches, such as
/// requesting a data type not supported by the device, and configuration
/// errors like attempting to change a policy in an invalid context.
#[derive(Debug, Error)]
pub enum DeviceError {
/// Unsupported data type by the device.
#[error("Device {device} does not support the requested data type {dtype:?}")]
UnsupportedDType {
/// The string representation of the device.
device: String,
/// The data type that caused the error.
dtype: DType,
},
// TODO: `InvalidContext` if a device policy cannot be changed after init / during training / etc.
}
impl DeviceError {
/// Helper to create a [`DeviceError::UnsupportedDType`] from any device.
pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {
Self::UnsupportedDType {
device: format!("{device:?}"),
dtype,
}
}
}
fn check_dtype_support<B: Backend>(
device: &B::Device,
dtype: impl Into<DType>,
) -> Result<(), DeviceError> {
let dtype = dtype.into();
// Default dtypes should have `DTypeUsage::general()`. Types restricted to specialized
// operations should not be used as default.
if B::supports_dtype(device, dtype) {
Ok(())
} else {
Err(DeviceError::unsupported_dtype(device, dtype))
}
}
/// Sets the default data types for the device.
///
/// This updates the device's default data types used for tensor creation.
/// The policy should typically be set once during initialization and then
/// remains global for all subsequent operations on that device.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{DType, Int, Tensor, set_default_dtypes};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
///
/// // Update the device policy
/// set_default_dtypes::<B>(&device, DType::F16, DType::I32);
///
/// // All float tensors created after this will use F16 by default
/// let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
/// // All int tensors created after this will use I32 default
/// let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
/// }
/// ```
pub fn set_default_dtypes<B: Backend>(
device: &B::Device,
float_dtype: impl Into<FloatDType>,
int_dtype: impl Into<IntDType>,
) -> Result<(), DeviceError> {
let float_dtype = float_dtype.into();
let int_dtype = int_dtype.into();
check_dtype_support::<B>(device, float_dtype)?;
check_dtype_support::<B>(device, int_dtype)?;
set_default_dtypes_unchecked(device, float_dtype, int_dtype);
Ok(())
}
/// Sets the default floating-point data type for the device.
///
/// This updates the device's default data types used for tensor creation.
/// The policy should typically be set once during initialization and then
/// remains global for all subsequent operations on that device.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{DType, Tensor, set_default_float_dtype};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
///
/// // Update the device policy
/// set_default_float_dtype::<B>(&device, DType::F16);
///
/// // All float tensors created after this will use F16 by default
/// let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
/// }
/// ```
pub fn set_default_float_dtype<B: Backend>(
device: &B::Device,
dtype: impl Into<FloatDType>,
) -> Result<(), DeviceError> {
let dtype = dtype.into();
check_dtype_support::<B>(device, dtype)?;
set_default_float_dtype_unchecked(device, dtype);
Ok(())
}
/// Sets the default integer data type for the device.
///
/// This updates the device's default data types used for tensor creation.
/// The policy should typically be set once during initialization and then
/// remains global for all subsequent operations on that device.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{DType, Int, Tensor, set_default_int_dtype};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
///
/// // Update the device policy
/// set_default_int_dtype::<B>(&device, DType::I32);
///
/// // All int tensors created after this will use I32 default
/// let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
/// }
/// ```
pub fn set_default_int_dtype<B: Backend>(
device: &B::Device,
dtype: impl Into<IntDType>,
) -> Result<(), DeviceError> {
let dtype = dtype.into();
check_dtype_support::<B>(device, dtype)?;
set_default_int_dtype_unchecked(device, dtype);
Ok(())
}
// Unchecked versions
fn set_default_dtypes_unchecked<D: DeviceOps>(
device: &D,
float_dtype: FloatDType,
int_dtype: IntDType,
) {
DevicePolicyRegistry::update(device, |p| {
p.set_float_dtype(float_dtype);
p.set_int_dtype(int_dtype);
});
}
fn set_default_float_dtype_unchecked<D: DeviceOps>(device: &D, dtype: FloatDType) {
DevicePolicyRegistry::update(device, |p| {
p.set_float_dtype(dtype);
});
}
fn set_default_int_dtype_unchecked<D: DeviceOps>(device: &D, dtype: IntDType) {
DevicePolicyRegistry::update(device, |p| {
p.set_int_dtype(dtype);
});
}
#[cfg(all(test, feature = "std"))]
mod tests {
use serial_test::serial;
use super::*;
fn clear_registry() {
REGISTRY.write().unwrap().clear();
}
#[derive(Clone, Debug, Default, PartialEq, new)]
pub struct TestDeviceA {
index: u32,
}
impl Device for TestDeviceA {
fn from_id(device_id: DeviceId) -> Self {
Self {
index: device_id.index_id,
}
}
fn to_id(&self) -> DeviceId {
DeviceId {
type_id: 0,
index_id: self.index,
}
}
fn device_count(_type_id: u16) -> usize {
1
}
}
impl DeviceOps for TestDeviceA {}
#[derive(Clone, Debug, Default, PartialEq, new)]
pub struct TestDeviceB {
index: u32,
}
impl Device for TestDeviceB {
fn from_id(device_id: DeviceId) -> Self {
Self {
index: device_id.index_id,
}
}
fn to_id(&self) -> DeviceId {
DeviceId {
type_id: 0,
index_id: self.index,
}
}
fn device_count(_type_id: u16) -> usize {
1
}
}
impl DeviceOps for TestDeviceB {}
#[test]
#[serial]
fn default_policy_is_created_and_shared() {
clear_registry(); // reset registry for each test
let device = TestDeviceA::new(0);
let p1 = get_device_policy(&device);
let p2 = get_device_policy(&device);
assert!(Arc::ptr_eq(&p1, &p2));
// Not explicitly set
assert!(p1.float_dtype().is_none());
assert!(p1.int_dtype().is_none());
assert!(p2.float_dtype().is_none());
assert!(p2.int_dtype().is_none());
}
#[test]
#[serial]
fn updated_policy_is_shared() {
clear_registry(); // reset registry for each test
let device = TestDeviceA::new(0);
// The device policy is meant to be set once at initialization
set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32);
let p1 = get_device_policy(&device);
let p2 = get_device_policy(&device);
assert!(Arc::ptr_eq(&p1, &p2));
assert_eq!(p1.float_dtype(), Some(FloatDType::BF16));
assert_eq!(p1.int_dtype(), Some(IntDType::I32));
assert_eq!(p2.float_dtype(), Some(FloatDType::BF16));
assert_eq!(p2.int_dtype(), Some(IntDType::I32));
}
#[test]
#[serial]
fn policy_is_device_id_specific() {
clear_registry(); // reset registry for each test
let d1 = TestDeviceA::new(0);
let d2 = TestDeviceA::new(1);
set_default_float_dtype_unchecked(&d1, FloatDType::F16);
let p1 = get_device_policy(&d1);
let p2 = get_device_policy(&d2);
assert!(!Arc::ptr_eq(&p1, &p2));
assert_eq!(p1.float_dtype(), Some(FloatDType::F16));
assert!(p1.int_dtype().is_none());
assert!(p2.float_dtype().is_none());
assert!(p2.int_dtype().is_none());
}
#[test]
#[serial]
fn policy_is_device_type_specific() {
clear_registry(); // reset registry for each test
let d1 = TestDeviceA::new(0);
let d2 = TestDeviceB::new(0);
set_default_float_dtype_unchecked(&d2, FloatDType::F16);
let p1 = get_device_policy(&d1);
let p2 = get_device_policy(&d2);
assert!(p1.float_dtype().is_none());
assert!(p1.int_dtype().is_none());
assert_eq!(p2.float_dtype(), Some(FloatDType::F16));
assert!(p2.int_dtype().is_none());
}
#[test]
#[serial]
fn updating_policy_should_not_affect_snapshot() {
clear_registry(); // reset registry for each test
// The device policy is meant to be set once at initialization
let device = TestDeviceA::new(0);
let before = get_device_policy(&device);
set_default_float_dtype_unchecked(&device, FloatDType::BF16);
let after = get_device_policy(&device);
assert!(!Arc::ptr_eq(&before, &after));
assert_eq!(after.float_dtype(), Some(FloatDType::BF16));
assert!(before.float_dtype().is_none());
}
#[test]
#[serial]
fn set_default_dtypes_overwrites_fields() {
clear_registry(); // reset registry for each test
let device = TestDeviceA::new(0);
set_default_dtypes_unchecked(&device, FloatDType::F16, IntDType::I64);
let policy = get_device_policy(&device);
assert_eq!(policy.float_dtype(), Some(FloatDType::F16));
assert_eq!(policy.int_dtype(), Some(IntDType::I64));
}
}

View File

@@ -0,0 +1,23 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! This library provides the core abstractions required to run tensor operations with Burn.
//! `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations.
//! Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait.
#[macro_use]
extern crate derive_new;
extern crate alloc;
mod tensor;
pub(crate) use tensor::check::macros::check;
pub use tensor::*;
// Re-exported types
pub use burn_backend::{AllocationProperty, Bytes, StreamId, bf16, f16, read_sync, try_read_sync};
mod device;
pub use device::*;

View File

@@ -0,0 +1,647 @@
use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, TensorPrimitive, check, s};
/// Applies the rectified linear unit function element-wise
/// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).
///
#[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")]
#[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")]
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.relu()
}
/// Applies the leaky rectified linear unit function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\)
$$
or
$$
\text{LeakyReLU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \newline
\text{negative\\_slope} \cdot x & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x if x < 0`"
)]
pub fn leaky_relu<const D: usize, B: Backend>(
tensor: Tensor<B, D>,
negative_slope: f64,
) -> Tensor<B, D> {
Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu(
tensor.primitive.tensor(),
negative_slope.into(),
)))
}
/// Applies the Gaussian Error Linear Units function as described in the paper
/// [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{GELU}(x)
= x \cdot \Phi(x)
= x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
$$
where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.
"#
)]
#[cfg_attr(
not(doc),
doc = r#"
`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`
where `Φ(x)` is the cumulative distribution function for the Gaussian distribution.
"#
)]
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor())))
}
/// Applies the tanh-based approximate GELU function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{GELU\_approx}(x)
= \frac{x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`GELU_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`"
)]
pub fn gelu_approximate<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
/// sqrt(2/π) precomputed as FRAC_2_SQRT_PI * FRAC_1_SQRT_2
const SQRT_2_OVER_PI: f64 =
core::f64::consts::FRAC_2_SQRT_PI * core::f64::consts::FRAC_1_SQRT_2;
let x = tensor;
let inner = x.clone() + x.clone().powf_scalar(3.0) * 0.044715;
let inner = inner * SQRT_2_OVER_PI;
(x.clone() * (inner.tanh() + 1)) * 0.5
}
/// Applies Parametric ReLu activation function as described in the paper
/// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).
///
/// - The tensor is assumed to be of shape `[batch_size, channels, ...]`.
/// - `alpha` is assumed to be of shape `[channels]` or `[1]`.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\)
$$
or
$$
\text{PReLU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \newline
\alpha x & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`PReLu(x) = max(0,x) + alpha * min(0,x)`")]
pub fn prelu<const D: usize, B: Backend>(
tensor: Tensor<B, D>,
alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
check!(TensorCheck::check_prelu_shape::<D>(
&tensor.shape(),
&alpha.shape()
));
let weight = if alpha.dims()[0] == 1 {
// if there is only 1 weight, then reshape it to (1,1,1... D times) so that the rank is D
alpha.reshape([1; D])
} else {
// D>=2 because the case where D==1 and num_weights >1 is handled by check function
// there is more than 1 weight and rank is more than 2
let num_weights = alpha.dims()[0];
let mut s = [1; D];
s[1] = num_weights;
// reshape the weights to (1, channels,1 ...)
alpha.reshape(s)
};
Tensor::from_primitive(TensorPrimitive::Float(B::prelu(
tensor.primitive.tensor(),
weight.primitive.tensor(),
)))
}
/// Applies the softmax function on the input tensor along the given dimension.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")]
///
/// # Arguments
/// - `dim`: the dimension along which Softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("softmax", dim));
let tensor = tensor.clone() - tensor.detach().max_dim(dim);
let tensor = tensor.exp();
let tensor_tmp = tensor.clone().sum_dim(dim);
tensor.div(tensor_tmp)
}
/// Applies the softmin function on the input tensor along the given dimension.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j)`")]
///
/// # Arguments
/// - `dim`: the dimension along which Softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("softmin", dim));
softmax(tensor.neg(), dim)
}
/// Applies the SoftPlus function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")]
///
/// The SoftPlus function is a smooth approximation of the ReLU function.
pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
let tensor = (tensor.mul_scalar(beta).exp() + 1).log();
tensor.div_scalar(beta)
}
/// Applies the "quiet softmax" function on the input tensor along the given dimension.
///
/// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html).
///
/// This function is similar to the softmax function, but it allows for "no selection" when
/// all the outputs are close to zero.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`"
)]
///
/// # Arguments
/// - `dim`: the dimension along which Softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("softmax", dim));
let max_vals = tensor.clone().detach().max_dim(dim);
let exp_x = (tensor - max_vals.clone()).exp();
let sum_exp = exp_x.clone().sum_dim(dim);
exp_x.div(sum_exp + max_vals.neg().exp())
}
/// Applies the log softmax function on the input tensor along the given dimension.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{log\\_softmax}\(x_i\)
= \log\left(\text{softmax}\(x_i\)\right)
= \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`"
)]
///
/// # Arguments
/// - `dim`: the dimension along which Softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("log softmax", dim));
let tensor = tensor.clone() - tensor.detach().max_dim(dim);
let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();
tensor.sub(tensor_tmp)
}
/// Applies the sigmoid function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{sigmoid}\(x\)
= \sigma(x)
= \frac{1}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")]
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid(
tensor.primitive.tensor(),
)))
}
/// Applies the hard sigmoid function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta))
$$
"#
)]
#[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")]
pub fn hard_sigmoid<const D: usize, B: Backend>(
tensor: Tensor<B, D>,
alpha: f64,
beta: f64,
) -> Tensor<B, D> {
Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(
tensor.primitive.tensor(),
alpha.into(),
beta.into(),
)))
}
/// Applies the log sigmoid function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")]
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(
tensor.primitive.tensor(),
)))
}
/// Applies the SiLU function (also known as the swish function) element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")]
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.clone().mul(sigmoid(tensor))
}
/// Applies the hard swish function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`"
)]
pub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5))
}
/// Applies the Mish function as described in the paper in
/// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{Mish}\(x\)
= x \cdot \tanh(\text{Softplus}(x))
= \tanh\left(\log\(1 + \exp\(x\)\)\right)
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`mish(x) = x * tanh(softplus(x)) = tanh(log(1 + exp(x)))`"
)]
pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.clone().mul(softplus(tensor, 1.0).tanh())
}
/// Applies the tanh function element-wise.
pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.tanh()
}
/// Applies the Exponential Linear Unit function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{ELU}\(x\) =
\begin{cases}
x & \text{if } x > 0 \newline
\alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`f(x) =`\n- `x for x > 0`\n- `alpha * (exp(x) - 1) for x <= 0`"
)]
pub fn elu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
let mask = tensor.clone().lower_equal_elem(0);
let scaled = tensor.clone().exp().sub_scalar(1).mul_scalar(alpha);
tensor.mask_where(mask, scaled)
}
/// Applies the Continuously Differentiable Exponential Linear Unit function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{CELU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \newline
\alpha \cdot \left(\exp\left(\frac{x}{\alpha}\right) - 1\right) & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`"
)]
///
/// See also [CELU](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html)
///
/// # Arguments
/// - `alpha`: scaling parameter for the negative part.
pub fn celu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
let mask = tensor.clone().lower_equal_elem(0);
let scaled = tensor
.clone()
.div_scalar(alpha)
.exp()
.sub_scalar(1)
.mul_scalar(alpha);
tensor.mask_where(mask, scaled)
}
/// Applies the Scaled Exponential Linear Unit function element-wise
/// as described in the paper [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{SELU}\(x\) = \gamma \cdot
\begin{cases}
x & \text{if } x > 0 \newline
\alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
\end{cases}
$$
where $\alpha \approx 1.6733$ and $\gamma \approx 1.0507$.
"#
)]
#[cfg_attr(
not(doc),
doc = "`selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0`"
)]
pub fn selu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
// Constants from the SELU paper / ONNX spec
const ALPHA: f64 = 1.6732632423543772848170429916717_f64;
const GAMMA: f64 = 1.0507009873554804934193349852946_f64;
let mask = tensor.clone().greater_equal_elem(0.0);
let positive = tensor.clone().mul_scalar(GAMMA);
let negative = tensor.exp().sub_scalar(1.0).mul_scalar(ALPHA * GAMMA);
negative.mask_where(mask, positive)
}
/// Applies the thresholded rectified linear unit function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{ThresholdedReLU}(x) =
\begin{cases}
x & \text{if } x > \alpha \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`f(x) =`\n- `x if x > alpha`\n- `0 otherwise`")]
///
/// # Arguments
/// - `alpha`: threshold value (default in ONNX is 1.0).
pub fn thresholded_relu<const D: usize, B: Backend>(
tensor: Tensor<B, D>,
alpha: f64,
) -> Tensor<B, D> {
let mask = tensor.clone().lower_equal_elem(alpha);
tensor.mask_fill(mask, 0)
}
/// Applies the gated linear unit function.
///
/// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.
///
/// **Note**:
/// * The size of the input tensor along `dim` must be divisible by 2.
///
/// ### Arguments
/// * `tensor` - The input tensor.
///
/// ### Returns
/// * A tensor with the same shape as the input, except the size along `dim` is halved.
pub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
// TODO: Handle negative indices with AsIndex for compatibility with Pytorch nn.GLU.
assert!(
tensor.dims()[dim].is_multiple_of(2),
"Input tensor along dimension {dim} must have an even size. N is divisible by 2."
);
let new_len = tensor.dims()[dim] / 2;
let a = tensor.clone().slice_dim(dim, s![0..new_len]);
let b = tensor.slice_dim(dim, s![new_len..new_len * 2]);
a.mul(sigmoid(b))
}
/// Applies the Softsign function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{softsign}(x) = \frac{x}{1 + |x|}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softsign(x_i) = x_i / (1 + |x_i|)`")]
pub fn softsign<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.clone().div(tensor.abs() + 1)
}
/// Applies the HardShrink function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{hard\_shrink}(x) =
\begin{cases}
x & \text{if } x > \lambda \newline
x & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Hard Shrink formulation. Default is 0.5.
pub fn hard_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
let mask = tensor.clone().abs().lower_equal_elem(lambda);
tensor.mask_fill(mask, 0)
}
/// Applies the SoftShrink function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{soft\_shrink}(x) =
\begin{cases}
x - \lambda & \text{if } x > \lambda \newline
x + \lambda & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Soft Shrink formulation. Default is 0.5.
pub fn soft_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
shrink(tensor, lambda, lambda)
}
/// Applies the Shrink function element-wise.
///
#[cfg_attr(
doc,
doc = r#"
$$
\text{shrink}(x) =
\begin{cases}
x - \text{bias} & \text{if } x > \lambda \newline
x + \text{bias} & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Shrink formulation.
/// - `bias`: the bias value for the Shrink formulation.
pub fn shrink<const D: usize, B: Backend>(
tensor: Tensor<B, D>,
lambda: f64,
bias: f64,
) -> Tensor<B, D> {
let abs_tensor = tensor.clone().abs();
let sign = tensor.clone().sign();
let shrunk = tensor.sub(sign.mul_scalar(bias));
let mask = abs_tensor.lower_equal_elem(lambda);
shrunk.mask_fill(mask, 0)
}

View File

@@ -0,0 +1,3 @@
mod base;
pub use base::*;

View File

@@ -0,0 +1,75 @@
pub use burn_backend::tensor::BasicAutodiffOps;
use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
/// Backward pass of the tensor.
pub fn backward(&self) -> B::Gradients {
B::backward(self.primitive.clone().tensor())
}
/// Get the gradients of a tensor if it exist.
///
/// Returns a new reference to the same tensor. Therefore the same grad tensor can
/// be accessed multiple times. If you only need to get the gradients one time,
/// consider using [grad_remove](Tensor::grad_remove) for better performance.
pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
match &self.primitive {
TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
}
}
/// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
match &self.primitive {
TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
TensorPrimitive::QFloat(_tensor) => {
B::grad_remove(&self.primitive.clone().tensor(), grads)
.map(TensorPrimitive::Float)
.map(Tensor::new)
}
}
}
/// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
/// gradient.
pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
match &self.primitive {
TensorPrimitive::Float(tensor) => {
B::grad_replace(tensor, grads, grad.primitive.tensor())
}
TensorPrimitive::QFloat(_tensor) => B::grad_replace(
&self.primitive.clone().tensor(),
grads,
grad.primitive.tensor(),
),
}
}
}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
/// Returns the inner tensor without the autodiff information.
pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
Tensor::new(K::inner(self.primitive))
}
/// Convert a tensor to the autodiff backend.
///
/// # Arguments
///
/// * `inner` - The tensor to convert.
///
/// # Returns
///
/// The tensor converted to the autodiff backend.
pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
Self::new(K::from_inner(inner.primitive))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,429 @@
use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
use alloc::{vec, vec::Vec};
use crate::try_read_sync;
/// The part of the tensor to keep when creating a triangular mask.
enum TriPart {
/// Upper triangular part.
Upper,
/// Lower triangular part.
Lower,
/// Diagonal part.
Diagonal,
}
impl<B, const D: usize> Tensor<B, D, Bool>
where
B: Backend,
{
/// Create a boolean tensor from data on the given device.
///
/// # Arguments
///
/// * `data` - The tensor data.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// A boolean tensor.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
/// println!("{tensor}");
/// }
/// ```
pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
}
/// Convert the bool tensor into an int tensor.
///
/// # Returns
///
/// An integer tensor where `true` is converted to `1` and `false` to `0`.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
/// let int_tensor = bool_tensor.int();
/// println!("{int_tensor}"); // [1, 0, 1]
/// }
/// ```
pub fn int(self) -> Tensor<B, D, Int> {
Tensor::new(B::bool_into_int(self.primitive))
}
/// Convert the bool tensor into a float tensor.
///
/// # Returns
///
/// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
/// let float_tensor = bool_tensor.float();
/// println!("{float_tensor}"); // [1.0, 0.0, 1.0]
/// }
/// ```
pub fn float(self) -> Tensor<B, D> {
Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
}
/// Inverses boolean values.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
/// let inverted = tensor.bool_not();
/// println!("{inverted}"); // [[false, true], [true, false]]
/// }
/// ```
pub fn bool_not(self) -> Self {
Tensor::new(B::bool_not(self.primitive))
}
/// Performs logical and (`&&`) on two boolean tensors.
///
/// # Arguments
///
/// * `rhs` - The right-hand side tensor for the AND operation.
///
/// # Returns
///
/// A boolean tensor where each element is the result of `self[i] && rhs[i]`.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
/// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
/// let result = a.bool_and(b);
/// println!("{result}"); // [[true, false], [false, false]]
/// }
/// ```
pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
Tensor::new(B::bool_and(self.primitive, rhs.primitive))
}
/// Performs logical or (`||`) on two boolean tensors.
///
/// # Arguments
///
/// * `rhs` - The right-hand side tensor for the OR operation.
///
/// # Returns
///
/// A boolean tensor where each element is the result of `self[i] || rhs[i]`.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
/// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
/// let result = a.bool_or(b);
/// println!("{result}"); // [[true, true], [true, false]]
/// }
/// ```
pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
Tensor::new(B::bool_or(self.primitive, rhs.primitive))
}
/// Performs logical xor (`^`) on two boolean tensors.
///
/// # Arguments
///
/// * `rhs` - The right-hand side tensor for the XOR operation.
///
/// # Returns
///
/// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.
/// Returns `true` when exactly one of the operands is `true`.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
/// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
/// let result = a.bool_xor(b);
/// println!("{result}"); // [[false, true], [true, false]]
/// }
/// ```
pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
}
/// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
///
/// # Returns
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let tensor = Tensor::<B, 2, Bool>::from_bool(
/// [[true, false, true], [false, true, false], [false, true, false]].into(),
/// &device,
/// );
/// let indices = tensor.nonzero();
/// println!("{}", indices[0]); // [0, 0, 1, 2]
/// println!("{}", indices[1]); // [0, 2, 1, 1]
/// }
/// ```
pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
try_read_sync(self.nonzero_async())
.expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
}
/// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
///
/// # Returns
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
let indices = self.argwhere_async().await;
if indices.shape().num_elements() == 0 {
// Return empty vec when all elements are zero
return vec![];
}
let dims = indices.shape();
indices
.chunk(dims[1], 1)
.into_iter()
.map(|t| t.reshape(Shape::new([dims[0]])))
.collect()
}
/// Compute the indices of the elements that are true, grouped by element.
///
/// # Returns
///
/// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
/// result contains the indices of a non-zero element.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let tensor = Tensor::<B, 2, Bool>::from_bool(
/// [[true, false, true], [false, true, false], [false, true, false]].into(),
/// &device,
/// );
/// let indices = tensor.argwhere();
/// println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]]
/// }
/// ```
pub fn argwhere(self) -> Tensor<B, 2, Int> {
try_read_sync(self.argwhere_async())
.expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
}
/// Compute the indices of the elements that are true, grouped by element.
///
/// # Returns
///
/// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
/// result contains the indices of a non-zero element.
pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
Tensor::new(B::bool_argwhere(self.primitive).await)
}
/// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
/// fill the specified area with a value.
fn tri_mask<S: Into<Shape>>(
shape: S,
tri_part: TriPart,
offset: i64,
device: &B::Device,
) -> Self {
let shape: Shape = shape.into();
let height = shape[D - 2];
let width = shape[D - 1];
// Generate row and column index tensors.
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
// Prepare shapes for broadcasting.
let mut row_shape = [1; D];
row_shape[D - 2] = height;
let mut col_shape = [1; D];
col_shape[D - 1] = width;
// Reshape for broadcasting.
let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
let col_broadcast = col_indices.reshape(Shape::new(col_shape));
// Broadcasting trick to create a matrix that facilitates comparison for mask generation.
let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
// Select the appropriate comparison function based on `tri_part`.
let compare = match tri_part {
TriPart::Upper => Tensor::greater_elem,
TriPart::Lower => Tensor::lower_elem,
TriPart::Diagonal => Tensor::not_equal_elem,
};
// Generate and return the mask by applying the comparison to the matrix.
compare(matrix, 0).unsqueeze()
}
/// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
/// area with a value.
///
/// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
///
/// # Arguments
///
/// * `shape`: The shape of the matrix.
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
/// towards the upper triangle.
/// * `device`: The device on which the tensor will be allocated.
///
/// # Returns
///
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
/// upper triangle taking into account the specified `offset`. All other elements are `true`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
/// println!("{mask}");
/// // [[false, false, false],
/// // [true, false, false],
/// // [true, true, false]]
/// }
/// ```
pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Upper, offset, device)
}
/// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
/// area with a value.
///
/// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
///
/// # Arguments
///
/// * `shape`: The shape of the matrix.
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
/// towards the lower triangle.
/// * `device`: The device on which the tensor will be allocated.
///
/// # Returns
///
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
/// lower triangle taking into account the specified `offset`. All other elements are `true`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
/// println!("{mask}");
/// // [[false, true, true],
/// // [false, false, true],
/// // [false, false, false]]
/// }
/// ```
pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Lower, offset, device)
}
/// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
/// area with a value.
///
/// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
///
/// # Arguments
///
/// * `shape`: The shape of the matrix.
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
/// towards the upper triangle.
/// * `device`: The device on which the tensor will be allocated.
///
/// # Returns
///
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
/// diagonal. All other elements are `true`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
/// println!("{mask}");
/// // [[false, true, true],
/// // [true, false, true],
/// // [true, true, false]]
/// }
/// ```
pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Diagonal, offset, device)
}
}

View File

@@ -0,0 +1,56 @@
use crate::{Int, Shape, Tensor, backend::Backend};
use alloc::vec::Vec;
/// Generates a cartesian grid for the given tensor shape on the specified device.
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
/// println!("{}", result);
/// }
/// ```
pub fn cartesian_grid<B: Backend, S: Into<Shape>, const D: usize, const D2: usize>(
shape: S,
device: &B::Device,
) -> Tensor<B, D2, Int> {
if D2 != D + 1 {
panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
}
let dims = shape.into();
let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
for dim in 0..D {
let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
let mut shape = [1; D];
shape[dim] = dims[dim];
let mut dim_range = dim_range.reshape(shape);
for (i, &item) in dims.iter().enumerate() {
if i == dim {
continue;
}
dim_range = dim_range.repeat_dim(i, item);
}
indices.push(dim_range);
}
Tensor::stack::<D2>(indices, D)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,111 @@
use crate::{Float, Tensor, backend::Backend};
impl<B, const D: usize> Tensor<B, D, Float>
where
B: Backend,
{
/// Computes the floating-point remainder of dividing `self` by `other`.
///
/// The result has the same sign as `self` and magnitude less than `other`.
/// This is equivalent to the IEEE 754 remainder operation.
///
/// # Special Cases (IEEE 754 compliant)
///
/// - If `self` is ±∞ and `other` is not NaN, NaN is returned
/// - If `other` is ±0 and `self` is not NaN, NaN is returned
/// - If `other` is ±∞ and `self` is finite, `self` is returned
/// - If either argument is NaN, NaN is returned
///
/// # Arguments
///
/// * `other` - The divisor tensor. Must have the same shape as `self`.
///
/// # Returns
///
/// A tensor with the same shape where each element is the floating-point remainder.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let dividend = Tensor::<B, 1>::from_data([5.3, -5.3, 5.3, -5.3], &device);
/// let divisor = Tensor::<B, 1>::from_data([2.0, 2.0, -2.0, -2.0], &device);
/// let result = dividend.fmod(divisor);
///
/// // Result: [1.3, -1.3, 1.3, -1.3]
/// }
/// ```
pub fn fmod(self, other: Self) -> Self {
// Normal case: fmod(x, y) = x - y * trunc(x / y)
let quotient = self.clone().div(other.clone());
let truncated = quotient.trunc();
let product = other.clone() * truncated.clone();
// When divisor is infinity and dividend is finite:
// - quotient is 0, truncated is 0
// - but 0 * infinity = NaN, which is wrong
// We need to handle this case by replacing NaN with 0 when appropriate
// Check if the product is NaN due to 0 * inf
let is_zero_times_inf = truncated.equal_elem(0.0).bool_and(other.is_inf());
let zero_tensor = self.clone().mul_scalar(0.0);
let corrected_product = product.mask_where(is_zero_times_inf, zero_tensor);
self - corrected_product
}
/// Computes the floating-point remainder of dividing `self` by a scalar.
///
/// The result has the same sign as `self` and magnitude less than the scalar.
///
/// # Special Cases (IEEE 754 compliant)
///
/// - If `self` is ±∞ and scalar is not NaN, NaN is returned
/// - If scalar is ±0 and `self` is not NaN, NaN is returned
/// - If scalar is ±∞ and `self` is finite, `self` is returned
/// - If either argument is NaN, NaN is returned
///
/// # Arguments
///
/// * `scalar` - The scalar divisor.
///
/// # Returns
///
/// A tensor with the same shape where each element is the floating-point remainder.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let tensor = Tensor::<B, 1>::from_data([5.3, -5.3, 7.5, -7.5], &device);
/// let result = tensor.fmod_scalar(2.0);
///
/// // Result: [1.3, -1.3, 1.5, -1.5]
/// }
/// ```
pub fn fmod_scalar(self, scalar: f32) -> Self {
// Normal case: fmod(x, y) = x - y * trunc(x / y)
let quotient = self.clone().div_scalar(scalar);
let truncated = quotient.trunc();
let product = truncated.mul_scalar(scalar);
// Handle the special case where scalar is infinity
// When scalar is ±∞ and self is finite, quotient is 0, truncated is 0
// but 0 * infinity = NaN, which is wrong - it should be 0
if scalar.is_infinite() {
// For finite values, fmod(x, ±∞) = x
// For infinite values, fmod(±∞, ±∞) = NaN (which is handled by arithmetic)
return self;
}
self - product
}
}

View File

@@ -0,0 +1,182 @@
use burn_backend::Scalar;
use crate::{
Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend,
cartesian_grid,
};
use core::ops::Range;
impl<B> Tensor<B, 1, Int>
where
B: Backend,
{
/// Returns a new integer tensor on the specified device.
///
/// # Arguments
///
/// * `range` - The range of values to generate.
/// * `device` - The device to create the tensor on.
pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
Tensor::new(B::int_arange(range, device))
}
/// Returns a new integer tensor on the specified device.
///
/// # Arguments
///
/// * `range` - The range of values to generate.
/// * `step` - The step between each value.
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::int_arange_step(range, step, device))
}
}
impl<const D: usize, B> Tensor<B, D, Int>
where
B: Backend,
{
/// Create a tensor from integers (i32), placing it on a given device.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
/// let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
/// }
/// ```
pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
Self::from_data(ints.into().convert::<i32>(), device)
}
/// Returns a new tensor with the same shape and device as the current tensor and the data
/// cast to Float.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Int, Tensor};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
/// let float_tensor = int_tensor.float();
/// }
/// ```
pub fn float(self) -> Tensor<B, D, Float> {
Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))
}
/// Generates a cartesian grid for the given tensor shape on the specified device.
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
/// println!("{}", result);
/// }
/// ```
pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
shape: S,
device: &B::Device,
) -> Tensor<B, D2, Int> {
cartesian_grid::<B, S, D, D2>(shape, device)
}
/// Applies the bitwise logical and operation with each bit representing the integer.
pub fn bitwise_and(self, other: Self) -> Self {
Self::new(B::bitwise_and(self.primitive, other.primitive))
}
/// Applies the bitwise logical or operation with another tensor.
pub fn bitwise_or(self, other: Self) -> Self {
Self::new(B::bitwise_or(self.primitive, other.primitive))
}
/// Applies the bitwise logical xor operation with another tensor.
pub fn bitwise_xor(self, other: Self) -> Self {
Self::new(B::bitwise_xor(self.primitive, other.primitive))
}
/// Applies the bitwise logical not operation.
pub fn bitwise_not(self) -> Self {
Self::new(B::bitwise_not(self.primitive))
}
/// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.
pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
let other = Scalar::new(other, &self.dtype());
Self::new(B::bitwise_and_scalar(self.primitive, other))
}
/// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.
pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
let other = Scalar::new(other, &self.dtype());
Self::new(B::bitwise_or_scalar(self.primitive, other))
}
/// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.
pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
let other = Scalar::new(other, &self.dtype());
Self::new(B::bitwise_xor_scalar(self.primitive, other))
}
/// Applies the bitwise left shift operation with the integers in the tensor.
pub fn bitwise_left_shift(self, other: Self) -> Self {
Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
}
/// Applies the bitwise right shift operation with the integers in the tensor.
pub fn bitwise_right_shift(self, other: Self) -> Self {
Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
}
/// Applies the bitwise left shift operation with the scalar.
pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
let other = Scalar::new(other, &self.dtype());
Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
}
/// Applies the bitwise right shift operation with the scalar.
pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
let other = Scalar::new(other, &self.dtype());
Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
}
/// Converts a tensor to the specified integer data type.
///
/// This is always a no-op when casting to the current dtype.
///
/// # Warning
/// Most backends don't have automatic type promotion at this time, so make sure that all tensors
/// have the same integer data type for operations multiple input tensors (e.g., binary ops).
pub fn cast<F: Into<IntDType>>(self, dtype: F) -> Tensor<B, D, Int> {
let dtype = dtype.into();
let self_dtype: IntDType = self.dtype().into();
if dtype == self_dtype {
// no-op.
return self;
}
Tensor::new(B::int_cast(self.primitive, dtype))
}
}

View File

@@ -0,0 +1,27 @@
pub(crate) mod check;
mod autodiff;
mod base;
mod bool;
mod cartesian_grid;
mod float;
mod fmod;
mod int;
mod numeric;
mod options;
mod orderable;
mod pad;
pub use pad::IntoPadding;
mod take;
mod transaction;
mod trunc;
pub use autodiff::*;
pub use base::*;
pub use cartesian_grid::cartesian_grid;
pub use float::{DEFAULT_ATOL, DEFAULT_RTOL};
pub use numeric::*;
pub use options::*;
pub use transaction::*;
pub use burn_backend::tensor::IndexingUpdateOp;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,116 @@
use burn_backend::{Backend, Element, tensor::Device};
use burn_std::DType;
use crate::get_device_policy;
/// Options for tensor creation.
///
/// This struct allows specifying the `device` and overriding the data type when creating a tensor.
/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.
#[derive(Debug, Clone)]
pub struct TensorCreationOptions<B: Backend> {
/// Device where the tensor will be created.
pub device: Device<B>,
/// Optional data type.
/// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).
pub dtype: Option<DType>,
}
impl<B: Backend> Default for TensorCreationOptions<B> {
/// Returns new options with the backend's default device.
fn default() -> Self {
Self::new(Default::default())
}
}
impl<B: Backend> TensorCreationOptions<B> {
/// Create new options with a specific device.
///
/// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.
pub fn new(device: Device<B>) -> Self {
Self {
device,
dtype: None,
}
}
/// Set the tensor creation data type.
pub fn with_dtype(mut self, dtype: DType) -> Self {
self.dtype = Some(dtype);
self
}
/// Set the tensor creation device.
pub fn with_device(mut self, device: Device<B>) -> Self {
self.device = device;
self
}
/// Create options with backend's default device and float dtype.
pub fn float() -> Self {
Self::default().with_dtype(<B::FloatElem as Element>::dtype())
}
/// Create options with backend's default device and int dtype.
pub fn int() -> Self {
Self::default().with_dtype(<B::IntElem as Element>::dtype())
}
/// Create options with backend's default device and bool dtype.
pub fn bool() -> Self {
Self::default().with_dtype(<B::BoolElem as Element>::dtype())
}
/// Returns the tensor data type, or a provided default if not set.
///
/// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
pub fn dtype_or(&self, dtype: DType) -> DType {
self.dtype.unwrap_or(dtype)
}
/// Returns the tensor data type, or the default from the [device policy](crate::set_default_dtypes).
pub(crate) fn resolve_policy(&self, dtype: DType) -> DType {
// TODO: should rely on tensor kind, not element dtype
self.dtype.unwrap_or_else(|| {
let policy = get_device_policy(&self.device);
if dtype.is_float()
&& let Some(float_dtype) = policy.float_dtype()
{
float_dtype.into()
} else if (dtype.is_int() || dtype.is_uint())
&& let Some(int_dtype) = policy.int_dtype()
{
int_dtype.into()
} else {
// If policy was not explicitly set, use the fallback dtype (default backend elem type)
dtype
}
})
}
}
impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
/// Convenience conversion from a reference to a device.
///
/// Example:
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::TensorCreationOptions;
///
/// fn example<B: Backend>(device: B::Device) {
/// let options: TensorCreationOptions<B> = (&device).into();
/// }
/// ```
fn from(device: &Device<B>) -> Self {
TensorCreationOptions::new(device.clone())
}
}
impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
/// Convenience conversion for a specified `(&device, dtype)` tuple.
fn from(args: (&Device<B>, DType)) -> Self {
TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,363 @@
use alloc::vec::Vec;
use core::ops::Range;
use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
use super::Numeric;
/// Trait for types that can be used as padding specifications.
///
/// Padding is specified as `(before, after)` pairs per dimension, returned as a
/// fixed-size array `[(usize, usize); D]`. If fewer pairs than dimensions are provided,
/// they apply to the **last** N dimensions (earlier dimensions are left unpadded).
pub trait IntoPadding<const D: usize> {
/// Converts into a fixed-size array of `(before, after)` padding pairs.
fn into_padding(self) -> [(usize, usize); D];
}
impl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {
fn into_padding(self) -> [(usize, usize); D] {
assert!(
N <= D,
"Padding has {} pairs but tensor only has {} dimensions",
N,
D
);
let mut result = [(0usize, 0usize); D];
let offset = D - N;
for (i, pair) in self.into_iter().enumerate() {
result[offset + i] = pair;
}
result
}
}
/// Backward-compatible: `(left, right, top, bottom)` maps to last 2 dimensions.
///
/// Equivalent to `[(top, bottom), (left, right)]`.
impl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {
fn into_padding(self) -> [(usize, usize); D] {
let (left, right, top, bottom) = self;
let mut result = [(0usize, 0usize); D];
result[D - 2] = (top, bottom);
result[D - 1] = (left, right);
result
}
}
impl<const D: usize> IntoPadding<D> for &[(usize, usize)] {
fn into_padding(self) -> [(usize, usize); D] {
assert!(
self.len() <= D,
"Padding has {} pairs but tensor only has {} dimensions",
self.len(),
D
);
let mut result = [(0usize, 0usize); D];
let offset = D - self.len();
for (i, &pair) in self.iter().enumerate() {
result[offset + i] = pair;
}
result
}
}
impl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {
fn into_padding(self) -> [(usize, usize); D] {
assert!(
self.len() <= D,
"Padding has {} pairs but tensor only has {} dimensions",
self.len(),
D
);
let mut result = [(0usize, 0usize); D];
let offset = D - self.len();
for (i, pair) in self.into_iter().enumerate() {
result[offset + i] = pair;
}
result
}
}
/// Helper to build a range array for slice_assign, selecting a portion of one dimension.
fn build_slice_ranges<const D: usize>(
dims: [usize; D],
target_dim: usize,
start: usize,
len: usize,
) -> [Range<usize>; D] {
dims.iter()
.enumerate()
.map(|(i, &size)| {
if i == target_dim {
start..start + len
} else {
0..size
}
})
.collect::<Vec<Range<usize>>>()
.try_into()
.unwrap()
}
impl<B, const D: usize, K> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
{
/// Pads the tensor using the specified padding mode.
///
/// Padding is specified as `(before, after)` pairs. If fewer pairs than tensor dimensions
/// are provided, they apply to the **last** N dimensions (unspecified leading dimensions
/// are left unpadded).
///
/// For backward compatibility, a `(left, right, top, bottom)` tuple is also accepted,
/// which pads the last two dimensions.
///
/// # Arguments
///
/// * `padding` - Padding specification. Accepts:
/// - `[(before, after); N]` fixed-size array of pairs (N <= D)
/// - `&[(before, after)]` slice of pairs per dimension
/// - `Vec<(before, after)>` vector of pairs
/// - `(left, right, top, bottom)` tuple for last-2-dim backward compatibility
/// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`.
///
/// # Returns
///
/// A new tensor with the specified padding applied.
///
/// # Panics
///
/// - Panics if more padding pairs are provided than tensor dimensions.
/// - `Reflect` mode panics if padding exceeds `dimension_size - 1`.
/// - `Edge` mode panics if padding is applied to a zero-sized dimension.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
/// use burn_tensor::ops::PadMode;
///
/// fn example<B: Backend<FloatElem: From<f32>>>() {
/// let device = B::Device::default();
/// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
///
/// // Constant padding with value 0.0 (backward-compatible tuple)
/// let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0));
///
/// // Pad arbitrary dimensions with slice of (before, after) pairs
/// let padded = tensor.clone().pad([(1, 1), (2, 2)], PadMode::Constant(0.0));
///
/// // Pad only the last dimension
/// let padded = tensor.pad([(1, 1)], PadMode::Reflect);
/// }
/// ```
pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {
let pairs = padding.into_padding();
match mode.into() {
PadMode::Constant(value) => pad_constant(self, &pairs, value),
PadMode::Reflect => pad_reflect(self, &pairs),
PadMode::Edge => pad_edge(self, &pairs),
}
}
}
/// Pad with a constant value.
fn pad_constant<B, const D: usize, K, E>(
tensor: Tensor<B, D, K>,
padding: &[(usize, usize); D],
value: E,
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
E: ElementConversion,
{
let mut padded_dims: [usize; D] = tensor.dims();
for (i, &(before, after)) in padding.iter().enumerate() {
padded_dims[i] += before + after;
}
let ranges: [Range<usize>; D] = padded_dims
.iter()
.enumerate()
.map(|(i, &dim)| {
let (before, after) = padding[i];
before..dim - after
})
.collect::<Vec<Range<usize>>>()
.try_into()
.unwrap();
let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
padded_tensor.slice_assign(ranges, tensor)
}
/// Pad using reflection at the boundaries (excluding edge values).
///
/// For ONNX "reflect" mode: mirrors from index 1, not index 0.
/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]`
fn pad_reflect<B, const D: usize, K>(
tensor: Tensor<B, D, K>,
padding: &[(usize, usize); D],
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
{
let dims = tensor.dims();
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
assert!(
before < dims[i] && after < dims[i],
"Reflect padding ({}, {}) must be less than dimension {} size ({})",
before,
after,
i,
dims[i]
);
}
}
let mut result = tensor;
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
result = pad_reflect_dim(result, i, before, after);
}
}
result
}
/// Helper to pad a single dimension using reflection.
fn pad_reflect_dim<B, const D: usize, K>(
tensor: Tensor<B, D, K>,
dim: usize,
pad_before: usize,
pad_after: usize,
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
{
let dims = tensor.dims();
let dim_size = dims[dim];
// Calculate output dimensions
let mut output_dims = dims;
output_dims[dim] += pad_before + pad_after;
// Create output tensor and place original in the center
let output = Tensor::zeros(output_dims, &tensor.device());
let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
let mut output = output.slice_assign(original_range, tensor.clone());
// Assign reflected "before" padding (e.g., top or left)
// Reflect excludes the edge, so we take indices [1..pad_before+1] and flip
if pad_before > 0 {
let before_slice = tensor.clone().narrow(dim, 1, pad_before);
let before_flipped = before_slice.flip([dim as isize]);
let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
output = output.slice_assign(before_range, before_flipped);
}
// Assign reflected "after" padding (e.g., bottom or right)
// Take indices [dim_size - pad_after - 1..dim_size - 1] and flip
if pad_after > 0 {
let start = dim_size - pad_after - 1;
let after_slice = tensor.narrow(dim, start, pad_after);
let after_flipped = after_slice.flip([dim as isize]);
let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
output = output.slice_assign(after_range, after_flipped);
}
output
}
/// Pad by replicating edge values.
///
/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]`
fn pad_edge<B, const D: usize, K>(
tensor: Tensor<B, D, K>,
padding: &[(usize, usize); D],
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
{
let dims = tensor.dims();
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
assert!(
dims[i] > 0,
"Cannot apply edge padding to zero-sized dimension {}",
i
);
}
}
let mut result = tensor;
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
result = pad_edge_dim(result, i, before, after);
}
}
result
}
/// Helper to pad a single dimension by replicating edge values.
fn pad_edge_dim<B, const D: usize, K>(
tensor: Tensor<B, D, K>,
dim: usize,
pad_before: usize,
pad_after: usize,
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
{
let dims = tensor.dims();
let dim_size = dims[dim];
// Calculate output dimensions
let mut output_dims = dims;
output_dims[dim] += pad_before + pad_after;
// Create output tensor and place original in the center
let output = Tensor::zeros(output_dims, &tensor.device());
let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
let mut output = output.slice_assign(original_range, tensor.clone());
// Assign "before" padding by repeating the first element
if pad_before > 0 {
let first_slice = tensor.clone().narrow(dim, 0, 1);
let before_pad = first_slice.repeat_dim(dim, pad_before);
let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
output = output.slice_assign(before_range, before_pad);
}
// Assign "after" padding by repeating the last element
if pad_after > 0 {
let last_slice = tensor.narrow(dim, dim_size - 1, 1);
let after_pad = last_slice.repeat_dim(dim, pad_after);
let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
output = output.slice_assign(after_range, after_pad);
}
output
}

View File

@@ -0,0 +1,98 @@
use crate::{AsIndex, BasicOps, Int, Tensor, backend::Backend, check, check::TensorCheck};
use alloc::vec::Vec;
impl<B, const D: usize, K> Tensor<B, D, K>
where
B: Backend,
K: BasicOps<B>,
{
/// Takes elements from the tensor along the given dimension using indices of any dimensionality.
///
/// This behaves like numpy's take function. When indices is multi-dimensional,
/// the output shape will be: input.shape\[:dim\] + indices.shape + input.shape\[dim+1:\]
///
/// # Arguments
///
/// * `dim` - The dimension along which to select elements. Supports negative indexing.
/// * `indices` - The indices of elements to select. Can be any dimensionality.
/// Must be valid indices in the range [0, dim_size).
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
///
/// // Example with 1D indices
/// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
/// let indices = Tensor::<B, 1, Int>::from_data([2, 0, 1], &device);
/// let result: Tensor<B, 2> = tensor.clone().take::<1, 2>(-1, indices); // -1 refers to last dimension
/// println!("{result}");
/// // [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]]
///
/// // Example with 2D indices - output will have +1 dimension (2D -> 3D)
/// let indices_2d = Tensor::<B, 2, Int>::from_data([[0, 2], [1, 0]], &device);
/// let result: Tensor<B, 3> = tensor.take::<2, 3>(1, indices_2d);
/// println!("{result}");
/// // [[[1.0, 3.0], [2.0, 1.0]], [[4.0, 6.0], [5.0, 4.0]]]
/// }
/// ```
pub fn take<const DI: usize, const DO: usize>(
self,
dim: impl AsIndex,
indices: Tensor<B, DI, Int>,
) -> Tensor<B, DO, K> {
let dim = dim.expect_dim_index(D);
check!(TensorCheck::take::<D, DI, DO>(dim));
// Store the indices shape for reshaping later
let indices_shape = indices.shape();
let indices_dims = indices_shape.clone();
// Flatten indices to 1D for processing
let indices_flat = indices.reshape([indices_shape.num_elements()]);
// Perform the selection with the flattened indices
let selected = self.select(dim, indices_flat);
// Build the output shape
// Output shape = input.shape[:dim] + indices.shape + input.shape[dim+1:]
let selected_shape = selected.shape();
let mut new_shape = Vec::with_capacity(DO);
// Add dimensions before the selected dimension
for i in 0..dim {
new_shape.push(selected_shape[i]);
}
// Add all indices dimensions
for &idx_dim in indices_dims.iter() {
new_shape.push(idx_dim);
}
// Add dimensions after the selected dimension
for i in (dim + 1)..D {
new_shape.push(selected_shape[i]);
}
// Verify we have the correct number of dimensions
assert_eq!(
new_shape.len(),
DO,
"Internal error: shape calculation resulted in {} dims but expected {}",
new_shape.len(),
DO
);
// Convert to fixed-size array for reshape
let mut shape_array = [0; DO];
for (i, &s) in new_shape.iter().enumerate() {
shape_array[i] = s;
}
selected.reshape(shape_array)
}
}

View File

@@ -0,0 +1,57 @@
use super::{BasicOps, Tensor};
use crate::{
TensorData,
backend::{Backend, ExecutionError},
ops::TransactionPrimitive,
};
use alloc::vec::Vec;
#[derive(Default)]
/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
/// compute utilization with optimized laziness.
///
/// # Example
///
/// ```rust,ignore
/// let [output_data, loss_data, targets_data] = Transaction::default()
/// .register(output)
/// .register(loss)
/// .register(targets)
/// .execute()
/// .try_into()
/// .expect("Correct amount of tensor data");
/// ```
pub struct Transaction<B: Backend> {
op: TransactionPrimitive<B>,
}
impl<B: Backend> Transaction<B> {
/// Add a [tensor](Tensor) to the transaction to be read.
pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
K::register_transaction(&mut self.op, tensor.into_primitive());
self
}
/// Executes the transaction synchronously and returns the [data](TensorData) in the same order
/// in which they were [registered](Self::register).
pub fn execute(self) -> Vec<TensorData> {
burn_std::future::block_on(self.execute_async())
.expect("Error while reading data: use `try_execute` to handle error at runtime")
}
/// Executes the transaction synchronously and returns the [data](TensorData) in the same
/// order in which they were [registered](Self::register).
///
/// # Returns
///
/// Any error that might have occurred since the last time the device was synchronized.
pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
burn_std::future::block_on(self.execute_async())
}
/// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
/// in which they were [registered](Self::register).
pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
self.op.execute_async().await
}
}

View File

@@ -0,0 +1,42 @@
use crate::{Float, Tensor, TensorPrimitive, backend::Backend};
impl<B, const D: usize> Tensor<B, D, Float>
where
B: Backend,
{
/// Truncates the tensor element-wise, rounding toward zero.
///
/// This function returns a new tensor with the same shape as the input tensor,
/// where each element is truncated toward zero. For positive values, this is
/// equivalent to floor, and for negative values, it's equivalent to ceil.
///
/// # Special Cases (IEEE 754 compliant)
///
/// - `trunc(±0)` returns ±0 (preserves sign of zero)
/// - `trunc(±∞)` returns ±∞
/// - `trunc(NaN)` returns NaN
///
/// # Returns
///
/// A tensor with the same shape where each element has been truncated toward zero.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let tensor = Tensor::<B, 1>::from_data([2.3, -1.7, 0.5, -0.5, 3.9], &device);
/// let truncated = tensor.trunc();
///
/// // Result: [2.0, -1.0, 0.0, -0.0, 3.0]
/// }
/// ```
pub fn trunc(self) -> Self {
Self::new(TensorPrimitive::Float(B::float_trunc(
self.primitive.tensor(),
)))
}
}

View File

@@ -0,0 +1,60 @@
use crate::ElementConversion;
use crate::backend::Backend;
use crate::s;
use crate::tensor::{Int, Tensor};
use alloc::vec;
/// Generate a tensor with homogeonous coordinates of each element's
/// transformed location
///
///
/// See:
/// - [torch.nn.functional.affine_grid](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html)
///
/// * `transform` - Transformation with shape (batch_size, 2, 3)
/// * `dims` - dimensions as (batch_size, channels, height, width)
///
/// # Returns
///
/// Tensor with shape (batch_size, height, width, 2), where dim 2 is (x, y)
/// All coordinates are broadcast on the batch dim
pub fn affine_grid_2d<B: Backend>(transform: Tensor<B, 3>, dims: [usize; 4]) -> Tensor<B, 4> {
let [batch_size, _c, height, width] = dims;
let device = &transform.device();
let x = Tensor::<B, 1, Int>::arange(0..width as i64, device)
.reshape([1, width])
.expand([height, width]);
let y = Tensor::<B, 1, Int>::arange(0..height as i64, device)
.reshape([height, 1])
.expand([height, width]);
// from ints (0..(width-1)) and (0..(height-1)), to (-1.0..1.0)
let x = x
.float()
.div_scalar(((width - 1) as f32 / 2.0).elem::<f32>())
.sub_scalar((1_f32).elem::<f32>());
let y = y
.float()
.div_scalar(((height - 1) as f32 / 2.0).elem::<f32>())
.sub_scalar((1_f32).elem::<f32>());
// Broadcast to batch dimension
let x = x.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W]
let y = y.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W]
// Apply affine transform
let a_11 = transform.clone().slice(s![.., 0, 0]);
let a_12 = transform.clone().slice(s![.., 0, 1]);
let trans_x = transform.clone().slice(s![.., 0, 2]);
let a_21 = transform.clone().slice(s![.., 1, 0]);
let a_22 = transform.clone().slice(s![.., 1, 1]);
let trans_y = transform.slice(s![.., 1, 2]);
let grid_x = a_11.mul(x.clone()).add(a_12.mul(y.clone())).add(trans_x);
let grid_y = a_21.mul(x).add(a_22.mul(y)).add(trans_y);
Tensor::stack(vec![grid_x, grid_y], 3)
}

View File

@@ -0,0 +1,107 @@
use crate::backend::Backend;
use crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos};
use crate::tensor::{BasicOps, Tensor};
use alloc::vec::Vec;
/// Return a collection of coordinate matrices for coordinate vectors.
///
/// Takes N 1D tensors and returns N tensors where each tensor represents the coordinates
/// in one dimension across an N-dimensional grid.
///
/// Based upon `options.sparse`, the generated coordinate tensors can either be `Sparse` or `Dense`:
/// * In `Sparse` mode, output tensors will have shape 1 everywhere except their cardinal dimension.
/// * In `Dense` mode, output tensors will be expanded to the full grid shape.
///
/// Based upon `options.indexing`, the generated coordinate tensors will use either:
/// * `Matrix` indexing, where dimensions are in the same order as their cardinality.
/// * `Cartesian` indexing; where the first two dimensions are swapped.
///
/// See:
/// - [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html)
/// - [torch.meshgrid](https://pytorch.org/docs/stable/generated/torch.meshgrid.html)
///
/// # Arguments
///
/// * `tensors` - A slice of 1D tensors
/// * `options` - the options.
///
/// # Returns
///
/// A vector of N N-dimensional tensors representing the grid coordinates.
pub fn meshgrid<B: Backend, const N: usize, K, O>(
tensors: &[Tensor<B, 1, K>; N],
options: O,
) -> [Tensor<B, N, K>; N]
where
K: BasicOps<B>,
O: Into<GridOptions>,
{
let options = options.into();
let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1;
let dense = options.sparsity == GridSparsity::Dense;
let grid_shape: [usize; N] = tensors
.iter()
.map(|t| t.dims()[0])
.collect::<Vec<_>>()
.try_into()
.unwrap();
tensors
.iter()
.enumerate()
.map(|(i, tensor)| {
let mut coord_tensor_shape = [1; N];
coord_tensor_shape[i] = grid_shape[i];
// Reshape the tensor to have singleton dimensions in all but the i-th dimension
let mut tensor = tensor.clone().reshape(coord_tensor_shape);
if dense {
tensor = tensor.expand(grid_shape);
}
if swap_dims {
tensor = tensor.swap_dims(0, 1);
}
tensor
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
/// Return a coordinate matrix for a given set of 1D coordinate tensors.
///
/// Equivalent to stacking a dense matrix `meshgrid`,
/// where the stack is along the first or last dimension.
///
/// # Arguments
///
/// * `tensors`: A slice of 1D tensors.
/// * `index_pos`: The position of the index in the output tensor.
///
/// # Returns
///
/// A tensor of either ``(N, ..., |T[i]|, ...)`` or ``(..., |T[i]|, ..., N)``,
/// of coordinates, indexed on the first or last dimension.
pub fn meshgrid_stack<B: Backend, const D: usize, const D2: usize, K>(
tensors: &[Tensor<B, 1, K>; D],
index_pos: IndexPos,
) -> Tensor<B, D2, K>
where
K: BasicOps<B>,
{
assert_eq!(D2, D + 1, "D2 ({D2}) != D ({D}) + 1");
let xs: Vec<Tensor<B, D, K>> = meshgrid(tensors, GridOptions::default())
.into_iter()
.collect();
let dim = match index_pos {
IndexPos::First => 0,
IndexPos::Last => D,
};
Tensor::stack(xs, dim)
}

View File

@@ -0,0 +1,68 @@
mod affine_grid;
mod meshgrid;
pub use meshgrid::*;
pub use affine_grid::*;
/// Enum to specify index cardinal layout.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridIndexing {
/// Dimensions are in the same order as the cardinality of the inputs.
/// Equivalent to "ij" indexing in NumPy and PyTorch.
#[default]
Matrix,
/// The same as Matrix, but the first two dimensions are swapped.
/// Equivalent to "xy" indexing in NumPy and PyTorch.
Cartesian,
}
/// Enum to specify grid sparsity mode.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridSparsity {
/// The grid is fully expanded to the full cartesian product shape.
#[default]
Dense,
/// The grid is sparse, expanded only at the cardinal dimensions.
Sparse,
}
/// Grid policy options.
#[derive(new, Default, Debug, Copy, Clone)]
pub struct GridOptions {
/// Indexing mode.
pub indexing: GridIndexing,
/// Sparsity mode.
pub sparsity: GridSparsity,
}
impl From<GridIndexing> for GridOptions {
fn from(value: GridIndexing) -> Self {
Self {
indexing: value,
..Default::default()
}
}
}
impl From<GridSparsity> for GridOptions {
fn from(value: GridSparsity) -> Self {
Self {
sparsity: value,
..Default::default()
}
}
}
/// Enum to specify the index dimension position.
#[derive(Default, Debug, Copy, Clone)]
pub enum IndexPos {
/// The index is in the first dimension.
#[default]
First,
/// The index is in the last dimension.
Last,
}

View File

@@ -0,0 +1,48 @@
use crate::ElementConversion;
use crate::backend::Backend;
use crate::tensor::Tensor;
use super::l2_norm;
/// Default epsilon value to avoid division by zero
pub const DEFAULT_EPSILON: f64 = 1e-8;
/// Computes the cosine similarity between two tensors along a specified dimension.
///
/// Calculates the cosine of the angle between inputs as their dot product divided
/// by the product of their L2 norms.
///
/// # Arguments
///
/// * `x1` - First input tensor
/// * `x2` - Second input tensor
/// * `dim` - Dimension along which to compute the similarity
/// (negative indices allowed: -1 for last dimension)
/// * `eps` - Small value to avoid division by zero (default: 1e-8)
///
/// # Returns
///
/// Tensor containing the cosine similarity between x1 and x2
pub fn cosine_similarity<B: Backend, const D: usize>(
x1: Tensor<B, D>,
x2: Tensor<B, D>,
dim: i32,
eps: Option<B::FloatElem>,
) -> Tensor<B, D> {
let eps = eps.unwrap_or_else(|| B::FloatElem::from_elem(DEFAULT_EPSILON));
// Convert negative dimension to positive
let dim_idx = if dim < 0 { D as i32 + dim } else { dim } as usize;
// Compute dot product: sum(x1 * x2) along the specified dimension
let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx);
// Compute L2 norms: ||x1|| and ||x2||
let norm_x1 = l2_norm(x1, dim_idx);
let norm_x2 = l2_norm(x2, dim_idx);
// Calculate the denominator (product of the norms) with epsilon to avoid division by zero
let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps);
// Return the cosine similarity (dot product divided by the product of norms)
dot_product / denominator
}

View File

@@ -0,0 +1,44 @@
use crate::backend::Backend;
use crate::check;
use crate::check::TensorCheck;
use crate::tensor::{Int, Tensor};
use crate::{BasicOps, TensorKind};
/// Returns the diag of a matrix.
///
/// For batched inputs, returns of each matrix in the batch independently.
///
/// The diag operation extracts the diagonal elements of the last two dimensions,
/// treating them as the matrix dimensions, while preserving all leading batch dimensions.
///
/// # Arguments
///
/// * `tensor` - The input tensor with at least 2 dimensions.
///
/// # Returns
/// A tensor of rank `D - 1`, where the last dimension contains the diagonal elements of the input.
pub fn diag<B: Backend, const D: usize, const DO: usize, K>(
tensor: Tensor<B, D, K>,
) -> Tensor<B, DO, K>
where
K: TensorKind<B> + BasicOps<B>,
{
check!(TensorCheck::diag::<D, DO>());
let shape = tensor.shape();
let rows = shape[D - 2];
let cols = shape[D - 1];
let diag_len = rows.min(cols);
let device = tensor.device();
// create the indices for the diag
let mut flat_shape = shape.clone();
flat_shape[D - 2] = rows * cols;
flat_shape[D - 1] = 1;
let flat: Tensor<B, D, K> = tensor.reshape(flat_shape);
let range = Tensor::<B, 1, Int>::arange(0..diag_len as i64, &device);
let step_tensor = Tensor::<B, 1, Int>::from_data([cols as i64 + 1], &device);
let indices = range * step_tensor;
flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1)
}

View File

@@ -0,0 +1,79 @@
use crate::{
Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s,
tensor::Tensor,
};
/// Performs PLU decomposition of a square matrix.
///
/// The function decomposes a given square matrix `A` into three matrices: a permutation vector `p`,
/// a lower triangular matrix `L`, and an upper triangular matrix `U`, such that `PA = LU`.
/// The permutation vector `p` represents the row swaps made during the decomposition process.
/// The lower triangular matrix `L` has ones on its diagonal and contains the multipliers used
/// during the elimination process below the diagonal. The upper triangular matrix `U` contains
/// the resulting upper triangular form of the matrix after the elimination process.
///
/// # Arguments
/// * `tensor` - A square matrix to decompose, represented as a 2D tensor.
///
/// # Returns
/// A tuple containing:
/// - A 2D tensor representing the combined `L` and `U` matrices.
/// - A 1D tensor representing the permutation vector `p`.
///
/// # Panics and numerical issues
/// - The function will panic if the input matrix is singular or near-singular.
/// - The function will panic if the input matrix is not square.
/// # Performance note (synchronization / device transfers)
/// This function may involve multiple synchronizations and device transfers, especially
/// when determining pivot elements and performing row swaps. This can impact performance,
pub fn lu_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
check!(TensorCheck::is_square::<2>(
"lu_decomposition",
&tensor.shape()
));
let dims = tensor.shape().dims::<2>();
let n = dims[0];
let mut permutations = Tensor::arange(0..n as i64, &tensor.device());
let mut tensor = tensor;
for k in 0..n {
// Find the pivot row
let p = tensor
.clone()
.slice(s![k.., k])
.abs()
.argmax(0)
.into_scalar()
.to_usize()
+ k;
let max = tensor.clone().slice(s![p, k]).abs();
// Avoid division by zero
let pivot = max.into_scalar();
check!(TensorCheck::lu_decomposition_pivot::<B>(pivot));
if p != k {
tensor = swap_slices(tensor, s![k, ..], s![p, ..]);
permutations = swap_slices(permutations, s![k], s![p]);
}
// Normalize k-th column under the diagonal
if k < n - 1 {
let a_kk = tensor.clone().slice(s![k, k]);
let column = tensor.clone().slice(s![(k + 1).., k]) / a_kk;
tensor = tensor.slice_assign(s![(k + 1).., k], column);
}
// Update the trailing submatrix
for i in (k + 1)..n {
// a[i, k+1..] -= a[i, k] * a[k, k+1..]
let a_ik = tensor.clone().slice(s![i, k]);
let row_k = tensor.clone().slice(s![k, (k + 1)..]);
let update = a_ik * row_k;
let row_i = tensor.clone().slice(s![i, (k + 1)..]);
tensor = tensor.slice_assign(s![i, (k + 1)..], row_i - update);
}
}
(tensor, permutations)
}

View File

@@ -0,0 +1,59 @@
use crate::Numeric;
use crate::backend::Backend;
use crate::tensor::{BasicOps, Shape, Tensor};
/// Performs matrix-vector multiplication with optional batch dimensions.
///
/// The `matrix` tensor is expected to have rank `DM` with the last two dimensions representing
/// the matrix rows and columns. The `vector` tensor should have rank `DV = DM - 1`, sharing
/// broadcast-compatible batch dimensions and matching the last dimension of the matrix.
///
/// # Panics
///
/// * If the matrix rank is lower than 2.
/// * If the vector rank isn't one less than the matrix rank.
/// * If batch dimensions differ between the operands.
/// * If the inner dimensions are incompatible for multiplication.
pub fn matvec<B: Backend, const DM: usize, const DV: usize, K>(
matrix: Tensor<B, DM, K>,
vector: Tensor<B, DV, K>,
) -> Tensor<B, DV, K>
where
K: BasicOps<B> + Numeric<B>,
{
assert!(
DM >= 2,
"matvec expects the matrix to be at least rank 2 (got {DM})"
);
assert!(
DM == DV + 1,
"matvec expects the vector rank ({DV}) to be exactly one less than the matrix rank ({DM})",
);
let matrix_dims = matrix.shape().dims::<DM>();
let vector_dims = vector.shape().dims::<DV>();
// Validate batch dimensions (all leading dimensions prior to the matrix axes).
let batch_rank = DM.saturating_sub(2);
if batch_rank > 0 {
let matrix_batch = Shape::from(&matrix_dims[..batch_rank]);
let vector_batch = Shape::from(&vector_dims[..batch_rank]);
assert!(
matrix_batch.broadcast(&vector_batch).is_ok(),
"Batch dimensions are not broadcast-compatible: matrix {:?} vs vector {:?}",
&matrix_dims[..batch_rank],
&vector_dims[..batch_rank]
);
}
let matrix_inner = matrix_dims[DM - 1];
let vector_inner = vector_dims[DV - 1];
assert!(
matrix_inner == vector_inner,
"Inner dimension mismatch: matrix has {matrix_inner} columns but vector has {vector_inner} entries",
);
let vector_expanded = vector.unsqueeze_dim::<DM>(DV);
matrix.matmul(vector_expanded).squeeze_dim::<DV>(DM - 1)
}

View File

@@ -0,0 +1,42 @@
mod cosine_similarity;
mod diag;
mod lu_decomposition;
mod matvec;
mod outer;
mod trace;
mod vector_norm;
pub use cosine_similarity::*;
pub use diag::*;
pub use lu_decomposition::*;
pub use matvec::*;
pub use outer::*;
pub use trace::*;
pub use vector_norm::*;
use crate::{BasicOps, SliceArg, Tensor, TensorKind, backend::Backend};
/// Swaps two slices of a tensor.
/// # Arguments
/// * `tensor` - The input tensor.
/// * `slices1` - The first slice to swap.
/// * `slices2` - The second slice to swap.
/// # Returns
/// A new tensor with the specified slices swapped.
/// # Notes
/// This method will be useful for matrix factorization algorithms.
fn swap_slices<B: Backend, const D: usize, K, S>(
tensor: Tensor<B, D, K>,
slices1: S,
slices2: S,
) -> Tensor<B, D, K>
where
S: SliceArg + Clone,
K: TensorKind<B> + BasicOps<B>,
{
let temporary = tensor.clone().slice(slices1.clone());
let tensor = tensor
.clone()
.slice_assign(slices1, tensor.slice(slices2.clone()));
tensor.slice_assign(slices2, temporary)
}

View File

@@ -0,0 +1,77 @@
use crate::backend::Backend;
use crate::tensor::{BasicOps, Tensor};
use crate::{AsIndex, Numeric};
/// Computes the outer product for the last columns of 2 tensors.
///
/// See also: [`outer_dim`].
///
/// # Arguments
/// - `lhs`: the "row" tensor, with shape ``[..., i]``.
/// - `rhs`: the "col" tensor, with shape ``[..., j]``.
/// - `dim`: the dimension to product.
///
/// # Returns
///
/// A tensor of rank `R = D + 1`, where:
///
/// ``
/// result[..., i, j] = lhs[..., i] * rhs[..., j]
/// ``
pub fn outer<B: Backend, const D: usize, const R: usize, K>(
x: Tensor<B, D, K>,
y: Tensor<B, D, K>,
) -> Tensor<B, R, K>
where
K: BasicOps<B> + Numeric<B>,
{
outer_dim(x, y, -1)
}
/// Computes the outer product along a specific dimension, broadcasting over others.
///
/// For the given `dim`, computes the outer product of elements along that dimension,
/// expanding it into two dimensions of size ``M × N`` at positions ``(dim, dim + 1)``.
///
/// # Arguments
///
/// - `lhs`: left operand, the "row" tensor, with size `M` at dimension `dim`.
/// - `rhs`: right operand, the "col" tensor, with size `N` at dimension `dim`.
/// - `dim`: dimension to compute the outer product along (supports negative indexing).
///
/// # Returns
///
/// A tensor of rank `R = D + 1`, where:
///
/// ``
/// result[..., i, j, ...] = lhs[..., i, ...] * rhs[..., j, ...]
/// ``
//
// Notes:
// - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant
// than broadcasted elemwise multiply; benchmarking needed to confirm.
pub fn outer_dim<B: Backend, const D: usize, const R: usize, Dim: AsIndex, K>(
lhs: Tensor<B, D, K>,
rhs: Tensor<B, D, K>,
dim: Dim,
) -> Tensor<B, R, K>
where
K: BasicOps<B> + Numeric<B>,
{
assert_eq!(
R,
D + 1,
"`outer` with D={D} expects R={} (got R={R})",
D + 1
);
let dim = dim.expect_dim_index(D);
// (..., i, 1, ...)
let x = lhs.unsqueeze_dim::<R>(dim + 1);
// (..., 1, j, ...)
let y = rhs.unsqueeze_dim::<R>(dim);
// (..., i, j, ...)
x * y
}

View File

@@ -0,0 +1,24 @@
use super::diag;
use crate::backend::Backend;
use crate::tensor::Tensor;
/// Computes the trace of a matrix.
///
/// For batched inputs, computes the trace of each matrix in the batch independently.
///
/// The trace operation sums the diagonal elements of the last two dimensions,
/// treating them as the matrix dimensions, while preserving all leading batch dimensions.
///
/// # Arguments
///
/// * `tensor` - The input tensor with at least 2 dimensions.
///
/// # Returns
///
/// A tensor of rank `D - 1`, where the last dimension contains the sum along the diagonals
/// of the input.
pub fn trace<B: Backend, const D: usize, const DO: usize>(tensor: Tensor<B, D>) -> Tensor<B, DO> {
let diag_tensor = diag::<_, D, DO, _>(tensor);
diag_tensor.sum_dim(DO - 1)
}

View File

@@ -0,0 +1,291 @@
use burn_backend::tensor::Ordered;
use crate::backend::Backend;
use crate::tensor::{BasicOps, Tensor};
use crate::{ElementConversion, Numeric};
#[allow(unused_imports)]
use num_traits::float::Float;
/// Specifies the type of norm to compute.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Norm {
/// L0 norm (count of non-zero elements)
L0,
/// L1 norm (sum of absolute values)
L1,
/// L2 norm (Euclidean norm)
L2,
/// L:INFINITY norm (maximum absolute value)
LInf,
/// L:NEG_INFINITY norm (minimum absolute value)
LNegInf,
/// Lp norm (generalized norm)
Lp(f64),
}
impl Norm {
/// Get the exponent of the norm.
pub fn to_exponent(self) -> f64 {
use Norm::*;
match self {
L0 => 0.0,
L1 => 1.0,
L2 => 2.0,
LInf => f64::INFINITY,
LNegInf => f64::NEG_INFINITY,
Lp(p) => p,
}
}
}
impl From<u32> for Norm {
fn from(value: u32) -> Self {
use Norm::*;
match value {
0 => L0,
1 => L1,
2 => L2,
u32::MAX => LInf,
_ => Lp(value as f64),
}
}
}
impl From<i32> for Norm {
fn from(value: i32) -> Self {
use Norm::*;
match value {
0 => L0,
1 => L1,
2 => L2,
i32::MAX => LInf,
i32::MIN => LNegInf,
_ => Lp(value as f64),
}
}
}
impl From<f32> for Norm {
fn from(value: f32) -> Self {
use Norm::*;
match value {
0.0 => L0,
1.0 => L1,
2.0 => L2,
f32::INFINITY => LInf,
f32::NEG_INFINITY => LNegInf,
_ => Lp(value as f64),
}
}
}
impl From<f64> for Norm {
fn from(value: f64) -> Self {
use Norm::*;
match value {
0.0 => L0,
1.0 => L1,
2.0 => L2,
f64::INFINITY => LInf,
f64::NEG_INFINITY => LNegInf,
_ => Lp(value),
}
}
}
/// Computes the vector norm of a tensor along a specified dimension.
///
/// Generic dispatch wrapper over specialized / optimized norms.
///
/// See:
/// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)
/// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html)
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The vector norm of the input tensor.
pub fn vector_norm<B: Backend, const D: usize>(
x: Tensor<B, D>,
norm: impl Into<Norm>,
dim: usize,
) -> Tensor<B, D> {
lp_norm(x, norm.into().to_exponent(), dim)
}
/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
///
/// Uses the specialized implementations for:
/// * 0.0
/// * 1.0
/// * 2.0
/// * 2 * N for integral N,
/// * f64::INFINITY,
/// * f64::NEG_INFINITY,
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `p` - The exponent of the Lp norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The ``L(p)`` norm of the input tensor.
pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
match p {
0.0 => l0_norm(x, dim),
1.0 => l1_norm(x, dim),
2.0 => l2_norm(x, dim),
p if is_even_integer(p) => lp_signed_norm(x, p as u32, dim),
f64::INFINITY => max_abs_norm(x, dim),
f64::NEG_INFINITY => min_abs_norm(x, dim),
_ => lp_norm_base(x, p, dim),
}
}
/// Normalize a tensor versus its `vector_norm`.
///
/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
/// * `eps` - The epsilon for the norm.
///
/// # Returns
///
/// The normalized tensor.
pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
x: Tensor<B, D>,
norm: impl Into<Norm>,
dim: usize,
eps: E,
) -> Tensor<B, D> {
let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps);
x / norm
}
/// Computes the L0 norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L0 norm of the input tensor.
pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
K: BasicOps<B> + Numeric<B>,
{
x.zeros_like()
.mask_fill(x.not_equal_elem(0), 1)
.sum_dim(dim)
}
/// Computes the L1 norm of a tensor along a specified dimension.
///
/// This is a convenience function that wraps `vector_norm` with `p = 1.0`.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L1 norm of the input tensor.
pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
K: BasicOps<B> + Numeric<B>,
{
x.abs().sum_dim(dim)
}
/// Computes the L2 norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L2 norm of the input tensor.
pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
x.square().sum_dim(dim).sqrt()
}
fn is_even_integer(x: f64) -> bool {
x.fract() == 0.0 && (x as i64) % 2 == 0
}
/// Computes ``L(2*n)`` for even integer ``n``.
///
/// This lets us skip the abs.
fn lp_signed_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: u32, dim: usize) -> Tensor<B, D> {
x.powi_scalar(p).sum_dim(dim).powf_scalar(1. / (p as f64))
}
/// Computes the general ``L(p)`` using the generalized method.
///
/// This uses no specialized implementations and cannot handle:
/// * 0.0
/// * f64::INFINITY,
/// * f64::NEG_INFINITY,
fn lp_norm_base<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
}
/// Computes the L:INFINITY norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:INFINITY norm of the input tensor.
pub fn max_abs_norm<B: Backend, const D: usize, K>(
x: Tensor<B, D, K>,
dim: usize,
) -> Tensor<B, D, K>
where
K: Ordered<B>,
{
x.max_abs_dim(dim)
}
/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:NEG_INFINITY norm of the input tensor.
pub fn min_abs_norm<B: Backend, const D: usize, K>(
x: Tensor<B, D, K>,
dim: usize,
) -> Tensor<B, D, K>
where
K: Ordered<B>,
{
x.abs().min_dim(dim)
}

View File

@@ -0,0 +1,23 @@
use crate::backend::Backend;
use crate::{Tensor, activation};
/// Computes the log softmax cross entropy between logits and target probabilities.
///
/// # Arguments
///
/// * `logits` - The logits.
/// * `target_probs` - The target probabilities.
///
/// # Returns
///
/// The log softmax cross entropy.
pub fn cross_entropy_with_logits<B: Backend, const D: usize>(
logits: Tensor<B, D>,
target_probs: Tensor<B, D>,
) -> Tensor<B, 1> {
let tensor = activation::log_softmax(logits, D - 1);
let tensor = tensor.mul(target_probs);
let tensor = tensor.sum_dim(D - 1);
tensor.mean().neg()
}

View File

@@ -0,0 +1,67 @@
pub(crate) mod stats;
mod api;
pub use api::*;
// Re-exported types
pub use burn_backend::{
DType, DataError, FloatDType, IntDType, TensorData, TensorMetadata, TensorPrimitive, Tolerance,
distribution::*,
element::*,
indexing::*,
ops::TransactionPrimitive,
shape::*,
slice::*,
tensor::{Bool, Float, Int, TensorKind},
};
/// The activation module.
pub mod activation;
/// The backend module.
pub mod backend {
pub use burn_backend::backend::*;
}
/// The container module.
pub mod container {
pub use burn_backend::tensor::TensorContainer;
}
/// The grid module.
pub mod grid;
/// The linalg module.
pub mod linalg;
/// The loss module.
pub mod loss;
/// The neural network module.
pub mod module;
/// Operations on tensors module.
pub mod ops {
pub use burn_backend::backend::ops::*;
pub use burn_backend::tensor::{
BoolElem, BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
};
}
/// Tensor quantization module.
pub mod quantization;
#[cfg(feature = "std")]
pub use report::*;
#[cfg(feature = "std")]
mod report;
#[cfg(feature = "experimental-named-tensor")]
mod named;
#[cfg(feature = "experimental-named-tensor")]
pub use named::*;
pub use ops::Device; // Re-export device so that it's available from `burn_tensor::Device`.

View File

@@ -0,0 +1,555 @@
use crate::{
Bool, Int, Tensor, TensorPrimitive,
backend::Backend,
check,
check::TensorCheck,
ops::{
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode,
PaddedConvOptions, UnfoldOptions,
},
};
use super::ops::DeformConvOptions;
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::embedding(
weights.primitive.tensor(),
indices.primitive,
)))
}
/// Applies a [1D convolution](crate::ops::ModuleOps::conv1d).
///
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
/// asymmetric padding. When asymmetric padding is specified, an explicit pad
/// operation is applied before the convolution backend op.
pub fn conv1d<B>(
x: Tensor<B, 3>,
weight: Tensor<B, 3>,
bias: Option<Tensor<B, 1>>,
options: impl Into<PaddedConvOptions<1>>,
) -> Tensor<B, 3>
where
B: Backend,
{
let padded_options = options.into();
check!(TensorCheck::conv(
"conv1d",
x.dims(),
weight.dims(),
padded_options.options.groups,
));
if let Some(padding_end) = padded_options.padding_end {
let left = padded_options.options.padding[0];
let right = padding_end[0];
// For 1D (NCL format), pad the length dimension
let padded = x.pad((left, right, 0, 0), PadMode::Constant(0.0));
let zero_options = ConvOptions::new(
padded_options.options.stride,
[0],
padded_options.options.dilation,
padded_options.options.groups,
);
Tensor::new(TensorPrimitive::Float(B::conv1d(
padded.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
zero_options,
)))
} else {
Tensor::new(TensorPrimitive::Float(B::conv1d(
x.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
padded_options.options,
)))
}
}
/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).
///
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
/// asymmetric padding. When asymmetric padding is specified, an explicit pad
/// operation is applied before the convolution backend op.
pub fn conv2d<B>(
x: Tensor<B, 4>,
weight: Tensor<B, 4>,
bias: Option<Tensor<B, 1>>,
options: impl Into<PaddedConvOptions<2>>,
) -> Tensor<B, 4>
where
B: Backend,
{
let padded_options = options.into();
check!(TensorCheck::conv(
"conv2d",
x.dims(),
weight.dims(),
padded_options.options.groups,
));
if let Some(padding_end) = padded_options.padding_end {
let top = padded_options.options.padding[0];
let left = padded_options.options.padding[1];
let bottom = padding_end[0];
let right = padding_end[1];
// For 2D (NCHW format), pad height and width
let padded = x.pad((left, right, top, bottom), PadMode::Constant(0.0));
let zero_options = ConvOptions::new(
padded_options.options.stride,
[0, 0],
padded_options.options.dilation,
padded_options.options.groups,
);
Tensor::new(TensorPrimitive::Float(B::conv2d(
padded.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
zero_options,
)))
} else {
Tensor::new(TensorPrimitive::Float(B::conv2d(
x.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
padded_options.options,
)))
}
}
/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).
///
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
/// asymmetric padding. Asymmetric 3D padding is not yet supported.
pub fn conv3d<B>(
x: Tensor<B, 5>,
weight: Tensor<B, 5>,
bias: Option<Tensor<B, 1>>,
options: impl Into<PaddedConvOptions<3>>,
) -> Tensor<B, 5>
where
B: Backend,
{
let padded_options = options.into();
check!(TensorCheck::conv(
"conv3d",
x.dims(),
weight.dims(),
padded_options.options.groups,
));
if padded_options.is_asymmetric() {
panic!("Asymmetric padding is not yet supported for conv3d");
}
Tensor::new(TensorPrimitive::Float(B::conv3d(
x.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
padded_options.options,
)))
}
/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
pub fn deform_conv2d<B>(
x: Tensor<B, 4>,
offset: Tensor<B, 4>,
weight: Tensor<B, 4>,
mask: Option<Tensor<B, 4>>,
bias: Option<Tensor<B, 1>>,
options: DeformConvOptions<2>,
) -> Tensor<B, 4>
where
B: Backend,
{
check!(TensorCheck::conv(
"deform_conv2d",
x.dims(),
weight.dims(),
options.weight_groups,
));
Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
x.primitive.tensor(),
offset.primitive.tensor(),
weight.primitive.tensor(),
mask.map(|m| m.primitive.tensor()),
bias.map(|b| b.primitive.tensor()),
options,
)))
}
/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
pub fn conv_transpose1d<B>(
x: Tensor<B, 3>,
weight: Tensor<B, 3>,
bias: Option<Tensor<B, 1>>,
options: ConvTransposeOptions<1>,
) -> Tensor<B, 3>
where
B: Backend,
{
check!(TensorCheck::conv_transpose(
"conv_transpose1d",
x.dims(),
weight.dims(),
));
Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
x.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
options,
)))
}
/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).
pub fn conv_transpose2d<B>(
x: Tensor<B, 4>,
weight: Tensor<B, 4>,
bias: Option<Tensor<B, 1>>,
options: ConvTransposeOptions<2>,
) -> Tensor<B, 4>
where
B: Backend,
{
check!(TensorCheck::conv_transpose(
"conv_transpose2d",
x.dims(),
weight.dims(),
));
Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
x.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
options,
)))
}
/// Applies a 3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).
pub fn conv_transpose3d<B>(
x: Tensor<B, 5>,
weight: Tensor<B, 5>,
bias: Option<Tensor<B, 1>>,
options: ConvTransposeOptions<3>,
) -> Tensor<B, 5>
where
B: Backend,
{
check!(TensorCheck::conv_transpose(
"conv_transpose3d",
x.dims(),
weight.dims(),
));
Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
x.primitive.tensor(),
weight.primitive.tensor(),
bias.map(|b| b.primitive.tensor()),
options,
)))
}
/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).
pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::unfold4d(
x.primitive.tensor(),
kernel_size,
options,
)))
}
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
pub fn max_pool1d<B>(
x: Tensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::max_pool1d(
x.primitive.tensor(),
kernel_size,
stride,
padding,
dilation,
ceil_mode,
)))
}
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
pub fn max_pool2d<B>(
x: Tensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::max_pool2d(
x.primitive.tensor(),
kernel_size,
stride,
padding,
dilation,
ceil_mode,
)))
}
/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
pub fn avg_pool2d<B>(
x: Tensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
x.primitive.tensor(),
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
)))
}
/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
pub fn avg_pool1d<B>(
x: Tensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
ceil_mode: bool,
) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
x.primitive.tensor(),
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
)))
}
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
pub fn max_pool1d_with_indices<B>(
x: Tensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
where
B: Backend,
{
let output = B::max_pool1d_with_indices(
x.primitive.tensor(),
kernel_size,
stride,
padding,
dilation,
ceil_mode,
);
(
Tensor::new(TensorPrimitive::Float(output.output)),
Tensor::new(output.indices),
)
}
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
pub fn max_pool2d_with_indices<B>(
x: Tensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
where
B: Backend,
{
let output = B::max_pool2d_with_indices(
x.primitive.tensor(),
kernel_size,
stride,
padding,
dilation,
ceil_mode,
);
(
Tensor::new(TensorPrimitive::Float(output.output)),
Tensor::new(output.indices),
)
}
/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
x.primitive.tensor(),
output_size,
)))
}
/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).
pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
x.primitive.tensor(),
output_size,
)))
}
/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).
pub fn interpolate<B>(
x: Tensor<B, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::interpolate(
x.primitive.tensor(),
output_size,
options,
)))
}
/// Applies a linear transformation to the input tensor using the given weight and bias.
///
/// ```math
/// y = x @ weight + [bias]
/// ```
///
/// # Arguments:
///
/// - `input` is the input tensor, ``[..., d_input]``.
/// - `weight` is the weight tensor, ``[d_input, d_output]``.
/// - `bias` is the bias tensor (optional), ``[d_output]``.
///
/// # Returns:
///
/// The transformed tensor, ``[..., d_output]``.
///
/// # Compatibility
///
/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
/// multiplication:
///
/// ```math
/// y = x @ weight^T + [bias]
/// ```
pub fn linear<B: Backend, const D: usize>(
input: Tensor<B, D>,
weight: Tensor<B, 2>,
bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D> {
if D == 1 {
// Insert and remove an extra batch dimension for the batch matmul to work.
let input = input.unsqueeze::<2>();
let output = linear(input, weight, bias);
return output.squeeze_dim(0);
}
// Perform broadcasting
//
// Important to be done before doing operations to easily fuse.
let weight = weight.unsqueeze::<D>();
let bias = bias.map(|bias| bias.unsqueeze::<D>());
let output = input.matmul(weight);
match bias {
Some(bias) => output.add(bias),
None => output,
}
}
/// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
/// where scale defaults to 1/sqrt(head_dim) (configurable via `options.scale`).
/// Optionally applies masking, additive bias, causal masking, and softcap.
///
/// # Arguments
/// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
/// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
/// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
/// where `true` indicates positions to mask (i.e. set to -inf before softmax).
/// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
/// added to the attention scores before softmax (e.g. ALiBi, relative position biases).
/// - `options`: Additional attention options (custom scale, softcap, causal masking).
///
/// # Returns
/// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
/// representing the attended context per head.
///
/// # Note
/// This implementation does not support dropout and is intended for inference or
/// use cases where dropout is not needed.
pub fn attention<B: Backend>(
query: Tensor<B, 4>,
key: Tensor<B, 4>,
value: Tensor<B, 4>,
mask: Option<Tensor<B, 4, Bool>>,
attn_bias: Option<Tensor<B, 4>>,
options: AttentionModuleOptions,
) -> Tensor<B, 4> {
Tensor::new(TensorPrimitive::Float(B::attention(
query.primitive.tensor(),
key.primitive.tensor(),
value.primitive.tensor(),
mask.map(|mask| mask.primitive),
attn_bias.map(|bias| bias.primitive.tensor()),
options,
)))
}
/// Exports attention fallback to test backend's attention against.
pub fn attention_fallback<B: Backend>(
query: Tensor<B, 4>,
key: Tensor<B, 4>,
value: Tensor<B, 4>,
mask: Option<Tensor<B, 4, Bool>>,
attn_bias: Option<Tensor<B, 4>>,
options: AttentionModuleOptions,
) -> Tensor<B, 4> {
Tensor::new(TensorPrimitive::Float(
crate::ops::attention::attention_fallback::<B>(
query.primitive.tensor(),
key.primitive.tensor(),
value.primitive.tensor(),
mask.map(|mask| mask.primitive),
attn_bias.map(|bias| bias.primitive.tensor()),
options,
),
))
}

View File

@@ -0,0 +1,85 @@
use alloc::format;
use crate::backend::Backend;
use crate::{Distribution, NamedDims, Shape, Tensor};
/// A tensor with named dimensions.
#[derive(Debug, Clone)]
pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
pub(crate) tensor: D::Tensor,
}
impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<NamedTensor<B, ND>>
for Tensor<B, D>
{
fn from(nt: NamedTensor<B, ND>) -> Self {
nt.tensor
}
}
impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<Tensor<B, D>>
for NamedTensor<B, ND>
{
fn from(tensor: Tensor<B, D>) -> Self {
Self::from_tensor(tensor)
}
}
impl<B: Backend, const D: usize, ND: NamedDims<B>> core::fmt::Display for NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(&format!(
"NamedTensor[shape={:?}, dims={}]",
self.shape(),
ND::to_string(),
))
}
}
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
/// Create a named tensor from a tensor.
pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
Self { tensor }
}
/// Create a random named tensor of the given shape where each element is sampled from
/// the given distribution.
pub fn random<S: Into<Shape>>(
shape: S,
distribution: Distribution,
device: &B::Device,
) -> Self {
Self::from_tensor(Tensor::random(shape, distribution, device))
}
/// Returns the shape of the current tensor.
pub fn shape(&self) -> Shape {
self.tensor.shape()
}
/// Applies element wise multiplication operation.
///
/// `y = x2 * x1`
#[allow(clippy::should_implement_trait)]
pub fn mul(self, rhs: Self) -> Self {
Self::from_tensor(self.tensor.mul(rhs.tensor))
}
/// Reshape the tensor to have the given shape.
///
/// # Panics
///
/// If the tensor can not be reshape to the given shape.
pub fn reshape<const D2: usize, S, ND2>(self, shape: S, _: ND2) -> NamedTensor<B, ND2>
where
S: Into<Shape>,
ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
{
NamedTensor::from_tensor(self.tensor.reshape(shape.into()))
}
}

View File

@@ -0,0 +1,95 @@
use alloc::format;
use alloc::string::String;
use crate::Tensor;
use crate::backend::Backend;
/// Dimension trait.
pub trait Dim: core::fmt::Debug {
/// Converts the dimension to a string.
fn to_string() -> String;
}
/// Named dimensions trait.
pub trait NamedDims<B: Backend>: core::fmt::Debug {
/// Tensor type.
type Tensor;
/// Converts the named dimensions to a string.
fn to_string() -> String;
}
/// Named dimension macro.
#[macro_export]
macro_rules! NamedDim {
($name:ident) => {
#[derive(Debug, Clone)]
pub struct $name;
impl Dim for $name {
fn to_string() -> String {
stringify!($name).to_string()
}
}
};
}
impl<B: Backend, D1> NamedDims<B> for (D1,)
where
B: Backend,
D1: Dim,
{
type Tensor = Tensor<B, 1>;
fn to_string() -> String {
format!("[{}]", D1::to_string())
}
}
impl<B: Backend, D1, D2> NamedDims<B> for (D1, D2)
where
B: Backend,
D1: Dim,
D2: Dim,
{
type Tensor = Tensor<B, 2>;
fn to_string() -> String {
format!("[{}, {}]", D1::to_string(), D2::to_string())
}
}
impl<B: Backend, D1, D2, D3> NamedDims<B> for (D1, D2, D3)
where
B: Backend,
D1: Dim,
D2: Dim,
D3: Dim,
{
type Tensor = Tensor<B, 3>;
fn to_string() -> String {
format!(
"[{}, {}, {}]",
D1::to_string(),
D2::to_string(),
D3::to_string()
)
}
}
impl<B: Backend, D1, D2, D3, D4> NamedDims<B> for (D1, D2, D3, D4)
where
B: Backend,
D1: Dim,
D2: Dim,
D3: Dim,
D4: Dim,
{
type Tensor = Tensor<B, 4>;
fn to_string() -> String {
format!(
"[{}, {}, {}, {}]",
D1::to_string(),
D2::to_string(),
D3::to_string(),
D4::to_string()
)
}
}

View File

@@ -0,0 +1,59 @@
use crate::backend::Backend;
use crate::{Dim, NamedDims, NamedTensor, Tensor};
pub trait Matmul<Rhs, Out> {
fn matmul(self, rhs: Rhs) -> Out;
}
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
/// Applies the matrix multiplication operation.
///
/// `C = AB`
///
/// # Panics
///
/// If the two tensors dont' have a compatible shape.
pub fn matmul<NamedDimsRhs, NamedDimsOut>(
self,
rhs: NamedTensor<B, NamedDimsRhs>,
) -> NamedTensor<B, NamedDimsOut>
where
NamedDimsRhs: NamedDims<B, Tensor = Tensor<B, D>>,
NamedDimsOut: NamedDims<B, Tensor = Tensor<B, D>>,
Self: Matmul<NamedTensor<B, NamedDimsRhs>, NamedTensor<B, NamedDimsOut>>,
{
Matmul::matmul(self, rhs)
}
}
impl<B: Backend, X: Dim, Y: Dim, Z: Dim> Matmul<NamedTensor<B, (Y, Z)>, NamedTensor<B, (X, Z)>>
for NamedTensor<B, (X, Y)>
{
fn matmul(self, rhs: NamedTensor<B, (Y, Z)>) -> NamedTensor<B, (X, Z)> {
NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
}
}
impl<B: Backend, Batch: Dim, X: Dim, Y: Dim, Z: Dim>
Matmul<NamedTensor<B, (Batch, Y, Z)>, NamedTensor<B, (Batch, X, Z)>>
for NamedTensor<B, (Batch, X, Y)>
{
fn matmul(self, rhs: NamedTensor<B, (Batch, Y, Z)>) -> NamedTensor<B, (Batch, X, Z)> {
NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
}
}
impl<B: Backend, Batch1: Dim, Batch2: Dim, X: Dim, Y: Dim, Z: Dim>
Matmul<NamedTensor<B, (Batch1, Batch2, Y, Z)>, NamedTensor<B, (Batch1, Batch2, X, Z)>>
for NamedTensor<B, (Batch1, Batch2, X, Y)>
{
fn matmul(
self,
rhs: NamedTensor<B, (Batch1, Batch2, Y, Z)>,
) -> NamedTensor<B, (Batch1, Batch2, X, Z)> {
NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
}
}

View File

@@ -0,0 +1,7 @@
mod base;
mod dims;
mod matmul;
mod swap_dims;
pub use base::*;
pub use dims::*;

View File

@@ -0,0 +1,62 @@
use crate::backend::Backend;
use crate::{Dim, NamedDims, NamedTensor, Tensor};
pub trait SwapDims<N, const D1: usize, const D2: usize> {
fn swap_dims(self) -> N;
}
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
/// Swap two dimensions.
pub fn swap_dims<ND2, const D1: usize, const D2: usize>(self) -> NamedTensor<B, ND2>
where
ND2: NamedDims<B, Tensor = Tensor<B, D>>,
Self: SwapDims<NamedTensor<B, ND2>, D1, D2>,
{
SwapDims::swap_dims(self)
}
}
macro_rules! generate_permute {
(2 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
for NamedTensor<B, (D1, D2)>
{
fn swap_dims(self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};
(3 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
for NamedTensor<B, (D1, D2, D3)>
{
fn swap_dims(self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};
(4 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim, D4: Dim>
SwapDims<NamedTensor<B, $output>, $dim1, $dim2> for NamedTensor<B, (D1, D2, D3, D4)>
{
fn swap_dims(self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};
}
generate_permute!(2 => (D2, D1), (0, 1));
generate_permute!(3 => (D2, D1, D3), (0, 1));
generate_permute!(3 => (D3, D2, D1), (0, 2));
generate_permute!(3 => (D1, D3, D2), (1, 2));
generate_permute!(4 => (D2, D1, D3, D4), (0, 1));
generate_permute!(4 => (D3, D2, D1, D4), (0, 2));
generate_permute!(4 => (D4, D2, D3, D1), (0, 3));
generate_permute!(4 => (D1, D3, D2, D4), (1, 2));
generate_permute!(4 => (D1, D4, D3, D2), (1, 3));

View File

@@ -0,0 +1,52 @@
use crate::{Tensor, TensorPrimitive, backend::Backend};
use burn_backend::tensor::quantization;
// We re-export those types.
pub use burn_backend::{QTensorPrimitive, quantization::*};
/// The tensor quantization parameters.
pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
/// The observed input calibration range.
#[derive(Clone, Debug)]
pub struct CalibrationRange<B: Backend> {
/// Minimum observed value(s).
pub min: Tensor<B, 1>,
/// Maximum observed value(s).
pub max: Tensor<B, 1>,
}
/// Compute the quantization range mapping.
pub fn compute_range<B: Backend, const D: usize>(
scheme: &QuantScheme,
tensor: &Tensor<B, D>,
calibration: &Calibration,
) -> CalibrationRange<B> {
let (min, max) = match &tensor.primitive {
TensorPrimitive::Float(tensor) => {
quantization::compute_range::<B>(scheme, tensor.clone(), calibration)
}
TensorPrimitive::QFloat(_) => unreachable!(),
};
CalibrationRange {
min: Tensor::from_primitive(TensorPrimitive::Float(min)),
max: Tensor::from_primitive(TensorPrimitive::Float(max)),
}
}
/// Compute the quantization parameters.
pub fn compute_q_params<B: Backend>(
scheme: &QuantScheme,
range: CalibrationRange<B>,
) -> QuantizationParameters<B> {
match (range.min.primitive, range.max.primitive) {
(TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => {
let qparams = quantization::compute_q_params::<B>(scheme, min, max);
QuantizationParameters {
scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)),
}
}
_ => unreachable!(),
}
}

View File

@@ -0,0 +1,106 @@
use super::{Tensor, backend::Backend};
use colored::*;
/// Checks the closeness of two tensors and prints the results.
///
/// Compares tensors by checking the absolute difference between each element.
/// Prints the percentage of elements within specified tolerances.
///
/// # Arguments
///
/// * `output` - The output tensor.
/// * `expected` - The expected tensor.
///
/// # Example
///
/// ```no_run
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{check_closeness, Tensor};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let tensor1 = Tensor::<B, 1>::from_floats(
/// [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
/// &device,
/// );
/// let tensor2 = Tensor::<B, 1>::from_floats(
/// [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
/// &device,
/// );
/// check_closeness(&tensor1, &tensor2);
///}
/// ```
///
/// # Output
///
/// ```text
/// Tensor Closeness Check Results:
/// ===============================
/// Epsilon: 1e-1
/// Close elements: 10/10 (100.00%)
/// [PASS] All elements are within tolerance
///
/// Epsilon: 1e-2
/// Close elements: 10/10 (100.00%)
/// [PASS] All elements are within tolerance
///
/// Epsilon: 1e-3
/// Close elements: 9/10 (90.00%)
/// [WARN] Most elements are within tolerance
///
/// Epsilon: 1e-4
/// Close elements: 6/10 (60.00%)
/// [FAIL] Significant differences detected
///
/// Epsilon: 1e-5
/// Close elements: 5/10 (50.00%)
/// [FAIL] Significant differences detected
///
/// Epsilon: 1e-6
/// Close elements: 5/10 (50.00%)
/// [FAIL] Significant differences detected
///
/// Epsilon: 1e-7
/// Close elements: 5/10 (50.00%)
/// [FAIL] Significant differences detected
///
/// Epsilon: 1e-8
/// Close elements: 5/10 (50.00%)
/// [FAIL] Significant differences detected
///
/// Closeness check complete.
/// ```
pub fn check_closeness<B: Backend, const D: usize>(output: &Tensor<B, D>, expected: &Tensor<B, D>) {
println!("{}", "Tensor Closeness Check Results:".bold());
println!("===============================");
for epsilon in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8].iter() {
println!("{} {:e}", "Epsilon:".bold(), epsilon);
let close = output
.clone()
.is_close(expected.clone(), Some(*epsilon), Some(*epsilon));
let data = close.clone().into_data();
let num_elements = data.num_elements();
// Count the number of elements that are close (true)
let count = data.iter::<bool>().filter(|x| *x).count();
let percentage = (count as f64 / num_elements as f64) * 100.0;
println!(" Close elements: {count}/{num_elements} ({percentage:.2}%)");
if percentage == 100.0 {
println!(" {} All elements are within tolerance", "[PASS]".green());
} else if percentage >= 90.0 {
println!(" {} Most elements are within tolerance", "[WARN]".yellow());
} else {
println!(" {} Significant differences detected", "[FAIL]".red());
}
println!();
}
println!("{}", "Closeness check complete.".bold());
}

View File

@@ -0,0 +1,74 @@
use crate::{Tensor, backend::Backend};
use burn_backend::tensor::Int;
pub fn var<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let mean = tensor.clone().mean_dim(dim);
var_with_mean(tensor, mean, dim)
}
pub fn var_with_mean<B: Backend, const D: usize>(
tensor: Tensor<B, D>,
mean: Tensor<B, D>,
dim: usize,
) -> Tensor<B, D> {
let n = tensor.shape()[dim] - 1;
var_with_mean_n(tensor, mean, dim, n)
}
pub fn var_bias<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let mean = tensor.clone().mean_dim(dim);
var_with_mean_bias(tensor, mean, dim)
}
pub fn var_with_mean_bias<B: Backend, const D: usize>(
tensor: Tensor<B, D>,
mean: Tensor<B, D>,
dim: usize,
) -> Tensor<B, D> {
let n = tensor.shape()[dim];
var_with_mean_n(tensor, mean, dim, n)
}
pub fn var_with_mean_n<B: Backend, const D: usize>(
tensor: Tensor<B, D>,
mean: Tensor<B, D>,
dim: usize,
n: usize,
) -> Tensor<B, D> {
tensor.sub(mean).square().sum_dim(dim).div_scalar(n as f32)
}
pub fn median<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let total_elem_numbers = tensor.dims()[dim];
let sorted_tensor = tensor.sort(dim);
// Following the PyTorch behavior:
// - Odd count: the median
// - Even count: the lower of the two median elements
//
// Example:
// - 5 elements: (5 - 1) / 2 = 4 / 2 = 2
// - 4 elements: (4 - 1) / 2 = 3 / 2 = 1
let median_index = (total_elem_numbers - 1) / 2;
sorted_tensor.narrow(dim, median_index, 1)
}
pub fn median_with_indices<B: Backend, const D: usize>(
tensor: Tensor<B, D>,
dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
let total_elem_numbers = tensor.dims()[dim];
let (sorted_tensor, indices) = tensor.sort_with_indices(dim);
// Following the PyTorch behavior:
// - Odd count: the median
// - Even count: the lower of the two median elements
//
// Example:
// - 5 elements: (5 - 1) / 2 = 4 / 2 = 2
// - 4 elements: (4 - 1) / 2 = 3 / 2 = 1
let median_index = (total_elem_numbers - 1) / 2;
let median_values = sorted_tensor.narrow(dim, median_index, 1);
let median_indices = indices.narrow(dim, median_index, 1);
(median_values, median_indices)
}