feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Tensor library with user-friendly APIs and automatic differentiation support"
|
||||
documentation = "https://docs.rs/burn-tensor"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-tensor"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
doc = ["default"]
|
||||
std = [
|
||||
"num-traits/std",
|
||||
"burn-std/std",
|
||||
"burn-backend/std",
|
||||
"colored",
|
||||
]
|
||||
tracing = [
|
||||
"burn-std/tracing",
|
||||
"burn-backend/tracing",
|
||||
]
|
||||
|
||||
cubecl = ["burn-std/cubecl", "burn-backend/cubecl"]
|
||||
cubecl-cuda = ["burn-backend/cubecl-cuda"]
|
||||
cubecl-hip = ["burn-backend/cubecl-hip"]
|
||||
cubecl-wgpu = ["burn-backend/cubecl-wgpu"]
|
||||
cubecl-cpu = ["burn-backend/cubecl-cpu"]
|
||||
experimental-named-tensor = []
|
||||
|
||||
[dependencies]
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
colored = { workspace = true, optional = true }
|
||||
derive-new = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
# Device
|
||||
hashbrown = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
|
||||
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
|
||||
portable-atomic-util = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-tensor/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-tensor/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,12 @@
|
||||
# Burn Tensor
|
||||
|
||||
> [Burn](https://github.com/tracel-ai/burn) Tensor Library
|
||||
|
||||
[burn-tensor on crates.io](https://crates.io/crates/burn-tensor)
|
||||
[License](https://github.com/tracel-ai/burn-tensor/blob/master/README.md)
|
||||
|
||||
This library provides the core abstractions required to run tensor operations with Burn.
|
||||
|
||||
`Tensor`s are generic over the backend to allow users to perform operations using different
|
||||
`Backend` implementations. Burn's tensors also support auto-differentiation thanks to the
|
||||
`AutodiffBackend` trait.
|
||||
@@ -0,0 +1 @@
|
||||
../../docs/katex-header.html
|
||||
@@ -0,0 +1,465 @@
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use burn_backend::{Backend, Device, DeviceId, DeviceOps};
|
||||
use burn_std::stub::RwLock;
|
||||
use burn_std::{DType, FloatDType, IntDType};
|
||||
|
||||
#[cfg(target_has_atomic = "ptr")]
|
||||
use alloc::sync::Arc;
|
||||
|
||||
#[cfg(not(target_has_atomic = "ptr"))]
|
||||
use portable_atomic_util::Arc;
|
||||
use thiserror::Error;
|
||||
|
||||
use core::any::TypeId;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub use std::collections::HashMap;
|
||||
#[cfg(feature = "std")]
|
||||
use std::sync::LazyLock;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub use hashbrown::HashMap;
|
||||
#[cfg(not(feature = "std"))]
|
||||
use spin::Lazy as LazyLock;
|
||||
|
||||
/// Policy controlling default device behavior.
|
||||
///
|
||||
/// This includes default data types used for tensor creation.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub(crate) struct DevicePolicy {
|
||||
/// Default floating-point data type for tensor creation.
|
||||
float_dtype: Option<FloatDType>,
|
||||
/// Default integer data type for tensor creation.
|
||||
int_dtype: Option<IntDType>,
|
||||
}
|
||||
|
||||
impl DevicePolicy {
|
||||
/// Returns the default floating-point data type used for tensor creation.
|
||||
pub(crate) fn float_dtype(&self) -> Option<FloatDType> {
|
||||
self.float_dtype
|
||||
}
|
||||
|
||||
/// Returns the default integer data type used for tensor creation.
|
||||
pub(crate) fn int_dtype(&self) -> Option<IntDType> {
|
||||
self.int_dtype
|
||||
}
|
||||
|
||||
/// Sets the default floating-point data type.
|
||||
pub(crate) fn set_float_dtype(&mut self, dtype: FloatDType) {
|
||||
self.float_dtype = Some(dtype);
|
||||
}
|
||||
|
||||
/// Sets the default integer data type.
|
||||
pub(crate) fn set_int_dtype(&mut self, dtype: IntDType) {
|
||||
self.int_dtype = Some(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
/// Key for the registry: physical device type + device id
|
||||
type RegistryKey = (DeviceId, TypeId);
|
||||
|
||||
/// Global registry mapping devices to their policies.
|
||||
static REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<DevicePolicy>>>> =
|
||||
LazyLock::new(|| RwLock::new(HashMap::new()));
|
||||
|
||||
/// Device policy management for controlling default tensor creation behavior.
|
||||
///
|
||||
/// # Policy Semantics
|
||||
///
|
||||
/// Device policies use snapshot semantics: when you retrieve a policy with
|
||||
/// [`get_device_policy`], you get an immutable snapshot of the current configuration.
|
||||
/// Updates to the policy (via [`set_default_dtypes`], [`set_default_float_dtype`], etc.)
|
||||
/// only affect future policy retrievals, not existing references.
|
||||
///
|
||||
/// This is intended for the common case where policies are set once during
|
||||
/// initialization and then read frequently during tensor creation.
|
||||
struct DevicePolicyRegistry;
|
||||
|
||||
impl DevicePolicyRegistry {
|
||||
/// Get the policy for a physical device type and device id.
|
||||
///
|
||||
/// If no policy exists yet, a default one is created and stored.
|
||||
fn get<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
|
||||
let key = Self::key(device);
|
||||
|
||||
if let Some(policy) = REGISTRY.read().unwrap().get(&key) {
|
||||
return Arc::clone(policy);
|
||||
}
|
||||
|
||||
let mut map = REGISTRY.write().unwrap();
|
||||
Arc::clone(
|
||||
map.entry(key)
|
||||
.or_insert_with(|| Arc::new(DevicePolicy::default())),
|
||||
)
|
||||
}
|
||||
|
||||
/// Mutate the policy for a given device.
|
||||
fn update<D: DeviceOps>(device: &D, update_fn: impl FnOnce(&mut DevicePolicy)) {
|
||||
let key = Self::key(device);
|
||||
let mut map = REGISTRY.write().unwrap();
|
||||
|
||||
let policy = map
|
||||
.entry(key)
|
||||
.or_insert_with(|| Arc::new(DevicePolicy::default()));
|
||||
|
||||
// Update the policy
|
||||
let policy_mut = Arc::make_mut(policy);
|
||||
update_fn(policy_mut);
|
||||
}
|
||||
|
||||
/// Returns the device registry key.
|
||||
fn key<D: Device>(device: &D) -> RegistryKey {
|
||||
(device.to_id(), TypeId::of::<D>())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [`device`'s policy](DevicePolicy).
|
||||
///
|
||||
/// Returns an immutable snapshot of the device's current policy. If the policy
|
||||
/// is updated after retrieval, this snapshot will not reflect those changes.
|
||||
pub(crate) fn get_device_policy<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
|
||||
DevicePolicyRegistry::get(device)
|
||||
}
|
||||
|
||||
/// Errors that can occur during device-related operations.
|
||||
///
|
||||
/// This covers errors related to hardware capability mismatches, such as
|
||||
/// requesting a data type not supported by the device, and configuration
|
||||
/// errors like attempting to change a policy in an invalid context.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DeviceError {
|
||||
/// Unsupported data type by the device.
|
||||
#[error("Device {device} does not support the requested data type {dtype:?}")]
|
||||
UnsupportedDType {
|
||||
/// The string representation of the device.
|
||||
device: String,
|
||||
/// The data type that caused the error.
|
||||
dtype: DType,
|
||||
},
|
||||
// TODO: `InvalidContext` if a device policy cannot be changed after init / during training / etc.
|
||||
}
|
||||
|
||||
impl DeviceError {
|
||||
/// Helper to create a [`DeviceError::UnsupportedDType`] from any device.
|
||||
pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {
|
||||
Self::UnsupportedDType {
|
||||
device: format!("{device:?}"),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_dtype_support<B: Backend>(
|
||||
device: &B::Device,
|
||||
dtype: impl Into<DType>,
|
||||
) -> Result<(), DeviceError> {
|
||||
let dtype = dtype.into();
|
||||
// Default dtypes should have `DTypeUsage::general()`. Types restricted to specialized
|
||||
// operations should not be used as default.
|
||||
if B::supports_dtype(device, dtype) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(DeviceError::unsupported_dtype(device, dtype))
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the default data types for the device.
|
||||
///
|
||||
/// This updates the device's default data types used for tensor creation.
|
||||
/// The policy should typically be set once during initialization and then
|
||||
/// remains global for all subsequent operations on that device.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{DType, Int, Tensor, set_default_dtypes};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
///
|
||||
/// // Update the device policy
|
||||
/// set_default_dtypes::<B>(&device, DType::F16, DType::I32);
|
||||
///
|
||||
/// // All float tensors created after this will use F16 by default
|
||||
/// let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
|
||||
/// // All int tensors created after this will use I32 default
|
||||
/// let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn set_default_dtypes<B: Backend>(
|
||||
device: &B::Device,
|
||||
float_dtype: impl Into<FloatDType>,
|
||||
int_dtype: impl Into<IntDType>,
|
||||
) -> Result<(), DeviceError> {
|
||||
let float_dtype = float_dtype.into();
|
||||
let int_dtype = int_dtype.into();
|
||||
check_dtype_support::<B>(device, float_dtype)?;
|
||||
check_dtype_support::<B>(device, int_dtype)?;
|
||||
|
||||
set_default_dtypes_unchecked(device, float_dtype, int_dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the default floating-point data type for the device.
|
||||
///
|
||||
/// This updates the device's default data types used for tensor creation.
|
||||
/// The policy should typically be set once during initialization and then
|
||||
/// remains global for all subsequent operations on that device.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{DType, Tensor, set_default_float_dtype};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
///
|
||||
/// // Update the device policy
|
||||
/// set_default_float_dtype::<B>(&device, DType::F16);
|
||||
///
|
||||
/// // All float tensors created after this will use F16 by default
|
||||
/// let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn set_default_float_dtype<B: Backend>(
|
||||
device: &B::Device,
|
||||
dtype: impl Into<FloatDType>,
|
||||
) -> Result<(), DeviceError> {
|
||||
let dtype = dtype.into();
|
||||
check_dtype_support::<B>(device, dtype)?;
|
||||
|
||||
set_default_float_dtype_unchecked(device, dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the default integer data type for the device.
|
||||
///
|
||||
/// This updates the device's default data types used for tensor creation.
|
||||
/// The policy should typically be set once during initialization and then
|
||||
/// remains global for all subsequent operations on that device.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{DType, Int, Tensor, set_default_int_dtype};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
///
|
||||
/// // Update the device policy
|
||||
/// set_default_int_dtype::<B>(&device, DType::I32);
|
||||
///
|
||||
/// // All int tensors created after this will use I32 default
|
||||
/// let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn set_default_int_dtype<B: Backend>(
|
||||
device: &B::Device,
|
||||
dtype: impl Into<IntDType>,
|
||||
) -> Result<(), DeviceError> {
|
||||
let dtype = dtype.into();
|
||||
check_dtype_support::<B>(device, dtype)?;
|
||||
|
||||
set_default_int_dtype_unchecked(device, dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Unchecked versions
|
||||
fn set_default_dtypes_unchecked<D: DeviceOps>(
|
||||
device: &D,
|
||||
float_dtype: FloatDType,
|
||||
int_dtype: IntDType,
|
||||
) {
|
||||
DevicePolicyRegistry::update(device, |p| {
|
||||
p.set_float_dtype(float_dtype);
|
||||
p.set_int_dtype(int_dtype);
|
||||
});
|
||||
}
|
||||
|
||||
fn set_default_float_dtype_unchecked<D: DeviceOps>(device: &D, dtype: FloatDType) {
|
||||
DevicePolicyRegistry::update(device, |p| {
|
||||
p.set_float_dtype(dtype);
|
||||
});
|
||||
}
|
||||
|
||||
fn set_default_int_dtype_unchecked<D: DeviceOps>(device: &D, dtype: IntDType) {
|
||||
DevicePolicyRegistry::update(device, |p| {
|
||||
p.set_int_dtype(dtype);
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
mod tests {
    use serial_test::serial;

    use super::*;

    /// Empties the global registry so each test starts from a clean state.
    ///
    /// The registry is shared by all tests, which is why every test is marked `#[serial]`.
    fn clear_registry() {
        REGISTRY.write().unwrap().clear();
    }

    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceA {
        index: u32,
    }

    impl Device for TestDeviceA {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            1
        }
    }

    impl DeviceOps for TestDeviceA {}

    // Deliberately reports the same `DeviceId` as `TestDeviceA` for equal indices, so that
    // only the Rust type distinguishes the two in the registry key.
    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceB {
        index: u32,
    }

    impl Device for TestDeviceB {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            1
        }
    }

    impl DeviceOps for TestDeviceB {}

    #[test]
    #[serial]
    fn default_policy_is_created_and_shared() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        let p1 = get_device_policy(&device);
        let p2 = get_device_policy(&device);

        assert!(Arc::ptr_eq(&p1, &p2));
        // Not explicitly set
        assert!(p1.float_dtype().is_none());
        assert!(p1.int_dtype().is_none());
        assert!(p2.float_dtype().is_none());
        assert!(p2.int_dtype().is_none());
    }

    #[test]
    #[serial]
    fn updated_policy_is_shared() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        // The device policy is meant to be set once at initialization
        set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32);
        let p1 = get_device_policy(&device);
        let p2 = get_device_policy(&device);

        assert!(Arc::ptr_eq(&p1, &p2));
        assert_eq!(p1.float_dtype(), Some(FloatDType::BF16));
        assert_eq!(p1.int_dtype(), Some(IntDType::I32));
        assert_eq!(p2.float_dtype(), Some(FloatDType::BF16));
        assert_eq!(p2.int_dtype(), Some(IntDType::I32));
    }

    #[test]
    #[serial]
    fn policy_is_device_id_specific() {
        clear_registry(); // reset registry for each test

        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceA::new(1);

        set_default_float_dtype_unchecked(&d1, FloatDType::F16);

        let p1 = get_device_policy(&d1);
        let p2 = get_device_policy(&d2);

        assert!(!Arc::ptr_eq(&p1, &p2));
        assert_eq!(p1.float_dtype(), Some(FloatDType::F16));
        assert!(p1.int_dtype().is_none());
        assert!(p2.float_dtype().is_none());
        assert!(p2.int_dtype().is_none());
    }

    #[test]
    #[serial]
    fn policy_is_device_type_specific() {
        clear_registry(); // reset registry for each test

        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceB::new(0);

        set_default_float_dtype_unchecked(&d2, FloatDType::F16);

        let p1 = get_device_policy(&d1);
        let p2 = get_device_policy(&d2);

        // Same device id, different device type: the entries must be distinct.
        assert!(!Arc::ptr_eq(&p1, &p2));
        assert!(p1.float_dtype().is_none());
        assert!(p1.int_dtype().is_none());
        assert_eq!(p2.float_dtype(), Some(FloatDType::F16));
        assert!(p2.int_dtype().is_none());
    }

    #[test]
    #[serial]
    fn updating_policy_should_not_affect_snapshot() {
        clear_registry(); // reset registry for each test

        // The device policy is meant to be set once at initialization
        let device = TestDeviceA::new(0);
        let before = get_device_policy(&device);

        set_default_float_dtype_unchecked(&device, FloatDType::BF16);

        let after = get_device_policy(&device);

        assert!(!Arc::ptr_eq(&before, &after));
        assert_eq!(after.float_dtype(), Some(FloatDType::BF16));
        assert!(before.float_dtype().is_none());
    }

    #[test]
    #[serial]
    fn set_default_dtypes_overwrites_fields() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        set_default_dtypes_unchecked(&device, FloatDType::F16, IntDType::I64);

        let policy = get_device_policy(&device);

        assert_eq!(policy.float_dtype(), Some(FloatDType::F16));
        assert_eq!(policy.int_dtype(), Some(IntDType::I64));

        // A second call must replace both previously stored values.
        set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32);

        let policy = get_device_policy(&device);

        assert_eq!(policy.float_dtype(), Some(FloatDType::BF16));
        assert_eq!(policy.int_dtype(), Some(IntDType::I32));
    }
}
|
||||
@@ -0,0 +1,23 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! This library provides the core abstractions required to run tensor operations with Burn.
|
||||
//! `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations.
|
||||
//! Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
mod tensor;
|
||||
|
||||
pub(crate) use tensor::check::macros::check;
|
||||
pub use tensor::*;
|
||||
|
||||
// Re-exported types
|
||||
pub use burn_backend::{AllocationProperty, Bytes, StreamId, bf16, f16, read_sync, try_read_sync};
|
||||
|
||||
mod device;
|
||||
pub use device::*;
|
||||
@@ -0,0 +1,647 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::check::TensorCheck;
|
||||
use crate::{Tensor, TensorPrimitive, check, s};
|
||||
|
||||
/// Applies the rectified linear unit function element-wise
|
||||
/// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).
|
||||
///
|
||||
#[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")]
|
||||
#[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")]
|
||||
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor.relu()
|
||||
}
|
||||
|
||||
/// Applies the leaky rectified linear unit function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\)
|
||||
$$
|
||||
|
||||
or
|
||||
|
||||
$$
|
||||
\text{LeakyReLU}(x) =
|
||||
\begin{cases}
|
||||
x & \text{if } x \geq 0 \newline
|
||||
\text{negative\\_slope} \cdot x & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x if x < 0`"
|
||||
)]
|
||||
pub fn leaky_relu<const D: usize, B: Backend>(
|
||||
tensor: Tensor<B, D>,
|
||||
negative_slope: f64,
|
||||
) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu(
|
||||
tensor.primitive.tensor(),
|
||||
negative_slope.into(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies the Gaussian Error Linear Units function as described in the paper
|
||||
/// [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{GELU}(x)
|
||||
= x \cdot \Phi(x)
|
||||
= x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
|
||||
$$
|
||||
|
||||
where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = r#"
|
||||
`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`
|
||||
|
||||
where `Φ(x)` is the cumulative distribution function for the Gaussian distribution.
|
||||
"#
|
||||
)]
|
||||
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor())))
|
||||
}
|
||||
|
||||
/// Applies the tanh-based approximate GELU function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{GELU\_approx}(x)
|
||||
= \frac{x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`GELU_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`"
|
||||
)]
|
||||
pub fn gelu_approximate<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
/// sqrt(2/π) precomputed as FRAC_2_SQRT_PI * FRAC_1_SQRT_2
|
||||
const SQRT_2_OVER_PI: f64 =
|
||||
core::f64::consts::FRAC_2_SQRT_PI * core::f64::consts::FRAC_1_SQRT_2;
|
||||
|
||||
let x = tensor;
|
||||
let inner = x.clone() + x.clone().powf_scalar(3.0) * 0.044715;
|
||||
let inner = inner * SQRT_2_OVER_PI;
|
||||
(x.clone() * (inner.tanh() + 1)) * 0.5
|
||||
}
|
||||
|
||||
/// Applies Parametric ReLu activation function as described in the paper
|
||||
/// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).
|
||||
///
|
||||
/// - The tensor is assumed to be of shape `[batch_size, channels, ...]`.
|
||||
/// - `alpha` is assumed to be of shape `[channels]` or `[1]`.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\)
|
||||
$$
|
||||
|
||||
or
|
||||
|
||||
$$
|
||||
\text{PReLU}(x) =
|
||||
\begin{cases}
|
||||
x & \text{if } x \geq 0 \newline
|
||||
\alpha x & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`PReLu(x) = max(0,x) + alpha * min(0,x)`")]
|
||||
pub fn prelu<const D: usize, B: Backend>(
|
||||
tensor: Tensor<B, D>,
|
||||
alpha: Tensor<B, 1>,
|
||||
) -> Tensor<B, D> {
|
||||
check!(TensorCheck::check_prelu_shape::<D>(
|
||||
&tensor.shape(),
|
||||
&alpha.shape()
|
||||
));
|
||||
|
||||
let weight = if alpha.dims()[0] == 1 {
|
||||
// if there is only 1 weight, then reshape it to (1,1,1... D times) so that the rank is D
|
||||
alpha.reshape([1; D])
|
||||
} else {
|
||||
// D>=2 because the case where D==1 and num_weights >1 is handled by check function
|
||||
// there is more than 1 weight and rank is more than 2
|
||||
let num_weights = alpha.dims()[0];
|
||||
let mut s = [1; D];
|
||||
s[1] = num_weights;
|
||||
// reshape the weights to (1, channels,1 ...)
|
||||
alpha.reshape(s)
|
||||
};
|
||||
|
||||
Tensor::from_primitive(TensorPrimitive::Float(B::prelu(
|
||||
tensor.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies the softmax function on the input tensor along the given dimension.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")]
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `dim`: the dimension along which Softmax will be computed.
|
||||
///
|
||||
/// # Panics
|
||||
/// - If `dim` is outside [0, D)
|
||||
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
check!(TensorCheck::dim_ops::<D>("softmax", dim));
|
||||
|
||||
let tensor = tensor.clone() - tensor.detach().max_dim(dim);
|
||||
let tensor = tensor.exp();
|
||||
let tensor_tmp = tensor.clone().sum_dim(dim);
|
||||
|
||||
tensor.div(tensor_tmp)
|
||||
}
|
||||
|
||||
/// Applies the softmin function on the input tensor along the given dimension.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j)`")]
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `dim`: the dimension along which Softmax will be computed.
|
||||
///
|
||||
/// # Panics
|
||||
/// - If `dim` is outside [0, D)
|
||||
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
check!(TensorCheck::dim_ops::<D>("softmin", dim));
|
||||
softmax(tensor.neg(), dim)
|
||||
}
|
||||
|
||||
/// Applies the SoftPlus function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")]
|
||||
///
|
||||
/// The SoftPlus function is a smooth approximation of the ReLU function.
|
||||
pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
|
||||
let tensor = (tensor.mul_scalar(beta).exp() + 1).log();
|
||||
tensor.div_scalar(beta)
|
||||
}
|
||||
|
||||
/// Applies the "quiet softmax" function on the input tensor along the given dimension.
|
||||
///
|
||||
/// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html).
|
||||
///
|
||||
/// This function is similar to the softmax function, but it allows for "no selection" when
|
||||
/// all the outputs are close to zero.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`"
|
||||
)]
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `dim`: the dimension along which Softmax will be computed.
|
||||
///
|
||||
/// # Panics
|
||||
/// - If `dim` is outside [0, D)
|
||||
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
check!(TensorCheck::dim_ops::<D>("softmax", dim));
|
||||
|
||||
let max_vals = tensor.clone().detach().max_dim(dim);
|
||||
let exp_x = (tensor - max_vals.clone()).exp();
|
||||
let sum_exp = exp_x.clone().sum_dim(dim);
|
||||
|
||||
exp_x.div(sum_exp + max_vals.neg().exp())
|
||||
}
|
||||
|
||||
/// Applies the log softmax function on the input tensor along the given dimension.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{log\\_softmax}\(x_i\)
|
||||
= \log\left(\text{softmax}\(x_i\)\right)
|
||||
= \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`"
|
||||
)]
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `dim`: the dimension along which Softmax will be computed.
|
||||
///
|
||||
/// # Panics
|
||||
/// - If `dim` is outside [0, D)
|
||||
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
check!(TensorCheck::dim_ops::<D>("log softmax", dim));
|
||||
|
||||
let tensor = tensor.clone() - tensor.detach().max_dim(dim);
|
||||
let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();
|
||||
|
||||
tensor.sub(tensor_tmp)
|
||||
}
|
||||
|
||||
/// Applies the sigmoid function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{sigmoid}\(x\)
|
||||
= \sigma(x)
|
||||
= \frac{1}{1 + \exp(-x)}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")]
|
||||
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid(
|
||||
tensor.primitive.tensor(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies the hard sigmoid function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta))
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")]
|
||||
pub fn hard_sigmoid<const D: usize, B: Backend>(
|
||||
tensor: Tensor<B, D>,
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(
|
||||
tensor.primitive.tensor(),
|
||||
alpha.into(),
|
||||
beta.into(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies the log sigmoid function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")]
|
||||
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(
|
||||
tensor.primitive.tensor(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies the SiLU function (also known as the swish function) element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")]
|
||||
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor.clone().mul(sigmoid(tensor))
|
||||
}
|
||||
|
||||
/// Applies the hard swish function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`"
|
||||
)]
|
||||
pub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5))
|
||||
}
|
||||
|
||||
/// Applies the Mish function as described in the paper in
|
||||
/// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{Mish}\(x\)
|
||||
= x \cdot \tanh(\text{Softplus}(x))
|
||||
= \tanh\left(\log\(1 + \exp\(x\)\)\right)
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`mish(x) = x * tanh(softplus(x)) = tanh(log(1 + exp(x)))`"
|
||||
)]
|
||||
pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor.clone().mul(softplus(tensor, 1.0).tanh())
|
||||
}
|
||||
|
||||
/// Applies the tanh function element-wise.
|
||||
pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor.tanh()
|
||||
}
|
||||
|
||||
/// Applies the Exponential Linear Unit function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{ELU}\(x\) =
|
||||
\begin{cases}
|
||||
x & \text{if } x > 0 \newline
|
||||
\alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`f(x) =`\n- `x for x > 0`\n- `alpha * (exp(x) - 1) for x <= 0`"
|
||||
)]
|
||||
pub fn elu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
|
||||
let mask = tensor.clone().lower_equal_elem(0);
|
||||
let scaled = tensor.clone().exp().sub_scalar(1).mul_scalar(alpha);
|
||||
tensor.mask_where(mask, scaled)
|
||||
}
|
||||
|
||||
/// Applies the Continuously Differentiable Exponential Linear Unit function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{CELU}(x) =
|
||||
\begin{cases}
|
||||
x & \text{if } x \geq 0 \newline
|
||||
\alpha \cdot \left(\exp\left(\frac{x}{\alpha}\right) - 1\right) & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`"
|
||||
)]
|
||||
///
|
||||
/// See also [CELU](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html)
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `alpha`: scaling parameter for the negative part.
|
||||
pub fn celu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
|
||||
let mask = tensor.clone().lower_equal_elem(0);
|
||||
let scaled = tensor
|
||||
.clone()
|
||||
.div_scalar(alpha)
|
||||
.exp()
|
||||
.sub_scalar(1)
|
||||
.mul_scalar(alpha);
|
||||
tensor.mask_where(mask, scaled)
|
||||
}
|
||||
|
||||
/// Applies the Scaled Exponential Linear Unit function element-wise
|
||||
/// as described in the paper [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{SELU}\(x\) = \gamma \cdot
|
||||
\begin{cases}
|
||||
x & \text{if } x > 0 \newline
|
||||
\alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
|
||||
\end{cases}
|
||||
$$
|
||||
|
||||
where $\alpha \approx 1.6733$ and $\gamma \approx 1.0507$.
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0`"
|
||||
)]
|
||||
pub fn selu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
// Constants from the SELU paper / ONNX spec
|
||||
const ALPHA: f64 = 1.6732632423543772848170429916717_f64;
|
||||
const GAMMA: f64 = 1.0507009873554804934193349852946_f64;
|
||||
|
||||
let mask = tensor.clone().greater_equal_elem(0.0);
|
||||
let positive = tensor.clone().mul_scalar(GAMMA);
|
||||
let negative = tensor.exp().sub_scalar(1.0).mul_scalar(ALPHA * GAMMA);
|
||||
|
||||
negative.mask_where(mask, positive)
|
||||
}
|
||||
|
||||
/// Applies the thresholded rectified linear unit function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{ThresholdedReLU}(x) =
|
||||
\begin{cases}
|
||||
x & \text{if } x > \alpha \newline
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`f(x) =`\n- `x if x > alpha`\n- `0 otherwise`")]
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `alpha`: threshold value (default in ONNX is 1.0).
|
||||
pub fn thresholded_relu<const D: usize, B: Backend>(
|
||||
tensor: Tensor<B, D>,
|
||||
alpha: f64,
|
||||
) -> Tensor<B, D> {
|
||||
let mask = tensor.clone().lower_equal_elem(alpha);
|
||||
tensor.mask_fill(mask, 0)
|
||||
}
|
||||
|
||||
/// Applies the gated linear unit function.
|
||||
///
|
||||
/// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.
|
||||
///
|
||||
/// **Note**:
|
||||
/// * The size of the input tensor along `dim` must be divisible by 2.
|
||||
///
|
||||
/// ### Arguments
|
||||
/// * `tensor` - The input tensor.
|
||||
///
|
||||
/// ### Returns
|
||||
/// * A tensor with the same shape as the input, except the size along `dim` is halved.
|
||||
pub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
// TODO: Handle negative indices with AsIndex for compatibility with Pytorch nn.GLU.
|
||||
|
||||
assert!(
|
||||
tensor.dims()[dim].is_multiple_of(2),
|
||||
"Input tensor along dimension {dim} must have an even size. N is divisible by 2."
|
||||
);
|
||||
let new_len = tensor.dims()[dim] / 2;
|
||||
|
||||
let a = tensor.clone().slice_dim(dim, s![0..new_len]);
|
||||
let b = tensor.slice_dim(dim, s![new_len..new_len * 2]);
|
||||
|
||||
a.mul(sigmoid(b))
|
||||
}
|
||||
|
||||
/// Applies the Softsign function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{softsign}(x) = \frac{x}{1 + |x|}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(not(doc), doc = "`softsign(x_i) = x_i / (1 + |x_i|)`")]
|
||||
pub fn softsign<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor.clone().div(tensor.abs() + 1)
|
||||
}
|
||||
|
||||
/// Applies the HardShrink function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{hard\_shrink}(x) =
|
||||
\begin{cases}
|
||||
x & \text{if } x > \lambda \newline
|
||||
x & \text{if } x < -\lambda \newline
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`"
|
||||
)]
|
||||
/// # Arguments
|
||||
/// - `lambda`: the lambda value for the Hard Shrink formulation. Default is 0.5.
|
||||
pub fn hard_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
|
||||
let mask = tensor.clone().abs().lower_equal_elem(lambda);
|
||||
tensor.mask_fill(mask, 0)
|
||||
}
|
||||
|
||||
/// Applies the SoftShrink function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{soft\_shrink}(x) =
|
||||
\begin{cases}
|
||||
x - \lambda & \text{if } x > \lambda \newline
|
||||
x + \lambda & \text{if } x < -\lambda \newline
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`"
|
||||
)]
|
||||
/// # Arguments
|
||||
/// - `lambda`: the lambda value for the Soft Shrink formulation. Default is 0.5.
|
||||
pub fn soft_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
|
||||
shrink(tensor, lambda, lambda)
|
||||
}
|
||||
|
||||
/// Applies the Shrink function element-wise.
|
||||
///
|
||||
#[cfg_attr(
|
||||
doc,
|
||||
doc = r#"
|
||||
$$
|
||||
\text{shrink}(x) =
|
||||
\begin{cases}
|
||||
x - \text{bias} & \text{if } x > \lambda \newline
|
||||
x + \text{bias} & \text{if } x < -\lambda \newline
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
not(doc),
|
||||
doc = "`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`"
|
||||
)]
|
||||
/// # Arguments
|
||||
/// - `lambda`: the lambda value for the Shrink formulation.
|
||||
/// - `bias`: the bias value for the Shrink formulation.
|
||||
pub fn shrink<const D: usize, B: Backend>(
|
||||
tensor: Tensor<B, D>,
|
||||
lambda: f64,
|
||||
bias: f64,
|
||||
) -> Tensor<B, D> {
|
||||
let abs_tensor = tensor.clone().abs();
|
||||
let sign = tensor.clone().sign();
|
||||
let shrunk = tensor.sub(sign.mul_scalar(bias));
|
||||
let mask = abs_tensor.lower_equal_elem(lambda);
|
||||
shrunk.mask_fill(mask, 0)
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,75 @@
|
||||
pub use burn_backend::tensor::BasicAutodiffOps;
|
||||
|
||||
use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
|
||||
/// Backward pass of the tensor.
|
||||
pub fn backward(&self) -> B::Gradients {
|
||||
B::backward(self.primitive.clone().tensor())
|
||||
}
|
||||
|
||||
/// Get the gradients of a tensor if it exist.
|
||||
///
|
||||
/// Returns a new reference to the same tensor. Therefore the same grad tensor can
|
||||
/// be accessed multiple times. If you only need to get the gradients one time,
|
||||
/// consider using [grad_remove](Tensor::grad_remove) for better performance.
|
||||
pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
|
||||
match &self.primitive {
|
||||
TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
|
||||
.map(TensorPrimitive::Float)
|
||||
.map(Tensor::new),
|
||||
TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
|
||||
.map(TensorPrimitive::Float)
|
||||
.map(Tensor::new),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
|
||||
pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
|
||||
match &self.primitive {
|
||||
TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
|
||||
.map(TensorPrimitive::Float)
|
||||
.map(Tensor::new),
|
||||
TensorPrimitive::QFloat(_tensor) => {
|
||||
B::grad_remove(&self.primitive.clone().tensor(), grads)
|
||||
.map(TensorPrimitive::Float)
|
||||
.map(Tensor::new)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
|
||||
/// gradient.
|
||||
pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
|
||||
match &self.primitive {
|
||||
TensorPrimitive::Float(tensor) => {
|
||||
B::grad_replace(tensor, grads, grad.primitive.tensor())
|
||||
}
|
||||
TensorPrimitive::QFloat(_tensor) => B::grad_replace(
|
||||
&self.primitive.clone().tensor(),
|
||||
grads,
|
||||
grad.primitive.tensor(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
|
||||
/// Returns the inner tensor without the autodiff information.
|
||||
pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
|
||||
Tensor::new(K::inner(self.primitive))
|
||||
}
|
||||
|
||||
/// Convert a tensor to the autodiff backend.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `inner` - The tensor to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor converted to the autodiff backend.
|
||||
pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
|
||||
Self::new(K::from_inner(inner.primitive))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,429 @@
|
||||
use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
|
||||
use alloc::{vec, vec::Vec};
|
||||
|
||||
use crate::try_read_sync;
|
||||
|
||||
/// The part of the tensor to keep when creating a triangular mask.
///
/// Used by `tri_mask` to pick the element-wise comparison that produces the mask.
enum TriPart {
    /// Upper triangular part.
    Upper,

    /// Lower triangular part.
    Lower,

    /// Diagonal part.
    Diagonal,
}
|
||||
|
||||
impl<B, const D: usize> Tensor<B, D, Bool>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// Create a boolean tensor from data on the given device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The tensor data.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
|
||||
/// println!("{tensor}");
|
||||
/// }
|
||||
/// ```
|
||||
pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
|
||||
Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
|
||||
}
|
||||
|
||||
/// Convert the bool tensor into an int tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An integer tensor where `true` is converted to `1` and `false` to `0`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
|
||||
/// let int_tensor = bool_tensor.int();
|
||||
/// println!("{int_tensor}"); // [1, 0, 1]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn int(self) -> Tensor<B, D, Int> {
|
||||
Tensor::new(B::bool_into_int(self.primitive))
|
||||
}
|
||||
|
||||
/// Convert the bool tensor into a float tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
|
||||
/// let float_tensor = bool_tensor.float();
|
||||
/// println!("{float_tensor}"); // [1.0, 0.0, 1.0]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn float(self) -> Tensor<B, D> {
|
||||
Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
|
||||
}
|
||||
|
||||
/// Inverses boolean values.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
|
||||
/// let inverted = tensor.bool_not();
|
||||
/// println!("{inverted}"); // [[false, true], [true, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn bool_not(self) -> Self {
|
||||
Tensor::new(B::bool_not(self.primitive))
|
||||
}
|
||||
|
||||
/// Performs logical and (`&&`) on two boolean tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rhs` - The right-hand side tensor for the AND operation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor where each element is the result of `self[i] && rhs[i]`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
|
||||
/// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
|
||||
/// let result = a.bool_and(b);
|
||||
/// println!("{result}"); // [[true, false], [false, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::bool_and(self.primitive, rhs.primitive))
|
||||
}
|
||||
|
||||
/// Performs logical or (`||`) on two boolean tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rhs` - The right-hand side tensor for the OR operation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor where each element is the result of `self[i] || rhs[i]`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
|
||||
/// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
|
||||
/// let result = a.bool_or(b);
|
||||
/// println!("{result}"); // [[true, true], [true, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::bool_or(self.primitive, rhs.primitive))
|
||||
}
|
||||
|
||||
/// Performs logical xor (`^`) on two boolean tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rhs` - The right-hand side tensor for the XOR operation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.
|
||||
/// Returns `true` when exactly one of the operands is `true`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
|
||||
/// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
|
||||
/// let result = a.bool_xor(b);
|
||||
/// println!("{result}"); // [[false, true], [true, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
|
||||
}
|
||||
|
||||
/// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
|
||||
/// the non-zero elements in that dimension.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let tensor = Tensor::<B, 2, Bool>::from_bool(
|
||||
/// [[true, false, true], [false, true, false], [false, true, false]].into(),
|
||||
/// &device,
|
||||
/// );
|
||||
/// let indices = tensor.nonzero();
|
||||
/// println!("{}", indices[0]); // [0, 0, 1, 2]
|
||||
/// println!("{}", indices[1]); // [0, 2, 1, 1]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
|
||||
try_read_sync(self.nonzero_async())
|
||||
.expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
|
||||
}
|
||||
|
||||
/// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
|
||||
/// the non-zero elements in that dimension.
|
||||
pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
|
||||
let indices = self.argwhere_async().await;
|
||||
|
||||
if indices.shape().num_elements() == 0 {
|
||||
// Return empty vec when all elements are zero
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let dims = indices.shape();
|
||||
indices
|
||||
.chunk(dims[1], 1)
|
||||
.into_iter()
|
||||
.map(|t| t.reshape(Shape::new([dims[0]])))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute the indices of the elements that are true, grouped by element.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
|
||||
/// result contains the indices of a non-zero element.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let tensor = Tensor::<B, 2, Bool>::from_bool(
|
||||
/// [[true, false, true], [false, true, false], [false, true, false]].into(),
|
||||
/// &device,
|
||||
/// );
|
||||
/// let indices = tensor.argwhere();
|
||||
/// println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn argwhere(self) -> Tensor<B, 2, Int> {
|
||||
try_read_sync(self.argwhere_async())
|
||||
.expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
|
||||
}
|
||||
|
||||
/// Compute the indices of the elements that are true, grouped by element.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
|
||||
/// result contains the indices of a non-zero element.
|
||||
pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
|
||||
Tensor::new(B::bool_argwhere(self.primitive).await)
|
||||
}
|
||||
|
||||
/// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
|
||||
/// fill the specified area with a value.
|
||||
fn tri_mask<S: Into<Shape>>(
|
||||
shape: S,
|
||||
tri_part: TriPart,
|
||||
offset: i64,
|
||||
device: &B::Device,
|
||||
) -> Self {
|
||||
let shape: Shape = shape.into();
|
||||
let height = shape[D - 2];
|
||||
let width = shape[D - 1];
|
||||
|
||||
// Generate row and column index tensors.
|
||||
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
|
||||
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
|
||||
|
||||
// Prepare shapes for broadcasting.
|
||||
let mut row_shape = [1; D];
|
||||
row_shape[D - 2] = height;
|
||||
let mut col_shape = [1; D];
|
||||
col_shape[D - 1] = width;
|
||||
|
||||
// Reshape for broadcasting.
|
||||
let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
|
||||
let col_broadcast = col_indices.reshape(Shape::new(col_shape));
|
||||
|
||||
// Broadcasting trick to create a matrix that facilitates comparison for mask generation.
|
||||
let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
|
||||
|
||||
// Select the appropriate comparison function based on `tri_part`.
|
||||
let compare = match tri_part {
|
||||
TriPart::Upper => Tensor::greater_elem,
|
||||
TriPart::Lower => Tensor::lower_elem,
|
||||
TriPart::Diagonal => Tensor::not_equal_elem,
|
||||
};
|
||||
|
||||
// Generate and return the mask by applying the comparison to the matrix.
|
||||
compare(matrix, 0).unsqueeze()
|
||||
}
|
||||
|
||||
/// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
|
||||
/// area with a value.
|
||||
///
|
||||
/// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape`: The shape of the matrix.
|
||||
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
|
||||
/// towards the upper triangle.
|
||||
/// * `device`: The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
|
||||
/// upper triangle taking into account the specified `offset`. All other elements are `true`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
|
||||
/// println!("{mask}");
|
||||
/// // [[false, false, false],
|
||||
/// // [true, false, false],
|
||||
/// // [true, true, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
|
||||
Self::tri_mask(shape, TriPart::Upper, offset, device)
|
||||
}
|
||||
|
||||
/// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
|
||||
/// area with a value.
|
||||
///
|
||||
/// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape`: The shape of the matrix.
|
||||
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
|
||||
/// towards the lower triangle.
|
||||
/// * `device`: The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
|
||||
/// lower triangle taking into account the specified `offset`. All other elements are `true`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
|
||||
/// println!("{mask}");
|
||||
/// // [[false, true, true],
|
||||
/// // [false, false, true],
|
||||
/// // [false, false, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
|
||||
Self::tri_mask(shape, TriPart::Lower, offset, device)
|
||||
}
|
||||
|
||||
/// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
|
||||
/// area with a value.
|
||||
///
|
||||
/// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape`: The shape of the matrix.
|
||||
/// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
|
||||
/// towards the upper triangle.
|
||||
/// * `device`: The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
|
||||
/// diagonal. All other elements are `true`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Bool};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
|
||||
/// println!("{mask}");
|
||||
/// // [[false, true, true],
|
||||
/// // [true, false, true],
|
||||
/// // [true, true, false]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
|
||||
Self::tri_mask(shape, TriPart::Diagonal, offset, device)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
use crate::{Int, Shape, Tensor, backend::Backend};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `D2` is not equal to `D+1`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::Int;
|
||||
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||
/// println!("{}", result);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn cartesian_grid<B: Backend, S: Into<Shape>, const D: usize, const D2: usize>(
|
||||
shape: S,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D2, Int> {
|
||||
if D2 != D + 1 {
|
||||
panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
|
||||
}
|
||||
|
||||
let dims = shape.into();
|
||||
let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
|
||||
|
||||
for dim in 0..D {
|
||||
let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
|
||||
|
||||
let mut shape = [1; D];
|
||||
shape[dim] = dims[dim];
|
||||
let mut dim_range = dim_range.reshape(shape);
|
||||
|
||||
for (i, &item) in dims.iter().enumerate() {
|
||||
if i == dim {
|
||||
continue;
|
||||
}
|
||||
dim_range = dim_range.repeat_dim(i, item);
|
||||
}
|
||||
|
||||
indices.push(dim_range);
|
||||
}
|
||||
|
||||
Tensor::stack::<D2>(indices, D)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,111 @@
|
||||
use crate::{Float, Tensor, backend::Backend};
|
||||
|
||||
impl<B, const D: usize> Tensor<B, D, Float>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// Computes the floating-point remainder of dividing `self` by `other`.
|
||||
///
|
||||
/// The result has the same sign as `self` and magnitude less than `other`.
|
||||
/// This is equivalent to the IEEE 754 remainder operation.
|
||||
///
|
||||
/// # Special Cases (IEEE 754 compliant)
|
||||
///
|
||||
/// - If `self` is ±∞ and `other` is not NaN, NaN is returned
|
||||
/// - If `other` is ±0 and `self` is not NaN, NaN is returned
|
||||
/// - If `other` is ±∞ and `self` is finite, `self` is returned
|
||||
/// - If either argument is NaN, NaN is returned
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `other` - The divisor tensor. Must have the same shape as `self`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape where each element is the floating-point remainder.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::Tensor;
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
/// let dividend = Tensor::<B, 1>::from_data([5.3, -5.3, 5.3, -5.3], &device);
|
||||
/// let divisor = Tensor::<B, 1>::from_data([2.0, 2.0, -2.0, -2.0], &device);
|
||||
/// let result = dividend.fmod(divisor);
|
||||
///
|
||||
/// // Result: [1.3, -1.3, 1.3, -1.3]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn fmod(self, other: Self) -> Self {
|
||||
// Normal case: fmod(x, y) = x - y * trunc(x / y)
|
||||
let quotient = self.clone().div(other.clone());
|
||||
let truncated = quotient.trunc();
|
||||
let product = other.clone() * truncated.clone();
|
||||
|
||||
// When divisor is infinity and dividend is finite:
|
||||
// - quotient is 0, truncated is 0
|
||||
// - but 0 * infinity = NaN, which is wrong
|
||||
// We need to handle this case by replacing NaN with 0 when appropriate
|
||||
|
||||
// Check if the product is NaN due to 0 * inf
|
||||
let is_zero_times_inf = truncated.equal_elem(0.0).bool_and(other.is_inf());
|
||||
let zero_tensor = self.clone().mul_scalar(0.0);
|
||||
let corrected_product = product.mask_where(is_zero_times_inf, zero_tensor);
|
||||
|
||||
self - corrected_product
|
||||
}
|
||||
|
||||
/// Computes the floating-point remainder of dividing `self` by a scalar.
|
||||
///
|
||||
/// The result has the same sign as `self` and magnitude less than the scalar.
|
||||
///
|
||||
/// # Special Cases (IEEE 754 compliant)
|
||||
///
|
||||
/// - If `self` is ±∞ and scalar is not NaN, NaN is returned
|
||||
/// - If scalar is ±0 and `self` is not NaN, NaN is returned
|
||||
/// - If scalar is ±∞ and `self` is finite, `self` is returned
|
||||
/// - If either argument is NaN, NaN is returned
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `scalar` - The scalar divisor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape where each element is the floating-point remainder.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::Tensor;
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
/// let tensor = Tensor::<B, 1>::from_data([5.3, -5.3, 7.5, -7.5], &device);
|
||||
/// let result = tensor.fmod_scalar(2.0);
|
||||
///
|
||||
/// // Result: [1.3, -1.3, 1.5, -1.5]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn fmod_scalar(self, scalar: f32) -> Self {
|
||||
// Normal case: fmod(x, y) = x - y * trunc(x / y)
|
||||
let quotient = self.clone().div_scalar(scalar);
|
||||
let truncated = quotient.trunc();
|
||||
let product = truncated.mul_scalar(scalar);
|
||||
|
||||
// Handle the special case where scalar is infinity
|
||||
// When scalar is ±∞ and self is finite, quotient is 0, truncated is 0
|
||||
// but 0 * infinity = NaN, which is wrong - it should be 0
|
||||
if scalar.is_infinite() {
|
||||
// For finite values, fmod(x, ±∞) = x
|
||||
// For infinite values, fmod(±∞, ±∞) = NaN (which is handled by arithmetic)
|
||||
return self;
|
||||
}
|
||||
|
||||
self - product
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
use burn_backend::Scalar;
|
||||
|
||||
use crate::{
|
||||
Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend,
|
||||
cartesian_grid,
|
||||
};
|
||||
|
||||
use core::ops::Range;
|
||||
|
||||
impl<B> Tensor<B, 1, Int>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// Returns a new integer tensor on the specified device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
|
||||
Tensor::new(B::int_arange(range, device))
|
||||
}
|
||||
|
||||
/// Returns a new integer tensor on the specified device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `step` - The step between each value.
|
||||
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
|
||||
Tensor::new(B::int_arange_step(range, step, device))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B> Tensor<B, D, Int>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// Create a tensor from integers (i32), placing it on a given device.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Int};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
/// let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
|
||||
/// let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
|
||||
Self::from_data(ints.into().convert::<i32>(), device)
|
||||
}
|
||||
|
||||
/// Returns a new tensor with the same shape and device as the current tensor and the data
|
||||
/// cast to Float.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Int, Tensor};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
|
||||
/// let float_tensor = int_tensor.float();
|
||||
/// }
|
||||
/// ```
|
||||
pub fn float(self) -> Tensor<B, D, Float> {
|
||||
Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))
|
||||
}
|
||||
|
||||
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `D2` is not equal to `D+1`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::Int;
|
||||
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||
/// println!("{}", result);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
|
||||
shape: S,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D2, Int> {
|
||||
cartesian_grid::<B, S, D, D2>(shape, device)
|
||||
}
|
||||
|
||||
/// Applies the bitwise logical and operation with each bit representing the integer.
|
||||
pub fn bitwise_and(self, other: Self) -> Self {
|
||||
Self::new(B::bitwise_and(self.primitive, other.primitive))
|
||||
}
|
||||
|
||||
/// Applies the bitwise logical or operation with another tensor.
|
||||
pub fn bitwise_or(self, other: Self) -> Self {
|
||||
Self::new(B::bitwise_or(self.primitive, other.primitive))
|
||||
}
|
||||
|
||||
/// Applies the bitwise logical xor operation with another tensor.
|
||||
pub fn bitwise_xor(self, other: Self) -> Self {
|
||||
Self::new(B::bitwise_xor(self.primitive, other.primitive))
|
||||
}
|
||||
|
||||
/// Applies the bitwise logical not operation.
|
||||
pub fn bitwise_not(self) -> Self {
|
||||
Self::new(B::bitwise_not(self.primitive))
|
||||
}
|
||||
|
||||
/// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.
|
||||
pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
|
||||
let other = Scalar::new(other, &self.dtype());
|
||||
Self::new(B::bitwise_and_scalar(self.primitive, other))
|
||||
}
|
||||
|
||||
/// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.
|
||||
pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
|
||||
let other = Scalar::new(other, &self.dtype());
|
||||
Self::new(B::bitwise_or_scalar(self.primitive, other))
|
||||
}
|
||||
|
||||
/// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.
|
||||
pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
|
||||
let other = Scalar::new(other, &self.dtype());
|
||||
Self::new(B::bitwise_xor_scalar(self.primitive, other))
|
||||
}
|
||||
|
||||
/// Applies the bitwise left shift operation with the integers in the tensor.
|
||||
pub fn bitwise_left_shift(self, other: Self) -> Self {
|
||||
Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
|
||||
}
|
||||
|
||||
/// Applies the bitwise right shift operation with the integers in the tensor.
|
||||
pub fn bitwise_right_shift(self, other: Self) -> Self {
|
||||
Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
|
||||
}
|
||||
|
||||
/// Applies the bitwise left shift operation with the scalar.
|
||||
pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
|
||||
let other = Scalar::new(other, &self.dtype());
|
||||
Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
|
||||
}
|
||||
|
||||
/// Applies the bitwise right shift operation with the scalar.
|
||||
pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
|
||||
let other = Scalar::new(other, &self.dtype());
|
||||
Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
|
||||
}
|
||||
|
||||
/// Converts a tensor to the specified integer data type.
|
||||
///
|
||||
/// This is always a no-op when casting to the current dtype.
|
||||
///
|
||||
/// # Warning
|
||||
/// Most backends don't have automatic type promotion at this time, so make sure that all tensors
|
||||
/// have the same integer data type for operations multiple input tensors (e.g., binary ops).
|
||||
pub fn cast<F: Into<IntDType>>(self, dtype: F) -> Tensor<B, D, Int> {
|
||||
let dtype = dtype.into();
|
||||
let self_dtype: IntDType = self.dtype().into();
|
||||
if dtype == self_dtype {
|
||||
// no-op.
|
||||
return self;
|
||||
}
|
||||
Tensor::new(B::int_cast(self.primitive, dtype))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
pub(crate) mod check;
|
||||
|
||||
mod autodiff;
|
||||
mod base;
|
||||
mod bool;
|
||||
mod cartesian_grid;
|
||||
mod float;
|
||||
mod fmod;
|
||||
mod int;
|
||||
mod numeric;
|
||||
mod options;
|
||||
mod orderable;
|
||||
mod pad;
|
||||
pub use pad::IntoPadding;
|
||||
mod take;
|
||||
mod transaction;
|
||||
mod trunc;
|
||||
|
||||
pub use autodiff::*;
|
||||
pub use base::*;
|
||||
pub use cartesian_grid::cartesian_grid;
|
||||
pub use float::{DEFAULT_ATOL, DEFAULT_RTOL};
|
||||
pub use numeric::*;
|
||||
pub use options::*;
|
||||
pub use transaction::*;
|
||||
|
||||
pub use burn_backend::tensor::IndexingUpdateOp;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,116 @@
|
||||
use burn_backend::{Backend, Element, tensor::Device};
|
||||
use burn_std::DType;
|
||||
|
||||
use crate::get_device_policy;
|
||||
|
||||
/// Options for tensor creation.
|
||||
///
|
||||
/// This struct allows specifying the `device` and overriding the data type when creating a tensor.
|
||||
/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TensorCreationOptions<B: Backend> {
|
||||
/// Device where the tensor will be created.
|
||||
pub device: Device<B>,
|
||||
/// Optional data type.
|
||||
/// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).
|
||||
pub dtype: Option<DType>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for TensorCreationOptions<B> {
|
||||
/// Returns new options with the backend's default device.
|
||||
fn default() -> Self {
|
||||
Self::new(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorCreationOptions<B> {
|
||||
/// Create new options with a specific device.
|
||||
///
|
||||
/// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.
|
||||
pub fn new(device: Device<B>) -> Self {
|
||||
Self {
|
||||
device,
|
||||
dtype: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the tensor creation data type.
|
||||
pub fn with_dtype(mut self, dtype: DType) -> Self {
|
||||
self.dtype = Some(dtype);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the tensor creation device.
|
||||
pub fn with_device(mut self, device: Device<B>) -> Self {
|
||||
self.device = device;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Create options with backend's default device and float dtype.
|
||||
pub fn float() -> Self {
|
||||
Self::default().with_dtype(<B::FloatElem as Element>::dtype())
|
||||
}
|
||||
|
||||
/// Create options with backend's default device and int dtype.
|
||||
pub fn int() -> Self {
|
||||
Self::default().with_dtype(<B::IntElem as Element>::dtype())
|
||||
}
|
||||
|
||||
/// Create options with backend's default device and bool dtype.
|
||||
pub fn bool() -> Self {
|
||||
Self::default().with_dtype(<B::BoolElem as Element>::dtype())
|
||||
}
|
||||
|
||||
/// Returns the tensor data type, or a provided default if not set.
|
||||
///
|
||||
/// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
|
||||
pub fn dtype_or(&self, dtype: DType) -> DType {
|
||||
self.dtype.unwrap_or(dtype)
|
||||
}
|
||||
|
||||
/// Returns the tensor data type, or the default from the [device policy](crate::set_default_dtypes).
|
||||
pub(crate) fn resolve_policy(&self, dtype: DType) -> DType {
|
||||
// TODO: should rely on tensor kind, not element dtype
|
||||
self.dtype.unwrap_or_else(|| {
|
||||
let policy = get_device_policy(&self.device);
|
||||
if dtype.is_float()
|
||||
&& let Some(float_dtype) = policy.float_dtype()
|
||||
{
|
||||
float_dtype.into()
|
||||
} else if (dtype.is_int() || dtype.is_uint())
|
||||
&& let Some(int_dtype) = policy.int_dtype()
|
||||
{
|
||||
int_dtype.into()
|
||||
} else {
|
||||
// If policy was not explicitly set, use the fallback dtype (default backend elem type)
|
||||
dtype
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
    /// Builds options from a borrowed device by cloning it; no dtype override is set.
    ///
    /// Example:
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::TensorCreationOptions;
    ///
    /// fn example<B: Backend>(device: B::Device) {
    ///     let options: TensorCreationOptions<B> = (&device).into();
    /// }
    /// ```
    fn from(device: &Device<B>) -> Self {
        let owned = device.clone();
        Self::new(owned)
    }
}
|
||||
|
||||
impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
    /// Builds options from a `(&device, dtype)` tuple: the device is cloned and the dtype is
    /// set as an explicit override.
    fn from(args: (&Device<B>, DType)) -> Self {
        let (device, dtype) = args;
        Self::new(device.clone()).with_dtype(dtype)
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,363 @@
|
||||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
|
||||
use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
|
||||
|
||||
use super::Numeric;
|
||||
|
||||
/// Trait for types that can be used as padding specifications.
///
/// Padding is specified as `(before, after)` pairs per dimension, returned as a
/// fixed-size array `[(usize, usize); D]`. If fewer pairs than dimensions are provided,
/// they apply to the **last** N dimensions (earlier dimensions are left unpadded).
pub trait IntoPadding<const D: usize> {
    /// Converts into a fixed-size array of `(before, after)` padding pairs.
    ///
    /// # Panics
    ///
    /// Implementations panic when the specification does not fit in `D` dimensions.
    fn into_padding(self) -> [(usize, usize); D];
}

/// Places `pairs` in the trailing slots of a `D`-sized padding array, leaving the leading
/// dimensions at `(0, 0)`.
///
/// # Panics
///
/// Panics if `pairs` yields more than `D` items.
fn right_align_padding<const D: usize>(
    pairs: impl ExactSizeIterator<Item = (usize, usize)>,
) -> [(usize, usize); D] {
    let count = pairs.len();
    assert!(
        count <= D,
        "Padding has {} pairs but tensor only has {} dimensions",
        count,
        D
    );

    let mut result = [(0usize, 0usize); D];
    // `count <= D` was just checked, so the subtraction cannot underflow.
    for (slot, pair) in result[D - count..].iter_mut().zip(pairs) {
        *slot = pair;
    }
    result
}

/// Fixed-size array of `(before, after)` pairs, applied to the last `N` dimensions.
impl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {
    fn into_padding(self) -> [(usize, usize); D] {
        right_align_padding(IntoIterator::into_iter(self))
    }
}

/// Backward-compatible: `(left, right, top, bottom)` maps to last 2 dimensions.
///
/// Equivalent to `[(top, bottom), (left, right)]`.
///
/// # Panics
///
/// Panics if the tensor has fewer than 2 dimensions.
impl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {
    fn into_padding(self) -> [(usize, usize); D] {
        let (left, right, top, bottom) = self;
        let mut result = [(0usize, 0usize); D];

        // Match on the last two slots instead of indexing with `D - 2`, which would underflow
        // (or index out of bounds) for tensors with fewer than two dimensions.
        match &mut result[..] {
            [.., rows, cols] => {
                *rows = (top, bottom);
                *cols = (left, right);
            }
            _ => panic!(
                "`(left, right, top, bottom)` padding requires at least 2 dimensions but tensor only has {}",
                D
            ),
        }

        result
    }
}

/// Slice of `(before, after)` pairs, applied to the last `self.len()` dimensions.
impl<const D: usize> IntoPadding<D> for &[(usize, usize)] {
    fn into_padding(self) -> [(usize, usize); D] {
        right_align_padding(self.iter().copied())
    }
}

/// Vector of `(before, after)` pairs, applied to the last `self.len()` dimensions.
impl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {
    fn into_padding(self) -> [(usize, usize); D] {
        right_align_padding(self.into_iter())
    }
}
|
||||
|
||||
/// Helper to build a range array for slice_assign, selecting a portion of one dimension.
///
/// Every dimension gets the full range `0..dims[i]`, except `target_dim`, which gets
/// `start..start + len`. If `target_dim` is not a valid index into `dims`, all dimensions keep
/// their full range.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    // Build the fixed-size array directly: no intermediate `Vec`, no fallible conversion.
    core::array::from_fn(|i| {
        if i == target_dim {
            start..start + len
        } else {
            0..dims[i]
        }
    })
}
|
||||
|
||||
impl<B, const D: usize, K> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
/// Pads the tensor using the specified padding mode.
|
||||
///
|
||||
/// Padding is specified as `(before, after)` pairs. If fewer pairs than tensor dimensions
|
||||
/// are provided, they apply to the **last** N dimensions (unspecified leading dimensions
|
||||
/// are left unpadded).
|
||||
///
|
||||
/// For backward compatibility, a `(left, right, top, bottom)` tuple is also accepted,
|
||||
/// which pads the last two dimensions.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `padding` - Padding specification. Accepts:
|
||||
/// - `[(before, after); N]` fixed-size array of pairs (N <= D)
|
||||
/// - `&[(before, after)]` slice of pairs per dimension
|
||||
/// - `Vec<(before, after)>` vector of pairs
|
||||
/// - `(left, right, top, bottom)` tuple for last-2-dim backward compatibility
|
||||
/// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor with the specified padding applied.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// - Panics if more padding pairs are provided than tensor dimensions.
|
||||
/// - `Reflect` mode panics if padding exceeds `dimension_size - 1`.
|
||||
/// - `Edge` mode panics if padding is applied to a zero-sized dimension.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Shape};
|
||||
/// use burn_tensor::ops::PadMode;
|
||||
///
|
||||
/// fn example<B: Backend<FloatElem: From<f32>>>() {
|
||||
/// let device = B::Device::default();
|
||||
/// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
|
||||
///
|
||||
/// // Constant padding with value 0.0 (backward-compatible tuple)
|
||||
/// let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0));
|
||||
///
|
||||
/// // Pad arbitrary dimensions with slice of (before, after) pairs
|
||||
/// let padded = tensor.clone().pad([(1, 1), (2, 2)], PadMode::Constant(0.0));
|
||||
///
|
||||
/// // Pad only the last dimension
|
||||
/// let padded = tensor.pad([(1, 1)], PadMode::Reflect);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {
|
||||
let pairs = padding.into_padding();
|
||||
match mode.into() {
|
||||
PadMode::Constant(value) => pad_constant(self, &pairs, value),
|
||||
PadMode::Reflect => pad_reflect(self, &pairs),
|
||||
PadMode::Edge => pad_edge(self, &pairs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pad with a constant value.
|
||||
fn pad_constant<B, const D: usize, K, E>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
padding: &[(usize, usize); D],
|
||||
value: E,
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
E: ElementConversion,
|
||||
{
|
||||
let mut padded_dims: [usize; D] = tensor.dims();
|
||||
|
||||
for (i, &(before, after)) in padding.iter().enumerate() {
|
||||
padded_dims[i] += before + after;
|
||||
}
|
||||
|
||||
let ranges: [Range<usize>; D] = padded_dims
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &dim)| {
|
||||
let (before, after) = padding[i];
|
||||
before..dim - after
|
||||
})
|
||||
.collect::<Vec<Range<usize>>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
|
||||
|
||||
padded_tensor.slice_assign(ranges, tensor)
|
||||
}
|
||||
|
||||
/// Pad using reflection at the boundaries (excluding edge values).
|
||||
///
|
||||
/// For ONNX "reflect" mode: mirrors from index 1, not index 0.
|
||||
/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]`
|
||||
fn pad_reflect<B, const D: usize, K>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
padding: &[(usize, usize); D],
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
let dims = tensor.dims();
|
||||
|
||||
for (i, &(before, after)) in padding.iter().enumerate() {
|
||||
if before > 0 || after > 0 {
|
||||
assert!(
|
||||
before < dims[i] && after < dims[i],
|
||||
"Reflect padding ({}, {}) must be less than dimension {} size ({})",
|
||||
before,
|
||||
after,
|
||||
i,
|
||||
dims[i]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = tensor;
|
||||
|
||||
for (i, &(before, after)) in padding.iter().enumerate() {
|
||||
if before > 0 || after > 0 {
|
||||
result = pad_reflect_dim(result, i, before, after);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Helper to pad a single dimension using reflection.
|
||||
fn pad_reflect_dim<B, const D: usize, K>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
dim: usize,
|
||||
pad_before: usize,
|
||||
pad_after: usize,
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
let dims = tensor.dims();
|
||||
let dim_size = dims[dim];
|
||||
|
||||
// Calculate output dimensions
|
||||
let mut output_dims = dims;
|
||||
output_dims[dim] += pad_before + pad_after;
|
||||
|
||||
// Create output tensor and place original in the center
|
||||
let output = Tensor::zeros(output_dims, &tensor.device());
|
||||
let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
|
||||
let mut output = output.slice_assign(original_range, tensor.clone());
|
||||
|
||||
// Assign reflected "before" padding (e.g., top or left)
|
||||
// Reflect excludes the edge, so we take indices [1..pad_before+1] and flip
|
||||
if pad_before > 0 {
|
||||
let before_slice = tensor.clone().narrow(dim, 1, pad_before);
|
||||
let before_flipped = before_slice.flip([dim as isize]);
|
||||
let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
|
||||
output = output.slice_assign(before_range, before_flipped);
|
||||
}
|
||||
|
||||
// Assign reflected "after" padding (e.g., bottom or right)
|
||||
// Take indices [dim_size - pad_after - 1..dim_size - 1] and flip
|
||||
if pad_after > 0 {
|
||||
let start = dim_size - pad_after - 1;
|
||||
let after_slice = tensor.narrow(dim, start, pad_after);
|
||||
let after_flipped = after_slice.flip([dim as isize]);
|
||||
let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
|
||||
output = output.slice_assign(after_range, after_flipped);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Pad by replicating edge values.
|
||||
///
|
||||
/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]`
|
||||
fn pad_edge<B, const D: usize, K>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
padding: &[(usize, usize); D],
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
let dims = tensor.dims();
|
||||
|
||||
for (i, &(before, after)) in padding.iter().enumerate() {
|
||||
if before > 0 || after > 0 {
|
||||
assert!(
|
||||
dims[i] > 0,
|
||||
"Cannot apply edge padding to zero-sized dimension {}",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = tensor;
|
||||
|
||||
for (i, &(before, after)) in padding.iter().enumerate() {
|
||||
if before > 0 || after > 0 {
|
||||
result = pad_edge_dim(result, i, before, after);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Helper to pad a single dimension by replicating edge values.
|
||||
fn pad_edge_dim<B, const D: usize, K>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
dim: usize,
|
||||
pad_before: usize,
|
||||
pad_after: usize,
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
let dims = tensor.dims();
|
||||
let dim_size = dims[dim];
|
||||
|
||||
// Calculate output dimensions
|
||||
let mut output_dims = dims;
|
||||
output_dims[dim] += pad_before + pad_after;
|
||||
|
||||
// Create output tensor and place original in the center
|
||||
let output = Tensor::zeros(output_dims, &tensor.device());
|
||||
let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
|
||||
let mut output = output.slice_assign(original_range, tensor.clone());
|
||||
|
||||
// Assign "before" padding by repeating the first element
|
||||
if pad_before > 0 {
|
||||
let first_slice = tensor.clone().narrow(dim, 0, 1);
|
||||
let before_pad = first_slice.repeat_dim(dim, pad_before);
|
||||
let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
|
||||
output = output.slice_assign(before_range, before_pad);
|
||||
}
|
||||
|
||||
// Assign "after" padding by repeating the last element
|
||||
if pad_after > 0 {
|
||||
let last_slice = tensor.narrow(dim, dim_size - 1, 1);
|
||||
let after_pad = last_slice.repeat_dim(dim, pad_after);
|
||||
let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
|
||||
output = output.slice_assign(after_range, after_pad);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use crate::{AsIndex, BasicOps, Int, Tensor, backend::Backend, check, check::TensorCheck};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
impl<B, const D: usize, K> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: BasicOps<B>,
|
||||
{
|
||||
/// Takes elements from the tensor along the given dimension using indices of any dimensionality.
|
||||
///
|
||||
/// This behaves like numpy's take function. When indices is multi-dimensional,
|
||||
/// the output shape will be: input.shape\[:dim\] + indices.shape + input.shape\[dim+1:\]
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension along which to select elements. Supports negative indexing.
|
||||
/// * `indices` - The indices of elements to select. Can be any dimensionality.
|
||||
/// Must be valid indices in the range [0, dim_size).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Int};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = B::Device::default();
|
||||
///
|
||||
/// // Example with 1D indices
|
||||
/// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
|
||||
/// let indices = Tensor::<B, 1, Int>::from_data([2, 0, 1], &device);
|
||||
/// let result: Tensor<B, 2> = tensor.clone().take::<1, 2>(-1, indices); // -1 refers to last dimension
|
||||
/// println!("{result}");
|
||||
/// // [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]]
|
||||
///
|
||||
/// // Example with 2D indices - output will have +1 dimension (2D -> 3D)
|
||||
/// let indices_2d = Tensor::<B, 2, Int>::from_data([[0, 2], [1, 0]], &device);
|
||||
/// let result: Tensor<B, 3> = tensor.take::<2, 3>(1, indices_2d);
|
||||
/// println!("{result}");
|
||||
/// // [[[1.0, 3.0], [2.0, 1.0]], [[4.0, 6.0], [5.0, 4.0]]]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn take<const DI: usize, const DO: usize>(
|
||||
self,
|
||||
dim: impl AsIndex,
|
||||
indices: Tensor<B, DI, Int>,
|
||||
) -> Tensor<B, DO, K> {
|
||||
let dim = dim.expect_dim_index(D);
|
||||
check!(TensorCheck::take::<D, DI, DO>(dim));
|
||||
|
||||
// Store the indices shape for reshaping later
|
||||
let indices_shape = indices.shape();
|
||||
let indices_dims = indices_shape.clone();
|
||||
|
||||
// Flatten indices to 1D for processing
|
||||
let indices_flat = indices.reshape([indices_shape.num_elements()]);
|
||||
|
||||
// Perform the selection with the flattened indices
|
||||
let selected = self.select(dim, indices_flat);
|
||||
|
||||
// Build the output shape
|
||||
// Output shape = input.shape[:dim] + indices.shape + input.shape[dim+1:]
|
||||
let selected_shape = selected.shape();
|
||||
let mut new_shape = Vec::with_capacity(DO);
|
||||
|
||||
// Add dimensions before the selected dimension
|
||||
for i in 0..dim {
|
||||
new_shape.push(selected_shape[i]);
|
||||
}
|
||||
|
||||
// Add all indices dimensions
|
||||
for &idx_dim in indices_dims.iter() {
|
||||
new_shape.push(idx_dim);
|
||||
}
|
||||
|
||||
// Add dimensions after the selected dimension
|
||||
for i in (dim + 1)..D {
|
||||
new_shape.push(selected_shape[i]);
|
||||
}
|
||||
|
||||
// Verify we have the correct number of dimensions
|
||||
assert_eq!(
|
||||
new_shape.len(),
|
||||
DO,
|
||||
"Internal error: shape calculation resulted in {} dims but expected {}",
|
||||
new_shape.len(),
|
||||
DO
|
||||
);
|
||||
|
||||
// Convert to fixed-size array for reshape
|
||||
let mut shape_array = [0; DO];
|
||||
for (i, &s) in new_shape.iter().enumerate() {
|
||||
shape_array[i] = s;
|
||||
}
|
||||
|
||||
selected.reshape(shape_array)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
use super::{BasicOps, Tensor};
|
||||
use crate::{
|
||||
TensorData,
|
||||
backend::{Backend, ExecutionError},
|
||||
ops::TransactionPrimitive,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[derive(Default)]
|
||||
/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
|
||||
/// compute utilization with optimized laziness.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let [output_data, loss_data, targets_data] = Transaction::default()
|
||||
/// .register(output)
|
||||
/// .register(loss)
|
||||
/// .register(targets)
|
||||
/// .execute()
|
||||
/// .try_into()
|
||||
/// .expect("Correct amount of tensor data");
|
||||
/// ```
|
||||
pub struct Transaction<B: Backend> {
|
||||
op: TransactionPrimitive<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Transaction<B> {
|
||||
/// Add a [tensor](Tensor) to the transaction to be read.
|
||||
pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
|
||||
K::register_transaction(&mut self.op, tensor.into_primitive());
|
||||
self
|
||||
}
|
||||
|
||||
/// Executes the transaction synchronously and returns the [data](TensorData) in the same order
|
||||
/// in which they were [registered](Self::register).
|
||||
pub fn execute(self) -> Vec<TensorData> {
|
||||
burn_std::future::block_on(self.execute_async())
|
||||
.expect("Error while reading data: use `try_execute` to handle error at runtime")
|
||||
}
|
||||
|
||||
/// Executes the transaction synchronously and returns the [data](TensorData) in the same
|
||||
/// order in which they were [registered](Self::register).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Any error that might have occurred since the last time the device was synchronized.
|
||||
pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
|
||||
burn_std::future::block_on(self.execute_async())
|
||||
}
|
||||
|
||||
/// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
|
||||
/// in which they were [registered](Self::register).
|
||||
pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
|
||||
self.op.execute_async().await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
use crate::{Float, Tensor, TensorPrimitive, backend::Backend};
|
||||
|
||||
impl<B, const D: usize> Tensor<B, D, Float>
where
    B: Backend,
{
    /// Truncates the tensor element-wise, rounding toward zero.
    ///
    /// This function returns a new tensor with the same shape as the input tensor,
    /// where each element is truncated toward zero. For positive values, this is
    /// equivalent to floor, and for negative values, it's equivalent to ceil.
    ///
    /// # Special Cases (IEEE 754 compliant)
    ///
    /// - `trunc(±0)` returns ±0 (preserves sign of zero)
    /// - `trunc(±∞)` returns ±∞
    /// - `trunc(NaN)` returns NaN
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element has been truncated toward zero.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 1>::from_data([2.3, -1.7, 0.5, -0.5, 3.9], &device);
    ///     let truncated = tensor.trunc();
    ///
    ///     // Result: [2.0, -1.0, 0.0, -0.0, 3.0]
    /// }
    /// ```
    pub fn trunc(self) -> Self {
        // Unwrap to the backend float primitive, delegate to the backend op, and re-wrap.
        Self::new(TensorPrimitive::Float(B::float_trunc(
            self.primitive.tensor(),
        )))
    }
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use crate::ElementConversion;
|
||||
use crate::backend::Backend;
|
||||
use crate::s;
|
||||
use crate::tensor::{Int, Tensor};
|
||||
use alloc::vec;
|
||||
|
||||
/// Generate a tensor with homogeonous coordinates of each element's
|
||||
/// transformed location
|
||||
///
|
||||
///
|
||||
/// See:
|
||||
/// - [torch.nn.functional.affine_grid](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html)
|
||||
///
|
||||
/// * `transform` - Transformation with shape (batch_size, 2, 3)
|
||||
/// * `dims` - dimensions as (batch_size, channels, height, width)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Tensor with shape (batch_size, height, width, 2), where dim 2 is (x, y)
|
||||
/// All coordinates are broadcast on the batch dim
|
||||
pub fn affine_grid_2d<B: Backend>(transform: Tensor<B, 3>, dims: [usize; 4]) -> Tensor<B, 4> {
|
||||
let [batch_size, _c, height, width] = dims;
|
||||
|
||||
let device = &transform.device();
|
||||
|
||||
let x = Tensor::<B, 1, Int>::arange(0..width as i64, device)
|
||||
.reshape([1, width])
|
||||
.expand([height, width]);
|
||||
let y = Tensor::<B, 1, Int>::arange(0..height as i64, device)
|
||||
.reshape([height, 1])
|
||||
.expand([height, width]);
|
||||
|
||||
// from ints (0..(width-1)) and (0..(height-1)), to (-1.0..1.0)
|
||||
let x = x
|
||||
.float()
|
||||
.div_scalar(((width - 1) as f32 / 2.0).elem::<f32>())
|
||||
.sub_scalar((1_f32).elem::<f32>());
|
||||
let y = y
|
||||
.float()
|
||||
.div_scalar(((height - 1) as f32 / 2.0).elem::<f32>())
|
||||
.sub_scalar((1_f32).elem::<f32>());
|
||||
|
||||
// Broadcast to batch dimension
|
||||
let x = x.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W]
|
||||
let y = y.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W]
|
||||
|
||||
// Apply affine transform
|
||||
let a_11 = transform.clone().slice(s![.., 0, 0]);
|
||||
let a_12 = transform.clone().slice(s![.., 0, 1]);
|
||||
let trans_x = transform.clone().slice(s![.., 0, 2]);
|
||||
|
||||
let a_21 = transform.clone().slice(s![.., 1, 0]);
|
||||
let a_22 = transform.clone().slice(s![.., 1, 1]);
|
||||
let trans_y = transform.slice(s![.., 1, 2]);
|
||||
|
||||
let grid_x = a_11.mul(x.clone()).add(a_12.mul(y.clone())).add(trans_x);
|
||||
let grid_y = a_21.mul(x).add(a_22.mul(y)).add(trans_y);
|
||||
|
||||
Tensor::stack(vec![grid_x, grid_y], 3)
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos};
|
||||
use crate::tensor::{BasicOps, Tensor};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Return a collection of coordinate matrices for coordinate vectors.
|
||||
///
|
||||
/// Takes N 1D tensors and returns N tensors where each tensor represents the coordinates
|
||||
/// in one dimension across an N-dimensional grid.
|
||||
///
|
||||
/// Based upon `options.sparse`, the generated coordinate tensors can either be `Sparse` or `Dense`:
|
||||
/// * In `Sparse` mode, output tensors will have shape 1 everywhere except their cardinal dimension.
|
||||
/// * In `Dense` mode, output tensors will be expanded to the full grid shape.
|
||||
///
|
||||
/// Based upon `options.indexing`, the generated coordinate tensors will use either:
|
||||
/// * `Matrix` indexing, where dimensions are in the same order as their cardinality.
|
||||
/// * `Cartesian` indexing; where the first two dimensions are swapped.
|
||||
///
|
||||
/// See:
|
||||
/// - [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html)
|
||||
/// - [torch.meshgrid](https://pytorch.org/docs/stable/generated/torch.meshgrid.html)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - A slice of 1D tensors
|
||||
/// * `options` - the options.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of N N-dimensional tensors representing the grid coordinates.
|
||||
pub fn meshgrid<B: Backend, const N: usize, K, O>(
|
||||
tensors: &[Tensor<B, 1, K>; N],
|
||||
options: O,
|
||||
) -> [Tensor<B, N, K>; N]
|
||||
where
|
||||
K: BasicOps<B>,
|
||||
O: Into<GridOptions>,
|
||||
{
|
||||
let options = options.into();
|
||||
let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1;
|
||||
let dense = options.sparsity == GridSparsity::Dense;
|
||||
|
||||
let grid_shape: [usize; N] = tensors
|
||||
.iter()
|
||||
.map(|t| t.dims()[0])
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
tensors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, tensor)| {
|
||||
let mut coord_tensor_shape = [1; N];
|
||||
coord_tensor_shape[i] = grid_shape[i];
|
||||
|
||||
// Reshape the tensor to have singleton dimensions in all but the i-th dimension
|
||||
let mut tensor = tensor.clone().reshape(coord_tensor_shape);
|
||||
|
||||
if dense {
|
||||
tensor = tensor.expand(grid_shape);
|
||||
}
|
||||
if swap_dims {
|
||||
tensor = tensor.swap_dims(0, 1);
|
||||
}
|
||||
|
||||
tensor
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Return a coordinate matrix for a given set of 1D coordinate tensors.
|
||||
///
|
||||
/// Equivalent to stacking a dense matrix `meshgrid`,
|
||||
/// where the stack is along the first or last dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors`: A slice of 1D tensors.
|
||||
/// * `index_pos`: The position of the index in the output tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor of either ``(N, ..., |T[i]|, ...)`` or ``(..., |T[i]|, ..., N)``,
|
||||
/// of coordinates, indexed on the first or last dimension.
|
||||
pub fn meshgrid_stack<B: Backend, const D: usize, const D2: usize, K>(
|
||||
tensors: &[Tensor<B, 1, K>; D],
|
||||
index_pos: IndexPos,
|
||||
) -> Tensor<B, D2, K>
|
||||
where
|
||||
K: BasicOps<B>,
|
||||
{
|
||||
assert_eq!(D2, D + 1, "D2 ({D2}) != D ({D}) + 1");
|
||||
|
||||
let xs: Vec<Tensor<B, D, K>> = meshgrid(tensors, GridOptions::default())
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let dim = match index_pos {
|
||||
IndexPos::First => 0,
|
||||
IndexPos::Last => D,
|
||||
};
|
||||
|
||||
Tensor::stack(xs, dim)
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
mod affine_grid;
|
||||
mod meshgrid;
|
||||
|
||||
pub use meshgrid::*;
|
||||
|
||||
pub use affine_grid::*;
|
||||
|
||||
/// Enum to specify index cardinal layout.
///
/// Selects the axis ordering of generated grids; `Matrix` is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridIndexing {
    /// Dimensions are in the same order as the cardinality of the inputs.
    /// Equivalent to "ij" indexing in NumPy and PyTorch.
    #[default]
    Matrix,

    /// The same as Matrix, but the first two dimensions are swapped.
    /// Equivalent to "xy" indexing in NumPy and PyTorch.
    Cartesian,
}
|
||||
|
||||
/// Enum to specify grid sparsity mode.
///
/// Selects whether generated coordinate tensors are expanded to the full grid shape;
/// `Dense` is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridSparsity {
    /// The grid is fully expanded to the full cartesian product shape.
    #[default]
    Dense,

    /// The grid is sparse, expanded only at the cardinal dimensions.
    Sparse,
}
|
||||
|
||||
/// Grid policy options.
///
/// Bundles the indexing and sparsity modes of a grid. It can also be built from a single
/// [`GridIndexing`] or [`GridSparsity`] value via `From`, which leaves the other field at
/// its default.
#[derive(new, Default, Debug, Copy, Clone)]
pub struct GridOptions {
    /// Indexing mode.
    pub indexing: GridIndexing,

    /// Sparsity mode.
    pub sparsity: GridSparsity,
}
|
||||
|
||||
impl From<GridIndexing> for GridOptions {
|
||||
fn from(value: GridIndexing) -> Self {
|
||||
Self {
|
||||
indexing: value,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
impl From<GridSparsity> for GridOptions {
|
||||
fn from(value: GridSparsity) -> Self {
|
||||
Self {
|
||||
sparsity: value,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Enum to specify the index dimension position.
///
/// Selects whether the index dimension of a stacked grid is the first or the last
/// dimension of the output; `First` is the default.
#[derive(Default, Debug, Copy, Clone)]
pub enum IndexPos {
    /// The index is in the first dimension.
    #[default]
    First,

    /// The index is in the last dimension.
    Last,
}
|
||||
@@ -0,0 +1,48 @@
|
||||
use crate::ElementConversion;
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
use super::l2_norm;
|
||||
|
||||
/// Default epsilon value to avoid division by zero.
///
/// Used by `cosine_similarity` as the lower bound of each norm when no `eps` is supplied.
pub const DEFAULT_EPSILON: f64 = 1e-8;
|
||||
/// Computes the cosine similarity between two tensors along a specified dimension.
|
||||
///
|
||||
/// Calculates the cosine of the angle between inputs as their dot product divided
|
||||
/// by the product of their L2 norms.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x1` - First input tensor
|
||||
/// * `x2` - Second input tensor
|
||||
/// * `dim` - Dimension along which to compute the similarity
|
||||
/// (negative indices allowed: -1 for last dimension)
|
||||
/// * `eps` - Small value to avoid division by zero (default: 1e-8)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Tensor containing the cosine similarity between x1 and x2
|
||||
pub fn cosine_similarity<B: Backend, const D: usize>(
|
||||
x1: Tensor<B, D>,
|
||||
x2: Tensor<B, D>,
|
||||
dim: i32,
|
||||
eps: Option<B::FloatElem>,
|
||||
) -> Tensor<B, D> {
|
||||
let eps = eps.unwrap_or_else(|| B::FloatElem::from_elem(DEFAULT_EPSILON));
|
||||
|
||||
// Convert negative dimension to positive
|
||||
let dim_idx = if dim < 0 { D as i32 + dim } else { dim } as usize;
|
||||
|
||||
// Compute dot product: sum(x1 * x2) along the specified dimension
|
||||
let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx);
|
||||
|
||||
// Compute L2 norms: ||x1|| and ||x2||
|
||||
let norm_x1 = l2_norm(x1, dim_idx);
|
||||
let norm_x2 = l2_norm(x2, dim_idx);
|
||||
|
||||
// Calculate the denominator (product of the norms) with epsilon to avoid division by zero
|
||||
let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps);
|
||||
|
||||
// Return the cosine similarity (dot product divided by the product of norms)
|
||||
dot_product / denominator
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::check;
|
||||
use crate::check::TensorCheck;
|
||||
use crate::tensor::{Int, Tensor};
|
||||
use crate::{BasicOps, TensorKind};
|
||||
|
||||
/// Returns the diag of a matrix.
|
||||
///
|
||||
/// For batched inputs, returns of each matrix in the batch independently.
|
||||
///
|
||||
/// The diag operation extracts the diagonal elements of the last two dimensions,
|
||||
/// treating them as the matrix dimensions, while preserving all leading batch dimensions.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The input tensor with at least 2 dimensions.
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of rank `D - 1`, where the last dimension contains the diagonal elements of the input.
|
||||
pub fn diag<B: Backend, const D: usize, const DO: usize, K>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
) -> Tensor<B, DO, K>
|
||||
where
|
||||
K: TensorKind<B> + BasicOps<B>,
|
||||
{
|
||||
check!(TensorCheck::diag::<D, DO>());
|
||||
|
||||
let shape = tensor.shape();
|
||||
let rows = shape[D - 2];
|
||||
let cols = shape[D - 1];
|
||||
let diag_len = rows.min(cols);
|
||||
let device = tensor.device();
|
||||
|
||||
// create the indices for the diag
|
||||
let mut flat_shape = shape.clone();
|
||||
flat_shape[D - 2] = rows * cols;
|
||||
flat_shape[D - 1] = 1;
|
||||
let flat: Tensor<B, D, K> = tensor.reshape(flat_shape);
|
||||
|
||||
let range = Tensor::<B, 1, Int>::arange(0..diag_len as i64, &device);
|
||||
let step_tensor = Tensor::<B, 1, Int>::from_data([cols as i64 + 1], &device);
|
||||
let indices = range * step_tensor;
|
||||
flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1)
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
use crate::{
|
||||
Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s,
|
||||
tensor::Tensor,
|
||||
};
|
||||
/// Performs PLU decomposition of a square matrix.
///
/// The function decomposes a given square matrix `A` into three matrices: a permutation vector `p`,
/// a lower triangular matrix `L`, and an upper triangular matrix `U`, such that `PA = LU`.
/// The permutation vector `p` represents the row swaps made during the decomposition process.
/// The lower triangular matrix `L` has ones on its diagonal and contains the multipliers used
/// during the elimination process below the diagonal. The upper triangular matrix `U` contains
/// the resulting upper triangular form of the matrix after the elimination process.
///
/// # Arguments
/// * `tensor` - A square matrix to decompose, represented as a 2D tensor.
///
/// # Returns
/// A tuple containing:
/// - A 2D tensor representing the combined `L` and `U` matrices: the multipliers of `L`
///   strictly below the diagonal and `U` on and above it.
/// - A 1D tensor representing the permutation vector `p`.
///
/// # Panics and numerical issues
/// - The function will panic if the input matrix is singular or near-singular.
/// - The function will panic if the input matrix is not square.
///
/// # Performance note (synchronization / device transfers)
/// This function may involve multiple synchronizations and device transfers, especially
/// when determining pivot elements and performing row swaps. This can impact performance:
/// every elimination step reads the pivot row index and the pivot value back with
/// `into_scalar`.
pub fn lu_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
    check!(TensorCheck::is_square::<2>(
        "lu_decomposition",
        &tensor.shape()
    ));
    let dims = tensor.shape().dims::<2>();
    let n = dims[0];

    // Start from the identity permutation [0, 1, ..., n - 1].
    let mut permutations = Tensor::arange(0..n as i64, &tensor.device());
    let mut tensor = tensor;

    for k in 0..n {
        // Find the pivot row: the largest absolute value in column k at or below the
        // diagonal. `argmax` is relative to the slice, hence the `+ k` offset.
        let p = tensor
            .clone()
            .slice(s![k.., k])
            .abs()
            .argmax(0)
            .into_scalar()
            .to_usize()
            + k;
        let max = tensor.clone().slice(s![p, k]).abs();

        // Avoid division by zero
        let pivot = max.into_scalar();
        check!(TensorCheck::lu_decomposition_pivot::<B>(pivot));

        // Bring the pivot row to position k and record the swap.
        if p != k {
            tensor = swap_slices(tensor, s![k, ..], s![p, ..]);
            permutations = swap_slices(permutations, s![k], s![p]);
        }

        // Normalize k-th column under the diagonal
        if k < n - 1 {
            let a_kk = tensor.clone().slice(s![k, k]);
            let column = tensor.clone().slice(s![(k + 1).., k]) / a_kk;
            tensor = tensor.slice_assign(s![(k + 1).., k], column);
        }

        // Update the trailing submatrix
        for i in (k + 1)..n {
            // a[i, k+1..] -= a[i, k] * a[k, k+1..]
            let a_ik = tensor.clone().slice(s![i, k]);
            let row_k = tensor.clone().slice(s![k, (k + 1)..]);
            let update = a_ik * row_k;
            let row_i = tensor.clone().slice(s![i, (k + 1)..]);
            tensor = tensor.slice_assign(s![i, (k + 1)..], row_i - update);
        }
    }

    (tensor, permutations)
}
|
||||
@@ -0,0 +1,59 @@
|
||||
use crate::Numeric;
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::{BasicOps, Shape, Tensor};
|
||||
|
||||
/// Performs matrix-vector multiplication with optional batch dimensions.
|
||||
///
|
||||
/// The `matrix` tensor is expected to have rank `DM` with the last two dimensions representing
|
||||
/// the matrix rows and columns. The `vector` tensor should have rank `DV = DM - 1`, sharing
|
||||
/// broadcast-compatible batch dimensions and matching the last dimension of the matrix.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the matrix rank is lower than 2.
|
||||
/// * If the vector rank isn't one less than the matrix rank.
|
||||
/// * If batch dimensions differ between the operands.
|
||||
/// * If the inner dimensions are incompatible for multiplication.
|
||||
pub fn matvec<B: Backend, const DM: usize, const DV: usize, K>(
|
||||
matrix: Tensor<B, DM, K>,
|
||||
vector: Tensor<B, DV, K>,
|
||||
) -> Tensor<B, DV, K>
|
||||
where
|
||||
K: BasicOps<B> + Numeric<B>,
|
||||
{
|
||||
assert!(
|
||||
DM >= 2,
|
||||
"matvec expects the matrix to be at least rank 2 (got {DM})"
|
||||
);
|
||||
assert!(
|
||||
DM == DV + 1,
|
||||
"matvec expects the vector rank ({DV}) to be exactly one less than the matrix rank ({DM})",
|
||||
);
|
||||
|
||||
let matrix_dims = matrix.shape().dims::<DM>();
|
||||
let vector_dims = vector.shape().dims::<DV>();
|
||||
|
||||
// Validate batch dimensions (all leading dimensions prior to the matrix axes).
|
||||
let batch_rank = DM.saturating_sub(2);
|
||||
if batch_rank > 0 {
|
||||
let matrix_batch = Shape::from(&matrix_dims[..batch_rank]);
|
||||
let vector_batch = Shape::from(&vector_dims[..batch_rank]);
|
||||
|
||||
assert!(
|
||||
matrix_batch.broadcast(&vector_batch).is_ok(),
|
||||
"Batch dimensions are not broadcast-compatible: matrix {:?} vs vector {:?}",
|
||||
&matrix_dims[..batch_rank],
|
||||
&vector_dims[..batch_rank]
|
||||
);
|
||||
}
|
||||
|
||||
let matrix_inner = matrix_dims[DM - 1];
|
||||
let vector_inner = vector_dims[DV - 1];
|
||||
assert!(
|
||||
matrix_inner == vector_inner,
|
||||
"Inner dimension mismatch: matrix has {matrix_inner} columns but vector has {vector_inner} entries",
|
||||
);
|
||||
|
||||
let vector_expanded = vector.unsqueeze_dim::<DM>(DV);
|
||||
matrix.matmul(vector_expanded).squeeze_dim::<DV>(DM - 1)
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
mod cosine_similarity;
|
||||
mod diag;
|
||||
mod lu_decomposition;
|
||||
mod matvec;
|
||||
mod outer;
|
||||
mod trace;
|
||||
mod vector_norm;
|
||||
|
||||
pub use cosine_similarity::*;
|
||||
pub use diag::*;
|
||||
pub use lu_decomposition::*;
|
||||
pub use matvec::*;
|
||||
pub use outer::*;
|
||||
pub use trace::*;
|
||||
pub use vector_norm::*;
|
||||
|
||||
use crate::{BasicOps, SliceArg, Tensor, TensorKind, backend::Backend};
|
||||
|
||||
/// Swaps two slices of a tensor.
|
||||
/// # Arguments
|
||||
/// * `tensor` - The input tensor.
|
||||
/// * `slices1` - The first slice to swap.
|
||||
/// * `slices2` - The second slice to swap.
|
||||
/// # Returns
|
||||
/// A new tensor with the specified slices swapped.
|
||||
/// # Notes
|
||||
/// This method will be useful for matrix factorization algorithms.
|
||||
fn swap_slices<B: Backend, const D: usize, K, S>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
slices1: S,
|
||||
slices2: S,
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
S: SliceArg + Clone,
|
||||
K: TensorKind<B> + BasicOps<B>,
|
||||
{
|
||||
let temporary = tensor.clone().slice(slices1.clone());
|
||||
let tensor = tensor
|
||||
.clone()
|
||||
.slice_assign(slices1, tensor.slice(slices2.clone()));
|
||||
tensor.slice_assign(slices2, temporary)
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::{BasicOps, Tensor};
|
||||
use crate::{AsIndex, Numeric};
|
||||
|
||||
/// Computes the outer product for the last columns of 2 tensors.
///
/// Equivalent to calling [`outer_dim`] with `dim = -1`.
///
/// See also: [`outer_dim`].
///
/// # Arguments
/// - `x`: the "row" tensor, with shape ``[..., i]``.
/// - `y`: the "col" tensor, with shape ``[..., j]``.
///
/// # Returns
///
/// A tensor of rank `R = D + 1`, where:
///
/// ```text
/// result[..., i, j] = x[..., i] * y[..., j]
/// ```
pub fn outer<B: Backend, const D: usize, const R: usize, K>(
    x: Tensor<B, D, K>,
    y: Tensor<B, D, K>,
) -> Tensor<B, R, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    outer_dim(x, y, -1)
}
|
||||
|
||||
/// Computes the outer product along a specific dimension, broadcasting over others.
|
||||
///
|
||||
/// For the given `dim`, computes the outer product of elements along that dimension,
|
||||
/// expanding it into two dimensions of size ``M × N`` at positions ``(dim, dim + 1)``.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `lhs`: left operand, the "row" tensor, with size `M` at dimension `dim`.
|
||||
/// - `rhs`: right operand, the "col" tensor, with size `N` at dimension `dim`.
|
||||
/// - `dim`: dimension to compute the outer product along (supports negative indexing).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor of rank `R = D + 1`, where:
|
||||
///
|
||||
/// ``
|
||||
/// result[..., i, j, ...] = lhs[..., i, ...] * rhs[..., j, ...]
|
||||
/// ``
|
||||
//
|
||||
// Notes:
|
||||
// - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant
|
||||
// than broadcasted elemwise multiply; benchmarking needed to confirm.
|
||||
pub fn outer_dim<B: Backend, const D: usize, const R: usize, Dim: AsIndex, K>(
|
||||
lhs: Tensor<B, D, K>,
|
||||
rhs: Tensor<B, D, K>,
|
||||
dim: Dim,
|
||||
) -> Tensor<B, R, K>
|
||||
where
|
||||
K: BasicOps<B> + Numeric<B>,
|
||||
{
|
||||
assert_eq!(
|
||||
R,
|
||||
D + 1,
|
||||
"`outer` with D={D} expects R={} (got R={R})",
|
||||
D + 1
|
||||
);
|
||||
let dim = dim.expect_dim_index(D);
|
||||
|
||||
// (..., i, 1, ...)
|
||||
let x = lhs.unsqueeze_dim::<R>(dim + 1);
|
||||
|
||||
// (..., 1, j, ...)
|
||||
let y = rhs.unsqueeze_dim::<R>(dim);
|
||||
|
||||
// (..., i, j, ...)
|
||||
x * y
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use super::diag;
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
/// Computes the trace of a matrix.
|
||||
///
|
||||
/// For batched inputs, computes the trace of each matrix in the batch independently.
|
||||
///
|
||||
/// The trace operation sums the diagonal elements of the last two dimensions,
|
||||
/// treating them as the matrix dimensions, while preserving all leading batch dimensions.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The input tensor with at least 2 dimensions.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor of rank `D - 1`, where the last dimension contains the sum along the diagonals
|
||||
/// of the input.
|
||||
pub fn trace<B: Backend, const D: usize, const DO: usize>(tensor: Tensor<B, D>) -> Tensor<B, DO> {
|
||||
let diag_tensor = diag::<_, D, DO, _>(tensor);
|
||||
|
||||
diag_tensor.sum_dim(DO - 1)
|
||||
}
|
||||
@@ -0,0 +1,291 @@
|
||||
use burn_backend::tensor::Ordered;
|
||||
|
||||
use crate::backend::Backend;
|
||||
use crate::tensor::{BasicOps, Tensor};
|
||||
use crate::{ElementConversion, Numeric};
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::float::Float;
|
||||
/// Specifies the type of norm to compute.
///
/// Can be built from `u32`, `i32`, `f32` or `f64` exponents via `From`; the values
/// 0, 1, 2 and the type's infinity/extreme sentinels map to the dedicated variants and
/// everything else becomes [`Norm::Lp`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Norm {
    /// L0 norm (count of non-zero elements)
    L0,

    /// L1 norm (sum of absolute values)
    L1,

    /// L2 norm (Euclidean norm)
    L2,

    /// L-infinity norm (maximum absolute value)
    LInf,

    /// L-negative-infinity norm (minimum absolute value)
    LNegInf,

    /// Lp norm (generalized norm)
    Lp(f64),
}

impl Norm {
    /// Get the exponent of the norm.
    pub fn to_exponent(self) -> f64 {
        match self {
            Self::L0 => 0.0,
            Self::L1 => 1.0,
            Self::L2 => 2.0,
            Self::LInf => f64::INFINITY,
            Self::LNegInf => f64::NEG_INFINITY,
            Self::Lp(exponent) => exponent,
        }
    }
}

impl From<u32> for Norm {
    /// `u32::MAX` stands for the infinity norm.
    fn from(value: u32) -> Self {
        match value {
            0 => Self::L0,
            1 => Self::L1,
            2 => Self::L2,
            u32::MAX => Self::LInf,
            other => Self::Lp(f64::from(other)),
        }
    }
}

impl From<i32> for Norm {
    /// `i32::MAX` stands for the infinity norm and `i32::MIN` for the negative-infinity norm.
    fn from(value: i32) -> Self {
        match value {
            0 => Self::L0,
            1 => Self::L1,
            2 => Self::L2,
            i32::MAX => Self::LInf,
            i32::MIN => Self::LNegInf,
            other => Self::Lp(f64::from(other)),
        }
    }
}

impl From<f32> for Norm {
    fn from(value: f32) -> Self {
        // Widening f32 -> f64 is exact, so the f64 conversion classifies the value identically.
        Self::from(f64::from(value))
    }
}

impl From<f64> for Norm {
    fn from(value: f64) -> Self {
        match value {
            0.0 => Self::L0,
            1.0 => Self::L1,
            2.0 => Self::L2,
            f64::INFINITY => Self::LInf,
            f64::NEG_INFINITY => Self::LNegInf,
            other => Self::Lp(other),
        }
    }
}
|
||||
|
||||
/// Computes the vector norm of a tensor along a specified dimension.
///
/// Generic dispatch wrapper over specialized / optimized norms: the norm is converted to
/// its exponent and forwarded to `lp_norm`.
///
/// See:
/// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)
/// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html)
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The vector norm of the input tensor.
pub fn vector_norm<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    norm: impl Into<Norm>,
    dim: usize,
) -> Tensor<B, D> {
    lp_norm(x, norm.into().to_exponent(), dim)
}
|
||||
|
||||
/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
|
||||
///
|
||||
/// Uses the specialized implementations for:
|
||||
/// * 0.0
|
||||
/// * 1.0
|
||||
/// * 2.0
|
||||
/// * 2 * N for integral N,
|
||||
/// * f64::INFINITY,
|
||||
/// * f64::NEG_INFINITY,
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The input tensor.
|
||||
/// * `p` - The exponent of the Lp norm.
|
||||
/// * `dim` - The dimension to compute the norm over.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The ``L(p)`` norm of the input tensor.
|
||||
pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
|
||||
match p {
|
||||
0.0 => l0_norm(x, dim),
|
||||
1.0 => l1_norm(x, dim),
|
||||
2.0 => l2_norm(x, dim),
|
||||
p if is_even_integer(p) => lp_signed_norm(x, p as u32, dim),
|
||||
f64::INFINITY => max_abs_norm(x, dim),
|
||||
f64::NEG_INFINITY => min_abs_norm(x, dim),
|
||||
_ => lp_norm_base(x, p, dim),
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalize a tensor versus its `vector_norm`.
///
/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
/// * `eps` - The epsilon for the norm; the norm is clamped to at least this value before
///   dividing.
///
/// # Returns
///
/// The normalized tensor.
pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
    x: Tensor<B, D>,
    norm: impl Into<Norm>,
    dim: usize,
    eps: E,
) -> Tensor<B, D> {
    // Clamp the norm from below so the division cannot hit zero.
    let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps);
    x / norm
}
|
||||
|
||||
/// Computes the L0 norm of a tensor along a specified dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The input tensor.
|
||||
/// * `dim` - The dimension to compute the norm over.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The L0 norm of the input tensor.
|
||||
pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
|
||||
where
|
||||
K: BasicOps<B> + Numeric<B>,
|
||||
{
|
||||
x.zeros_like()
|
||||
.mask_fill(x.not_equal_elem(0), 1)
|
||||
.sum_dim(dim)
|
||||
}
|
||||
|
||||
/// Computes the L1 norm of a tensor along a specified dimension.
///
/// Computed directly as the sum of absolute values along `dim`; `lp_norm` dispatches here
/// for `p = 1.0`.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L1 norm of the input tensor.
pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    x.abs().sum_dim(dim)
}
|
||||
|
||||
/// Computes the L2 norm of a tensor along a specified dimension.
///
/// Computed as the square root of the sum of squares along `dim`.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L2 norm of the input tensor.
pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    x.square().sum_dim(dim).sqrt()
}
|
||||
|
||||
/// Returns `true` when `x` has no fractional part and its integer value is even.
///
/// NaN and infinities are never even integers.
fn is_even_integer(x: f64) -> bool {
    // NaN compares unequal to everything, so non-finite inputs bail out here.
    if x.fract() != 0.0 {
        return false;
    }
    (x as i64) % 2 == 0
}
|
||||
|
||||
/// Computes the ``L(p)`` norm for an even integer exponent ``p`` (i.e. ``p = 2*n``).
///
/// This lets us skip the abs: an even power of `x` already discards the sign.
/// The caller is expected to pass an even `p`; this function does not check it.
fn lp_signed_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: u32, dim: usize) -> Tensor<B, D> {
    x.powi_scalar(p).sum_dim(dim).powf_scalar(1. / (p as f64))
}
|
||||
|
||||
/// Computes the general ``L(p)`` using the generalized method:
/// ``sum(|x|^p)^(1/p)`` along `dim`.
///
/// This uses no specialized implementations and cannot handle:
/// * 0.0
/// * f64::INFINITY,
/// * f64::NEG_INFINITY,
fn lp_norm_base<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
    x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
}
|
||||
|
||||
/// Computes the L:INFINITY norm of a tensor along a specified dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The input tensor.
|
||||
/// * `dim` - The dimension to compute the norm over.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The L:INFINITY norm of the input tensor.
|
||||
pub fn max_abs_norm<B: Backend, const D: usize, K>(
|
||||
x: Tensor<B, D, K>,
|
||||
dim: usize,
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
K: Ordered<B>,
|
||||
{
|
||||
x.max_abs_dim(dim)
|
||||
}
|
||||
|
||||
/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The input tensor.
|
||||
/// * `dim` - The dimension to compute the norm over.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The L:NEG_INFINITY norm of the input tensor.
|
||||
pub fn min_abs_norm<B: Backend, const D: usize, K>(
|
||||
x: Tensor<B, D, K>,
|
||||
dim: usize,
|
||||
) -> Tensor<B, D, K>
|
||||
where
|
||||
K: Ordered<B>,
|
||||
{
|
||||
x.abs().min_dim(dim)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::{Tensor, activation};
|
||||
|
||||
/// Computes the log softmax cross entropy between logits and target probabilities.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `logits` - The logits.
|
||||
/// * `target_probs` - The target probabilities.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The log softmax cross entropy.
|
||||
pub fn cross_entropy_with_logits<B: Backend, const D: usize>(
|
||||
logits: Tensor<B, D>,
|
||||
target_probs: Tensor<B, D>,
|
||||
) -> Tensor<B, 1> {
|
||||
let tensor = activation::log_softmax(logits, D - 1);
|
||||
let tensor = tensor.mul(target_probs);
|
||||
let tensor = tensor.sum_dim(D - 1);
|
||||
|
||||
tensor.mean().neg()
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
pub(crate) mod stats;
|
||||
|
||||
mod api;
|
||||
|
||||
pub use api::*;
|
||||
|
||||
// Re-exported types
|
||||
pub use burn_backend::{
|
||||
DType, DataError, FloatDType, IntDType, TensorData, TensorMetadata, TensorPrimitive, Tolerance,
|
||||
distribution::*,
|
||||
element::*,
|
||||
indexing::*,
|
||||
ops::TransactionPrimitive,
|
||||
shape::*,
|
||||
slice::*,
|
||||
tensor::{Bool, Float, Int, TensorKind},
|
||||
};
|
||||
|
||||
/// The activation module.
|
||||
pub mod activation;
|
||||
|
||||
/// The backend module.
|
||||
pub mod backend {
|
||||
pub use burn_backend::backend::*;
|
||||
}
|
||||
|
||||
/// The container module.
|
||||
pub mod container {
|
||||
pub use burn_backend::tensor::TensorContainer;
|
||||
}
|
||||
|
||||
/// The grid module.
|
||||
pub mod grid;
|
||||
|
||||
/// The linalg module.
|
||||
pub mod linalg;
|
||||
|
||||
/// The loss module.
|
||||
pub mod loss;
|
||||
|
||||
/// The neural network module.
|
||||
pub mod module;
|
||||
|
||||
/// Operations on tensors module.
|
||||
pub mod ops {
|
||||
pub use burn_backend::backend::ops::*;
|
||||
pub use burn_backend::tensor::{
|
||||
BoolElem, BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
|
||||
};
|
||||
}
|
||||
|
||||
/// Tensor quantization module.
|
||||
pub mod quantization;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub use report::*;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod report;
|
||||
|
||||
#[cfg(feature = "experimental-named-tensor")]
|
||||
mod named;
|
||||
|
||||
#[cfg(feature = "experimental-named-tensor")]
|
||||
pub use named::*;
|
||||
|
||||
pub use ops::Device; // Re-export device so that it's available from `burn_tensor::Device`.
|
||||
@@ -0,0 +1,555 @@
|
||||
use crate::{
|
||||
Bool, Int, Tensor, TensorPrimitive,
|
||||
backend::Backend,
|
||||
check,
|
||||
check::TensorCheck,
|
||||
ops::{
|
||||
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode,
|
||||
PaddedConvOptions, UnfoldOptions,
|
||||
},
|
||||
};
|
||||
|
||||
use super::ops::DeformConvOptions;
|
||||
|
||||
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
|
||||
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::embedding(
|
||||
weights.primitive.tensor(),
|
||||
indices.primitive,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D convolution](crate::ops::ModuleOps::conv1d).
|
||||
///
|
||||
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
|
||||
/// asymmetric padding. When asymmetric padding is specified, an explicit pad
|
||||
/// operation is applied before the convolution backend op.
|
||||
pub fn conv1d<B>(
|
||||
x: Tensor<B, 3>,
|
||||
weight: Tensor<B, 3>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: impl Into<PaddedConvOptions<1>>,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let padded_options = options.into();
|
||||
check!(TensorCheck::conv(
|
||||
"conv1d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
padded_options.options.groups,
|
||||
));
|
||||
|
||||
if let Some(padding_end) = padded_options.padding_end {
|
||||
let left = padded_options.options.padding[0];
|
||||
let right = padding_end[0];
|
||||
// For 1D (NCL format), pad the length dimension
|
||||
let padded = x.pad((left, right, 0, 0), PadMode::Constant(0.0));
|
||||
let zero_options = ConvOptions::new(
|
||||
padded_options.options.stride,
|
||||
[0],
|
||||
padded_options.options.dilation,
|
||||
padded_options.options.groups,
|
||||
);
|
||||
Tensor::new(TensorPrimitive::Float(B::conv1d(
|
||||
padded.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
zero_options,
|
||||
)))
|
||||
} else {
|
||||
Tensor::new(TensorPrimitive::Float(B::conv1d(
|
||||
x.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
padded_options.options,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).
|
||||
///
|
||||
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
|
||||
/// asymmetric padding. When asymmetric padding is specified, an explicit pad
|
||||
/// operation is applied before the convolution backend op.
|
||||
pub fn conv2d<B>(
|
||||
x: Tensor<B, 4>,
|
||||
weight: Tensor<B, 4>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: impl Into<PaddedConvOptions<2>>,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let padded_options = options.into();
|
||||
check!(TensorCheck::conv(
|
||||
"conv2d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
padded_options.options.groups,
|
||||
));
|
||||
|
||||
if let Some(padding_end) = padded_options.padding_end {
|
||||
let top = padded_options.options.padding[0];
|
||||
let left = padded_options.options.padding[1];
|
||||
let bottom = padding_end[0];
|
||||
let right = padding_end[1];
|
||||
// For 2D (NCHW format), pad height and width
|
||||
let padded = x.pad((left, right, top, bottom), PadMode::Constant(0.0));
|
||||
let zero_options = ConvOptions::new(
|
||||
padded_options.options.stride,
|
||||
[0, 0],
|
||||
padded_options.options.dilation,
|
||||
padded_options.options.groups,
|
||||
);
|
||||
Tensor::new(TensorPrimitive::Float(B::conv2d(
|
||||
padded.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
zero_options,
|
||||
)))
|
||||
} else {
|
||||
Tensor::new(TensorPrimitive::Float(B::conv2d(
|
||||
x.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
padded_options.options,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).
|
||||
///
|
||||
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
|
||||
/// asymmetric padding. Asymmetric 3D padding is not yet supported.
|
||||
pub fn conv3d<B>(
|
||||
x: Tensor<B, 5>,
|
||||
weight: Tensor<B, 5>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: impl Into<PaddedConvOptions<3>>,
|
||||
) -> Tensor<B, 5>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let padded_options = options.into();
|
||||
check!(TensorCheck::conv(
|
||||
"conv3d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
padded_options.options.groups,
|
||||
));
|
||||
|
||||
if padded_options.is_asymmetric() {
|
||||
panic!("Asymmetric padding is not yet supported for conv3d");
|
||||
}
|
||||
|
||||
Tensor::new(TensorPrimitive::Float(B::conv3d(
|
||||
x.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
padded_options.options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
|
||||
pub fn deform_conv2d<B>(
|
||||
x: Tensor<B, 4>,
|
||||
offset: Tensor<B, 4>,
|
||||
weight: Tensor<B, 4>,
|
||||
mask: Option<Tensor<B, 4>>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
check!(TensorCheck::conv(
|
||||
"deform_conv2d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
options.weight_groups,
|
||||
));
|
||||
Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
|
||||
x.primitive.tensor(),
|
||||
offset.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
mask.map(|m| m.primitive.tensor()),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
|
||||
pub fn conv_transpose1d<B>(
|
||||
x: Tensor<B, 3>,
|
||||
weight: Tensor<B, 3>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
check!(TensorCheck::conv_transpose(
|
||||
"conv_transpose1d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
));
|
||||
Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
|
||||
x.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).
|
||||
pub fn conv_transpose2d<B>(
|
||||
x: Tensor<B, 4>,
|
||||
weight: Tensor<B, 4>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
check!(TensorCheck::conv_transpose(
|
||||
"conv_transpose2d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
));
|
||||
Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
|
||||
x.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a 3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).
|
||||
pub fn conv_transpose3d<B>(
|
||||
x: Tensor<B, 5>,
|
||||
weight: Tensor<B, 5>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> Tensor<B, 5>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
check!(TensorCheck::conv_transpose(
|
||||
"conv_transpose3d",
|
||||
x.dims(),
|
||||
weight.dims(),
|
||||
));
|
||||
Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
|
||||
x.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).
|
||||
pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::unfold4d(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
|
||||
pub fn max_pool1d<B>(
|
||||
x: Tensor<B, 3>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::max_pool1d(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
|
||||
pub fn max_pool2d<B>(
|
||||
x: Tensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::max_pool2d(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
|
||||
pub fn avg_pool2d<B>(
|
||||
x: Tensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
|
||||
pub fn avg_pool1d<B>(
|
||||
x: Tensor<B, 3>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
|
||||
pub fn max_pool1d_with_indices<B>(
|
||||
x: Tensor<B, 3>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let output = B::max_pool1d_with_indices(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
(
|
||||
Tensor::new(TensorPrimitive::Float(output.output)),
|
||||
Tensor::new(output.indices),
|
||||
)
|
||||
}
|
||||
|
||||
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
|
||||
pub fn max_pool2d_with_indices<B>(
|
||||
x: Tensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let output = B::max_pool2d_with_indices(
|
||||
x.primitive.tensor(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
(
|
||||
Tensor::new(TensorPrimitive::Float(output.output)),
|
||||
Tensor::new(output.indices),
|
||||
)
|
||||
}
|
||||
|
||||
/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
|
||||
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
|
||||
x.primitive.tensor(),
|
||||
output_size,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).
|
||||
pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
|
||||
x.primitive.tensor(),
|
||||
output_size,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).
|
||||
pub fn interpolate<B>(
|
||||
x: Tensor<B, 4>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::interpolate(
|
||||
x.primitive.tensor(),
|
||||
output_size,
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a linear transformation to the input tensor using the given weight and bias.
|
||||
///
|
||||
/// ```math
|
||||
/// y = x @ weight + [bias]
|
||||
/// ```
|
||||
///
|
||||
/// # Arguments:
|
||||
///
|
||||
/// - `input` is the input tensor, ``[..., d_input]``.
|
||||
/// - `weight` is the weight tensor, ``[d_input, d_output]``.
|
||||
/// - `bias` is the bias tensor (optional), ``[d_output]``.
|
||||
///
|
||||
/// # Returns:
|
||||
///
|
||||
/// The transformed tensor, ``[..., d_output]``.
|
||||
///
|
||||
/// # Compatibility
|
||||
///
|
||||
/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
|
||||
/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
|
||||
/// multiplication:
|
||||
///
|
||||
/// ```math
|
||||
/// y = x @ weight^T + [bias]
|
||||
/// ```
|
||||
pub fn linear<B: Backend, const D: usize>(
|
||||
input: Tensor<B, D>,
|
||||
weight: Tensor<B, 2>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
) -> Tensor<B, D> {
|
||||
if D == 1 {
|
||||
// Insert and remove an extra batch dimension for the batch matmul to work.
|
||||
let input = input.unsqueeze::<2>();
|
||||
let output = linear(input, weight, bias);
|
||||
return output.squeeze_dim(0);
|
||||
}
|
||||
|
||||
// Perform broadcasting
|
||||
//
|
||||
// Important to be done before doing operations to easily fuse.
|
||||
let weight = weight.unsqueeze::<D>();
|
||||
let bias = bias.map(|bias| bias.unsqueeze::<D>());
|
||||
|
||||
let output = input.matmul(weight);
|
||||
match bias {
|
||||
Some(bias) => output.add(bias),
|
||||
None => output,
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
|
||||
/// where scale defaults to 1/sqrt(head_dim) (configurable via `options.scale`).
|
||||
/// Optionally applies masking, additive bias, causal masking, and softcap.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
|
||||
/// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
|
||||
/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
|
||||
/// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
|
||||
/// where `true` indicates positions to mask (i.e. set to -inf before softmax).
|
||||
/// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
|
||||
/// added to the attention scores before softmax (e.g. ALiBi, relative position biases).
|
||||
/// - `options`: Additional attention options (custom scale, softcap, causal masking).
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
|
||||
/// representing the attended context per head.
|
||||
///
|
||||
/// # Note
|
||||
/// This implementation does not support dropout and is intended for inference or
|
||||
/// use cases where dropout is not needed.
|
||||
pub fn attention<B: Backend>(
|
||||
query: Tensor<B, 4>,
|
||||
key: Tensor<B, 4>,
|
||||
value: Tensor<B, 4>,
|
||||
mask: Option<Tensor<B, 4, Bool>>,
|
||||
attn_bias: Option<Tensor<B, 4>>,
|
||||
options: AttentionModuleOptions,
|
||||
) -> Tensor<B, 4> {
|
||||
Tensor::new(TensorPrimitive::Float(B::attention(
|
||||
query.primitive.tensor(),
|
||||
key.primitive.tensor(),
|
||||
value.primitive.tensor(),
|
||||
mask.map(|mask| mask.primitive),
|
||||
attn_bias.map(|bias| bias.primitive.tensor()),
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Exports attention fallback to test backend's attention against.
|
||||
pub fn attention_fallback<B: Backend>(
|
||||
query: Tensor<B, 4>,
|
||||
key: Tensor<B, 4>,
|
||||
value: Tensor<B, 4>,
|
||||
mask: Option<Tensor<B, 4, Bool>>,
|
||||
attn_bias: Option<Tensor<B, 4>>,
|
||||
options: AttentionModuleOptions,
|
||||
) -> Tensor<B, 4> {
|
||||
Tensor::new(TensorPrimitive::Float(
|
||||
crate::ops::attention::attention_fallback::<B>(
|
||||
query.primitive.tensor(),
|
||||
key.primitive.tensor(),
|
||||
value.primitive.tensor(),
|
||||
mask.map(|mask| mask.primitive),
|
||||
attn_bias.map(|bias| bias.primitive.tensor()),
|
||||
options,
|
||||
),
|
||||
))
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
use alloc::format;
|
||||
|
||||
use crate::backend::Backend;
|
||||
use crate::{Distribution, NamedDims, Shape, Tensor};
|
||||
|
||||
/// A tensor with named dimensions.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
|
||||
pub(crate) tensor: D::Tensor,
|
||||
}
|
||||
|
||||
impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<NamedTensor<B, ND>>
|
||||
for Tensor<B, D>
|
||||
{
|
||||
fn from(nt: NamedTensor<B, ND>) -> Self {
|
||||
nt.tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<Tensor<B, D>>
|
||||
for NamedTensor<B, ND>
|
||||
{
|
||||
fn from(tensor: Tensor<B, D>) -> Self {
|
||||
Self::from_tensor(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND: NamedDims<B>> core::fmt::Display for NamedTensor<B, ND>
|
||||
where
|
||||
ND: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(&format!(
|
||||
"NamedTensor[shape={:?}, dims={}]",
|
||||
self.shape(),
|
||||
ND::to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
|
||||
where
|
||||
ND: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
{
|
||||
/// Create a named tensor from a tensor.
|
||||
pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
|
||||
Self { tensor }
|
||||
}
|
||||
|
||||
/// Create a random named tensor of the given shape where each element is sampled from
|
||||
/// the given distribution.
|
||||
pub fn random<S: Into<Shape>>(
|
||||
shape: S,
|
||||
distribution: Distribution,
|
||||
device: &B::Device,
|
||||
) -> Self {
|
||||
Self::from_tensor(Tensor::random(shape, distribution, device))
|
||||
}
|
||||
|
||||
/// Returns the shape of the current tensor.
|
||||
pub fn shape(&self) -> Shape {
|
||||
self.tensor.shape()
|
||||
}
|
||||
|
||||
/// Applies element wise multiplication operation.
|
||||
///
|
||||
/// `y = x2 * x1`
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn mul(self, rhs: Self) -> Self {
|
||||
Self::from_tensor(self.tensor.mul(rhs.tensor))
|
||||
}
|
||||
|
||||
/// Reshape the tensor to have the given shape.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the tensor can not be reshape to the given shape.
|
||||
pub fn reshape<const D2: usize, S, ND2>(self, shape: S, _: ND2) -> NamedTensor<B, ND2>
|
||||
where
|
||||
S: Into<Shape>,
|
||||
ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
|
||||
{
|
||||
NamedTensor::from_tensor(self.tensor.reshape(shape.into()))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
|
||||
use crate::Tensor;
|
||||
use crate::backend::Backend;
|
||||
|
||||
/// A single named dimension.
pub trait Dim
where
    Self: core::fmt::Debug,
{
    /// Returns the dimension rendered as a string.
    fn to_string() -> String;
}
|
||||
|
||||
/// Named dimensions trait.
|
||||
pub trait NamedDims<B: Backend>: core::fmt::Debug {
|
||||
/// Tensor type.
|
||||
type Tensor;
|
||||
|
||||
/// Converts the named dimensions to a string.
|
||||
fn to_string() -> String;
|
||||
}
|
||||
|
||||
/// Declares a unit struct named `$name` and implements `Dim` for it.
///
/// The generated `to_string` returns the struct's identifier as text. `Dim` and `String`
/// are referenced unqualified, so both must be in scope where the macro is invoked.
#[macro_export]
macro_rules! NamedDim {
    ($name:ident) => {
        #[derive(Debug, Clone)]
        pub struct $name;

        impl Dim for $name {
            fn to_string() -> String {
                String::from(stringify!($name))
            }
        }
    };
}
|
||||
|
||||
impl<B: Backend, D1> NamedDims<B> for (D1,)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 1>;
|
||||
fn to_string() -> String {
|
||||
format!("[{}]", D1::to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, D1, D2> NamedDims<B> for (D1, D2)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
D2: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 2>;
|
||||
fn to_string() -> String {
|
||||
format!("[{}, {}]", D1::to_string(), D2::to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, D1, D2, D3> NamedDims<B> for (D1, D2, D3)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
D2: Dim,
|
||||
D3: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 3>;
|
||||
fn to_string() -> String {
|
||||
format!(
|
||||
"[{}, {}, {}]",
|
||||
D1::to_string(),
|
||||
D2::to_string(),
|
||||
D3::to_string()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, D1, D2, D3, D4> NamedDims<B> for (D1, D2, D3, D4)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
D2: Dim,
|
||||
D3: Dim,
|
||||
D4: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 4>;
|
||||
fn to_string() -> String {
|
||||
format!(
|
||||
"[{}, {}, {}, {}]",
|
||||
D1::to_string(),
|
||||
D2::to_string(),
|
||||
D3::to_string(),
|
||||
D4::to_string()
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::{Dim, NamedDims, NamedTensor, Tensor};
|
||||
|
||||
pub trait Matmul<Rhs, Out> {
|
||||
fn matmul(self, rhs: Rhs) -> Out;
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
|
||||
where
|
||||
ND: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
{
|
||||
/// Applies the matrix multiplication operation.
|
||||
///
|
||||
/// `C = AB`
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the two tensors dont' have a compatible shape.
|
||||
pub fn matmul<NamedDimsRhs, NamedDimsOut>(
|
||||
self,
|
||||
rhs: NamedTensor<B, NamedDimsRhs>,
|
||||
) -> NamedTensor<B, NamedDimsOut>
|
||||
where
|
||||
NamedDimsRhs: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
NamedDimsOut: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
Self: Matmul<NamedTensor<B, NamedDimsRhs>, NamedTensor<B, NamedDimsOut>>,
|
||||
{
|
||||
Matmul::matmul(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, X: Dim, Y: Dim, Z: Dim> Matmul<NamedTensor<B, (Y, Z)>, NamedTensor<B, (X, Z)>>
|
||||
for NamedTensor<B, (X, Y)>
|
||||
{
|
||||
fn matmul(self, rhs: NamedTensor<B, (Y, Z)>) -> NamedTensor<B, (X, Z)> {
|
||||
NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, Batch: Dim, X: Dim, Y: Dim, Z: Dim>
|
||||
Matmul<NamedTensor<B, (Batch, Y, Z)>, NamedTensor<B, (Batch, X, Z)>>
|
||||
for NamedTensor<B, (Batch, X, Y)>
|
||||
{
|
||||
fn matmul(self, rhs: NamedTensor<B, (Batch, Y, Z)>) -> NamedTensor<B, (Batch, X, Z)> {
|
||||
NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, Batch1: Dim, Batch2: Dim, X: Dim, Y: Dim, Z: Dim>
|
||||
Matmul<NamedTensor<B, (Batch1, Batch2, Y, Z)>, NamedTensor<B, (Batch1, Batch2, X, Z)>>
|
||||
for NamedTensor<B, (Batch1, Batch2, X, Y)>
|
||||
{
|
||||
fn matmul(
|
||||
self,
|
||||
rhs: NamedTensor<B, (Batch1, Batch2, Y, Z)>,
|
||||
) -> NamedTensor<B, (Batch1, Batch2, X, Z)> {
|
||||
NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod base;
|
||||
mod dims;
|
||||
mod matmul;
|
||||
mod swap_dims;
|
||||
|
||||
pub use base::*;
|
||||
pub use dims::*;
|
||||
@@ -0,0 +1,62 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::{Dim, NamedDims, NamedTensor, Tensor};
|
||||
|
||||
pub trait SwapDims<N, const D1: usize, const D2: usize> {
|
||||
fn swap_dims(self) -> N;
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
|
||||
where
|
||||
ND: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
{
|
||||
/// Swap two dimensions.
|
||||
pub fn swap_dims<ND2, const D1: usize, const D2: usize>(self) -> NamedTensor<B, ND2>
|
||||
where
|
||||
ND2: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
Self: SwapDims<NamedTensor<B, ND2>, D1, D2>,
|
||||
{
|
||||
SwapDims::swap_dims(self)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! generate_permute {
|
||||
(2 => $output:ty, ($dim1:expr, $dim2:expr)) => {
|
||||
impl<B: Backend, D1: Dim, D2: Dim> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
|
||||
for NamedTensor<B, (D1, D2)>
|
||||
{
|
||||
fn swap_dims(self) -> NamedTensor<B, $output> {
|
||||
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(3 => $output:ty, ($dim1:expr, $dim2:expr)) => {
|
||||
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
|
||||
for NamedTensor<B, (D1, D2, D3)>
|
||||
{
|
||||
fn swap_dims(self) -> NamedTensor<B, $output> {
|
||||
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(4 => $output:ty, ($dim1:expr, $dim2:expr)) => {
|
||||
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim, D4: Dim>
|
||||
SwapDims<NamedTensor<B, $output>, $dim1, $dim2> for NamedTensor<B, (D1, D2, D3, D4)>
|
||||
{
|
||||
fn swap_dims(self) -> NamedTensor<B, $output> {
|
||||
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
generate_permute!(2 => (D2, D1), (0, 1));
|
||||
generate_permute!(3 => (D2, D1, D3), (0, 1));
|
||||
generate_permute!(3 => (D3, D2, D1), (0, 2));
|
||||
generate_permute!(3 => (D1, D3, D2), (1, 2));
|
||||
generate_permute!(4 => (D2, D1, D3, D4), (0, 1));
|
||||
generate_permute!(4 => (D3, D2, D1, D4), (0, 2));
|
||||
generate_permute!(4 => (D4, D2, D3, D1), (0, 3));
|
||||
generate_permute!(4 => (D1, D3, D2, D4), (1, 2));
|
||||
generate_permute!(4 => (D1, D4, D3, D2), (1, 3));
|
||||
@@ -0,0 +1,52 @@
|
||||
use crate::{Tensor, TensorPrimitive, backend::Backend};
|
||||
use burn_backend::tensor::quantization;
|
||||
|
||||
// We re-export those types.
|
||||
pub use burn_backend::{QTensorPrimitive, quantization::*};
|
||||
|
||||
/// The tensor quantization parameters.
|
||||
pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
|
||||
|
||||
/// The observed input calibration range.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CalibrationRange<B: Backend> {
|
||||
/// Minimum observed value(s).
|
||||
pub min: Tensor<B, 1>,
|
||||
/// Maximum observed value(s).
|
||||
pub max: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
/// Compute the quantization range mapping.
|
||||
pub fn compute_range<B: Backend, const D: usize>(
|
||||
scheme: &QuantScheme,
|
||||
tensor: &Tensor<B, D>,
|
||||
calibration: &Calibration,
|
||||
) -> CalibrationRange<B> {
|
||||
let (min, max) = match &tensor.primitive {
|
||||
TensorPrimitive::Float(tensor) => {
|
||||
quantization::compute_range::<B>(scheme, tensor.clone(), calibration)
|
||||
}
|
||||
TensorPrimitive::QFloat(_) => unreachable!(),
|
||||
};
|
||||
|
||||
CalibrationRange {
|
||||
min: Tensor::from_primitive(TensorPrimitive::Float(min)),
|
||||
max: Tensor::from_primitive(TensorPrimitive::Float(max)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the quantization parameters.
|
||||
pub fn compute_q_params<B: Backend>(
|
||||
scheme: &QuantScheme,
|
||||
range: CalibrationRange<B>,
|
||||
) -> QuantizationParameters<B> {
|
||||
match (range.min.primitive, range.max.primitive) {
|
||||
(TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => {
|
||||
let qparams = quantization::compute_q_params::<B>(scheme, min, max);
|
||||
QuantizationParameters {
|
||||
scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
use super::{Tensor, backend::Backend};
|
||||
|
||||
use colored::*;
|
||||
|
||||
/// Checks the closeness of two tensors and prints the results.
|
||||
///
|
||||
/// Compares tensors by checking the absolute difference between each element.
|
||||
/// Prints the percentage of elements within specified tolerances.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `output` - The output tensor.
|
||||
/// * `expected` - The expected tensor.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{check_closeness, Tensor};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let tensor1 = Tensor::<B, 1>::from_floats(
|
||||
/// [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
|
||||
/// &device,
|
||||
/// );
|
||||
/// let tensor2 = Tensor::<B, 1>::from_floats(
|
||||
/// [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
|
||||
/// &device,
|
||||
/// );
|
||||
/// check_closeness(&tensor1, &tensor2);
|
||||
///}
|
||||
/// ```
|
||||
///
|
||||
/// # Output
|
||||
///
|
||||
/// ```text
|
||||
/// Tensor Closeness Check Results:
|
||||
/// ===============================
|
||||
/// Epsilon: 1e-1
|
||||
/// Close elements: 10/10 (100.00%)
|
||||
/// [PASS] All elements are within tolerance
|
||||
///
|
||||
/// Epsilon: 1e-2
|
||||
/// Close elements: 10/10 (100.00%)
|
||||
/// [PASS] All elements are within tolerance
|
||||
///
|
||||
/// Epsilon: 1e-3
|
||||
/// Close elements: 9/10 (90.00%)
|
||||
/// [WARN] Most elements are within tolerance
|
||||
///
|
||||
/// Epsilon: 1e-4
|
||||
/// Close elements: 6/10 (60.00%)
|
||||
/// [FAIL] Significant differences detected
|
||||
///
|
||||
/// Epsilon: 1e-5
|
||||
/// Close elements: 5/10 (50.00%)
|
||||
/// [FAIL] Significant differences detected
|
||||
///
|
||||
/// Epsilon: 1e-6
|
||||
/// Close elements: 5/10 (50.00%)
|
||||
/// [FAIL] Significant differences detected
|
||||
///
|
||||
/// Epsilon: 1e-7
|
||||
/// Close elements: 5/10 (50.00%)
|
||||
/// [FAIL] Significant differences detected
|
||||
///
|
||||
/// Epsilon: 1e-8
|
||||
/// Close elements: 5/10 (50.00%)
|
||||
/// [FAIL] Significant differences detected
|
||||
///
|
||||
/// Closeness check complete.
|
||||
/// ```
|
||||
pub fn check_closeness<B: Backend, const D: usize>(output: &Tensor<B, D>, expected: &Tensor<B, D>) {
|
||||
println!("{}", "Tensor Closeness Check Results:".bold());
|
||||
println!("===============================");
|
||||
|
||||
for epsilon in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8].iter() {
|
||||
println!("{} {:e}", "Epsilon:".bold(), epsilon);
|
||||
|
||||
let close = output
|
||||
.clone()
|
||||
.is_close(expected.clone(), Some(*epsilon), Some(*epsilon));
|
||||
let data = close.clone().into_data();
|
||||
let num_elements = data.num_elements();
|
||||
|
||||
// Count the number of elements that are close (true)
|
||||
let count = data.iter::<bool>().filter(|x| *x).count();
|
||||
|
||||
let percentage = (count as f64 / num_elements as f64) * 100.0;
|
||||
|
||||
println!(" Close elements: {count}/{num_elements} ({percentage:.2}%)");
|
||||
|
||||
if percentage == 100.0 {
|
||||
println!(" {} All elements are within tolerance", "[PASS]".green());
|
||||
} else if percentage >= 90.0 {
|
||||
println!(" {} Most elements are within tolerance", "[WARN]".yellow());
|
||||
} else {
|
||||
println!(" {} Significant differences detected", "[FAIL]".red());
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
println!("{}", "Closeness check complete.".bold());
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
use crate::{Tensor, backend::Backend};
|
||||
use burn_backend::tensor::Int;
|
||||
|
||||
pub fn var<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
let mean = tensor.clone().mean_dim(dim);
|
||||
var_with_mean(tensor, mean, dim)
|
||||
}
|
||||
|
||||
pub fn var_with_mean<B: Backend, const D: usize>(
|
||||
tensor: Tensor<B, D>,
|
||||
mean: Tensor<B, D>,
|
||||
dim: usize,
|
||||
) -> Tensor<B, D> {
|
||||
let n = tensor.shape()[dim] - 1;
|
||||
var_with_mean_n(tensor, mean, dim, n)
|
||||
}
|
||||
|
||||
pub fn var_bias<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
let mean = tensor.clone().mean_dim(dim);
|
||||
var_with_mean_bias(tensor, mean, dim)
|
||||
}
|
||||
|
||||
pub fn var_with_mean_bias<B: Backend, const D: usize>(
|
||||
tensor: Tensor<B, D>,
|
||||
mean: Tensor<B, D>,
|
||||
dim: usize,
|
||||
) -> Tensor<B, D> {
|
||||
let n = tensor.shape()[dim];
|
||||
var_with_mean_n(tensor, mean, dim, n)
|
||||
}
|
||||
|
||||
/// Computes the sum of squared deviations from `mean` along `dim`, divided by `n`.
///
/// This is the shared kernel behind the variance helpers; `n` selects the divisor
/// and is converted to `f32` before the division.
pub fn var_with_mean_n<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    mean: Tensor<B, D>,
    dim: usize,
    n: usize,
) -> Tensor<B, D> {
    let deviations = tensor.sub(mean);
    let squared_sum = deviations.square().sum_dim(dim);

    squared_sum.div_scalar(n as f32)
}
|
||||
|
||||
/// Returns the median of `tensor` along `dim`.
///
/// The tensor is sorted along `dim` and a single-element slice is taken at the
/// median position, so the result keeps the rank `D`.
pub fn median<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    let len = tensor.dims()[dim];
    let ordered = tensor.sort(dim);

    // Mirrors PyTorch: for an even count the lower of the two middle elements is
    // selected, for an odd count the middle one.
    //   5 elements -> (5 - 1) / 2 = 2
    //   4 elements -> (4 - 1) / 2 = 1
    let mid = (len - 1) / 2;

    ordered.narrow(dim, mid, 1)
}
|
||||
|
||||
/// Returns the median of `tensor` along `dim` together with the indices produced by the sort.
///
/// The tensor is sorted along `dim` with `sort_with_indices`; both the sorted values
/// and the index tensor are then sliced to a single element at the median position.
pub fn median_with_indices<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
    let len = tensor.dims()[dim];
    let (ordered, positions) = tensor.sort_with_indices(dim);

    // Mirrors PyTorch: for an even count the lower of the two middle elements is
    // selected, for an odd count the middle one.
    //   5 elements -> (5 - 1) / 2 = 2
    //   4 elements -> (4 - 1) / 2 = 1
    let mid = (len - 1) / 2;

    (ordered.narrow(dim, mid, 1), positions.narrow(dim, mid, 1))
}
|
||||
Reference in New Issue
Block a user