feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,87 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Neural network building blocks for the Burn deep learning framework"
documentation = "https://docs.rs/burn-nn"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-nn"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-nn"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"std",
"burn-core/default",
]
doc = [
"std",
# Doc features
"burn-core/doc",
]
std = [
"burn-core/std",
"num-traits/std",
]
tracing = [
"burn-core/tracing",
"burn-cuda?/tracing",
"burn-rocm?/tracing",
"burn-tch?/tracing",
"burn-wgpu?/tracing",
"burn-fusion?/tracing",
]
test-cuda = [
"burn-cuda/default",
] # To use cuda during testing, default uses ndarray.
test-rocm = [
"burn-rocm/default",
] # To use hip during testing, default uses ndarray.
test-tch = [
"burn-tch/default",
] # To use tch during testing, default uses ndarray.
test-wgpu = [
"burn-wgpu/default",
] # To use wgpu during testing, default uses ndarray.
test-vulkan = [
"test-wgpu",
"burn-wgpu/vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.
test-metal = [
"test-wgpu",
"burn-wgpu/metal",
] # To use wgpu-spirv during testing, default uses ndarray.
# Memory checks are disabled by default
test-memory-checks = ["burn-fusion/memory-checks"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
num-traits = { workspace = true }
# FOR TESTING
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
rstest = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1,3 @@
# Burn Neural Networks
Core building blocks for Burn neural networks.

View File

@@ -0,0 +1,598 @@
use burn_core as burn;
use crate::activation::{
Celu, CeluConfig, Elu, EluConfig, Gelu, HardShrink, HardShrinkConfig, HardSigmoid,
HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu, PReluConfig, Relu, Selu,
Shrink, ShrinkConfig, Sigmoid, SoftShrink, SoftShrinkConfig, Softplus, SoftplusConfig,
Softsign, SwiGlu, SwiGluConfig, Tanh, ThresholdedRelu, ThresholdedReluConfig,
};
use burn::config::Config;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// [`Activation`] Configuration.
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum ActivationConfig {
/// [`Gelu`] activation layer.
Gelu,
/// [`Gelu`] activation layer with tanh approximation.
GeluApproximate,
/// [`PRelu`] activation layer.
PRelu(PReluConfig),
/// [`Relu`] activation layer.
Relu,
/// [`LeakyRelu`] activation layer.
LeakyRelu(LeakyReluConfig),
/// [`SwiGlu`] activation layer.
SwiGlu(SwiGluConfig),
/// [`Selu`] activation layer.
Selu,
/// [`Sigmoid`] activation layer.
Sigmoid,
/// [`Tanh`] activation layer.
Tanh,
/// [`HardSigmoid`] activation layer.
HardSigmoid(HardSigmoidConfig),
/// [`HardSwish`] activation layer.
HardSwish,
/// [`Softplus`] activation layer.
Softplus(SoftplusConfig),
/// [`Softsign`] activation layer.
Softsign,
/// [`Elu`] activation layer.
Elu(EluConfig),
/// [`Celu`] activation layer.
Celu(CeluConfig),
/// [`ThresholdedRelu`] activation layer.
ThresholdedRelu(ThresholdedReluConfig),
/// [`HardShrink`] activation layer.
HardShrink(HardShrinkConfig),
/// [`SoftShrink`] activation layer.
SoftShrink(SoftShrinkConfig),
/// [`Shrink`] activation layer.
Shrink(ShrinkConfig),
}
impl From<PReluConfig> for ActivationConfig {
fn from(config: PReluConfig) -> Self {
Self::PRelu(config)
}
}
impl From<LeakyReluConfig> for ActivationConfig {
fn from(config: LeakyReluConfig) -> Self {
Self::LeakyRelu(config)
}
}
impl From<SwiGluConfig> for ActivationConfig {
fn from(config: SwiGluConfig) -> Self {
Self::SwiGlu(config)
}
}
impl From<HardSigmoidConfig> for ActivationConfig {
fn from(config: HardSigmoidConfig) -> Self {
Self::HardSigmoid(config)
}
}
impl From<SoftplusConfig> for ActivationConfig {
fn from(config: SoftplusConfig) -> Self {
Self::Softplus(config)
}
}
impl From<EluConfig> for ActivationConfig {
fn from(config: EluConfig) -> Self {
Self::Elu(config)
}
}
impl From<CeluConfig> for ActivationConfig {
fn from(config: CeluConfig) -> Self {
Self::Celu(config)
}
}
impl From<ThresholdedReluConfig> for ActivationConfig {
fn from(config: ThresholdedReluConfig) -> Self {
Self::ThresholdedRelu(config)
}
}
impl From<HardShrinkConfig> for ActivationConfig {
fn from(config: HardShrinkConfig) -> Self {
Self::HardShrink(config)
}
}
impl From<SoftShrinkConfig> for ActivationConfig {
fn from(config: SoftShrinkConfig) -> Self {
Self::SoftShrink(config)
}
}
impl From<ShrinkConfig> for ActivationConfig {
fn from(config: ShrinkConfig) -> Self {
Self::Shrink(config)
}
}
impl ActivationConfig {
/// Initialize a wrapped activation layer.
pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {
match self {
ActivationConfig::Relu => Relu.into(),
ActivationConfig::LeakyRelu(conf) => conf.init().into(),
ActivationConfig::Gelu => Gelu::new().into(),
ActivationConfig::GeluApproximate => Gelu::new_approximate().into(),
ActivationConfig::PRelu(conf) => conf.init(device).into(),
ActivationConfig::SwiGlu(conf) => conf.init(device).into(),
ActivationConfig::HardSigmoid(conf) => conf.init().into(),
ActivationConfig::HardSwish => HardSwish.into(),
ActivationConfig::Softplus(conf) => conf.init().into(),
ActivationConfig::Selu => Selu.into(),
ActivationConfig::Sigmoid => Sigmoid.into(),
ActivationConfig::Tanh => Tanh.into(),
ActivationConfig::Softsign => Softsign.into(),
ActivationConfig::Elu(conf) => conf.init().into(),
ActivationConfig::Celu(conf) => conf.init().into(),
ActivationConfig::HardShrink(conf) => conf.init().into(),
ActivationConfig::SoftShrink(conf) => conf.init().into(),
ActivationConfig::Shrink(conf) => conf.init().into(),
ActivationConfig::ThresholdedRelu(conf) => conf.init().into(),
}
}
}
/// Activation Layer Wrapper.
///
/// Provides support for many in-built `burn::nn` activations.
#[derive(Module, Debug)]
#[non_exhaustive]
#[allow(clippy::large_enum_variant)]
pub enum Activation<B: Backend> {
/// [`Gelu`] activation layer.
Gelu(Gelu),
/// [`PRelu`] activation layer.
PRelu(PRelu<B>),
/// [`Relu`] activation layer.
Relu(Relu),
/// [`LeakyRelu`] activation layer.
LeakyRelu(LeakyRelu),
/// [`SwiGlu`] activation layer.
SwiGlu(SwiGlu<B>),
/// [`Selu`] activation layer.
Selu(Selu),
/// [`Sigmoid`] activation layer.
Sigmoid(Sigmoid),
/// [`Tanh`] activation layer.
Tanh(Tanh),
/// [`HardSigmoid`] activation layer.
HardSigmoid(HardSigmoid),
/// [`HardSwish`] activation layer.
HardSwish(HardSwish),
/// [`Softplus`] activation layer.
Softplus(Softplus),
/// [`Softsign`] activation layer.
Softsign(Softsign),
/// [`Elu`] activation layer.
Elu(Elu),
/// [`Celu`] activation layer.
Celu(Celu),
/// [`ThresholdedRelu`] activation layer.
ThresholdedRelu(ThresholdedRelu),
/// [`HardShrink`] activation layer.
HardShrink(HardShrink),
/// [`SoftShrink`] activation layer.
SoftShrink(SoftShrink),
/// [`Shrink`] activation layer.
Shrink(Shrink),
}
impl<B: Backend> From<Gelu> for Activation<B> {
fn from(layer: Gelu) -> Self {
Self::Gelu(layer)
}
}
impl<B: Backend> From<PRelu<B>> for Activation<B> {
fn from(layer: PRelu<B>) -> Self {
Self::PRelu(layer)
}
}
impl<B: Backend> From<Relu> for Activation<B> {
fn from(layer: Relu) -> Self {
Self::Relu(layer)
}
}
impl<B: Backend> From<LeakyRelu> for Activation<B> {
fn from(layer: LeakyRelu) -> Self {
Self::LeakyRelu(layer)
}
}
impl<B: Backend> From<SwiGlu<B>> for Activation<B> {
fn from(layer: SwiGlu<B>) -> Self {
Self::SwiGlu(layer)
}
}
impl<B: Backend> From<Selu> for Activation<B> {
fn from(layer: Selu) -> Self {
Self::Selu(layer)
}
}
impl<B: Backend> From<Sigmoid> for Activation<B> {
fn from(layer: Sigmoid) -> Self {
Self::Sigmoid(layer)
}
}
impl<B: Backend> From<Tanh> for Activation<B> {
fn from(layer: Tanh) -> Self {
Self::Tanh(layer)
}
}
impl<B: Backend> From<HardSigmoid> for Activation<B> {
fn from(layer: HardSigmoid) -> Self {
Self::HardSigmoid(layer)
}
}
impl<B: Backend> From<HardSwish> for Activation<B> {
fn from(layer: HardSwish) -> Self {
Self::HardSwish(layer)
}
}
impl<B: Backend> From<Softplus> for Activation<B> {
fn from(layer: Softplus) -> Self {
Self::Softplus(layer)
}
}
impl<B: Backend> From<Softsign> for Activation<B> {
fn from(layer: Softsign) -> Self {
Self::Softsign(layer)
}
}
impl<B: Backend> From<Elu> for Activation<B> {
fn from(layer: Elu) -> Self {
Self::Elu(layer)
}
}
impl<B: Backend> From<Celu> for Activation<B> {
fn from(layer: Celu) -> Self {
Self::Celu(layer)
}
}
impl<B: Backend> From<ThresholdedRelu> for Activation<B> {
fn from(layer: ThresholdedRelu) -> Self {
Self::ThresholdedRelu(layer)
}
}
impl<B: Backend> From<HardShrink> for Activation<B> {
fn from(layer: HardShrink) -> Self {
Self::HardShrink(layer)
}
}
impl<B: Backend> From<SoftShrink> for Activation<B> {
fn from(layer: SoftShrink) -> Self {
Self::SoftShrink(layer)
}
}
impl<B: Backend> From<Shrink> for Activation<B> {
fn from(layer: Shrink) -> Self {
Self::Shrink(layer)
}
}
impl<B: Backend> Activation<B> {
/// Forward pass.
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
match self {
Activation::Relu(layer) => layer.forward(input),
Activation::LeakyRelu(layer) => layer.forward(input),
Activation::Gelu(layer) => layer.forward(input),
Activation::PRelu(layer) => layer.forward(input),
Activation::SwiGlu(layer) => layer.forward(input),
Activation::HardSigmoid(layer) => layer.forward(input),
Activation::HardSwish(layer) => layer.forward(input),
Activation::Softplus(layer) => layer.forward(input),
Activation::Selu(layer) => layer.forward(input),
Activation::Sigmoid(layer) => layer.forward(input),
Activation::Tanh(layer) => layer.forward(input),
Activation::Softsign(layer) => layer.forward(input),
Activation::Elu(layer) => layer.forward(input),
Activation::Celu(layer) => layer.forward(input),
Activation::ThresholdedRelu(layer) => layer.forward(input),
Activation::HardShrink(layer) => layer.forward(input),
Activation::SoftShrink(layer) => layer.forward(input),
Activation::Shrink(layer) => layer.forward(input),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::module::Module;
fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)
}
fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {
actual.to_data().assert_eq(&expected.to_data(), true);
}
fn check_stateless_config_output<B: Backend, const D: usize>(
config: ActivationConfig,
input: Tensor<B, D>,
expected: Tensor<B, D>,
device: &B::Device,
) {
let act = config.init(device);
let output = act.forward(input);
expect_tensor(output, expected);
}
#[test]
fn test_gelu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Gelu::new().forward(input.clone());
check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)
}
#[test]
fn test_gelu_approximate() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Gelu::new_approximate().forward(input.clone());
check_stateless_config_output(ActivationConfig::GeluApproximate, input, expected, &device)
}
#[test]
fn test_prelu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = PReluConfig::new();
let expected = inner_config.init(&device).forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_relu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Relu.forward(input.clone());
check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)
}
#[test]
fn test_leaky_relu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = LeakyReluConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_swi_glu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let d_input = input.shape()[1];
let d_output = 2 * d_input;
let inner_config = SwiGluConfig::new(d_input, d_output);
let mut reference: SwiGlu<TestBackend> = inner_config.init(&device);
let config: ActivationConfig = inner_config.into();
let layer = config.init(&device);
match &layer {
Activation::SwiGlu(inner) => {
// Clone the initialized weights.
let state = inner.clone().into_record();
reference = reference.load_record(state);
}
_ => unreachable!(),
};
expect_tensor(
layer.forward(input.clone()),
reference.forward(input.clone()),
)
}
#[test]
fn test_selu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Selu.forward(input.clone());
check_stateless_config_output(ActivationConfig::Selu, input, expected, &device)
}
#[test]
fn test_sigmoid() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Sigmoid.forward(input.clone());
check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)
}
#[test]
fn test_tanh() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Tanh.forward(input.clone());
check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)
}
#[test]
fn test_hard_sigmoid() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = HardSigmoidConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_softsign() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let expected = Softsign.forward(input.clone());
check_stateless_config_output(ActivationConfig::Softsign, input, expected, &device)
}
#[test]
fn test_elu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = EluConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_softplus() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = SoftplusConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_celu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = CeluConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_thresholded_relu() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = ThresholdedReluConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_hard_shrink() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = HardShrinkConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_soft_shrink() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = SoftShrinkConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
#[test]
fn test_shrink() {
let device = Default::default();
let input = make_input::<TestBackend>(&device);
let inner_config = ShrinkConfig::new();
let expected = inner_config.init().forward(input.clone());
check_stateless_config_output(inner_config.into(), input, expected, &device)
}
}

View File

@@ -0,0 +1,104 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::celu;
use burn::tensor::backend::Backend;
/// CELU (Continuously Differentiable Exponential Linear Unit) layer.
///
/// Applies the CELU function element-wise:
/// `celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`
///
/// Should be created with [CeluConfig](CeluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Celu {
/// The alpha value for the CELU formulation.
pub alpha: f64,
}
/// Configuration to create a [Celu](Celu) layer using the [init function](CeluConfig::init).
#[derive(Config, Debug)]
pub struct CeluConfig {
/// The alpha value for the CELU formulation. Default is 1.0
#[config(default = "1.0")]
pub alpha: f64,
}
impl CeluConfig {
/// Initialize a new [Celu](Celu) Layer
pub fn init(&self) -> Celu {
Celu { alpha: self.alpha }
}
}
impl ModuleDisplay for Celu {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("alpha", &self.alpha).optional()
}
}
impl Celu {
/// Forward pass for the Celu layer.
///
/// See [celu](burn::tensor::activation::celu) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
celu(input, self.alpha)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_celu_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: Celu = CeluConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5, -0.5, -1.0]]), &device);
let out = model.forward(input);
// celu(0.5, 1) = 0.5
// celu(-0.5, 1) = 1 * (exp(-0.5) - 1) = -0.393469
// celu(-1.0, 1) = 1 * (exp(-1) - 1) = -0.632121
let expected = TensorData::from([[0.5, -0.393469, -0.632121]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_celu_with_alpha() {
let device = <TestBackend as Backend>::Device::default();
let model: Celu = CeluConfig::new().with_alpha(2.0).init();
let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, -2.0]]), &device);
let out = model.forward(input);
// celu(0, 2) = 0
// celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241
let expected = TensorData::from([[0.0, -1.264241]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = CeluConfig::new().init();
assert_eq!(alloc::format!("{config}"), "Celu {alpha: 1}");
}
}

View File

@@ -0,0 +1,85 @@
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_core as burn;
use burn::tensor::activation::elu;
/// ELU (Exponential Linear Unit) layer.
///
/// Should be created with [EluConfig](EluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Elu {
/// The alpha value.
pub alpha: f64,
}
/// Configuration to create an [Elu](Elu) layer using the [init function](EluConfig::init).
#[derive(Config, Debug)]
pub struct EluConfig {
/// The alpha value. Default is 1.0
#[config(default = "1.0")]
pub alpha: f64,
}
impl EluConfig {
/// Initialize a new [Elu](Elu) Layer
pub fn init(&self) -> Elu {
Elu { alpha: self.alpha }
}
}
impl ModuleDisplay for Elu {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("alpha", &self.alpha).optional()
}
}
impl Elu {
/// Forward pass for the ELU layer.
///
/// See [elu](burn::tensor::activation::elu) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
elu(input, self.alpha)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_elu_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: Elu = EluConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
let out = model.forward(input);
// elu(0.4410, 1.0) = 0.4410
// elu(-0.2507, 1.0) = 1.0 * (exp(-0.2507) - 1) = -0.22186
let expected = TensorData::from([[0.4410, -0.22186]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = EluConfig::new().init();
assert_eq!(alloc::format!("{config}"), "Elu {alpha: 1}");
}
}

View File

@@ -0,0 +1,82 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the Gaussian Error Linear Units function element-wise.
///
/// See also [gelu](burn::tensor::activation::gelu)
///
/// When `approximate` is true, uses the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
#[derive(Module, Clone, Debug, Default)]
pub struct Gelu {
/// Whether to use tanh approximation.
pub approximate: bool,
}
impl Gelu {
/// Create the module with exact GELU.
pub fn new() -> Self {
Self::default()
}
/// Create the module with tanh approximation.
pub fn new_approximate() -> Self {
Self { approximate: true }
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
if self.approximate {
burn::tensor::activation::gelu_approximate(input)
} else {
burn::tensor::activation::gelu(input)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
type FT = FloatElem<TestBackend>;
#[test]
fn display() {
let layer = Gelu::new();
assert_eq!(alloc::format!("{layer}"), "Gelu {\n approximate: false\n}");
}
#[test]
fn forward_approximate() {
let device = Default::default();
let input =
Tensor::<TestBackend, 2>::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device);
let output = Gelu::new_approximate().forward(input);
// PyTorch: torch.nn.functional.gelu(x, approximate="tanh")
let expected = Tensor::<TestBackend, 2>::from_data(
[
[-0.1588079929, 0.0000000000, 0.8411920071],
[0.3457140028, -0.1542859972, 1.9545977116],
],
&device,
);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::rel_abs(1e-5, 1e-5));
}
}

View File

@@ -0,0 +1,51 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the gated linear unit function.
///
/// See also [glu](burn::tensor::activation::glu)
#[derive(Module, Clone, Debug, Default)]
pub struct GLU {
dim: usize,
}
impl GLU {
/// Create the module.
///
/// # Arguments
/// * `dim` - The dimension on which to split the input.
pub fn new(dim: usize) -> Self {
Self { dim }
}
/// Applies the gated linear unit function.
///
/// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.
///
/// **Note**:
/// * The size of the input tensor along `dim` must be divisible by 2.
///
/// ### Arguments
/// * `tensor` - The input tensor.
///
/// ### Returns
/// * A tensor with the same shape as the input, except the size along `dim` is halved.
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::glu(input, self.dim)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = GLU::new(1);
assert_eq!(alloc::format!("{layer}"), "GLU {\n dim: 1\n}");
}
}

View File

@@ -0,0 +1,98 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::hard_shrink;
use burn::tensor::backend::Backend;
/// Hard Shrink layer.
///
/// Applies the Hard Shrink function element-wise:
/// `hard_shrink(x) = x if |x| > lambda else 0`
///
/// Should be created with [HardShrinkConfig](HardShrinkConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct HardShrink {
/// The lambda value for the Hard Shrink formulation.
pub lambda: f64,
}
/// Configuration to create a [HardShrink](HardShrink) layer using the [init function](HardShrinkConfig::init).
#[derive(Config, Debug)]
pub struct HardShrinkConfig {
/// The lambda value for the Hard Shrink formulation. Default is 0.5
#[config(default = "0.5")]
pub lambda: f64,
}
impl HardShrinkConfig {
/// Initialize a new [HardShrink](HardShrink) Layer
pub fn init(&self) -> HardShrink {
HardShrink {
lambda: self.lambda,
}
}
}
impl ModuleDisplay for HardShrink {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("lambda", &self.lambda).optional()
}
}
impl HardShrink {
/// Forward pass for the Hard Shrink layer.
///
/// See [hard_shrink](burn::tensor::activation::hard_shrink) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
hard_shrink(input, self.lambda)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn test_hard_shrink_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: HardShrink = HardShrinkConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0_f32, 0.0, -1.0], [8.0, 0.0, 0.0]]);
assert_eq!(out.into_data(), expected);
}
#[test]
fn test_hard_shrink_with_lambda() {
let device = <TestBackend as Backend>::Device::default();
let model: HardShrink = HardShrinkConfig::new().with_lambda(0.2).init();
let input =
Tensor::<TestBackend, 2>::from_data([[0.1, -0.1, -0.3], [0.5, 0.1, 0.0]], &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0_f32, 0.0, -0.3], [0.5, 0.0, 0.0]]);
assert_eq!(out.into_data(), expected);
}
#[test]
fn display() {
let config = HardShrinkConfig::new().init();
assert_eq!(alloc::format!("{config}"), "HardShrink {lambda: 0.5}");
}
}

View File

@@ -0,0 +1,97 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::hard_sigmoid;
use burn::tensor::backend::Backend;
/// Hard Sigmoid layer.
///
/// Should be created with [HardSigmoidConfig](HardSigmoidConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct HardSigmoid {
/// The alpha value.
pub alpha: f64,
/// The beta value.
pub beta: f64,
}
/// Configuration to create a [Hard Sigmoid](HardSigmoid) layer using the [init function](HardSigmoidConfig::init).
#[derive(Config, Debug)]
pub struct HardSigmoidConfig {
/// The alpha value. Default is 0.2
#[config(default = "0.2")]
pub alpha: f64,
/// The beta value. Default is 0.5
#[config(default = "0.5")]
pub beta: f64,
}
impl HardSigmoidConfig {
/// Initialize a new [Hard Sigmoid](HardSigmoid) Layer
pub fn init(&self) -> HardSigmoid {
HardSigmoid {
alpha: self.alpha,
beta: self.beta,
}
}
}
impl ModuleDisplay for HardSigmoid {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("alpha", &self.alpha)
.add("beta", &self.beta)
.optional()
}
}
impl HardSigmoid {
/// Forward pass for the Hard Sigmoid layer.
///
/// See [hard_sigmoid](burn::tensor::activation::hard_sigmoid) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
hard_sigmoid(input, self.alpha, self.beta)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_hard_sigmoid_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: HardSigmoid = HardSigmoidConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
let out = model.forward(input);
let expected = TensorData::from([[0.5882, 0.44986]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = HardSigmoidConfig::new().init();
assert_eq!(
alloc::format!("{config}"),
"HardSigmoid {alpha: 0.2, beta: 0.5}"
);
}
}

View File

@@ -0,0 +1,58 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::activation::hard_swish;
use burn::tensor::backend::Backend;
/// Hard Swish layer.
#[derive(Module, Clone, Debug, Default)]
pub struct HardSwish;
impl HardSwish {
/// Create the module.
pub fn new() -> Self {
Self
}
/// Forward pass for the Hard Swish layer.
///
/// See [hard_swish](burn::tensor::activation::hard_swish) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
hard_swish(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_hard_swish_forward() {
let device = <TestBackend as Backend>::Device::default();
let model = HardSwish::new();
let input = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[3.0f32, -3.0], [0.0, 1.0]]),
&device,
);
let out = model.forward(input);
let expected = TensorData::from([[3.0f32, 0.0], [0.0, 0.6666667]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let layer = HardSwish::new();
assert_eq!(alloc::format!("{layer}"), "HardSwish");
}
}

View File

@@ -0,0 +1,124 @@
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_core as burn;
use burn::tensor::activation::leaky_relu;
/// Leaky ReLu layer.
///
/// Should be created with [LeakyReluConfig](LeakyReluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct LeakyRelu {
/// The negative slope.
pub negative_slope: f64,
}
/// Configuration to create a [Leaky Relu](LeakyRelu) layer using the [init function](LeakyReluConfig::init).
#[derive(Config, Debug)]
pub struct LeakyReluConfig {
/// The negative slope. Default is 0.01
#[config(default = "0.01")]
pub negative_slope: f64,
}
impl LeakyReluConfig {
/// Initialize a new [Leaky Relu](LeakyRelu) Layer
pub fn init(&self) -> LeakyRelu {
LeakyRelu {
negative_slope: self.negative_slope,
}
}
}
impl ModuleDisplay for LeakyRelu {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("negative_slope", &self.negative_slope)
.optional()
}
}
impl LeakyRelu {
/// Forward pass for the Leaky ReLu layer.
///
/// See [leaky_relu](burn::tensor::activation::leaky_relu) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
leaky_relu(input, self.negative_slope)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_leaky_relu_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: LeakyRelu = LeakyReluConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
let out = model.forward(input);
let expected = TensorData::from([[0.4410, -0.002507]]);
out.to_data().assert_eq(&expected, false);
}
#[test]
fn test_leaky_relu_forward_multi_dim() {
let input = [
[
[-1.0222, 1.5810, 0.3457, -1.3530],
[0.0231, 0.8681, 0.2473, -0.0377],
[0.3520, -1.1199, 1.2219, 0.2804],
],
[
[1.0002, 0.7259, 0.8779, 0.2084],
[1.5615, -0.1057, -0.4886, -1.5184],
[-0.5523, -0.2741, -0.0210, -1.1352],
],
];
let expected = TensorData::from([
[
[-1.0222e-02, 1.5810e+00, 3.457e-01, -1.3530e-02],
[2.31e-02, 8.681e-01, 2.473e-01, -3.77e-04],
[3.52e-01, -1.1199e-02, 1.2219e+00, 2.804e-01],
],
[
[1.0002e+00, 7.259e-01, 8.779e-01, 2.084e-01],
[1.5615e+00, -1.057e-03, -4.886e-03, -1.5184e-02],
[-5.523e-03, -2.741e-03, -2.1e-04, -1.1352e-02],
],
]);
let device = <TestBackend as Backend>::Device::default();
let model: LeakyRelu = LeakyReluConfig::new().init();
let input_data = Tensor::<TestBackend, 3>::from_data(TensorData::from(input), &device);
let actual_output = model.forward(input_data);
actual_output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default())
}
#[test]
fn display() {
let config = LeakyReluConfig::new().init();
assert_eq!(
alloc::format!("{config}"),
"LeakyRelu {negative_slope: 0.01}"
);
}
}

View File

@@ -0,0 +1,68 @@
//! # Activation Layers
//!
//! Users who desire a selectable activation function should
//! consider [`Activation`], which provides an abstraction over:
//! * [`Relu`] - the default,
//! * ['PRelu']
//! * [`Gelu`]
//! * [`LeakyRelu`]
//! * [`SwiGlu`]
//! * [`Selu`]
//! * [`Sigmoid`]
//! * [`HardSigmoid`]
//! * [`HardSwish`]
//! * [`Softplus`]
//! * [`Softsign`]
//! * [`Tanh`]
//! * [`Elu`]
//! * [`Celu`]
//! * [`ThresholdedRelu`]
//!
//! The activation layer [`GLU`] has shape-changing behaviors
//! not compatible with the common API, and is not included
//! in the abstraction wrappers.
mod activation_wrapper;
// These are pub(crate) for dual-export in `nn` without re-exporting
// all of `nn.activation`, or manually listing each symbol.
pub(crate) mod celu;
pub(crate) mod elu;
pub(crate) mod gelu;
pub(crate) mod glu;
pub(crate) mod hard_shrink;
pub(crate) mod hard_sigmoid;
pub(crate) mod hard_swish;
pub(crate) mod leaky_relu;
pub(crate) mod prelu;
pub(crate) mod relu;
pub(crate) mod selu;
pub(crate) mod shrink;
pub(crate) mod sigmoid;
pub(crate) mod soft_shrink;
pub(crate) mod softplus;
pub(crate) mod softsign;
pub(crate) mod swiglu;
pub(crate) mod tanh;
pub(crate) mod thresholded_relu;
pub use activation_wrapper::*;
pub use celu::*;
pub use elu::*;
pub use gelu::*;
pub use glu::*;
pub use hard_shrink::*;
pub use hard_sigmoid::*;
pub use hard_swish::*;
pub use leaky_relu::*;
pub use prelu::*;
pub use relu::*;
pub use selu::*;
pub use shrink::*;
pub use sigmoid::*;
pub use soft_shrink::*;
pub use softplus::*;
pub use softsign::*;
pub use swiglu::*;
pub use tanh::*;
pub use thresholded_relu::*;

View File

@@ -0,0 +1,87 @@
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_core as burn;
/// Parametric Relu layer.
///
/// Should be created using [PReluConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PRelu<B: Backend> {
/// the weights learnt for PReLu. can be of shape \[1\] or \[num_parameters\] in which case it must
/// be the same as number of channels in the input tensor
pub alpha: Param<Tensor<B, 1>>,
/// Alpha value for the PRelu layer
pub alpha_value: f64,
}
impl<B: Backend> ModuleDisplay for PRelu<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [num_parameters] = self.alpha.shape().dims();
content
.add("num_parameters", &num_parameters)
.add("alpha_value", &self.alpha_value)
.optional()
}
}
/// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init).
#[derive(Config, Debug)]
pub struct PReluConfig {
/// The number of parameters.
#[config(default = "1")]
pub num_parameters: usize,
/// The learnable weight alpha. Default is 0.25
#[config(default = "0.25")]
pub alpha: f64,
}
impl PReluConfig {
/// Initialize a new [Parametric Relu](PRelu) Layer
pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
PRelu {
// alpha is a tensor of length num_parameters
alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
alpha_value: self.alpha,
}
}
}
impl<B: Backend> PRelu<B> {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
///
/// See also [prelu](burn::tensor::activation::prelu) for more information.
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::prelu(input, self.alpha.val())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn display() {
let layer = PReluConfig::new().init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{layer}"),
"PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}"
);
}
}

View File

@@ -0,0 +1,39 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the rectified linear unit function element-wise
/// See also [relu](burn::tensor::activation::relu)
///
#[derive(Module, Clone, Debug, Default)]
pub struct Relu;
impl Relu {
/// Create the module.
pub fn new() -> Self {
Self {}
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::relu(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Relu::new();
assert_eq!(alloc::format!("{layer}"), "Relu");
}
}

View File

@@ -0,0 +1,38 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the Scaled Exponential Linear Unit function element-wise.
/// See also [selu](burn::tensor::activation::selu)
#[derive(Module, Clone, Debug, Default)]
pub struct Selu;
impl Selu {
/// Create the module.
pub fn new() -> Self {
Self {}
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::selu(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Selu::new();
assert_eq!(alloc::format!("{layer}"), "Selu");
}
}

View File

@@ -0,0 +1,114 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::shrink;
use burn::tensor::backend::Backend;
/// Shrink layer.
///
/// Applies the Shrink function element-wise:
/// `shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`
///
/// Should be created with [ShrinkConfig](ShrinkConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Shrink {
/// The lambda value for the Shrink formulation.
pub lambda: f64,
/// The bias value for the Shrink formulation.
// Usually bias = lambda, but need this to handle onnx spec https://onnx.ai/onnx/operators/onnx__Shrink.html
pub bias: f64,
}
/// Configuration to create a [Shrink](Shrink) layer using the [init function](ShrinkConfig::init).
#[derive(Config, Debug)]
pub struct ShrinkConfig {
/// The lambda value for the Shrink formulation. Default is 0.5
#[config(default = "0.5")]
pub lambda: f64,
/// The bias value for the Shrink formulation. Default is 0.5.
#[config(default = "0.5")]
pub bias: f64,
}
impl ShrinkConfig {
/// Initialize a new [Shrink](Shrink) Layer
pub fn init(&self) -> Shrink {
Shrink {
lambda: self.lambda,
bias: self.bias,
}
}
}
impl ModuleDisplay for Shrink {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("lambda", &self.lambda)
.add("bias", &self.bias)
.optional()
}
}
impl Shrink {
/// Forward pass for the Shrink layer.
///
/// See [shrink](burn::tensor::activation::shrink) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
shrink(input, self.lambda, self.bias)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn test_shrink_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: Shrink = ShrinkConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]);
assert_eq!(out.into_data(), expected);
}
#[test]
fn test_shrink_with_lambda_and_bias() {
let device = <TestBackend as Backend>::Device::default();
let model: Shrink = ShrinkConfig::new()
.with_lambda(0.25)
.with_bias(0.125)
.init();
let input =
Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0_f32, 0.0, -0.375], [0.625, 0.0, 0.0]]);
assert_eq!(out.into_data(), expected);
}
#[test]
fn display() {
let config = ShrinkConfig::new().init();
assert_eq!(
alloc::format!("{config}"),
"Shrink {lambda: 0.5, bias: 0.5}"
);
}
}

View File

@@ -0,0 +1,38 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the sigmoid function element-wise
/// See also [sigmoid](burn::tensor::activation::sigmoid)
#[derive(Module, Clone, Debug, Default)]
pub struct Sigmoid;
impl Sigmoid {
/// Create the module.
pub fn new() -> Self {
Self {}
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::sigmoid(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Sigmoid::new();
assert_eq!(alloc::format!("{layer}"), "Sigmoid");
}
}

View File

@@ -0,0 +1,98 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::soft_shrink;
use burn::tensor::backend::Backend;
/// Soft Shrink layer.
///
/// Applies the Soft Shrink function element-wise:
/// `soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`
///
/// Should be created with [SoftShrinkConfig](SoftShrinkConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct SoftShrink {
/// The lambda value for the Soft Shrink formulation.
pub lambda: f64,
}
/// Configuration to create a [SoftShrink](SoftShrink) layer using the [init function](SoftShrinkConfig::init).
#[derive(Config, Debug)]
pub struct SoftShrinkConfig {
/// The lambda value for the Soft Shrink formulation. Default is 0.5
#[config(default = "0.5")]
pub lambda: f64,
}
impl SoftShrinkConfig {
/// Initialize a new [SoftShrink](SoftShrink) Layer
pub fn init(&self) -> SoftShrink {
SoftShrink {
lambda: self.lambda,
}
}
}
impl ModuleDisplay for SoftShrink {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("lambda", &self.lambda).optional()
}
}
impl SoftShrink {
/// Forward pass for the Soft Shrink layer.
///
/// See [soft_shrink](burn::tensor::activation::soft_shrink) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
soft_shrink(input, self.lambda)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn test_soft_shrink_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: SoftShrink = SoftShrinkConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]);
assert_eq!(out.into_data(), expected);
}
#[test]
fn test_soft_shrink_with_lambda() {
let device = <TestBackend as Backend>::Device::default();
let model: SoftShrink = SoftShrinkConfig::new().with_lambda(0.25).init();
let input =
Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0_f32, 0.0, -0.25], [0.5, 0.0, 0.0]]);
assert_eq!(out.into_data(), expected);
}
#[test]
fn display() {
let config = SoftShrinkConfig::new().init();
assert_eq!(alloc::format!("{config}"), "SoftShrink {lambda: 0.5}");
}
}

View File

@@ -0,0 +1,105 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::softplus;
use burn::tensor::backend::Backend;
/// Softplus layer.
///
/// Applies the softplus function element-wise:
/// `softplus(x) = (1/beta) * log(1 + exp(beta * x))`
///
/// Should be created with [SoftplusConfig](SoftplusConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Softplus {
/// The beta value.
pub beta: f64,
}
/// Configuration to create a [Softplus](Softplus) layer using the [init function](SoftplusConfig::init).
#[derive(Config, Debug)]
pub struct SoftplusConfig {
/// The beta value. Default is 1.0
#[config(default = "1.0")]
pub beta: f64,
}
impl SoftplusConfig {
/// Initialize a new [Softplus](Softplus) Layer
pub fn init(&self) -> Softplus {
Softplus { beta: self.beta }
}
}
impl ModuleDisplay for Softplus {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("beta", &self.beta).optional()
}
}
impl Softplus {
/// Forward pass for the Softplus layer.
///
/// See [softplus](burn::tensor::activation::softplus) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
softplus(input, self.beta)
}
}
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_softplus_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: Softplus = SoftplusConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0, -1.0]]), &device);
let out = model.forward(input);
// softplus(0) = log(2) ≈ 0.6931
// softplus(1) = log(1 + e) ≈ 1.3133
// softplus(-1) = log(1 + e^-1) ≈ 0.3133
let expected = TensorData::from([[0.6931, 1.3133, 0.3133]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_softplus_with_beta() {
let device = <TestBackend as Backend>::Device::default();
let model: Softplus = SoftplusConfig::new().with_beta(2.0).init();
let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0]]), &device);
let out = model.forward(input);
// softplus(0, beta=2) = (1/2) * log(1 + exp(0)) = 0.5 * log(2) ≈ 0.3466
// softplus(1, beta=2) = (1/2) * log(1 + exp(2)) = 0.5 * log(8.389) ≈ 1.0635
let expected = TensorData::from([[0.3466, 1.0635]]);
out.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = SoftplusConfig::new().init();
assert_eq!(alloc::format!("{config}"), "Softplus {beta: 1}");
}
}

View File

@@ -0,0 +1,38 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the softsign function element-wise
/// See also [softsign](burn::tensor::activation::softsign)
#[derive(Module, Clone, Debug, Default)]
pub struct Softsign;
impl Softsign {
/// Create the module.
pub fn new() -> Self {
Self {}
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::softsign(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Softsign::new();
assert_eq!(alloc::format!("{layer}"), "Softsign");
}
}

View File

@@ -0,0 +1,153 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::activation::silu;
use burn::tensor::{Tensor, backend::Backend};
use crate::{Linear, LinearConfig, LinearLayout};
/// Configuration to create a [SwiGlu](SwiGlu) activation layer using the [init function](SwiGluConfig::init).
#[derive(Config, Debug)]
pub struct SwiGluConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the output features.
pub d_output: usize,
/// If a bias should be applied during the linear transformation. Default behaviour is False
/// for SwiGLU activation implementations.
#[config(default = false)]
pub bias: bool,
/// The type of function used to initialize the linear layer parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
/// The layout in which the linear parameters are stored.
#[config(default = "LinearLayout::Row")]
pub layout: LinearLayout,
}
/// Applies the SwiGLU or Swish Gated Linear Unit to the input tensor.
/// The SwiGLU activation function is defined as:
/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`
///
/// Should be created with [SwiGluConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct SwiGlu<B: Backend> {
/// The inner linear layer for Swish activation function
/// with `d_input` input features and `d_output` output features.
pub linear_inner: Linear<B>,
/// The outer linear layer for element wise multiplication
/// with `d_input` input features and `d_output` output features.
pub linear_outer: Linear<B>,
}
impl<B: Backend> ModuleDisplay for SwiGlu<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, d_output] = self.linear_inner.weight.shape().dims();
content
.add("d_input", &d_input)
.add("d_output", &d_output)
.add("bias", &self.linear_inner.bias.is_some())
.optional()
}
}
impl SwiGluConfig {
/// Initialize a new [SwiGLU](SwiGlu) activation layer.
pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
SwiGlu {
linear_inner: LinearConfig::new(self.d_input, self.d_output)
.with_bias(self.bias)
.with_initializer(self.initializer.clone())
.with_layout(self.layout)
.init(device),
linear_outer: LinearConfig::new(self.d_input, self.d_output)
.with_bias(self.bias)
.with_initializer(self.initializer.clone())
.with_layout(self.layout)
.init(device),
}
}
}
impl<B: Backend> SwiGlu<B> {
/// Applies the Swish Gated Linear Unit to the input tensor.
///
/// # Shapes
///
/// - input: `[batch_size, seq_length, d_input]`
/// - output: `[batch_size, seq_length, d_output]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let x = self.linear_inner.forward(input.clone());
let x = silu(x);
x.mul(self.linear_outer.forward(input))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_swiglu_forward_no_bias() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = SwiGluConfig::new(3, 3).with_initializer(Initializer::Constant { value: 0.5 });
let swiglu = config.init(&device);
let input =
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let output = swiglu.forward(input);
let expected_output = Tensor::<TestBackend, 2>::from_data(
[[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
}
#[test]
fn test_swiglu_forward_with_bias() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = SwiGluConfig::new(3, 3)
.with_bias(true)
.with_initializer(Initializer::Constant { value: 0.5 });
let swiglu = config.init(&device);
let input =
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let output = swiglu.forward(input);
let expected_output = Tensor::<TestBackend, 2>::from_data(
[[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
}
#[test]
fn display() {
let config = SwiGluConfig::new(3, 5);
let swiglu = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{swiglu}"),
"SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}"
);
}
}

View File

@@ -0,0 +1,38 @@
use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the tanh activation function element-wise
/// See also [tanh](burn::tensor::activation::tanh)
#[derive(Module, Clone, Debug, Default)]
pub struct Tanh;
impl Tanh {
/// Create the module.
pub fn new() -> Self {
Self {}
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
burn::tensor::activation::tanh(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Tanh::new();
assert_eq!(alloc::format!("{layer}"), "Tanh");
}
}

View File

@@ -0,0 +1,82 @@
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_core as burn;
use burn::tensor::activation::thresholded_relu;
/// Thresholded ReLU layer.
///
/// Should be created with [ThresholdedReluConfig](ThresholdedReluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct ThresholdedRelu {
/// The alpha threshold.
pub alpha: f64,
}
/// Configuration to create a [ThresholdedRelu](ThresholdedRelu) layer using the [init function](ThresholdedReluConfig::init).
#[derive(Config, Debug)]
pub struct ThresholdedReluConfig {
/// The alpha threshold. Default is 1.0
#[config(default = "1.0")]
pub alpha: f64,
}
impl ThresholdedReluConfig {
/// Initialize a new [ThresholdedRelu](ThresholdedRelu) layer.
pub fn init(&self) -> ThresholdedRelu {
ThresholdedRelu { alpha: self.alpha }
}
}
impl ModuleDisplay for ThresholdedRelu {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("alpha", &self.alpha).optional()
}
}
impl ThresholdedRelu {
/// Forward pass for the Thresholded ReLU layer.
///
/// See [thresholded_relu](burn::tensor::activation::thresholded_relu) for more information.
///
/// # Shapes
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
thresholded_relu(input, self.alpha)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn test_thresholded_relu_forward() {
let device = <TestBackend as Backend>::Device::default();
let model: ThresholdedRelu = ThresholdedReluConfig::new().init();
let input =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5, 1.5, -0.2]]), &device);
let out = model.forward(input);
let expected = TensorData::from([[0.0, 1.5, 0.0]]);
out.to_data().assert_eq(&expected, false);
}
#[test]
fn display() {
let config = ThresholdedReluConfig::new().init();
assert_eq!(alloc::format!("{config}"), "ThresholdedRelu {alpha: 1}");
}
}

View File

@@ -0,0 +1,63 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "256"]
//! Burn neural network module.
/// Loss module
pub mod loss;
/// Neural network modules implementations.
pub mod modules;
pub use modules::*;
pub mod activation;
pub use activation::{
celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*,
relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*,
tanh::*, thresholded_relu::*,
};
mod padding;
pub use padding::*;
// For backward compat, `burn::nn::Initializer`
pub use burn_core::module::Initializer;
extern crate alloc;
/// Backend for test cases
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda"),
not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(all(test, feature = "test-tch"))]
/// Backend for test cases
pub type TestBackend = burn_tch::LibTorch<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases
pub type TestBackend = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases
pub type TestBackend = burn_cuda::Cuda;
#[cfg(all(test, feature = "test-rocm"))]
/// Backend for test cases
pub type TestBackend = burn_rocm::Rocm;
/// Backend for autodiff test cases
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
burn_fusion::memory_checks!();
}

View File

@@ -0,0 +1,432 @@
use burn_core as burn;
use alloc::vec::Vec;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::activation::log_sigmoid;
use burn::tensor::{Int, Tensor, backend::Backend};
use burn::{config::Config, module::Module};
/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct BinaryCrossEntropyLossConfig {
/// Create weighted binary cross-entropy with a weight for each class.
///
/// The loss of a specific sample will simply be multiplied by its label weight.
pub weights: Option<Vec<f32>>,
/// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).
///
/// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
/// Alpha = 0 would be the same as default.
pub smoothing: Option<f32>,
/// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
#[config(default = false)]
pub logits: bool,
}
impl BinaryCrossEntropyLossConfig {
/// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
self.assertions();
BinaryCrossEntropyLoss {
weights: self
.weights
.as_ref()
.map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
smoothing: self.smoothing,
logits: self.logits,
}
}
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(
weights.iter().all(|e| e > &0.),
"Weights of cross-entropy have to be positive."
);
}
}
}
/// Calculate the binary cross entropy loss from the input logits and the targets.
///
/// Should be created using [BinaryCrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
/// Weights for cross-entropy.
pub weights: Option<Tensor<B, 1>>,
/// Label smoothing alpha.
pub smoothing: Option<f32>,
/// Treat the inputs as logits
pub logits: bool,
}
impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("weights", &self.weights)
.add("smoothing", &self.smoothing)
.add("logits", &self.logits)
.optional()
}
}
impl<B: Backend> BinaryCrossEntropyLoss<B> {
/// Compute the criterion on the input tensor.
///
/// # Shapes
///
/// Binary:
/// - logits: `[batch_size]`
/// - targets: `[batch_size]`
///
/// Multi-label:
/// - logits: `[batch_size, num_classes]`
/// - targets: `[batch_size, num_classes]`
pub fn forward<const D: usize>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D, Int>,
) -> Tensor<B, 1> {
self.assertions(&logits, &targets);
let mut targets_float = targets.clone().float();
let shape = targets.dims();
if let Some(alpha) = self.smoothing {
let num_classes = if D > 1 { shape[D - 1] } else { 2 };
targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
}
let mut loss = if self.logits {
// Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
(targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
} else {
// - (target * log(input) + (1 - target) * log(1 - input))
// https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
(targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
- targets_float * logits.log().clamp_min(-100.0)
};
if let Some(weights) = &self.weights {
let weights = if D > 1 {
weights.clone().expand(shape)
} else {
// Flatten targets and expand resulting weights to make it compatible with
// Tensor<B, D> for binary 1-D case
weights
.clone()
.gather(0, targets.flatten(0, 0))
.expand(shape)
};
loss = loss * weights;
}
loss.mean()
}
fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
let logits_dims = logits.dims();
let targets_dims = targets.dims();
assert!(
logits_dims == targets_dims,
"Shape of targets ({targets_dims:?}) should correspond to outer shape of logits ({logits_dims:?})."
);
if let Some(weights) = &self.weights
&& D > 1
{
let targets_classes = targets_dims[D - 1];
let weights_classes = weights.dims()[0];
assert!(
weights_classes == targets_classes,
"The number of classes ({weights_classes}) does not match the weights provided ({targets_classes})."
);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{TensorData, activation::sigmoid};
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_binary_cross_entropy_preds_all_correct() {
let device = Default::default();
let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.init(&device)
.forward(preds, targets)
.into_data();
let loss_expected = TensorData::from([0.000]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
}
#[test]
fn test_binary_cross_entropy_preds_all_incorrect() {
let device = Default::default();
let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.init(&device)
.forward(preds, targets)
.into_data();
let loss_expected = TensorData::from([100.000]); // clamped value
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
}
#[test]
fn test_binary_cross_entropy() {
// import torch
// from torch import nn
// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
// target = torch.tensor([0., 1., 0., 1.])
// loss = nn.BCELoss()
// sigmoid = nn.Sigmoid()
// out = loss(sigmoid(input), target) # tensor(0.7491)
let device = Default::default();
let logits =
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.init(&device)
.forward(sigmoid(logits), targets)
.into_data();
let loss_expected = TensorData::from([0.7491]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
}
#[test]
fn test_binary_cross_entropy_with_logits() {
let device = Default::default();
let logits =
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.with_logits(true)
.init(&device)
.forward(logits, targets)
.into_data();
let loss_expected = TensorData::from([0.7491]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
}
#[test]
fn test_binary_cross_entropy_with_weights() {
// import torch
// from torch import nn
// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
// target = torch.tensor([0, 1, 0, 1])
// weights = torch.tensor([3., 7.]).gather(0, target)
// loss = nn.BCELoss(weights)
// sigmoid = nn.Sigmoid()
// out = loss(sigmoid(input), target.float()) # tensor(3.1531)
let device = Default::default();
let logits =
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
let weights = [3., 7.];
let loss_actual = BinaryCrossEntropyLossConfig::new()
.with_weights(Some(weights.to_vec()))
.init(&device)
.forward(sigmoid(logits), targets)
.into_data();
let loss_expected = TensorData::from([3.1531]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
}
#[test]
fn test_binary_cross_entropy_with_smoothing() {
// import torch
// from torch import nn
// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
// target = torch.tensor([0., 1., 0., 1.])
// target_smooth = target * (1 - 0.1) + (0.1 / 2)
// loss = nn.BCELoss()
// sigmoid = nn.Sigmoid()
// out = loss(sigmoid(input), target_smooth) # tensor(0.7490)
let device = Default::default();
let logits =
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.with_smoothing(Some(0.1))
.init(&device)
.forward(sigmoid(logits), targets)
.into_data();
let loss_expected = TensorData::from([0.7490]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
}
#[test]
fn test_binary_cross_entropy_multilabel() {
// import torch
// from torch import nn
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
// weights = torch.tensor([3., 7., 0.9])
// loss = nn.BCEWithLogitsLoss()
// out = loss(input, target) # tensor(0.7112)
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::from_floats(
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
&device,
);
let targets = Tensor::<TestBackend, 2, Int>::from_data(
TensorData::from([[1, 0, 1], [1, 0, 0]]),
&device,
);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.with_logits(true)
.init(&device)
.forward(logits, targets)
.into_data();
let loss_expected = TensorData::from([0.7112]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
}
#[test]
fn test_binary_cross_entropy_multilabel_with_weights() {
// import torch
// from torch import nn
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
// loss = nn.BCEWithLogitsLoss()
// out = loss(input, target) # tensor(3.1708)
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::from_floats(
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
&device,
);
let targets = Tensor::<TestBackend, 2, Int>::from_data(
TensorData::from([[1, 0, 1], [1, 0, 0]]),
&device,
);
let weights = [3., 7., 0.9];
let loss_actual = BinaryCrossEntropyLossConfig::new()
.with_logits(true)
.with_weights(Some(weights.to_vec()))
.init(&device)
.forward(logits, targets)
.into_data();
let loss_expected = TensorData::from([3.1708]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
}
#[test]
fn test_binary_cross_entropy_multilabel_with_smoothing() {
// import torch
// from torch import nn
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
// target_smooth = target * (1 - 0.1) + (0.1 / 3)
// loss = nn.BCELoss()
// sigmoid = nn.Sigmoid()
// out = loss(sigmoid(input), target_smooth) # tensor(0.7228)
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::from_floats(
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
&device,
);
let targets = Tensor::<TestBackend, 2, Int>::from_data(
TensorData::from([[1, 0, 1], [1, 0, 0]]),
&device,
);
let loss_actual = BinaryCrossEntropyLossConfig::new()
.with_smoothing(Some(0.1))
.init(&device)
.forward(sigmoid(logits), targets)
.into_data();
let loss_expected = TensorData::from([0.7228]);
loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
}
#[test]
#[should_panic = "The number of classes"]
fn multilabel_weights_should_match_target() {
// import torch
// from torch import nn
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
// loss = nn.BCEWithLogitsLoss()
// out = loss(input, target) # tensor(3.1708)
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::from_floats(
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
&device,
);
let targets = Tensor::<TestBackend, 2, Int>::from_data(
TensorData::from([[1, 0, 1], [1, 0, 0]]),
&device,
);
let weights = [3., 7.];
let _loss = BinaryCrossEntropyLossConfig::new()
.with_logits(true)
.with_weights(Some(weights.to_vec()))
.init(&device)
.forward(logits, targets);
}
#[test]
fn display() {
let config =
BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
let loss = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{loss}"),
"BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
);
}
}

View File

@@ -0,0 +1,317 @@
use alloc::format;
use burn::tensor::linalg::cosine_similarity;
use burn_core as burn;
use crate::loss::reduction::Reduction;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::{Int, Tensor, activation::relu, backend::Backend};
/// Configuration for CosineEmbeddingLoss.
#[derive(Config, Debug)]
pub struct CosineEmbeddingLossConfig {
/// Margin for negative samples.
#[config(default = 0.0)]
pub margin: f32,
/// Specifies the reduction to apply to the output.
#[config(default = "Reduction::Mean")]
pub reduction: Reduction,
}
impl CosineEmbeddingLossConfig {
/// Initialize CosineEmbeddingLoss.
pub fn init(&self) -> CosineEmbeddingLoss {
CosineEmbeddingLoss {
margin: self.margin,
reduction: Ignored(self.reduction.clone()),
}
}
}
/// Cosine embedding loss between two tensors.
///
/// Measures cosine distance between tensors.
/// Used for learning embeddings or similarity.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct CosineEmbeddingLoss {
/// Margin value. Default: 0.0
pub margin: f32,
/// Reduction method
pub reduction: Ignored<Reduction>,
}
impl Default for CosineEmbeddingLoss {
fn default() -> Self {
CosineEmbeddingLossConfig::new().init()
}
}
impl ModuleDisplay for CosineEmbeddingLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("margin", &self.margin)
.add("reduction", format!("{:?}", &self.reduction.0).as_str())
.optional()
}
}
impl CosineEmbeddingLoss {
/// Creates a new instance
pub fn new() -> Self {
CosineEmbeddingLossConfig::new().init()
}
/// Compute loss with reduction.
///
/// # Shapes
///
/// - input1: ``[batch_size, embedding_dim]``
/// - input2: ``[batch_size, embedding_dim]``
/// - target: ``[batch_size]`` with values 1 or -1
///
/// # Returns
///
/// Loss tensor of shape ``[1]``
pub fn forward<B: Backend>(
&self,
input1: Tensor<B, 2>,
input2: Tensor<B, 2>,
target: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
let tensor = self.forward_no_reduction(input1, input2, target);
match &self.reduction.0 {
Reduction::Mean | Reduction::Auto => tensor.mean(),
Reduction::Sum => tensor.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
/// Compute loss without applying reduction.
///
/// # Arguments
///
/// * `input1` - First input tensor of shape ``[batch_size, embedding_dim]``
/// * `input2` - Second input tensor of shape ``[batch_size, embedding_dim]``
/// * `target` - Target tensor of shape ``[batch_size]`` with values 1 or -1
///
/// # Returns
///
/// Tensor of per-element losses with shape ``[batch_size]``
pub fn forward_no_reduction<B: Backend>(
&self,
input1: Tensor<B, 2>,
input2: Tensor<B, 2>,
target: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
self.assertions(&input1, &input2, &target);
// cos_sim shape: [batch_size, 1]
let cos_sim = cosine_similarity(input1, input2, 1, None);
// cos_sim shape: [batch_size]
let cos_sim: Tensor<B, 1> = cos_sim.squeeze_dim(1);
let mut loss = cos_sim.zeros_like();
// Similar pairs (target == 1) - Formula: L = 1 - cos_sim
let similar_mask = target.clone().equal_elem(1);
let similar_loss = cos_sim.clone().neg().add_scalar(1);
loss = loss.mask_where(similar_mask, similar_loss);
// Dissimilar pairs (target == -1) - Formula: L = max(0, cos_sim - margin)
let dissimilar_mask = target.equal_elem(-1);
let dissimilar_loss = relu(cos_sim.clone().sub_scalar(self.margin));
loss = loss.mask_where(dissimilar_mask, dissimilar_loss);
// return loss shape: [batch_size]
loss
}
fn assertions<B: Backend>(
&self,
input1: &Tensor<B, 2>,
input2: &Tensor<B, 2>,
target: &Tensor<B, 1, Int>,
) {
let [batch_size1, dim1] = input1.dims();
let [batch_size2, dim2] = input2.dims();
let [batch_size_target] = target.dims();
assert_eq!(
batch_size1, batch_size2,
"Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})"
);
assert_eq!(
dim1, dim2,
"Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})"
);
assert_eq!(
batch_size1, batch_size_target,
"Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})"
);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn cosine_embedding_loss_positive_target() {
let device = Default::default();
// Two identical vectors should have cosine similarity of 1
let input1 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let input2 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
// Target 1 means that inputs should be similar
let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 1]), &device);
let loss = CosineEmbeddingLossConfig::new().init();
let loss_no_reduction =
loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());
let loss_sum = loss.forward(input1, input2, target);
// For identical vectors, 1 - cos_sim = 1 - 1 = 0
let expected_no_reduction = TensorData::from([0.0, 0.0]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());
let expected_mean = TensorData::from([0.0]);
loss_mean
.into_data()
.assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
let expected_sum = TensorData::from([0.0]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
}
#[test]
fn cosine_embedding_loss_negative_target() {
let device = Default::default();
// Two identical vectors should have cosine similarity of 1
let input1 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let input2 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
// Target -1 means that inputs should be dissimilar
let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([-1, -1]), &device);
// With margin 0.0, max(0, cos_sim - margin) = max(0, 1 - 0) = 1
let loss = CosineEmbeddingLossConfig::new().init();
let loss_no_reduction =
loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());
// Create a loss with Sum reduction for testing
let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);
let loss_sum =
loss_sum_config
.init()
.forward(input1.clone(), input2.clone(), target.clone());
let expected_no_reduction = TensorData::from([1.0, 1.0]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());
let expected_mean = TensorData::from([1.0]);
loss_mean
.into_data()
.assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
let expected_sum = TensorData::from([2.0]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
// With margin 0.5, max(0, cos_sim - margin) = max(0, 1 - 0.5) = 0.5
let loss_with_margin = CosineEmbeddingLossConfig::new().with_margin(0.5).init();
let loss_with_margin = loss_with_margin.forward(input1, input2, target);
let expected = TensorData::from([0.5]);
loss_with_margin
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn cosine_embedding_loss_mixed_targets() {
let device = Default::default();
let input1 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let input2 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
// Mixed targets
let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);
let loss = CosineEmbeddingLossConfig::new().init();
let loss_no_reduction =
loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
let loss_mean = loss.forward(input1, input2, target);
let expected_no_reduction = TensorData::from([0.0, 1.0]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());
let expected_mean = TensorData::from([0.5]);
loss_mean
.into_data()
.assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
}
#[test]
fn display() {
let config = CosineEmbeddingLossConfig::new().with_margin(0.5);
let loss = config.init();
assert_eq!(
alloc::format!("{loss}"),
"CosineEmbeddingLoss {margin: 0.5, reduction: Mean}"
);
}
}

View File

@@ -0,0 +1,466 @@
use burn_core as burn;
use burn_core::tensor::IndexingUpdateOp;
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::activation::log_softmax;
use burn::tensor::{Bool, Int, Tensor, backend::Backend};
use burn::{config::Config, module::Module};
/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
/// Create padded cross entropy.
///
/// Prevents pad tokens from impacting loss calculation.
pub pad_tokens: Option<Vec<usize>>,
/// Create weighted cross-entropy.
///
/// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
///
/// # Pre-conditions
/// - The order of the weight vector should correspond to the label integer assignment.
/// - Targets assigned negative Int's will not be allowed.
pub weights: Option<Vec<f32>>,
/// Create cross-entropy with label smoothing.
///
/// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
/// Alpha = 0 would be the same as default.
pub smoothing: Option<f32>,
/// Create cross-entropy with probabilities as input instead of logits.
///
#[config(default = true)]
pub logits: bool,
}
impl CrossEntropyLossConfig {
/// Initialize [Cross-entropy loss](CrossEntropyLoss).
pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
self.assertions();
CrossEntropyLoss {
pad_tokens: self.pad_tokens.clone(),
weights: self
.weights
.as_ref()
.map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
smoothing: self.smoothing,
logits: self.logits,
}
}
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(
weights.iter().all(|e| e > &0.),
"Weights of cross-entropy have to be positive."
);
}
}
}
/// Calculate the cross entropy loss from the input logits and the targets.
///
/// Should be created using [CrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
/// Pad tokens to ignore in the loss calculation.
pub pad_tokens: Option<Vec<usize>>,
/// Weights for cross-entropy.
pub weights: Option<Tensor<B, 1>>,
/// Label smoothing factor.
pub smoothing: Option<f32>,
/// Use logits as input.
pub logits: bool,
}
impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
alloc::format!("Vec<0..{}>", pad_tokens.len())
} else {
"None".to_string()
};
content
.add("pad_tokens", &pad_tokens)
.add("weights", &self.weights)
.add("smoothing", &self.smoothing)
.add("logits", &self.logits)
.optional()
}
}
impl<B: Backend> CrossEntropyLoss<B> {
/// For backward compatibility.
pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
CrossEntropyLossConfig::new()
.with_pad_tokens(pad_index.map(|e| vec![e]))
.init(device)
}
/// Compute the criterion on the input tensor.
///
/// # Shapes
///
/// - logits: `[batch_size, num_targets]`
/// - targets: `[batch_size]`
pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
Self::assertions(logits.clone(), targets.clone());
match self.smoothing {
Some(alpha) => self.forward_smoothed(logits, targets, alpha),
_ => self.forward_default(logits, targets),
}
}
fn forward_smoothed(
&self,
logits: Tensor<B, 2>,
targets: Tensor<B, 1, Int>,
alpha: f32,
) -> Tensor<B, 1> {
let mask = self.padding_mask(&targets);
let tensor = if self.logits {
log_softmax(logits, 1)
} else {
logits.log()
};
let [batch_size, nr_classes] = tensor.dims();
let tensor = tensor
* Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
match &self.weights {
Some(weights) => {
let tensor = tensor
* weights
.clone()
.reshape([1, nr_classes])
.repeat_dim(0, batch_size);
let weights = weights.clone().gather(0, targets);
let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum().neg() / weights.sum()
}
None => {
let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum_dim(1).mean().neg()
}
}
}
fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
let [batch_size] = targets.dims();
let mask = self.padding_mask(&targets);
let tensor = log_softmax(logits, 1);
let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
match &self.weights {
Some(weights) => {
let weights = weights.clone().gather(0, targets);
let tensor = tensor.reshape([batch_size]) * weights.clone();
let tensor = Self::apply_mask_1d(tensor, mask);
tensor.sum().neg() / weights.sum()
}
None => {
let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
tensor.mean().neg()
}
}
}
fn compute_smoothed_targets(
shape: [usize; 2],
targets: Tensor<B, 1, Int>,
alpha: f32,
) -> Tensor<B, 2> {
let [batch_size, nr_classes] = shape;
let device = &targets.device();
let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
1,
targets.reshape([batch_size, 1]),
Tensor::ones([batch_size, 1], device),
IndexingUpdateOp::Add,
);
targets_matrix * (1. - alpha) + alpha / nr_classes as f32
}
fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
let mut mask = None;
if let Some(pad_tokens) = &self.pad_tokens {
let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
for x in pad_tokens {
res = res + targets.clone().equal_elem(*x as i64).int();
}
mask = Some(res.greater_elem(0));
}
mask
}
fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
if let Some(mask) = mask {
tensor = tensor.mask_fill(mask, 0);
}
tensor
}
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
if let Some(mask) = mask {
let [batch_size, nr_classes] = tensor.dims();
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
}
tensor
}
fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
let [logits_height, _] = logits.dims();
let [targets_height] = targets.dims();
assert!(
logits_height == targets_height,
"Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
macro_rules! setup {
() => {{
let [batch_size, num_targets] = [4, 5];
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::random(
[batch_size, num_targets],
Distribution::Normal(0., 1.0),
&device,
);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
let targets_logits = Tensor::<TestBackend, 2>::from_data(
TensorData::from([
[0.0, 0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 0.0, 0.0, 0.0],
]),
&device,
);
(logits, targets, targets_logits)
}};
}
macro_rules! setup_padded {
() => {{
let [batch_size, num_targets, pad_index] = [4, 5, 1];
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::random(
[batch_size, num_targets],
Distribution::Normal(0., 1.0),
&device,
);
let targets = Tensor::<TestBackend, 1, Int>::from_data(
TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
&device,
);
let targets_logits = Tensor::<TestBackend, 2>::from_data(
TensorData::from([
[0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
]),
&device,
);
(logits, targets, targets_logits)
}};
}
#[test]
fn test_cross_entropy_loss_with_weights() {
let (logits, targets, targets_logits) = setup!();
let weights = vec![1.0, 2., 3., 4., 5.];
let device = Default::default();
let loss_1 = CrossEntropyLossConfig::new()
.with_weights(Some(weights.clone()))
.init(&device)
.forward(logits.clone(), targets);
let tensor = log_softmax(logits, 1);
let loss_2 = tensor
* targets_logits
* Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
.unsqueeze()
.repeat_dim(0, 4);
let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn test_label_smoothing_with_weights_and_alpha_zero() {
let (logits, targets, _) = setup!();
let device = Default::default();
let weights = vec![1.0, 2., 3., 4., 5.];
let loss_1 = CrossEntropyLossConfig::new()
.with_weights(Some(weights.clone()))
.init(&device)
.forward(logits.clone(), targets.clone());
let loss_2 = CrossEntropyLossConfig::new()
.with_weights(Some(weights.clone()))
.with_smoothing(Some(0.))
.init(&device)
.forward(logits.clone(), targets);
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn test_cross_entropy_loss() {
let (logits, targets, targets_logits) = setup!();
let device = Default::default();
let loss_1 = CrossEntropyLossConfig::new()
.init(&device)
.forward(logits.clone(), targets);
let loss_2 = cross_entropy_with_logits(logits, targets_logits);
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn test_label_smoothing_alpha_equal_zero() {
let (logits, targets, _) = setup!();
let device = Default::default();
let loss_1 = CrossEntropyLossConfig::new()
.init(&device)
.forward(logits.clone(), targets.clone());
let loss_2 = CrossEntropyLossConfig::new()
.with_smoothing(Some(0.))
.init(&device)
.forward(logits, targets);
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn test_cross_entropy_loss_with_pad_token() {
let (logits, targets, targets_logits) = setup_padded!();
let pad_index = 1;
let loss_1 = CrossEntropyLossConfig::new()
.with_pad_tokens(Some(vec![pad_index, 2]))
.init(&logits.device())
.forward(logits.clone(), targets);
let loss_2 = cross_entropy_with_logits(logits, targets_logits);
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn test_label_smoothing_with_zero_alpha_and_pad_token() {
let (logits, targets, _) = setup_padded!();
let pad_index = 1;
let loss_1 = CrossEntropyLossConfig::new()
.with_pad_tokens(Some(vec![pad_index, 2]))
.init(&logits.device())
.forward(logits.clone(), targets.clone());
let loss_2 = CrossEntropyLossConfig::new()
.with_pad_tokens(Some(vec![pad_index, 2]))
.with_smoothing(Some(0.))
.init(&logits.device())
.forward(logits.clone(), targets);
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn test_label_smoothing_target_conversion() {
let (logits, targets, _) = setup!();
let smoothed_targets =
CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
let targets_logits = Tensor::<TestBackend, 2>::from_data(
TensorData::from([
[0.01, 0.01, 0.96, 0.01, 0.01],
[0.96, 0.01, 0.01, 0.01, 0.01],
[0.01, 0.01, 0.01, 0.01, 0.96],
[0.01, 0.96, 0.01, 0.01, 0.01],
]),
&Default::default(),
);
smoothed_targets
.into_data()
.assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
}
#[test]
fn test_label_smoothing() {
let (logits, targets, _) = setup!();
let device = Default::default();
let loss_1 = CrossEntropyLossConfig::new()
.with_smoothing(Some(0.05))
.init(&device)
.forward(logits.clone(), targets);
let targets_logits = Tensor::<TestBackend, 2>::from_data(
TensorData::from([
[0.01, 0.01, 0.96, 0.01, 0.01],
[0.96, 0.01, 0.01, 0.01, 0.01],
[0.01, 0.01, 0.01, 0.01, 0.96],
[0.01, 0.96, 0.01, 0.01, 0.01],
]),
&device,
);
let x = log_softmax(logits, 1);
let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();
loss_1
.into_data()
.assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
}
#[test]
fn display() {
let config = CrossEntropyLossConfig::new()
.with_weights(Some(alloc::vec![3., 7., 0.9]))
.with_smoothing(Some(0.5));
let loss = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{loss}"),
"CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,215 @@
use burn_core as burn;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::{config::Config, module::Module};
use super::Reduction;
/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
pub delta: f32,
}
impl HuberLossConfig {
/// Initialize [Huber loss](HuberLoss).
pub fn init(&self) -> HuberLoss {
self.assertions();
HuberLoss {
delta: self.delta,
lin_bias: self.delta * self.delta * 0.5,
}
}
fn assertions(&self) {
assert!(
self.delta >= 0., // This also tests for normality
"Delta for Huber loss must be a non-negative number."
);
}
}
/// Calculate the Huber loss between the inputs and the target.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2 if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d) if |r| > d
/// ```
///
/// where `d` is the configured `delta`. In particular, this is equal to the
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
/// but behaves linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct HuberLoss {
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
pub delta: f32,
/// Precomputed value for the linear bias.
pub lin_bias: f32, // delta * delta * 0.5 precomputed
}
impl ModuleDisplay for HuberLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("delta", &self.delta)
.add("lin_bias", &self.lin_bias)
.optional()
}
}
impl HuberLoss {
/// Compute the loss element-wise for the predictions and targets, then reduce
/// to a single loss value.
///
/// `Reduction::Auto` behaves as `Reduction::Mean`.
///
/// # Shapes
///
/// - predictions: \[...dims\]
/// - targets: \[...dims\]
/// - output: \[1\]
pub fn forward<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let loss = self.forward_no_reduction(predictions, targets);
match reduction {
Reduction::Mean | Reduction::Auto => loss.mean(),
Reduction::Sum => loss.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
/// Compute the loss element-wise for the predictions and targets.
///
/// # Shapes
///
/// - predictions: [...dims]
/// - targets: [...dims]
/// - output: [...dims]
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
let residuals = targets - predictions;
self.forward_residuals(residuals)
}
/// Compute the loss element-wise for the given residuals.
///
/// # Shapes
///
/// - residuals: [...dims]
/// - output: [...dims]
pub fn forward_residuals<const D: usize, B: Backend>(
&self,
residuals: Tensor<B, D>,
) -> Tensor<B, D> {
let is_large = residuals.clone().abs().greater_elem(self.delta);
// We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
// `sign()` function, in general, suffers from a jump at 0.
// Instead the following tensor implements `delta * sign(r)` for values outside
// the bound:
let softsign = residuals.clone().clamp(-self.delta, self.delta);
// 0.5 * d^2 + d * (|r| - d) =
// d * |r| - 0.5 * d^2
// Moreover |r| = sign(r) * r
let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);
let inside = residuals.square().mul_scalar(0.5);
inside.mask_where(is_large, outside)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
type TestTensor<const D: usize> = Tensor<TestBackend, D>;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_huber_loss() {
let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
let targets = TensorData::from([0., 0., 0., 0., 0.]);
let device = Default::default();
let predict = TestTensor::<1>::from_data(predict, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let huber = HuberLossConfig::new(0.5).init();
let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
let loss_no_reduction = huber.forward_no_reduction(predict, targets);
let expected = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([0.284]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([1.42]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[cfg(feature = "std")]
#[test]
fn test_huber_ad_loss() {
type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;
let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
let targets = TensorData::from([0., 0., 0., 0., 0.]);
let device = Default::default();
let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
let targets = TestAutodiffTensor::from_data(targets, &device);
let loss = HuberLossConfig::new(0.5).init();
let loss = loss.forward_no_reduction(predict.clone(), targets);
let grads = loss.backward();
let grads_predict = predict.grad(&grads).unwrap();
let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
grads_predict
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = HuberLossConfig::new(0.5);
let loss = config.init();
assert_eq!(
alloc::format!("{loss}"),
"HuberLoss {delta: 0.5, lin_bias: 0.125}"
);
}
}

View File

@@ -0,0 +1,200 @@
use burn_core as burn;
use super::Reduction;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::{config::Config, module::Module};
/// Configuration to create a [KLDiv loss](KLDivLoss).
#[derive(Config, Debug)]
pub struct KLDivLossConfig {
/// Specifies whether target is the log space. Default: False.
#[config(default = false)]
pub log_target: bool,
}
impl KLDivLossConfig {
/// Initialize [KLDiv Loss](KLDivLoss).
pub fn init(&self) -> KLDivLoss {
KLDivLoss {
log_target: self.log_target,
}
}
}
/// Kullback-Leibler Divergence Loss
///
/// KL Divergence shows the difference between two probability distributions by measuring information loss
///
/// KLDivLoss =
/// ```tex
/// y_{true} \cdot (\log{y_{true}} - \log{y_{pred}})
/// ```
/// By default, the loss expects the input in the log-space.
/// The targets may also be provided in the log-space if `log_target` is true.
///
/// See
/// - [KullbackLeibler divergence](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence)
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct KLDivLoss {
/// Specifies whether target is the log space. Default: False.
pub log_target: bool,
}
impl ModuleDisplay for KLDivLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("log_target", &self.log_target).optional()
}
}
impl KLDivLoss {
/// Compute the criterion on the input tensor.
///
/// `Reduction::Auto` behaves as `Reduction::BatchMean`,`Reduction::Mean` dose not align with the math definition.
///
/// # Shapes
///
/// - predictions: \[batch_size,num_targets\]
/// - targets: \[batch_size,num_targets\]
/// - output: \[1\]
pub fn forward<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let loss = self.forward_no_reduction(predictions, targets);
match reduction {
Reduction::BatchMean | Reduction::Auto => {
let batch_size = loss.dims()[0] as f32;
loss.sum().div_scalar(batch_size)
}
Reduction::Mean => loss.mean(),
Reduction::Sum => loss.sum(),
}
}
/// Compute the criterion on the input tensor without reducing.
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
match self.log_target {
true => targets.clone().exp().mul(targets.sub(predictions)),
false => {
let epsilon = 1e-8;
let log_target = targets.clone().clamp(epsilon, 1.0).log();
targets.mul(log_target.sub(predictions))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
type TestTensor<const D: usize> = Tensor<TestBackend, D>;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_kl_div_loss() {
let predict = TensorData::from([[-1.0, -0.5], [-2.0, -0.2]]);
let targets = TensorData::from([[0.4, 0.6], [0.1, 0.9]]);
let device = Default::default();
let predict = TestTensor::<2>::from_data(predict, &device);
let targets = TestTensor::<2>::from_data(targets, &device);
let kl_loss = KLDivLossConfig { log_target: false }.init();
let loss_sum = kl_loss.forward(predict.clone(), targets.clone(), Reduction::Sum);
let loss_batch_mean =
kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);
let loss_no_reduction = kl_loss.forward_no_reduction(predict, targets);
let expected_no_reduction =
TensorData::from([[0.0334837139, -0.0064953566], [-0.0302585065, 0.0851755068]]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::absolute(1e-5));
let expected_sum = TensorData::from([0.08191]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));
let expected_batch_mean = TensorData::from([0.04095]);
loss_batch_mean
.into_data()
.assert_approx_eq::<FT>(&expected_batch_mean, Tolerance::absolute(1e-5));
}
#[test]
fn test_kl_div_loss_log_target() {
let device = Default::default();
let predict = TestTensor::<1>::from_data([-1.0, -2.0], &device);
let targets = TestTensor::<1>::from_data([-0.5, -1.5], &device);
let kl_loss = KLDivLossConfig { log_target: true }.init();
let loss_no_reduction = kl_loss.forward_no_reduction(predict.clone(), targets.clone());
let expected_none = TensorData::from([0.3032653299, 0.1115650801]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_none, Tolerance::absolute(1e-5));
let loss_batch_mean =
kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);
let expected_bm = TensorData::from([0.207415204965]);
loss_batch_mean
.into_data()
.assert_approx_eq::<FT>(&expected_bm, Tolerance::absolute(1e-5));
let loss_sum = kl_loss.forward(predict, targets, Reduction::Sum);
let expected_sum = TensorData::from([0.414830409931]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));
}
#[cfg(feature = "std")]
#[test]
fn test_kl_div_ad_loss() {
type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 2>;
let device = Default::default();
let predict = TestAutodiffTensor::from_data([[-1.0, -0.5]], &device).require_grad();
let targets = TestAutodiffTensor::from_data([[0.4, 0.6]], &device);
let kl_loss = KLDivLossConfig { log_target: false }.init();
let loss = kl_loss.forward(predict.clone(), targets, Reduction::Sum);
let grads = loss.backward();
let grads_predict = predict.grad(&grads).unwrap();
// d/d_pred [target * (log_target - pred)] = -target
let expected = TensorData::from([[-0.4, -0.6]]);
grads_predict
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = KLDivLossConfig { log_target: true };
let loss = config.init();
assert_eq!(alloc::format!("{loss}"), "KLDivLoss {log_target: true}");
}
}

View File

@@ -0,0 +1,672 @@
use super::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
use burn_core as burn;
/// Configuration for the [Lp Loss](LpLoss) module.
///
/// # Example
///
/// ```ignore
/// use burn_nn::loss::{LpLossConfig, Reduction};
///
/// // Create L1 loss (MAE when using mean reduction)
/// let l1_loss = LpLossConfig::l1();
///
/// // Create L2 loss (MSE when using mean reduction)
/// let l2_loss = LpLossConfig::l2();
///
/// // Create custom Lp loss with p=3
/// let l3_loss = LpLossConfig::new(3.0).init();
/// ```
#[derive(Config, Debug)]
pub struct LpLossConfig {
/// The exponent `p` determining the type of error measurement.
///
/// Common values:
/// - `p = 1.0`: L1 loss (MAE with mean reduction) - robust to outliers
/// - `p = 2.0`: L2 loss (MSE with mean reduction) - standard choice, differentiable everywhere
/// - `p > 2.0`: Increasingly sensitive to large errors (outliers)
/// - `0 < p < 1`: More robust to outliers than L1 (quasi-norm)
pub p: f64,
}
impl LpLossConfig {
/// Initializes a [Lp Loss](LpLoss) module.
///
/// # Panics
///
/// Panics if `p <= 0`.
pub fn init(&self) -> LpLoss {
self.assertions();
LpLoss { p: self.p }
}
/// Creates L1 loss (p=1).
///
/// When used with `Reduction::Mean`, this computes Mean Absolute Error (MAE).
/// When used with `Reduction::Sum`, this computes Sum of Absolute Errors (SAE).
pub fn l1() -> LpLoss {
LpLoss { p: 1.0 }
}
/// Creates L2 loss (p=2).
///
/// When used with `Reduction::Mean`, this computes Mean Squared Error (MSE).
/// When used with `Reduction::Sum`, this computes Sum of Squared Errors (SSE).
pub fn l2() -> LpLoss {
LpLoss { p: 2.0 }
}
fn assertions(&self) {
assert!(self.p > 0.0, "The order of the norm p must be positive.")
}
}
/// Computes the Lp Loss between predictions and targets.
///
/// This loss function computes the element-wise p-th power of absolute errors,
/// then reduces them via mean or sum.
///
/// # Mathematical Definition
///
/// For predictions `ŷ` and targets `y`, the element-wise loss is:
///
/// ```text
/// Lᵢ = |ŷᵢ - yᵢ|ᵖ
/// ```
///
/// With mean reduction (default), the final loss is:
///
/// ```text
/// L = (1/n) × Σᵢ |ŷᵢ - yᵢ|ᵖ
/// ```
///
/// # Notes
///
/// - This implementation computes `|error|^p`, **not** the Lp norm `(Σ|error|^p)^(1/p)`.
/// - The `p = 1` case uses an optimized `abs()` operation.
/// - The `p = 2` case uses an optimized computation `error * error` instead of `powf`.
///
/// # Example
///
/// ```ignore
/// use burn_nn::loss::{LpLossConfig, Reduction};
/// use burn::tensor::Tensor;
///
/// // Create L2 loss
/// let l2_loss = LpLossConfig::l2();
///
/// let predictions: Tensor<Backend, 2> = /* model output */;
/// let targets: Tensor<Backend, 2> = /* ground truth */;
///
/// // Compute loss with mean reduction (MSE)
/// let mse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Mean);
///
/// // Compute loss with sum reduction (SSE)
/// let sse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Sum);
///
/// // Compute loss with no reduction
/// let unreduced_l2_loss = l2_loss.forward_no_reduction(predictions, targets);
/// ```
#[derive(Module, Clone, Debug)]
pub struct LpLoss {
/// The order of the norm (e.g., 1 for L1, 2 for L2).
/// Equivalently, the exponent `p` for computing `|error|^p`.
pub p: f64,
}
impl LpLoss {
/// Computes the element-wise loss `|error|^p` with reduction.
///
/// # Arguments
///
/// * `predictions` - The model's predicted values.
/// * `targets` - The ground truth target values.
/// * `reduction` - Specifies how to reduce the element-wise losses:
/// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.
/// - `Reduction::Sum`: Returns the sum of all element-wise losses.
///
/// # Returns
///
/// A scalar tensor containing the reduced loss value.
///
/// # Shapes
///
/// - predictions: `[...dims]` - Any shape
/// - targets: `[...dims]` - Must match predictions shape
/// - output: `[1]` - Scalar loss value
pub fn forward<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let unreduced_loss = self.forward_no_reduction(predictions, targets);
match reduction {
Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
Reduction::Sum => unreduced_loss.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
/// Computes the element-wise loss `|error|^p` without reduction.
///
/// # Arguments
///
/// * `predictions` - The model's predicted values.
/// * `targets` - The ground truth target values.
///
/// # Returns
///
/// A tensor of the same shape as the inputs, containing `|prediction - target|^p`
/// for each element.
///
/// # Shapes
///
/// - predictions: `[...dims]` - Any shape
/// - targets: `[...dims]` - Must match predictions shape
/// - output: `[...dims]` - Same shape as inputs
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
let error = predictions.sub(targets);
// Use simplified/optimized expressions for common cases (p = 1, p = 2)
if self.p == 1.0 {
// L1 loss
error.abs()
} else if self.p == 2.0 {
// L2 loss
error.clone().mul(error)
} else {
error.abs().powf_scalar(self.p)
}
}
/// Computes the element-wise loss `|error|^p` with reduction over specified dimensions.
///
/// Calculates element-wise `|predictions - targets|^p`, then takes the mean
/// over the specified dimensions. Useful for per-sample or per-channel losses (e.g., when
/// working with images).
///
/// Dimensions can be provided in any order. They are sorted internally and
/// reduced from highest to lowest to ensure indices remain valid.
///
/// # Arguments
///
/// * `predictions` - The model's predicted values.
/// * `targets` - The ground truth target values.
/// * `dims` - Dimensions to reduce over.
///
/// # Returns
///
/// A tensor with the specified dimensions reduced to size 1.
///
/// # Example
///
/// ```ignore
/// // Image tensor: [batch, C, H, W]
/// let l2_loss = LpLossConfig::l2();
///
/// // Per-image MSE for PSNR: reduce over C, H, W → [batch, 1, 1, 1]
/// let mse_per_image = l2_loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
/// ```
pub fn forward_reduce_dims<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
dims: &[usize],
) -> Tensor<B, D> {
let error = self.forward_no_reduction(predictions, targets);
// Sort the dimensions to ascending order
let mut sorted_dims = dims.to_vec();
sorted_dims.sort();
// Reduce over specified dimensions
error.mean_dims(sorted_dims.as_slice())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_lp_loss_l1_constructor() {
let loss_func_l1 = LpLossConfig::l1();
let loss_func_p1 = LpLossConfig::new(1.0).init();
assert_eq!(loss_func_l1.p, 1.0);
assert_eq!(loss_func_l1.p, loss_func_p1.p);
}
#[test]
fn test_lp_loss_l2_constructor() {
let loss_func_l2 = LpLossConfig::l2();
let loss_func_p2 = LpLossConfig::new(2.0).init();
assert_eq!(loss_func_l2.p, 2.0);
assert_eq!(loss_func_l2.p, loss_func_p2.p);
}
#[test]
fn test_lp_loss_l1() {
let device = Default::default();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
&device,
);
let loss_func = LpLossConfig::l1();
let loss_no_reduction =
loss_func.forward_no_reduction(predictions.clone(), targets.clone());
let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);
loss_no_reduction.into_data().assert_eq(&expected, false);
let expected = TensorData::from([1.0]);
loss_auto.into_data().assert_eq(&expected, false);
let expected = TensorData::from([4.0]);
loss_sum.into_data().assert_eq(&expected, false);
}
#[test]
fn test_lp_loss_l2() {
let device = Default::default();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
&device,
);
let loss_func = LpLossConfig::l2();
let loss_no_reduction =
loss_func.forward_no_reduction(predictions.clone(), targets.clone());
let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);
loss_no_reduction.into_data().assert_eq(&expected, false);
let expected = TensorData::from([1.5]);
loss_auto.into_data().assert_eq(&expected, false);
let expected = TensorData::from([6.0]);
loss_sum.into_data().assert_eq(&expected, false);
}
#[test]
fn test_lp_loss_p_half() {
// L0.5 quasi-norm: more robust to outliers than L1
let device = Default::default();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[2.0, 1.0], [3.0, 0.0]]),
&device,
);
let loss_func = LpLossConfig::new(0.5).init();
let loss_no_reduction =
loss_func.forward_no_reduction(predictions.clone(), targets.clone());
let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
// |1-2|^0.5 = 1, |2-1|^0.5 = 1, |3-3|^0.5 = 0, |4-0|^0.5 = 2
let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);
loss_no_reduction.into_data().assert_eq(&expected, false);
let expected = TensorData::from([1.0]);
loss_auto.into_data().assert_eq(&expected, false);
let expected = TensorData::from([4.0]);
loss_sum.into_data().assert_eq(&expected, false);
}
#[test]
fn test_lp_loss_p3() {
// L3 norm: more sensitive to outliers than L2
let device = Default::default();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
&device,
);
let loss_func = LpLossConfig::new(3.0).init();
let loss_no_reduction =
loss_func.forward_no_reduction(predictions.clone(), targets.clone());
let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
// |1-2|^3 = 1, |2-1|^3 = 1, |3-3|^3 = 0, |4-2|^3 = 8
let expected = TensorData::from([[1.0, 1.0], [0.0, 8.0]]);
loss_no_reduction.into_data().assert_eq(&expected, false);
let expected = TensorData::from([2.5]);
loss_auto.into_data().assert_eq(&expected, false);
let expected = TensorData::from([10.0]);
loss_sum.into_data().assert_eq(&expected, false);
}
#[test]
fn test_lp_loss_zero_error() {
// Test when predictions exactly match targets
let device = Default::default();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = predictions.clone();
let loss_func_l1 = LpLossConfig::l1();
let loss_func_l2 = LpLossConfig::l2();
let l1_loss = loss_func_l1.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let l2_loss = loss_func_l2.forward(predictions, targets, Reduction::Auto);
let expected = TensorData::from([0.0]);
l1_loss.into_data().assert_eq(&expected, false);
l2_loss.into_data().assert_eq(&expected, false);
}
#[test]
fn test_lp_loss_negative_errors() {
// Test that negative errors are handled correctly (absolute value)
let device = Default::default();
let predictions =
Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);
let targets =
Tensor::<TestBackend, 1>::from_data(TensorData::from([3.0, 4.0, 5.0]), &device);
let loss_func_l1 = LpLossConfig::l1();
let loss_func_p1 = LpLossConfig::new(1.0).init();
let loss_no_reduction_l1 =
loss_func_l1.forward_no_reduction(predictions.clone(), targets.clone());
let loss_no_reduction_p1 = loss_func_p1.forward_no_reduction(predictions, targets);
// All errors are negative: 1-3=-2, 2-4=-2, 3-5=-2, but |error| = 2
let expected = TensorData::from([2.0, 2.0, 2.0]);
loss_no_reduction_l1.into_data().assert_eq(&expected, false);
loss_no_reduction_p1.into_data().assert_eq(&expected, false);
}
#[test]
fn test_lp_loss_3d_tensor() {
let device = Default::default();
let predictions = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
&device,
);
let targets = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]),
&device,
);
let loss_func_l2 = LpLossConfig::l2();
let loss_func_p2 = LpLossConfig::new(2.0).init();
let loss_l2 = loss_func_l2.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_p2 = loss_func_p2.forward(predictions, targets, Reduction::Auto);
// Errors: 1, 0, 0, -1, 1, 0, 0, -2
// Squared: 1, 0, 0, 1, 1, 0, 0, 4
// Mean: 7/8 = 0.875
let expected = TensorData::from([0.875]);
loss_l2.into_data().assert_eq(&expected, false);
loss_p2.into_data().assert_eq(&expected, false);
}
#[test]
#[should_panic(expected = "The order of the norm p must be positive.")]
fn test_lp_loss_negative_p_panics() {
let _ = LpLossConfig::new(-1.0).init();
}
#[test]
#[should_panic(expected = "The order of the norm p must be positive.")]
fn test_lp_loss_zero_p_panics() {
let _ = LpLossConfig::new(0.0).init();
}
#[test]
fn test_lp_loss_fractional_p() {
// Test p = 1.5
let device = Default::default();
let predictions =
Tensor::<TestBackend, 1>::from_data(TensorData::from([0.0, 4.0]), &device);
let targets = Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 0.0]), &device);
let loss_func = LpLossConfig::new(1.5).init();
let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);
// |0-1|^1.5 = 1, |4-0|^1.5 = 8
let expected = TensorData::from([1.0, 8.0]);
loss_no_reduction.into_data().assert_eq(&expected, false);
}
#[test]
fn test_forward_reduce_dims_single_dim() {
let device = Default::default();
// Shape: [2, 3]
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),
&device,
);
let loss_func_l2 = LpLossConfig::l2();
let loss_func_p2 = LpLossConfig::new(2.0).init();
// Reduce over dim 1 -> should give [2, 1] shape
let loss_l2 = loss_func_l2.forward_reduce_dims(predictions.clone(), targets.clone(), &[1]);
let loss_p2 = loss_func_p2.forward_reduce_dims(predictions, targets, &[1]);
// Errors row 0: [1, 0, -3] -> squared: [1, 0, 9] -> mean: 10/3
// Errors row 1: [3, 0, 0] -> squared: [9, 0, 0] -> mean: 3
let expected = TensorData::from([[10.0 / 3.0], [3.0]]);
loss_l2
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
loss_p2
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_forward_reduce_dims_first_dim() {
let device = Default::default();
// Shape: [2, 3]
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),
&device,
);
let loss_func = LpLossConfig::l2();
// Reduce over dim 0 -> should give [1, 3] shape
let loss = loss_func.forward_reduce_dims(predictions, targets, &[0]);
// Squared errors: [[1, 0, 9], [9, 0, 0]]
// Mean over dim 0: [5, 0, 4.5]
let expected = TensorData::from([[5.0, 0.0, 4.5]]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_forward_reduce_dims_multiple_dims() {
let device = Default::default();
// Shape: [2, 2, 2]
let predictions = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
&device,
);
let targets = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]),
&device,
);
let loss_func = LpLossConfig::l2();
// Reduce over dims 1 and 2 -> should give [2, 1, 1] shape
let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);
// Batch 0 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25
// Batch 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25
let expected = TensorData::from([[[1.25]], [[1.25]]]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_forward_reduce_dims_all_dims() {
let device = Default::default();
// Shape: [2, 2]
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
&device,
);
let loss_func = LpLossConfig::l2();
// Reduce over all dims -> should give [1, 1] shape
let loss = loss_func.forward_reduce_dims(predictions, targets, &[0, 1]);
// Errors: [[-1, 1], [0, 2]] -> squared: [[1, 1], [0, 4]] -> mean: 1.5
let expected = TensorData::from([[1.5]]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_forward_reduce_dims_image_batch() {
// Simulate per-image loss for [batch, C, H, W] tensor (common use case for PSNR)
let device = Default::default();
// Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2)
let predictions = Tensor::<TestBackend, 4>::from_data(
TensorData::from([
[[[1.0, 2.0], [3.0, 4.0]]], // Image 1
[[[5.0, 6.0], [7.0, 8.0]]], // Image 2
]),
&device,
);
let targets = Tensor::<TestBackend, 4>::from_data(
TensorData::from([
[[[0.0, 2.0], [3.0, 6.0]]], // Target 1
[[[5.0, 5.0], [7.0, 7.0]]], // Target 2
]),
&device,
);
let loss_func = LpLossConfig::l2();
// Reduce over C, H, W (dims 1, 2, 3) to get per-image MSE
let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
// Image 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 1.25
// Image 2 errors: [[0, 1], [0, 1]] -> squared: [[0, 1], [0, 1]] -> mean: 0.5
let expected = TensorData::from([[[[1.25]]], [[[0.5]]]]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_forward_reduce_dims_with_p1() {
let device = Default::default();
// Shape: [2, 3]
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]),
&device,
);
let loss_func = LpLossConfig::l1();
// Reduce over dim 1 -> should give [2, 1] shape
let loss = loss_func.forward_reduce_dims(predictions, targets, &[1]);
// Abs errors row 0: [1, 3, 0] -> mean: 4/3
// Abs errors row 1: [3, 0, 3] -> mean: 2
let expected = TensorData::from([[4.0 / 3.0], [2.0]]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_forward_reduce_dims_empty_dims() {
// Reducing over no dimensions should return the unreduced loss
let device = Default::default();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.0, 2.0], [3.0, 6.0]]),
&device,
);
let loss_func = LpLossConfig::l2();
let loss_reduce_dims =
loss_func.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);
// Should be equivalent
loss_reduce_dims
.into_data()
.assert_eq(&loss_no_reduction.into_data(), true);
}
#[test]
fn test_forward_reduce_dims_zero_error() {
let device = Default::default();
// Shape: [2, 2, 2]
let predictions = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
&device,
);
let targets = predictions.clone();
let loss_func = LpLossConfig::l2();
let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);
// All zeros, reduced to shape: [2, 1, 1]
let expected = TensorData::from([[[0.0]], [[0.0]]]);
loss.into_data().assert_eq(&expected, false);
}
}

View File

@@ -0,0 +1,23 @@
mod binary_cross_entropy;
mod cosine_embedding;
mod cross_entropy;
mod ctc;
mod huber;
mod kldiv;
mod lp_loss;
mod mse;
mod poisson;
mod reduction;
mod smooth_l1;
pub use binary_cross_entropy::*;
pub use cosine_embedding::*;
pub use cross_entropy::*;
pub use ctc::*;
pub use huber::*;
pub use kldiv::*;
pub use lp_loss::*;
pub use mse::*;
pub use poisson::*;
pub use reduction::*;
pub use smooth_l1::*;

View File

@@ -0,0 +1,93 @@
use burn_core as burn;
use crate::loss::reduction::Reduction;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
/// Calculate the mean squared error loss from the input logits and the targets.
#[derive(Module, Clone, Debug)]
pub struct MseLoss;
impl Default for MseLoss {
fn default() -> Self {
Self::new()
}
}
impl MseLoss {
/// Create the criterion.
pub fn new() -> Self {
Self
}
/// Compute the criterion on the input tensor.
///
/// # Shapes
///
/// - logits: [batch_size, num_targets]
/// - targets: [batch_size, num_targets]
pub fn forward<const D: usize, B: Backend>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let tensor = self.forward_no_reduction(logits, targets);
match reduction {
Reduction::Mean | Reduction::Auto => tensor.mean(),
Reduction::Sum => tensor.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
/// Compute the criterion on the input tensor without reducing.
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
logits.sub(targets).square()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn test_mse_loss() {
let device = Default::default();
let logits = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
&device,
);
let mse = MseLoss::new();
let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
let loss_sum = mse.forward(logits, targets, Reduction::Sum);
let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);
loss_no_reduction.into_data().assert_eq(&expected, false);
let expected = TensorData::from([1.5]);
loss.into_data().assert_eq(&expected, false);
let expected = TensorData::from([6.0]);
loss_sum.into_data().assert_eq(&expected, false);
}
#[test]
fn display() {
let loss = MseLoss::new();
assert_eq!(alloc::format!("{loss}"), "MseLoss");
}
}

View File

@@ -0,0 +1,417 @@
use burn_core as burn;
use core::f32::consts::PI;
use burn::tensor::cast::ToElement;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::{config::Config, module::Module};
use super::Reduction;
/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance.
///
/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss
/// behavior, such as whether the input is in log-space, whether to include the Stirling
/// approximation term, and a small epsilon value to avoid numerical instability.
#[derive(Config, Debug)]
pub struct PoissonNllLossConfig {
/// If `true`, the predictions are expected to be in log-space.
///
/// When `log_input` is `true`, the loss is computed as:
/// ```text
/// L(predictions, target) = exp(predictions) - target * predictions
/// ```
/// When `log_input` is `false`, the loss is computed as:
/// ```text
/// L(predictions, target) = predictions - target * log(predictions + eps)
/// ```
#[config(default = true)]
pub log_input: bool,
/// Whether to compute the full loss, including the Stirling approximation term.
///
/// When `full` is `true`, the Stirling approximation term is added to the loss:
/// ```text
/// target * log(target) - target + 0.5 * log(2 * PI * target)
/// ```
#[config(default = false)]
pub full: bool,
/// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
///
/// This epsilon value is added to the predictions to ensure numerical stability
/// when computing the logarithm.
#[config(default = 1e-8)]
pub eps: f64,
}
impl PoissonNllLossConfig {
/// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration.
///
/// # Panics
/// - Panics if `eps` is not a positive number.
pub fn init(&self) -> PoissonNllLoss {
self.assertions();
PoissonNllLoss {
log_input: self.log_input,
full: self.full,
eps: self.eps,
}
}
/// Validates the configuration parameters.
///
/// # Panics
/// - Panics if `eps` is not a positive number.
fn assertions(&self) {
assert!(
self.eps > 0.,
"eps for PoissonNllLoss must be a positive number."
);
}
}
/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target.
///
/// This loss function is used when the target values are assumed to follow a Poisson distribution.
/// The loss is defined as:
/// ```text
/// target ~ Poisson(input)
/// L(predictions, target) = predictions - target * log(predictions) + log(target!)
/// ```
/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula.
/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss.
///
/// For more details, see:
/// <https://en.wikipedia.org/wiki/Poisson_regression#Maximum_likelihood-based_parameter_estimation>
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct PoissonNllLoss {
/// If `true`, the predictions are expected to be in log-space.
pub log_input: bool,
/// Whether to compute the full loss, including the Stirling approximation term.
pub full: bool,
/// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
pub eps: f64,
}
impl ModuleDisplay for PoissonNllLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("log_input", &self.log_input)
.add("full", &self.full)
.add("eps", &self.eps)
.optional()
}
}
impl PoissonNllLoss {
/// Computes the loss element-wise for the given predictions and targets, then reduces
/// the result to a single loss value.
///
/// # Arguments
/// - `predictions`: The predicted values.
/// - `targets`: The target values.
/// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`.
///
/// # Shapes
/// - `predictions`: `[...dims]`
/// - `targets`: `[...dims]`
/// - `output`: `[1]`
///
/// # Panics
/// - Panics if the shapes of `predictions` and `targets` do not match.
/// - Panics if any target value is negative.
/// - Panics if `log_input` is `false` and any prediction value is negative.
pub fn forward<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let loss = self.forward_no_reduction(predictions, targets);
match reduction {
Reduction::Mean | Reduction::Auto => loss.mean(),
Reduction::Sum => loss.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
/// Computes the loss element-wise for the given predictions and targets without reduction.
///
/// # Arguments
/// - `predictions`: The predicted values.
/// - `targets`: The target values.
///
/// # Shapes
/// - `predictions`: `[...dims]`
/// - `targets`: `[...dims]`
/// - `output`: `[...dims]`
///
/// # Panics
/// - Panics if the shapes of `predictions` and `targets` do not match.
/// - Panics if any target value is negative.
/// - Panics if `log_input` is `false` and any prediction value is negative.
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
self.assertions(&predictions, &targets);
let mut loss;
if self.log_input {
loss = predictions.clone().exp() - targets.clone() * predictions;
} else {
loss = predictions.clone() - targets.clone() * (predictions + self.eps).log();
}
if self.full {
let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
+ (targets.clone() * 2. * PI).log() * 0.5;
loss = loss
+ log_stirling_term
.mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());
}
loss
}
/// Validates the input tensors for the loss computation.
///
/// # Panics
/// - Panics if the shapes of `predictions` and `targets` do not match.
/// - Panics if any target value is negative.
/// - Panics if `log_input` is `false` and any prediction value is negative.
fn assertions<const D: usize, B: Backend>(
&self,
predictions: &Tensor<B, D>,
targets: &Tensor<B, D>,
) {
let predictions_dims = predictions.dims();
let targets_dims = targets.dims();
assert!(
predictions_dims == targets_dims,
"Shape of targets ({targets_dims:?}) should correspond to outer shape of predictions ({predictions_dims:?})."
);
assert!(
targets
.clone()
.greater_equal_elem(0.)
.all()
.into_scalar()
.to_bool(),
"All the values of `targets` must be non-negative."
);
if !self.log_input {
assert!(
predictions
.clone()
.greater_equal_elem(0.)
.all()
.into_scalar()
.to_bool(),
"When `log_input` is `false`, all the values of `predictions` must be non-negative."
);
}
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::approx_constant)]
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
type TestTensor<const D: usize> = Tensor<TestBackend, D>;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_poisson_nll_loss() {
let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().init();
let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);
let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([21.0321]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([126.1929]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_poisson_nll_loss_no_log_input() {
let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_log_input(false).init();
let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());
let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_poisson_nll_loss_full() {
let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_full(true).init();
let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);
let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([21.9920]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([131.9518]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[cfg(feature = "std")]
#[test]
fn test_poisson_nll_loss_gradients() {
type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;
let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
let device = Default::default();
let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad();
let predictions2 = predictions1.clone();
let targets = TestAutodiffTensor::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_full(false).init();
let poisson_full = PoissonNllLossConfig::new().with_full(true).init();
let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
let loss_full_sum =
poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);
let grads = loss_sum.backward();
let grads_full = loss_full_sum.backward();
let grads_predictions1 = predictions1.grad(&grads).unwrap();
let grads_predictions2 = predictions2.grad(&grads_full).unwrap();
let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);
grads_predictions1
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
grads_predictions2
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
#[should_panic = "eps for PoissonNllLoss must be a positive number."]
fn test_negative_eps() {
let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
}
#[test]
#[should_panic = "All the values of `targets` must be non-negative."]
fn test_targets_with_negative_values() {
let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().init();
let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
}
#[test]
#[should_panic = "Shape of targets"]
fn test_shape_tensors() {
let predictions = TensorData::from([0., 1., 2.]);
let targets = TensorData::from([0., 1.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().init();
let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
}
#[test]
#[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
fn test_exp_predictions_non_negative() {
let predictions = TensorData::from([0.3, -0.1, 0.4]);
let targets = TensorData::from([0., 1., 0.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_log_input(false).init();
let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
}
#[test]
fn display() {
let config = PoissonNllLossConfig::new();
let loss = config.init();
assert_eq!(
alloc::format!("{loss}"),
"PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
);
}
}

View File

@@ -0,0 +1,19 @@
use burn_core as burn;
use burn::config::Config;
/// The reduction type for the loss.
#[derive(Config, Debug)]
pub enum Reduction {
/// The mean of the losses will be returned.
Mean,
/// The sum of the losses will be returned.
Sum,
/// The sum of the losses divided by the batch_size will be returned.
BatchMean,
/// The mean of the losses will be returned.
Auto,
}

View File

@@ -0,0 +1,520 @@
use super::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
use burn_core as burn;
/// Configuration for the [SmoothL1Loss](SmoothL1Loss) module.
///
/// Smooth L1 loss combines L1 and L2 loss, using L2 loss for small errors (below beta)
/// and L1 loss for large errors (above beta). This makes it less sensitive to outliers
/// than MSE while maintaining smooth gradients near zero.
///
/// # Example
///
/// ```ignore
/// use burn_nn::loss::{SmoothL1LossConfig, Reduction};
///
/// // Create Smooth L1 loss with default beta=1.0
/// let smooth_l1 = SmoothL1LossConfig::new().init();
///
/// // Create with custom beta
/// let smooth_l1_custom = SmoothL1LossConfig::new().with_beta(0.5).init();
/// ```
#[derive(Config, Debug)]
pub struct SmoothL1LossConfig {
/// Specifies the threshold at which to change between L1 and L2 loss.
/// The value must be positive. Default: 1.0
#[config(default = 1.0)]
pub beta: f32,
}
impl SmoothL1LossConfig {
/// Initializes a [Smooth L1 Loss](SmoothL1Loss) module.
///
/// # Panics
///
/// Panics if `beta <= 0`.
pub fn init(&self) -> SmoothL1Loss {
self.assertions();
SmoothL1Loss { beta: self.beta }
}
fn assertions(&self) {
assert!(self.beta > 0.0, "The parameter beta must be positive.")
}
}
/// Computes the Smooth L1 Loss between predictions and targets.
///
/// This loss function uses L2 loss for small errors (below beta) and L1 loss for
/// large errors (above beta), providing robustness to outliers while maintaining
/// smooth gradients near |x - y| = 0.
///
/// # Mathematical Definition
///
/// For predictions `x` and targets `y`, the element-wise loss is:
///
/// - L_i = 0.5 * (x_i - y_i)² / beta , if |x_i - y_i| < beta
/// - L_i = |x_i - y_i| - 0.5 * beta , otherwise
///
/// # Notes
///
/// Smooth L1 loss is closely related to HuberLoss since it is equivalent to HuberLoss
/// scaled by `1/beta`:
/// `SmoothL1(x, y, beta) = Huber(x, y, beta) / beta`
///
/// This leads to the following differences:
///
/// - As beta approaches 0, Smooth L1 loss converges to L1Loss, while HuberLoss converges to 0.
/// When beta = 0, Smooth L1 loss is equivalent to L1 loss. Thus, the `beta`
/// parameter in Burn must be positive. L1Loss should be used for beta = 0.
/// - As beta approaches positive infinity, Smooth L1 loss converges to a constant 0 loss, while
/// HuberLoss converges to L2Loss.
///
/// # Example
///
/// ```rust,ignore
/// use burn_nn::loss::{SmoothL1LossConfig, Reduction};
/// use burn::tensor::Tensor;
///
/// // Create Smooth L1 loss with the default beta=1.0
/// let smooth_l1 = SmoothL1LossConfig::new().init();
///
/// let predictions: Tensor<Backend, 2> = /* model output */;
/// let targets: Tensor<Backend, 2> = /* ground truth */;
///
/// // Compute element-wise loss without reduction
/// let element_wise = smooth_l1.forward(predictions.clone(), targets.clone());
///
/// // Compute loss with mean reduction
/// let loss = smooth_l1.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean);
///
/// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1]
/// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
/// ```
#[derive(Module, Clone, Debug)]
pub struct SmoothL1Loss {
/// Specifies the threshold at which to change between L1 and L2 loss.
/// The value must be positive. Default: 1.0
pub beta: f32,
}
impl SmoothL1Loss {
/// Computes the element-wise smooth L1 loss without reduction.
///
/// # Arguments
///
/// - `predictions` - The model's predicted values.
/// - `targets` - The ground truth target values.
///
/// # Returns
///
/// A tensor of the same shape as the inputs, containing the smooth L1 loss
/// for each element.
///
/// # Shapes
///
/// - predictions: `[...dims]` - Any shape
/// - targets: `[...dims]` - Must match predictions shape
/// - output: `[...dims]` - Same shape as inputs
pub fn forward<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
let error = predictions.sub(targets);
let abs_error = error.clone().abs();
// The L1 case: |error| - 0.5 * beta (when |error| >= beta)
let l1_loss = abs_error.clone().sub_scalar(0.5 * self.beta);
// The L2 case: 0.5 * (error)^2 / beta (when |error| < beta)
let l2_loss = error.square().mul_scalar(0.5).div_scalar(self.beta);
let l2_mask = abs_error.lower_elem(self.beta);
l1_loss.mask_where(l2_mask, l2_loss)
}
/// Computes the smooth L1 loss with reduction.
///
/// # Arguments
///
/// - `predictions` - The model's predicted values.
/// - `targets` - The ground truth target values.
/// - `reduction` - Specifies how to reduce the element-wise losses:
/// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.
/// - `Reduction::Sum`: Returns the sum of all element-wise losses.
///
/// # Returns
///
/// A scalar tensor containing the reduced loss value.
///
/// # Shapes
///
/// - predictions: `[...dims]` - Any shape
/// - targets: `[...dims]` - Must match predictions shape
/// - output: `[1]` - Scalar loss value
pub fn forward_with_reduction<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let unreduced_loss = self.forward(predictions, targets);
match reduction {
Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
Reduction::Sum => unreduced_loss.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
/// Computes the smooth L1 loss with reduction over specified dimensions.
///
/// Calculates element-wise smooth L1 loss, then takes the mean
/// over the specified dimensions. Useful for per-sample or per-channel losses.
///
/// Dimensions can be provided in any order. They are sorted internally and
/// reduced from highest to lowest to ensure indices remain valid.
///
/// # Arguments
///
/// - `predictions` - The model's predicted values.
/// - `targets` - The ground truth target values.
/// - `dims` - Dimensions to reduce over.
///
/// # Returns
///
/// A tensor with the specified dimensions reduced to size 1.
///
/// # Example
///
/// ```ignore
/// // Consider image tensor with shape [batch, C, H, W]
/// let smooth_l1 = SmoothL1LossConfig::new().init();
///
/// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1]
/// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
/// ```
pub fn forward_reduce_dims<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
dims: &[usize],
) -> Tensor<B, D> {
let error = self.forward(predictions, targets);
// Sort the dimensions to ascending order
let mut sorted_dims = dims.to_vec();
sorted_dims.sort();
// Reduce over specified dimensions
error.mean_dims(sorted_dims.as_slice())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
// =========================================================================
// Configuration Tests
// =========================================================================
#[test]
fn test_smooth_l1_config_default_beta() {
let loss = SmoothL1LossConfig::new().init();
assert_eq!(loss.beta, 1.0);
}
#[test]
fn test_smooth_l1_config_custom_beta() {
let loss = SmoothL1LossConfig::new().with_beta(2.5).init();
assert_eq!(loss.beta, 2.5);
}
#[test]
#[should_panic(expected = "The parameter beta must be positive")]
fn test_smooth_l1_config_beta_zero_panics() {
SmoothL1LossConfig::new().with_beta(0.0).init();
}
#[test]
#[should_panic(expected = "The parameter beta must be positive")]
fn test_smooth_l1_config_beta_negative_panics() {
SmoothL1LossConfig::new().with_beta(-1.0).init();
}
// =========================================================================
// Forward Pass (Element-wise) Tests
// =========================================================================
#[test]
fn test_smooth_l1_forward_l2_region() {
// Beta = 1.0, errors = 0.0 and 0.5 (both < beta, use L2 formula)
// L2 formula: 0.5 * error^2 / beta
// error = 0.0 -> loss = 0.5 * 0.0 / 1.0 = 0.0
// error = 0.5 -> loss = 0.5 * 0.25 / 1.0 = 0.125
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.5]]), &device);
let targets =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
let output = loss.forward(predictions, targets);
let expected = TensorData::from([[0.0_f32, 0.125]]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_forward_l1_region() {
// Beta = 1.0, errors = 0.0 and 2.0 (2.0 >= beta, use L1 formula)
// L1 formula: |error| - 0.5 * beta
// L2 formula: 0.5 * (error)^2 / beta
// error = 0.0 -> loss = 0.0
// error = 2.0 -> loss = 2.0 - 0.5 = 1.5
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 2.0]]), &device);
let targets =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
let output = loss.forward(predictions, targets);
let expected = TensorData::from([[0.0_f32, 1.5]]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_forward_zero_error() {
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[1.0_f32, 2.0, 3.0]]), &device);
let targets = predictions.clone();
let output = loss.forward(predictions, targets);
let expected = TensorData::from([[0.0_f32, 0.0, 0.0]]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_forward_negative_errors() {
// Ensure absolute value is used correctly
// L1 formula: |error| - 0.5 * beta
// L2 formula: 0.5 * (error)^2 / beta
// Beta = 1.0, error = -3.0 (L1: 3.0 - 0.5 = 2.5)
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 1>::from_data(TensorData::from([-3.0_f32]), &device);
let targets = Tensor::<TestBackend, 1>::zeros([1], &device);
let output = loss.forward(predictions, targets);
let expected = TensorData::from([2.5_f32]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_forward_mixed_regions() {
// Test with errors in both L1 and L2 regions
// Beta = 1.0
// L1 formula: |error| - 0.5 * beta
// L2 formula: 0.5 * (error)^2 / beta
// error = 0.5 -> L2: 0.5 * 0.25 / 1 = 0.125
// error = 1.5 -> L1: 1.5 - 0.5 = 1.0
// error = 3.0 -> L1: 3.0 - 0.5 = 2.5
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 1>::from_data(TensorData::from([0.5_f32, 1.5, 3.0]), &device);
let targets = Tensor::<TestBackend, 1>::zeros([3], &device);
let output = loss.forward(predictions, targets);
let expected = TensorData::from([0.125_f32, 1.0, 2.5]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_custom_beta_values() {
// Test with beta = 0.5
// error = 0.25 (< beta): L2 = 0.5 * 0.0625 / 0.5 = 0.0625
// error = 1.0 (>= beta): L1 = 1.0 - 0.25 = 0.75
let device = Default::default();
let loss = SmoothL1LossConfig::new().with_beta(0.5).init();
let predictions =
Tensor::<TestBackend, 1>::from_data(TensorData::from([0.25_f32, 1.0]), &device);
let targets = Tensor::<TestBackend, 1>::zeros([2], &device);
let output = loss.forward(predictions, targets);
let expected = TensorData::from([0.0625_f32, 0.75]);
output.into_data().assert_eq(&expected, false);
}
// =========================================================================
// forward_with_reduction Tests
// =========================================================================
#[test]
fn test_smooth_l1_reduction_mean() {
// Errors: 0.5 (L2: 0.125), 2.0 (L1: 1.5)
// Mean: (0.125 + 1.5) / 2 = 0.8125
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5_f32, 2.0]]), &device);
let targets =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
let output = loss.forward_with_reduction(predictions, targets, Reduction::Mean);
let expected = TensorData::from([0.8125_f32]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_reduction_sum() {
// Errors: 0.5 (L2: 0.125), 2.0 (L1: 1.5)
// Sum: 1.625
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5_f32, 2.0]]), &device);
let targets =
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
let output = loss.forward_with_reduction(predictions, targets, Reduction::Sum);
let expected = TensorData::from([1.625_f32]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_reduction_auto_equals_mean() {
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions = Tensor::<TestBackend, 1>::from_data(TensorData::from([2.0_f32]), &device);
let targets = Tensor::<TestBackend, 1>::zeros([1], &device);
let mean_out =
loss.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean);
let auto_out = loss.forward_with_reduction(predictions, targets, Reduction::Auto);
mean_out.into_data().assert_eq(&auto_out.into_data(), false);
}
// =========================================================================
// Dimension Reduction Tests
// =========================================================================
#[test]
fn test_smooth_l1_forward_reduce_dims_single_dim() {
// Beta = 2.0
// L1 formula: |error| - 0.5 * beta
// L2 formula: 0.5 * (error)^2 / beta
// Row 0: errors [0.0, 1.0, 4.0]
// error = 0.0 -> L2: 0.0
// error = 1.0 -> L2: 0.5 * 1.0 / 2.0 = 0.25
// error = 4.0 -> L1: 4.0 - 1.0 = 3.0
// Mean = 3.25 / 3 = 1.083333...
// Row 1: errors [0.0, 0.0, 0.0] -> Mean = 0.0
let device = Default::default();
let loss = SmoothL1LossConfig::new().with_beta(2.0).init();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.0_f32, 1.0, 4.0], [5.0_f32, 5.0, 5.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.0_f32, 0.0, 0.0], [5.0_f32, 5.0, 5.0]]),
&device,
);
let output = loss.forward_reduce_dims(predictions, targets, &[1]);
let expected = TensorData::from([[3.25_f32 / 3.0], [0.0]]); // 3.25/3 = 1.0833...
output
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn test_smooth_l1_forward_reduce_dims_image_batch() {
// Simulate per-image Smooth L1 loss for [batch, C, H, W] tensor
// (common in object detection like Fast R-CNN)
let device = Default::default();
let loss = SmoothL1LossConfig::new().init(); // beta = 1.0
// Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2)
let predictions = Tensor::<TestBackend, 4>::from_data(
TensorData::from([
[[[0.5_f32, 2.0], [0.0, 3.0]]], // Image 1
[[[1.0_f32, 0.0], [0.5, 1.5]]], // Image 2
]),
&device,
);
let targets = Tensor::<TestBackend, 4>::zeros([2, 1, 2, 2], &device);
// Reduce over C, H, W (dims 1, 2, 3) to get per-image loss
let output = loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
// Image 1: losses [[0.125, 1.5], [0.0, 2.5]] -> mean: 4.125 / 4 = 1.03125
// Image 2: losses [[0.5, 0.0], [0.125, 1.0]] -> mean: 1.625 / 4 = 0.40625
let expected = TensorData::from([[[[1.03125_f32]]], [[[0.40625_f32]]]]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn test_smooth_l1_forward_reduce_dims_unsorted() {
// Test that unsorted dimensions are handled correctly (sorted internally)
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[1.0_f32, 2.0], [3.0, 4.0]], [[5.0_f32, 6.0], [7.0, 8.0]]]),
&device,
);
let targets = Tensor::<TestBackend, 3>::zeros([2, 2, 2], &device);
// Pass dims in reverse order
let output = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[2, 1]);
let expected_output = loss.forward_reduce_dims(predictions, targets, &[1, 2]);
output
.into_data()
.assert_eq(&expected_output.into_data(), false);
}
#[test]
fn test_smooth_l1_forward_reduce_dims_empty_dims() {
// Reducing over no dimensions should return the unreduced loss
let device = Default::default();
let loss = SmoothL1LossConfig::new().init();
let predictions = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.5_f32, 2.0], [0.0, 3.0]]),
&device,
);
let targets = Tensor::<TestBackend, 2>::zeros([2, 2], &device);
let loss_reduce_dims = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
let loss_no_reduction = loss.forward(predictions, targets);
loss_reduce_dims
.into_data()
.assert_eq(&loss_no_reduction.into_data(), false);
}
}

View File

@@ -0,0 +1,621 @@
//! Cross-Attention Module for Burn
//!
//! Features:
//! - Asymmetric Input Shapes (Query vs Context)
//! - Grouped Query Attention (GQA) & Multi-Query Attention (MQA) support
//! - Quantization-Safe Masking (min_float)
//! - Sparse-Ready (quiet_softmax)
//! - KV Caching for Streaming Inference
use crate::cache::TensorCache;
use crate::modules::{Linear, LinearConfig};
use crate::{Dropout, DropoutConfig};
use burn_core as burn;
use burn::{
config::Config,
module::{Initializer, Module},
tensor::{
Bool, Tensor,
activation::{quiet_softmax, softmax},
backend::Backend,
},
};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
#[derive(Config, Debug)]
/// Configuration to create a [CrossAttention](CrossAttention) layer using the [init function](CrossAttentionConfig::init).
pub struct CrossAttentionConfig {
/// Dimension of the Query (e.g., Decoder state).
pub d_model: usize,
/// Dimension of the Context (e.g., Encoder audio embeddings).
pub d_context: usize,
/// Number of heads for the Query.
pub n_heads: usize,
/// Number of heads for Key/Value (Set to 1 for MQA, set to n_heads for MHA).
pub n_heads_kv: usize,
/// Dimension of a single head.
pub d_head: usize,
/// Dropout rate.
#[config(default = 0.1)]
pub dropout: f64,
/// Masking value. Use -1.0e4 for f16/bf16 safety.
#[config(default = -1.0e4)]
pub min_float: f64,
/// Use quiet_softmax to allow zero-attention (good for sparse/quantized models).
#[config(default = false)]
pub quiet_softmax: bool,
}
#[derive(Module, Debug)]
/// The Cross attention module
///
/// # Params
///
/// - `query`: [`Linear`] layer with `d_model` input and output features.
/// - `key`: [`Linear`] layer with `d_model` input and output features.
/// - `value`: [`Linear`] layer with `d_model` input and output features.
/// - `output`: [`Linear`] layer with `d_model` input and output features.
///
/// Should be created with [CrossAttentionConfig].
pub struct CrossAttention<B: Backend> {
query: Linear<B>,
key: Linear<B>,
value: Linear<B>,
output: Linear<B>,
dropout: Dropout,
n_heads: usize,
n_heads_kv: usize,
d_head: usize,
scale: f64,
min_float: f64,
quiet_softmax: bool,
}
/// Cache for the [Cross Attention](CrossAttention) layer.
///
/// To be used during inference when context is constant.
pub struct CrossAttentionCache<B: Backend> {
/// Cached key tensor.
pub k: TensorCache<B, 4>,
/// Cached value tensor.
pub v: TensorCache<B, 4>,
}
impl<B: Backend> CrossAttentionCache<B> {
/// Create a new empty cache.
pub fn new() -> Self {
Self {
k: TensorCache::empty(),
v: TensorCache::empty(),
}
}
}
impl<B: Backend> Default for CrossAttentionCache<B> {
fn default() -> Self {
Self::new()
}
}
impl CrossAttentionConfig {
/// Initializes a new cross-attention module.
///
/// # Arguments
///
/// * `device` - The device on which to initialize the module.
///
/// # Returns
///
/// A new [CrossAttention] module.
pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {
// Safety Rail for GQA
assert_eq!(
self.n_heads % self.n_heads_kv,
0,
"Query heads must be divisible by KV heads"
);
let init_linear = |in_dim, out_dim| {
LinearConfig::new(in_dim, out_dim)
.with_initializer(Initializer::KaimingUniform {
gain: 1.0 / (self.d_head as f64).sqrt(),
fan_out_only: false,
})
.init(device)
};
CrossAttention {
// ADVICE: Asymmetric Projections
query: init_linear(self.d_model, self.n_heads * self.d_head),
key: init_linear(self.d_context, self.n_heads_kv * self.d_head),
value: init_linear(self.d_context, self.n_heads_kv * self.d_head),
output: init_linear(self.n_heads * self.d_head, self.d_model),
dropout: DropoutConfig::new(self.dropout).init(),
n_heads: self.n_heads,
n_heads_kv: self.n_heads_kv,
d_head: self.d_head,
scale: (self.d_head as f64).sqrt().recip(),
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}
}
impl<B: Backend> CrossAttention<B> {
/// Applies cross-attention to query using context as key and value.
///
/// # Arguments
///
/// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
/// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
/// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
///
/// # Returns
///
/// Output tensor of shape `[batch, seq_len_query, d_model]`.
pub fn forward(
&self,
query: Tensor<B, 3>,
context: Tensor<B, 3>,
mask: Option<Tensor<B, 2, Bool>>,
) -> Tensor<B, 3> {
let [batch, l_q, _] = query.dims();
let [_, l_k, _] = context.dims();
// 1. Projections
let q = self.query.forward(query);
let k = self.key.forward(context.clone());
let v = self.value.forward(context);
// 2. Reshape Heads
// Q: [Batch, Heads, L_q, D_head]
let q = q
.reshape([batch, l_q, self.n_heads, self.d_head])
.swap_dims(1, 2);
// K, V: [Batch, Heads_KV, L_k, D_head]
let k = k
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2);
let v = v
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2);
// 3. GQA Expansion
// ADVICE: Handle GQA by repeating KV heads to match Query heads
let (k, v) = if self.n_heads != self.n_heads_kv {
let n_rep = self.n_heads / self.n_heads_kv;
(self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
} else {
(k, v)
};
// 4. Score Calculation
let scores = q.matmul(k.transpose()) * self.scale;
// 5. Masking
// ADVICE: Use min_float for F16/FP8 safety
let scores = if let Some(mask) = mask {
let mask = mask.reshape([batch, 1, 1, l_k]);
scores.mask_fill(mask, self.min_float)
} else {
scores
};
// 6. Softmax
// ADVICE: Optional Quiet Softmax for sparse networks
let weights = if self.quiet_softmax {
quiet_softmax(scores, 3)
} else {
softmax(scores, 3)
};
let weights = self.dropout.forward(weights);
// 7. Aggregate & Output
let output = weights.matmul(v);
let output = output
.swap_dims(1, 2)
.reshape([batch, l_q, self.n_heads * self.d_head]);
self.output.forward(output)
}
/// Applies cross-attention to query using context as key and value.
///
/// This method uses a cache to avoid recomputing key and value tensors when the context is the same.
///
/// # Arguments
///
/// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
/// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
/// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
/// * `cache` - The cache to use.
///
/// # Returns
///
/// Output tensor of shape `[batch, seq_len_query, d_model]`.
pub fn forward_cache(
&self,
query: Tensor<B, 3>,
context: Tensor<B, 3>,
mask: Option<Tensor<B, 2, Bool>>,
cache: &mut CrossAttentionCache<B>,
) -> Tensor<B, 3> {
let [batch, l_q, _] = query.dims();
// 1. Projections
let q = self.query.forward(query);
let k_compute = |context: Tensor<B, 3>| {
let [batch, l_k, _] = context.dims();
self.key
.forward(context)
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2)
};
let v_compute = |context: Tensor<B, 3>| {
let [batch, l_k, _] = context.dims();
self.value
.forward(context)
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2)
};
let k = cache.k.forward_full(context.clone(), k_compute);
let v = cache.v.forward_full(context, v_compute);
let [_, _, l_k, _] = k.dims();
// 2. Reshape Heads
// Q: [Batch, Heads, L_q, D_head]
let q = q
.reshape([batch, l_q, self.n_heads, self.d_head])
.swap_dims(1, 2);
// K, V are already in their correct shape from k_compute and v_compute
// 3. GQA Expansion
// ADVICE: Handle GQA by repeating KV heads to match Query heads
let (k, v) = if self.n_heads != self.n_heads_kv {
let n_rep = self.n_heads / self.n_heads_kv;
(self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
} else {
(k, v)
};
// 4. Score Calculation
let scores = q.matmul(k.transpose()) * self.scale;
// 5. Masking
// ADVICE: Use min_float for F16/FP8 safety
let scores = if let Some(mask) = mask {
let mask = mask.reshape([batch, 1, 1, l_k]);
scores.mask_fill(mask, self.min_float)
} else {
scores
};
// 6. Softmax
// ADVICE: Optional Quiet Softmax for sparse networks
let weights = if self.quiet_softmax {
quiet_softmax(scores, 3)
} else {
softmax(scores, 3)
};
let weights = self.dropout.forward(weights);
// 7. Aggregate & Output
let output = weights.matmul(v);
let output = output
.swap_dims(1, 2)
.reshape([batch, l_q, self.n_heads * self.d_head]);
self.output.forward(output)
}
/// Helper for Grouped Query Attention
fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
let [b, h, l, d] = x.dims();
x.reshape([b, h, 1, l, d])
.expand([b, h, n_rep, l, d])
.reshape([b, h * n_rep, l, d])
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};
#[test]
fn test_cross_attention_mha_shapes() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [7, 13, 15, 32, 40, 4, 8];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv: n_heads, // MHA case
d_head,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query = Tensor::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output = cross_attn.forward(query, context, None);
assert_eq!(
output.shape(),
Shape::new([batch_size, seq_len_query, d_model]),
"Output should have the correct shape",
);
}
#[test]
fn test_cross_attention_gqa_shapes() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
n_heads_kv,
d_head,
] = [7, 13, 15, 32, 40, 4, 2, 8];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv, // GQA case
d_head,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query = Tensor::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output = cross_attn.forward(query, context, None);
assert_eq!(
output.shape(),
Shape::new([batch_size, seq_len_query, d_model]),
"Output should have the correct shape",
);
}
#[test]
fn test_cross_attention_mqa_shapes() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [7, 13, 15, 32, 40, 4, 8];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv: 1, // MQA case
d_head,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query = Tensor::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output = cross_attn.forward(query, context, None);
assert_eq!(
output.shape(),
Shape::new([batch_size, seq_len_query, d_model]),
"Output should have the correct shape",
);
}
#[test]
fn test_cross_attention_mask() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [3, 6, 8, 12, 16, 4, 3];
let num_padded = 2;
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv: n_heads,
d_head,
dropout: 0.0, // No dropout for deterministic test
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
// Create a padding mask for the context
let mut mask: Tensor<TestBackend, 2, Int> =
Tensor::zeros([batch_size, seq_len_context], &device);
mask = mask.slice_assign(
[0..batch_size, seq_len_context - num_padded..seq_len_context],
Tensor::ones([batch_size, num_padded], &device),
);
let mask_bool = mask.equal_elem(1);
let query = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context_1 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
// Change the padded part of the context tensor
let context_2 = context_1.clone().slice_assign(
[
0..batch_size,
seq_len_context - num_padded..seq_len_context,
0..d_context,
],
Tensor::random(
[batch_size, num_padded, d_context],
Distribution::Default,
&device,
),
);
// The outputs should be the same since the changed part is masked.
let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));
let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());
}
#[test]
#[should_panic]
fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {
let device = Default::default();
let config = CrossAttentionConfig {
d_model: 32,
d_context: 32,
n_heads: 5,
n_heads_kv: 2,
d_head: 8,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
config.init::<TestBackend>(&device);
}
#[test]
fn test_cross_attention_cache() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [3, 6, 8, 12, 16, 4, 3];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv: n_heads,
d_head,
dropout: 0.0, // No dropout for deterministic test
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query1 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
// First forward pass, no cache
let output1 = cross_attn.forward(query1.clone(), context.clone(), None);
// Second forward pass with cache
let mut cache = CrossAttentionCache::new();
let output2 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache);
// The two outputs should be identical
output1
.into_data()
.assert_approx_eq(&output2.into_data(), Tolerance::<f32>::default());
// Third forward pass with different query, but same context and cache
let query2 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);
// For control, do a forward pass without cache with query2
let output4 = cross_attn.forward(query2.clone(), context.clone(), None);
// output3 and output4 should be identical
output3
.into_data()
.assert_approx_eq(&output4.into_data(), Tolerance::<f32>::default());
}
}

View File

@@ -0,0 +1,161 @@
use burn_core as burn;
use burn_core::config::Config;
use alloc::vec::Vec;
use burn::tensor::ops::IntElem;
use burn::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};
/// Generate an autoregressive attention mask.
///
/// The mask can be used in Transformer modules to train models to generate tensors sequentially.
pub fn generate_autoregressive_mask<B: Backend>(
batch_size: usize,
seq_length: usize,
device: &B::Device,
) -> Tensor<B, 3, Bool> {
let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);
mask.expand([batch_size, seq_length, seq_length])
}
/// Generate a padding attention mask.
pub struct GeneratePaddingMask<B: Backend> {
/// The generated tensor.
pub tensor: Tensor<B, 2, Int>,
/// The generated mask.
pub mask: Tensor<B, 2, Bool>,
}
/// Defines an enumeration to specify sequence length options for padding
#[derive(Config, Debug, Copy)]
pub enum SeqLengthOption {
/// No maximum length; use the longest sequence
NoMax,
/// Maximum length specified, truncate if necessary
Max(usize),
/// Fixed length, pad or truncate to this exact length
Fixed(usize),
}
impl From<Option<usize>> for SeqLengthOption {
fn from(val: Option<usize>) -> Self {
match val {
Some(max) => SeqLengthOption::Max(max),
None => SeqLengthOption::NoMax,
}
}
}
/// Generates a padding attention mask for a batch of token sequences.
///
/// # Arguments
///
/// * `pad_token` - The token ID used for padding
/// * `tokens_list` - Vector of token sequences (each sequence is a vector of token IDs)
/// * `seq_length` - Sequence length option (NoMax, Max, or Fixed)
/// * `device` - The device for tensor operations
///
/// # Returns
///
/// A `GeneratePaddingMask` containing the padded tensor and corresponding mask
pub fn generate_padding_mask<B: Backend>(
pad_token: usize,
tokens_list: Vec<Vec<usize>>,
seq_length: impl Into<SeqLengthOption>,
device: &B::Device,
) -> GeneratePaddingMask<B> {
let tokens_max = || {
tokens_list
.iter()
.map(|tokens| tokens.len())
.max()
.unwrap_or(1)
};
let size = match seq_length.into() {
SeqLengthOption::NoMax => tokens_max(),
SeqLengthOption::Max(max) => usize::min(tokens_max(), max),
SeqLengthOption::Fixed(limit) => limit,
};
let batch_size = tokens_list.len();
let mut tensor = Tensor::zeros([batch_size, size], device);
tensor = tensor.add_scalar(pad_token as i64);
for (index, tokens) in tokens_list.into_iter().enumerate() {
let seq_length = tokens.len().min(size);
tensor = tensor.slice_assign(
[index..index + 1, 0..seq_length],
Tensor::from_data(
TensorData::new(
tokens
.into_iter()
.take(size)
.map(|e| (e as i64).elem::<IntElem<B>>())
.collect(),
Shape::new([1, seq_length]),
),
device,
),
);
}
let mask = tensor.clone().equal_elem(pad_token as i64);
GeneratePaddingMask { tensor, mask }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use alloc::vec;
use burn::tensor::TensorData;
#[test]
fn test_generate_autoregressive_mask() {
let device = <TestBackend as Backend>::Device::default();
let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);
mask.into_data().assert_eq(
&TensorData::from([
[
[false, true, true],
[false, false, true],
[false, false, false],
],
[
[false, true, true],
[false, false, true],
[false, false, false],
],
]),
false,
);
}
#[test]
fn test_generate_padding_mask() {
let device = <TestBackend as Backend>::Device::default();
let tokens = vec![
vec![3, 3, 3],
vec![3, 3, 3],
vec![3, 3, 3, 4],
vec![3, 3, 3, 4, 10, 15],
];
let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);
mask.mask.into_data().assert_eq(
&TensorData::from([
[false, false, false, true, true, true],
[false, false, false, true, true, true],
[false, false, false, false, true, true],
[false, false, false, false, false, false],
]),
false,
);
}
}

View File

@@ -0,0 +1,531 @@
use burn_core as burn;
use crate::activation::Gelu;
use crate::cache::TensorCache;
use crate::{Dropout, DropoutConfig, Linear, LinearConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};
use burn::tensor::activation::{quiet_softmax, softmax};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init).
#[derive(Config, Debug)]
pub struct MultiHeadAttentionConfig {
/// The size of each linear layer.
pub d_model: usize,
/// The number of heads.
pub n_heads: usize,
/// The dropout rate. Default: 0.1
#[config(default = 0.1)]
pub dropout: f64,
/// The minimum value a float can take. Default: -1.0e4
/// This is used to mask attention scores before calculating attention weights.
/// A value too low might result in NaN.
#[config(default = -1.0e4)]
pub min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - `query`: [`Linear`] layer with `d_model` input and output features.
/// - `key`: [`Linear`] layer with `d_model` input and output features.
/// - `value`: [`Linear`] layer with `d_model` input and output features.
/// - `output`: [`Linear`] layer with `d_model` input and output features.
///
/// Should be created with [MultiHeadAttentionConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
/// Linear layer to transform the input features into the query space.
pub query: Linear<B>,
/// Linear layer to transform the input features into the key space.
pub key: Linear<B>,
/// Linear layer to transform the input features into the value space.
pub value: Linear<B>,
/// Linear layer to transform the output features back to the original space.
pub output: Linear<B>,
/// Dropout layer.
pub dropout: Dropout,
/// Activation function.
pub activation: Gelu,
/// The size of each linear layer.
pub d_model: usize,
/// The number of heads.
pub n_heads: usize,
/// Size of the key and query vectors.
pub d_k: usize,
/// Minimum value a float can take.
pub min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("n_heads", &self.n_heads)
.add("d_k", &self.d_k)
.add("dropout", &self.dropout.prob)
.add("min_float", &self.min_float)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
/// Shape `[batch_size, seq_length_1, d_model]`
query: Tensor<B, 3>,
/// Shape `[batch_size, seq_length_2, d_model]`
key: Tensor<B, 3>,
/// Shape `[batch_size, seq_length_2, d_model]`
value: Tensor<B, 3>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
}
impl MultiHeadAttentionConfig {
/// Initialize a new [multihead attention](MultiHeadAttention) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
let linear = |config: &Self| {
LinearConfig::new(config.d_model, config.d_model)
.with_initializer(self.initializer.clone())
.init(device)
};
MultiHeadAttention {
query: linear(self),
key: linear(self),
value: linear(self),
output: linear(self),
dropout: DropoutConfig::new(self.dropout).init(),
activation: Gelu::new(),
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
d_model: self.d_model,
}
}
}
impl<B: Backend> MhaInput<B> {
/// Create a [multihead attention](MultiHeadAttention) input argument
/// by setting the query, key and value to the given tensor.
///
/// # Shape
/// - tensor: `[batch_size, seq_length, d_model]`
pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
Self {
query: tensor.clone(),
key: tensor.clone(),
value: tensor,
mask_pad: None,
mask_attn: None,
}
}
/// Create a [multihead attention](MultiHeadAttention) input argument.
pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
Self {
query,
key,
value,
mask_pad: None,
mask_attn: None,
}
}
/// Register the padding mask.
pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
self.mask_pad = Some(mask_pad);
self
}
/// Register the attention mask.
pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
self.mask_attn = Some(mask_attn);
self
}
}
/// [Multihead attention](MultiHeadAttention) outputs.
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
/// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.
pub weights: Tensor<B, 4>,
/// The context tensor `[batch_size, seq_length_1, d_model]`.
pub context: Tensor<B, 3>,
}
impl<B: Backend> MultiHeadAttention<B> {
/// Applies the forward pass on the input tensors.
///
/// See [MultiHeadAttention](MultiHeadAttention) for more information.
///
/// # Shapes
///
/// - query: `[batch_size, seq_length_1, d_model]`
/// - key: `[batch_size, seq_length_2, d_model]`
/// - value: `[batch_size, seq_length_2, d_model]`
/// - output: `[batch_size, seq_length_1, d_model]`
pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
let [batch_size, seq_length_1, d_model] = input.query.dims();
let query = self.attention_linear(input.query, &self.query);
let key = self.attention_linear(input.key, &self.key);
let value = self.attention_linear(input.value, &self.value);
let attn_scores = self.attn_scores(query, key);
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
let context = weights.clone().matmul(value);
let context = context
.swap_dims(1, 2)
.reshape([batch_size, seq_length_1, d_model]);
let context = self.output.forward(context);
MhaOutput { weights, context }
}
/// Applies the forward pass using a cache.
///
/// # Shapes
///
/// - query: `[batch_size, seq_length_1, d_model]`
/// - key: `[batch_size, seq_length_2, d_model]`
/// - value: `[batch_size, seq_length_2, d_model]`
/// - output: `[batch_size, seq_length_1, d_model]`
pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
let [batch_size, seq_length_1, d_model] = input.query.dims();
let query = cache
.query
.forward(input.query, |t| self.attention_linear(t, &self.query));
let key = cache
.key
.forward(input.key, |t| self.attention_linear(t, &self.key));
let value = cache
.value
.forward(input.value, |t| self.attention_linear(t, &self.value));
let attn_scores = self.attn_scores(query, key);
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
let context = weights.clone().matmul(value);
let context = context
.swap_dims(1, 2)
.reshape([batch_size, seq_length_1, d_model]);
let context = cache.output.forward(context, |t| self.output.forward(t));
MhaOutput { weights, context }
}
fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
let attn_scores = query
.matmul(key.transpose())
.div_scalar((self.d_k as f32).sqrt());
self.dropout.forward(attn_scores)
}
fn attn_weights(
&self,
mut attn_scores: Tensor<B, 4>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
) -> Tensor<B, 4> {
if let Some(mask_pad) = mask_pad {
let [batch_size, seq_length] = mask_pad.dims();
attn_scores = attn_scores.mask_fill(
mask_pad.reshape([batch_size, 1, 1, seq_length]),
self.min_float,
);
}
if let Some(mask_attn) = mask_attn {
let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();
attn_scores = attn_scores.mask_fill(
mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
self.min_float,
);
}
if self.quiet_softmax {
quiet_softmax(attn_scores, 3)
} else {
softmax(attn_scores, 3)
}
}
fn attention_linear(&self, x: Tensor<B, 3>, linear: &Linear<B>) -> Tensor<B, 4> {
let [batch_size, seq_length, _d_model] = x.dims();
linear
.forward(x)
.reshape([batch_size, seq_length, self.n_heads, self.d_k])
.swap_dims(1, 2)
}
}
/// Cache for the [Multi Head Attention](MultiHeadAttention) layer.
///
/// To be used during inference when decoding tokens.
pub struct MhaCache<B: Backend> {
query: MhaLinearCache<B, 4>,
key: MhaLinearCache<B, 4>,
value: MhaLinearCache<B, 4>,
output: MhaLinearCache<B, 3>,
}
enum MhaLinearCache<B: Backend, const D: usize> {
Autoregressive(TensorCache<B, D>, usize),
Full(TensorCache<B, D>),
}
impl<B: Backend> MhaCache<B> {
/// Initialize a cache for autoregressive inference.
pub fn autoregressive() -> Self {
Self {
query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
}
}
/// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and
/// values (cross-attention).
pub fn autoregressive_cross_attention() -> Self {
Self {
query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
key: MhaLinearCache::Full(TensorCache::empty()),
value: MhaLinearCache::Full(TensorCache::empty()),
output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
}
}
}
impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(
&mut self,
tensor: Tensor<B, 3>,
func: F,
) -> Tensor<B, D> {
match self {
MhaLinearCache::Autoregressive(cache, dim) => {
cache.forward_autoregressive(tensor, *dim, func)
}
MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{TestBackend, attention::generate_autoregressive_mask};
use alloc::vec::Vec;
use burn::tensor::Int;
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
use burn::tensor::{Distribution, Shape};
#[test]
fn test_self_attention_shapes() {
let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
let device = Default::default();
let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
let input = MhaInput::self_attn(Tensor::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
));
let output = mha.forward(input);
assert_eq!(
output.context.shape(),
Shape::new([batch_size, seq_length, d_model]),
"Context should have the correct shape",
);
assert_eq!(
output.weights.shape(),
Shape::new([batch_size, n_heads, seq_length, seq_length]),
"Weights should have the correct shape",
);
}
#[test]
fn test_generic_mha_shapes() {
let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];
let mha = MultiHeadAttentionConfig::new(d_model, n_heads)
.init::<TestBackend>(&Default::default());
let device = Default::default();
let input = MhaInput::new(
Tensor::random(
[batch_size, seq_length_1, d_model],
Distribution::Default,
&device,
),
Tensor::random(
[batch_size, seq_length_2, d_model],
Distribution::Default,
&device,
),
Tensor::random(
[batch_size, seq_length_2, d_model],
Distribution::Default,
&device,
),
);
let output = mha.forward(input);
assert_eq!(
output.context.shape(),
Shape::new([batch_size, seq_length_1, d_model]),
"Context should have the correct shape",
);
assert_eq!(
output.weights.shape(),
Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
"Weights should have the correct shape",
);
}
#[test]
fn test_self_attention_mask_pad() {
let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
let device = Default::default();
let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
// Create a padding mask
let mask_pad: Tensor<TestBackend, 2, Int> =
Tensor::zeros([batch_size, seq_length], &device);
let mask_pad = mask_pad.slice_assign(
[0..batch_size, seq_length - num_padded..seq_length],
Tensor::ones([batch_size, num_padded], &device),
);
let mask_pad = mask_pad.equal_elem(1).to_device(&device);
let tensor_1 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
// Change the end of the tensor
let tensor_2 = tensor_1.clone().slice_assign(
[
0..batch_size,
seq_length - num_padded..seq_length,
0..d_model,
],
Tensor::random(
[batch_size, num_padded, d_model],
Distribution::Default,
&device,
),
);
let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());
let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);
let output_1 = mha.forward(input_1);
let output_2 = mha.forward(input_2);
// Check that the beginning of each tensor is the same
output_1
.context
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.into_data()
.assert_approx_eq(
&output_2
.context
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.into_data(),
Tolerance::<f32>::default(),
);
}
#[test]
fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
let device = Default::default();
let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
let tensor = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);
let output_1 = mha.forward(input);
let mut output_2 = Vec::new();
let mut cache = MhaCache::autoregressive();
for i in 1..seq_length + 1 {
let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
let input = MhaInput::self_attn(tensor);
let next_tok = mha.forward_cache(input, &mut cache).context.slice([
0..batch_size,
i - 1..i,
0..d_model,
]);
output_2.push(next_tok);
}
let output_2 = Tensor::cat(output_2, 1);
output_1
.context
.into_data()
.assert_approx_eq::<FloatElem<TestBackend>>(
&output_2.into_data(),
Tolerance::default(),
);
}
#[test]
fn display() {
let config = MultiHeadAttentionConfig::new(2, 4);
let mha = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{mha}"),
"MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
);
}
}

View File

@@ -0,0 +1,7 @@
mod cross_attention;
mod mask;
mod mha;
pub use cross_attention::*;
pub use mask::*;
pub use mha::*;

View File

@@ -0,0 +1,52 @@
use alloc::vec;
use burn_core as burn;
use super::{CacheState, TensorCache};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
impl<B: Backend, const D: usize> TensorCache<B, D> {
pub(crate) fn forward_autoregressive<F>(
&mut self,
tensor: Tensor<B, 3>,
dim_cat: usize,
func: F,
) -> Tensor<B, D>
where
F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
{
let mut tensor_old = CacheState::Empty;
core::mem::swap(&mut self.state, &mut tensor_old);
let tensor_new = match tensor_old {
CacheState::Value(tensor_old) => {
let [batch_size, seq_length, d_model] = tensor.dims();
let next_seq_token =
tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
let next_seq_token = func(next_seq_token);
Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)
}
_ => func(tensor),
};
self.state = CacheState::Value(tensor_new.clone());
tensor_new
}
pub(crate) fn forward_full<F>(&mut self, tensor: Tensor<B, 3>, func: F) -> Tensor<B, D>
where
F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
{
let mut tensor_old = CacheState::Empty;
core::mem::swap(&mut self.state, &mut tensor_old);
let tensor_new = match tensor_old {
CacheState::Value(tensor_old) => tensor_old,
_ => func(tensor),
};
self.state = CacheState::Value(tensor_new.clone());
tensor_new
}
}

View File

@@ -0,0 +1,27 @@
use burn_core as burn;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
pub(crate) enum CacheState<T> {
Value(T),
Empty,
}
/// A cache for a tensor.
pub struct TensorCache<B: Backend, const D: usize> {
pub(crate) state: CacheState<Tensor<B, D>>,
}
impl<B: Backend, const D: usize> TensorCache<B, D> {
/// Creates a new empty cache.
///
/// # Returns
///
/// The empty cache.
pub fn empty() -> Self {
Self {
state: CacheState::Empty,
}
}
}

View File

@@ -0,0 +1,4 @@
mod autoregressive;
mod base;
pub use base::*;

View File

@@ -0,0 +1,22 @@
pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) {
let channels_in_div_by_group = channels_in.is_multiple_of(groups);
let channels_out_div_by_group = channels_out.is_multiple_of(groups);
if !channels_in_div_by_group || !channels_out_div_by_group {
panic!(
"Both channels must be divisible by the number of groups. Got \
channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
);
}
}
// https://github.com/tracel-ai/burn/issues/2676
/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
/// size is not supported as it will not produce the same output size.
pub(crate) fn check_same_padding_support(kernel_size: &[usize]) {
for k in kernel_size.iter() {
if k % 2 == 0 {
unimplemented!("Same padding with an even kernel size is not supported");
}
}
}

View File

@@ -0,0 +1,283 @@
use alloc::format;
use burn_core as burn;
use crate::{PaddingConfig1d, conv::checks};
use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::PaddedConvOptions};
use burn::{
config::Config,
module::{Content, DisplaySettings, Ignored, Initializer, Module, ModuleDisplay, Param},
};
/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
#[derive(Config, Debug)]
pub struct Conv1dConfig {
/// The number of input channels.
pub channels_in: usize,
/// The number of output channels.
pub channels_out: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The stride of the convolution.
#[config(default = "1")]
pub stride: usize,
/// Spacing between kernel elements.
#[config(default = "1")]
pub dilation: usize,
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
///
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
/// will automatically use asymmetric padding to preserve input dimensions.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a 1D convolution over input tensors.
///
/// Should be created with [Conv1dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
pub weight: Param<Tensor<B, 3>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: usize,
/// Size of the kernel.
pub kernel_size: usize,
/// Spacing between kernel elements.
pub dilation: usize,
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: Ignored<PaddingConfig1d>,
}
impl<B: Backend> ModuleDisplay for Conv1d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Format padding
let padding_formatted = format!("{}", &self.padding);
// Format stride/dilation as strings
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
// Extract channels in/out from weight dims
let [channels_out, group_channels_in, _] = self.weight.dims();
let channels_in = group_channels_in * self.groups;
let ch_out = format!("{:?}", channels_out);
let ch_in = format!("{:?}", channels_in);
content
.add("ch_in", &ch_in)
.add("ch_out", &ch_out)
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl Conv1dConfig {
/// Initialize a new [conv1d](Conv1d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
let shape = [
self.channels_out,
self.channels_in / self.groups,
self.kernel_size,
];
let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
let weight = self
.initializer
.init_with(shape, Some(fan_in), None, device);
let mut bias = None;
if self.bias {
bias =
Some(
self.initializer
.init_with([self.channels_out], Some(fan_in), None, device),
);
}
Conv1d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
groups: self.groups,
}
}
}
impl<B: Backend> Conv1d<B> {
/// Applies the forward pass on the input tensor.
///
/// See [conv1d](burn::tensor::module::conv1d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, length_in]`
/// - output: `[batch_size, channels_out, length_out]`
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let length = input.dims()[2];
// Calculate padding as pair - handles Same, Valid, and Explicit uniformly
let (left, right) =
self.padding
.calculate_padding_1d_pair(length, self.kernel_size, self.stride);
let options = PaddedConvOptions::asymmetric(
[self.stride],
[left],
[right],
[self.dilation],
self.groups,
);
conv1d(
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
options,
)
}
}
#[cfg(test)]
mod tests {
use burn::tensor::{ElementConversion, ops::FloatElem};
type FT = FloatElem<TestBackend>;
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = Conv1dConfig::new(5, 5, 5);
let k = (config.channels_in * config.kernel_size) as f64;
let k = (config.groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&device);
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight
.to_data()
.assert_eq(&TensorData::zeros::<FT, _>(conv.weight.shape()), false);
}
#[test]
fn same_with_even_kernel_uses_asymmetric_padding() {
let device = Default::default();
let config = Conv1dConfig::new(4, 4, 2)
.with_padding(PaddingConfig1d::Same)
.with_initializer(Initializer::Constant { value: 1.0 })
.with_bias(false);
let conv = config.init::<TestBackend>(&device);
// Input: [batch=1, channels=4, length=5]
let input = Tensor::<TestBackend, 3>::ones([1, 4, 5], &device);
let output = conv.forward(input);
// Same padding should preserve spatial dimensions
assert_eq!(output.dims(), [1, 4, 5]);
}
#[test]
fn display() {
let config = Conv1dConfig::new(5, 5, 5);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{conv}"),
"Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = Conv1dConfig::new(5, 3, 3);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
let _ = conv.forward(input);
}
#[test]
fn asymmetric_padding_forward() {
let device = Default::default();
// Create conv with asymmetric padding: left=1, right=2
let config = Conv1dConfig::new(2, 3, 3)
.with_padding(PaddingConfig1d::Explicit(1, 2))
.with_initializer(Initializer::Constant { value: 1.0 })
.with_bias(false);
let conv = config.init::<TestBackend>(&device);
// Input: [batch=1, channels=2, length=4]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
let output = conv.forward(input);
// With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
// Output length = (7 - 3) / 1 + 1 = 5
assert_eq!(output.dims(), [1, 3, 5]);
}
#[test]
fn symmetric_explicit_padding_forward() {
let device = Default::default();
// Create conv with symmetric explicit padding: left=2, right=2
let config = Conv1dConfig::new(2, 3, 3)
.with_padding(PaddingConfig1d::Explicit(2, 2))
.with_initializer(Initializer::Constant { value: 1.0 })
.with_bias(false);
let conv = config.init::<TestBackend>(&device);
// Input: [batch=1, channels=2, length=4]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
let output = conv.forward(input);
// With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
// Output length = (8 - 3) / 1 + 1 = 6
assert_eq!(output.dims(), [1, 3, 6]);
}
}

View File

@@ -0,0 +1,349 @@
use alloc::format;
use burn_core as burn;
use crate::PaddingConfig2d;
use burn::config::Config;
use burn::module::Initializer;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::conv2d;
use burn::tensor::ops::PaddedConvOptions;
use crate::conv::checks;
/// Configuration to create a [2D convolution](Conv2d) layer, using the [init function](Conv2dConfig::init).
#[derive(Config, Debug)]
pub struct Conv2dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
#[config(default = "[1, 1]")]
pub stride: [usize; 2],
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
///
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
/// will automatically use asymmetric padding to preserve input dimensions.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a 2D convolution over input tensors.
///
/// Should be created with [Conv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv2d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
}
impl Conv2dConfig {
/// Initialize a new [conv2d](Conv2d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
let shape = [
self.channels[1],
self.channels[0] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
];
let k = self.kernel_size.iter().product::<usize>();
let fan_in = self.channels[0] / self.groups * k;
let fan_out = self.channels[1] / self.groups * k;
let weight = self
.initializer
.init_with(shape, Some(fan_in), Some(fan_out), device);
let mut bias = None;
if self.bias {
bias = Some(self.initializer.init_with(
[self.channels[1]],
Some(fan_in),
Some(fan_out),
device,
));
}
Conv2d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: Ignored(self.padding.clone()),
groups: self.groups,
}
}
}
impl<B: Backend> ModuleDisplay for Conv2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
let [channels_out, group_channels_in, _, _] = self.weight.dims();
let channels_in = group_channels_in * self.groups;
let ch_out = format!("{:?}", channels_out);
let ch_in = format!("{:?}", channels_in);
content
.add("ch_in", &ch_in)
.add("ch_out", &ch_out)
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl<B: Backend> Conv2d<B> {
/// Applies the forward pass on the input tensor.
///
/// See [conv2d](burn::tensor::module::conv2d) for more information.
///
/// # Shapes
/// - `input`: `[batch_size, channels_in, height_in, width_in]`
/// - `output`: `[batch_size, channels_out, height_out, width_out]`
///
/// # Example
/// ```rust,ignore
/// use burn::nn::conv::Conv2dConfig;
/// use burn::tensor::Tensor;
///
/// // Assuming backend type alias `B`
/// let device = Default::default();
/// let conv = Conv2dConfig::new([3, 8], [3, 3]).init::<B>(&device);
///
/// let x = Tensor::<B, 4>::zeros([1, 3, 28, 28], &device);
/// let y = conv.forward(x);
///
/// println!("{:?}", y.dims()); // [1, 8, 26, 26]
/// ```
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
// Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
height_in,
width_in,
&self.kernel_size,
&self.stride,
);
let options = PaddedConvOptions::asymmetric(
self.stride,
[top, left],
[bottom, right],
self.dilation,
self.groups,
);
conv2d(
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
options,
)
}
}
#[cfg(test)]
mod tests {
use burn::tensor::ops::FloatElem;
use burn::tensor::{ElementConversion, Tolerance};
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
type FT = FloatElem<TestBackend>; // Float test
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = Conv2dConfig::new([5, 1], [5, 5]);
let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
let k = (config.groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&device);
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
let conv = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight.to_data().assert_approx_eq::<FT>(
&TensorData::zeros::<FT, _>(conv.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn initializer_fan_out() {
let device = Default::default();
TestBackend::seed(&device, 0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true, // test that fan_out is passed to `init_with()`
};
let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
fn initializer_fan_with_groups_is_valid() {
let device = Default::default();
TestBackend::seed(&device, 0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true,
};
let config = Conv2dConfig::new([4, 4], [1, 1])
.with_initializer(init.clone())
.with_groups(4);
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
#[should_panic = "Both channels must be divisible by the number of groups."]
fn channels_with_groups_is_invalid() {
let device = Default::default();
let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);
let _ = config.init::<TestBackend>(&device);
}
#[test]
fn same_with_even_kernel_uses_asymmetric_padding() {
let device = Default::default();
let config = Conv2dConfig::new([4, 4], [2, 2])
.with_padding(PaddingConfig2d::Same)
.with_initializer(Initializer::Constant { value: 1.0 })
.with_bias(false);
let conv = config.init::<TestBackend>(&device);
// Input: [batch=1, channels=4, height=5, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 4, 5, 5], &device);
let output = conv.forward(input);
// Same padding should preserve spatial dimensions
assert_eq!(output.dims(), [1, 4, 5, 5]);
}
#[test]
fn display() {
let config = Conv2dConfig::new([5, 1], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{conv}"),
"Conv2d {ch_in: 5, ch_out: 1, stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = Conv2dConfig::new([5, 3], [3, 3]);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
let _ = conv.forward(input);
}
#[test]
fn asymmetric_padding_forward() {
let device = Default::default();
// Create conv with asymmetric padding: top=1, left=2, bottom=3, right=4
let config = Conv2dConfig::new([2, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
.with_initializer(Initializer::Constant { value: 1.0 })
.with_bias(false);
let conv = config.init::<TestBackend>(&device);
// Input: [batch=1, channels=2, height=4, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
let output = conv.forward(input);
// Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
// Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
assert_eq!(output.dims(), [1, 3, 6, 9]);
}
#[test]
fn symmetric_explicit_padding_forward() {
let device = Default::default();
// Create conv with symmetric explicit padding: top=2, left=2, bottom=2, right=2
let config = Conv2dConfig::new([2, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
.with_initializer(Initializer::Constant { value: 1.0 })
.with_bias(false);
let conv = config.init::<TestBackend>(&device);
// Input: [batch=1, channels=2, height=4, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
let output = conv.forward(input);
// Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
// Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
assert_eq!(output.dims(), [1, 3, 6, 7]);
}
}

View File

@@ -0,0 +1,276 @@
use alloc::format;
use burn_core as burn;
use crate::PaddingConfig3d;
use burn::config::Config;
use burn::module::Initializer;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::conv3d;
use burn::tensor::ops::ConvOptions;
use crate::conv::checks;
/// Configuration to create a [3D convolution](Conv3d) layer, using the [init function](Conv3dConfig::init).
#[derive(Config, Debug)]
pub struct Conv3dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 3],
/// The stride of the convolution.
#[config(default = "[1, 1, 1]")]
pub stride: [usize; 3],
/// Spacing between kernel elements.
#[config(default = "[1, 1, 1]")]
pub dilation: [usize; 3],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
#[config(default = "PaddingConfig3d::Valid")]
pub padding: PaddingConfig3d,
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a 3D convolution over input tensors.
///
/// Should be created with [Conv3dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv3d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2, kernel_size_3]`
pub weight: Param<Tensor<B, 5>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: [usize; 3],
/// Size of the kernel.
pub kernel_size: [usize; 3],
/// Spacing between kernel elements.
pub dilation: [usize; 3],
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig3d>,
}
impl Conv3dConfig {
/// Initialize a new [conv3d](Conv3d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv3d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
if self.padding == PaddingConfig3d::Same {
checks::check_same_padding_support(&self.kernel_size);
}
let shape = [
self.channels[1],
self.channels[0] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
self.kernel_size[2],
];
let k = self.kernel_size.iter().product::<usize>();
let fan_in = self.channels[0] / self.groups * k;
let fan_out = self.channels[1] / self.groups * k;
let weight = self
.initializer
.init_with(shape, Some(fan_in), Some(fan_out), device);
let mut bias = None;
if self.bias {
bias = Some(self.initializer.init_with(
[self.channels[1]],
Some(fan_in),
Some(fan_out),
device,
));
}
Conv3d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: Ignored(self.padding.clone()),
groups: self.groups,
}
}
}
impl<B: Backend> ModuleDisplay for Conv3d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Padding doesn't implement ModuleDisplay, so format manually.
let padding_formatted = format!("{}", &self.padding);
// Format arrays as strings (consistent with Conv2d/Conv1d).
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
// Weight dims: [channels_out, channels_in/groups, k1, k2, k3]
let [channels_out, group_channels_in, _, _, _] = self.weight.dims();
let channels_in = group_channels_in * self.groups;
let ch_out = format!("{:?}", channels_out);
let ch_in = format!("{:?}", channels_in);
content
.add("ch_in", &ch_in)
.add("ch_out", &ch_out)
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl<B: Backend> Conv3d<B> {
/// Applies the forward pass on the input tensor.
///
/// See [conv3d](burn::tensor::module::conv3d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`
/// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`
pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {
let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims();
let padding = self.padding.calculate_padding_3d(
depth_in,
height_in,
width_in,
&self.kernel_size,
&self.stride,
);
conv3d(
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
ConvOptions::new(self.stride, padding, self.dilation, self.groups),
)
}
}
#[cfg(test)]
mod tests {
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = Conv3dConfig::new([5, 1], [5, 5, 5]);
let k = (config.channels[0]
* config.kernel_size[0]
* config.kernel_size[1]
* config.kernel_size[2]) as f64;
let k = (config.groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&device);
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = Conv3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros);
let device = Default::default();
let conv = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight.to_data().assert_approx_eq::<FT>(
&TensorData::zeros::<f32, _>(conv.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn initializer_fan_out() {
let device = Default::default();
TestBackend::seed(&device, 0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true, // test that fan_out is passed to `init_with()`
};
let config = Conv3dConfig::new([5, 1], [5, 5, 5]).with_initializer(init.clone());
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
fn initializer_fan_with_groups_is_valid() {
let device = Default::default();
TestBackend::seed(&device, 0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true,
};
let config = Conv3dConfig::new([4, 4], [1, 1, 1])
.with_initializer(init.clone())
.with_groups(4);
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
#[should_panic = "Same padding with an even kernel size is not supported"]
fn same_with_even_kernel_is_invalid() {
let device = Default::default();
let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same);
let _ = config.init::<TestBackend>(&device);
}
#[test]
fn display() {
let config = Conv3dConfig::new([5, 1], [5, 5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{conv}"),
"Conv3d {ch_in: 5, ch_out: 1, stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: Valid, params: 626}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = Conv3dConfig::new([5, 3], [3, 3, 3]);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 5>::zeros([1, 4, 10, 10, 10], &Default::default());
let _ = conv.forward(input);
}
}

View File

@@ -0,0 +1,216 @@
use alloc::format;
use burn_core as burn;
use crate::conv::checks;
use burn::config::Config;
use burn::module::Content;
use burn::module::DisplaySettings;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::ModuleDisplay;
use burn::module::Param;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::conv_transpose1d;
use burn::tensor::ops::ConvTransposeOptions;
/// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer
/// using the [init function](ConvTranspose1dConfig::init).
#[derive(Config, Debug)]
pub struct ConvTranspose1dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: usize,
/// The stride of the convolution.
#[config(default = "1")]
pub stride: usize,
/// Spacing between kernel elements.
#[config(default = "1")]
pub dilation: usize,
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
#[config(default = "0")]
pub padding: usize,
/// The padding output configuration.
#[config(default = "0")]
pub padding_out: usize,
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a 1D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose1d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
pub weight: Param<Tensor<B, 3>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: usize,
/// Size of the kernel.
pub kernel_size: usize,
/// Spacing between kernel elements.
pub dilation: usize,
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: usize,
/// The padding output configuration.
pub padding_out: usize,
/// The number of channels.
pub channels: [usize; 2],
}
impl<B: Backend> ModuleDisplay for ConvTranspose1d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &self.stride)
.add("kernel_size", &self.kernel_size)
.add("dilation", &self.dilation)
.add("groups", &self.groups)
.add("padding", &self.padding)
.add("padding_out", &self.padding_out)
.optional()
}
}
impl ConvTranspose1dConfig {
/// Initialize a new [conv transpose 1d](ConvTranspose1d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose1d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
let shape = [
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size,
];
let fan_in = self.channels[1] / self.groups * self.kernel_size;
let weight = self
.initializer
.init_with(shape, Some(fan_in), None, device);
let mut bias = None;
if self.bias {
bias = Some(
self.initializer
.init_with([self.channels[1]], Some(fan_in), None, device),
);
}
ConvTranspose1d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
impl<B: Backend> ConvTranspose1d<B> {
/// Applies the forward pass on the input tensor.
///
/// See also [conv_transpose1d](burn::tensor::module::conv_transpose1d).
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, length_in]`
/// - output: `[batch_size, channels_out, length_out]`
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
conv_transpose1d(
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
ConvTransposeOptions::new(
[self.stride],
[self.padding],
[self.padding_out],
[self.dilation],
self.groups,
),
)
}
}
#[cfg(test)]
mod tests {
use burn::tensor::ops::FloatElem;
use burn::tensor::{ElementConversion, Tolerance};
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
type FT = FloatElem<TestBackend>;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = ConvTranspose1dConfig::new([5, 1], 5);
let k = (config.channels[1] * config.kernel_size) as f64;
let k = (config.groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&Default::default());
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight.to_data().assert_approx_eq::<f32>(
&TensorData::zeros::<f32, _>(conv.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn display() {
let config = ConvTranspose1dConfig::new([5, 2], 5);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{conv}"),
"ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = ConvTranspose1dConfig::new([5, 3], 3);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
let _ = conv.forward(input);
}
}

View File

@@ -0,0 +1,216 @@
use alloc::format;
use burn_core as burn;
use crate::conv::checks;
use burn::config::Config;
use burn::module::Content;
use burn::module::DisplaySettings;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::ModuleDisplay;
use burn::module::Param;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::conv_transpose2d;
use burn::tensor::ops::ConvTransposeOptions;
/// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer
/// using the [init function](ConvTranspose2dConfig::init).
#[derive(Config, Debug)]
pub struct ConvTranspose2dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
#[config(default = "[1, 1]")]
pub stride: [usize; 2],
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
#[config(default = "[0, 0]")]
pub padding: [usize; 2],
/// The padding output configuration.
#[config(default = "[0, 0]")]
pub padding_out: [usize; 2],
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a 2D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose2d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: [usize; 2],
/// Padding output configuration.
pub padding_out: [usize; 2],
/// Number of channels.
pub channels: [usize; 2],
}
impl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &format!("{:?}", &self.stride))
.add("kernel_size", &format!("{:?}", &self.kernel_size))
.add("dilation", &format!("{:?}", &self.dilation))
.add("groups", &self.groups)
.add("padding", &format!("{:?}", &self.padding))
.add("padding_out", &format!("{:?}", &self.padding_out))
.optional()
}
}
impl ConvTranspose2dConfig {
/// Initialize a new [conv transpose 2d](ConvTranspose2d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose2d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
let shape = [
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
];
let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::<usize>();
let weight = self
.initializer
.init_with(shape, Some(fan_in), None, device);
let mut bias = None;
if self.bias {
bias = Some(
self.initializer
.init_with([self.channels[1]], Some(fan_in), None, device),
);
}
ConvTranspose2d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
impl<B: Backend> ConvTranspose2d<B> {
/// Applies the forward pass on the input tensor.
///
/// See also [conv_transpose2d](burn::tensor::module::conv_transpose2d).
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, height_in, width_in]`
/// - output: `[batch_size, channels_out, height_out, width_out]`
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
conv_transpose2d(
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = ConvTranspose2dConfig::new([5, 1], [5, 5]);
let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64;
let k = (config.groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&Default::default());
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config =
ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight.to_data().assert_approx_eq::<FT>(
&TensorData::zeros::<f32, _>(conv.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn display() {
let config = ConvTranspose2dConfig::new([5, 2], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{conv}"),
"ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = ConvTranspose2dConfig::new([5, 3], [3, 3]);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
let _ = conv.forward(input);
}
}

View File

@@ -0,0 +1,221 @@
use alloc::format;
use burn_core as burn;
use crate::conv::checks;
use burn::config::Config;
use burn::module::Content;
use burn::module::DisplaySettings;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::ModuleDisplay;
use burn::module::Param;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::conv_transpose3d;
use burn::tensor::ops::ConvTransposeOptions;
/// Configuration to create an [3D transposed convolution](ConvTranspose3d) layer
/// using the [init function](ConvTranspose3dConfig::init).
#[derive(Config, Debug)]
pub struct ConvTranspose3dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 3],
/// The stride of the convolution.
#[config(default = "[1, 1, 1]")]
pub stride: [usize; 3],
/// Spacing between kernel elements.
#[config(default = "[1, 1, 1]")]
pub dilation: [usize; 3],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
#[config(default = "[0, 0, 0]")]
pub padding: [usize; 3],
/// The padding output configuration.
#[config(default = "[0, 0, 0]")]
pub padding_out: [usize; 3],
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a 3D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose3d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2, kernel_size_3]`
pub weight: Param<Tensor<B, 5>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: [usize; 3],
/// Size of the kernel.
pub kernel_size: [usize; 3],
/// Spacing between kernel elements.
pub dilation: [usize; 3],
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: [usize; 3],
/// Padding output configuration.
pub padding_out: [usize; 3],
/// Number of channels.
pub channels: [usize; 2],
}
impl<B: Backend> ModuleDisplay for ConvTranspose3d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &format!("{:?}", &self.stride))
.add("kernel_size", &format!("{:?}", &self.kernel_size))
.add("dilation", &format!("{:?}", &self.dilation))
.add("groups", &self.groups)
.add("padding", &format!("{:?}", &self.padding))
.add("padding_out", &format!("{:?}", &self.padding_out))
.optional()
}
}
impl ConvTranspose3dConfig {
/// Initialize a new [conv transpose 2d](ConvTranspose3d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose3d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
let shape = [
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
self.kernel_size[2],
];
let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::<usize>();
let weight = self
.initializer
.init_with(shape, Some(fan_in), None, device);
let mut bias = None;
if self.bias {
bias = Some(
self.initializer
.init_with([self.channels[1]], Some(fan_in), None, device),
);
}
ConvTranspose3d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
impl<B: Backend> ConvTranspose3d<B> {
/// Applies the forward pass on the input tensor.
///
/// See also [conv_transpose3d](burn::tensor::module::conv_transpose3d).
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`
/// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`
pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {
conv_transpose3d(
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
)
}
}
#[cfg(test)]
mod tests {
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = ConvTranspose3dConfig::new([5, 1], [5, 5, 5]);
let k = (config.channels[1]
* config.kernel_size[0]
* config.kernel_size[1]
* config.kernel_size[2]) as f64;
let k = (config.groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&Default::default());
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config =
ConvTranspose3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight.to_data().assert_approx_eq::<f32>(
&TensorData::zeros::<f32, _>(conv.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn display() {
let config = ConvTranspose3dConfig::new([5, 2], [5, 5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{conv}"),
"ConvTranspose3d {channels: [5, 2], stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: [0, 0, 0], padding_out: [0, 0, 0], params: 1252}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = ConvTranspose3dConfig::new([5, 3], [3, 3, 3]);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 5>::zeros([1, 4, 10, 10, 10], &Default::default());
let _ = conv.forward(input);
}
}

View File

@@ -0,0 +1,295 @@
use alloc::format;
use burn::tensor::ops::DeformConvOptions;
use burn_core as burn;
use crate::PaddingConfig2d;
use burn::config::Config;
use burn::module::Initializer;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::deform_conv2d;
use crate::conv::checks;
/// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init).
#[derive(Config, Debug)]
pub struct DeformConv2dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
#[config(default = "[1, 1]")]
pub stride: [usize; 2],
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub weight_groups: usize,
/// Offset groups.
#[config(default = "1")]
pub offset_groups: usize,
/// The padding configuration.
///
/// ### Warning
/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
/// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a deformable 2D convolution over input tensors.
///
/// Should be created with [DeformConv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct DeformConv2d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub weight_groups: usize,
/// Offset groups.
pub offset_groups: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
}
impl DeformConv2dConfig {
/// Initialize a new [DeformConv2d](DeformConv2d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
if self.padding == PaddingConfig2d::Same {
checks::check_same_padding_support(&self.kernel_size);
}
let shape = [
self.channels[1],
self.channels[0] / self.weight_groups,
self.kernel_size[0],
self.kernel_size[1],
];
let k = self.kernel_size.iter().product::<usize>();
let fan_in = self.channels[0] / self.weight_groups * k;
let fan_out = self.channels[1] / self.weight_groups * k;
let weight = self
.initializer
.init_with(shape, Some(fan_in), Some(fan_out), device);
let mut bias = None;
if self.bias {
bias = Some(self.initializer.init_with(
[self.channels[1]],
Some(fan_in),
Some(fan_out),
device,
));
}
DeformConv2d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: Ignored(self.padding.clone()),
weight_groups: self.weight_groups,
offset_groups: self.weight_groups,
}
}
}
impl<B: Backend> ModuleDisplay for DeformConv2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
content
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("weight_groups", &self.weight_groups)
.add("offset_groups", &self.offset_groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl<B: Backend> DeformConv2d<B> {
/// Applies the forward pass on the input tensor.
///
/// See [deform_conv2d](burn::tensor::module::deform_conv2d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, height_in, width_in]`
/// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]`
/// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]`
/// - output: `[batch_size, channels_out, height_out, width_out]`
pub fn forward(
&self,
input: Tensor<B, 4>,
offset: Tensor<B, 4>,
mask: Option<Tensor<B, 4>>,
) -> Tensor<B, 4> {
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
let padding =
self.padding
.calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
deform_conv2d(
input,
offset,
self.weight.val(),
mask,
self.bias.as_ref().map(|bias| bias.val()),
DeformConvOptions::new(
self.stride,
padding,
self.dilation,
self.weight_groups,
self.offset_groups,
),
)
}
}
#[cfg(test)]
mod tests {
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
let k = (config.offset_groups as f64 / k).sqrt().elem::<FT>();
let conv = config.init::<TestBackend>(&device);
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
let conv = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight.to_data().assert_approx_eq::<FT>(
&TensorData::zeros::<f32, _>(conv.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn initializer_fan_out() {
let device = Default::default();
TestBackend::seed(&device, 0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true, // test that fan_out is passed to `init_with()`
};
let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
fn initializer_fan_with_groups_is_valid() {
let device = Default::default();
TestBackend::seed(&device, 0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true,
};
let config = DeformConv2dConfig::new([4, 4], [1, 1])
.with_initializer(init.clone())
.with_weight_groups(4);
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
#[should_panic = "Both channels must be divisible by the number of groups."]
fn channels_with_groups_is_invalid() {
let device = Default::default();
let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4);
let _ = config.init::<TestBackend>(&device);
}
#[test]
#[should_panic = "Same padding with an even kernel size is not supported"]
fn same_with_even_kernel_is_invalid() {
let device = Default::default();
let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same);
let _ = config.init::<TestBackend>(&device);
}
#[test]
fn display() {
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{conv}"),
"DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}"
);
}
#[test]
#[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
fn input_channels_mismatch() {
let config = DeformConv2dConfig::new([5, 3], [3, 3]);
let conv = config.init::<TestBackend>(&Default::default());
let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
let offset = Tensor::<TestBackend, 4>::zeros([1, 2 * 3 * 3, 10, 10], &Default::default());
let _ = conv.forward(input, offset, None);
}
}

View File

@@ -0,0 +1,17 @@
mod conv1d;
mod conv2d;
mod conv3d;
mod conv_transpose1d;
mod conv_transpose2d;
mod conv_transpose3d;
mod deform_conv2d;
pub(crate) mod checks;
pub use conv_transpose1d::*;
pub use conv_transpose2d::*;
pub use conv_transpose3d::*;
pub use conv1d::*;
pub use conv2d::*;
pub use conv3d::*;
pub use deform_conv2d::*;

View File

@@ -0,0 +1,124 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Tensor};
/// Configuration to create a [Dropout](Dropout) layer using the [init function](DropoutConfig::init).
#[derive(Config, Debug)]
pub struct DropoutConfig {
/// The probability of randomly zeroes some elements of the input tensor during training.
pub prob: f64,
}
/// Set at random some elements of the input tensor to zero during training.
///
/// This is an effective regularization technique as describe in the paper
/// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
///
/// The input is also scaled during training to `1 / (1 - prob_keep)`.
///
/// Should be created with [DropoutConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Dropout {
/// The probability of randomly zeroes some elements of the input tensor during training.
pub prob: f64,
}
impl DropoutConfig {
/// Initialize a new [dropout](Dropout) module.
pub fn init(&self) -> Dropout {
if self.prob < 0.0 || self.prob > 1.0 {
panic!(
"Dropout probability should be between 0 and 1, but got {}",
self.prob
);
}
Dropout { prob: self.prob }
}
}
impl Dropout {
/// Applies the forward pass on the input tensor.
///
/// See [Dropout](Dropout) for more information.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
if !B::ad_enabled(&input.device()) || self.prob == 0.0 {
return input;
}
let prob_keep = 1.0 - self.prob;
let random = input.random_like(Distribution::Bernoulli(prob_keep));
let x = input * random;
x * (1.0 / prob_keep)
}
}
impl ModuleDisplay for Dropout {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("prob", &self.prob).optional()
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::tensor::Shape;
#[cfg(feature = "std")]
use crate::{TestAutodiffBackend, TestBackend};
#[cfg(not(feature = "std"))]
use crate::TestBackend;
#[cfg(feature = "std")]
#[test]
fn with_ad_backend_should_mark_input() {
let tensor =
Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
let dropout = DropoutConfig::new(0.5).init();
let output = dropout.forward(tensor.clone());
assert_ne!(tensor.to_data(), output.to_data());
}
#[test]
fn without_ad_backend_should_not_change_input() {
let tensor = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
let dropout = DropoutConfig::new(0.5).init();
let output = dropout.forward(tensor.clone());
assert_eq!(tensor.to_data(), output.to_data());
}
#[test]
fn display() {
let config = DropoutConfig::new(0.5);
let layer = config.init();
assert_eq!(alloc::format!("{layer}"), "Dropout {prob: 0.5}");
}
#[test]
#[should_panic = "Dropout probability should be between 0 and 1,"]
fn dropout_prob_invalid() {
let config = DropoutConfig::new(-10.);
let _layer = config.init();
}
}

View File

@@ -0,0 +1,111 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::Param;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Int;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::embedding;
/// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init).
#[derive(Config, Debug)]
pub struct EmbeddingConfig {
/// The number of embedding vectors.
pub n_embedding: usize,
/// The size of each vector.
pub d_model: usize,
/// The type of function used to initialize neural network parameters
#[config(default = "Initializer::Normal{mean:0.0, std:1.0}")]
pub initializer: Initializer,
}
/// Lookup table to store a fix number of vectors.
///
/// Should be created with [EmbeddingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Embedding<B: Backend> {
/// The learnable weights of the module of shape `[n_embedding, d_model]` initialized
/// from a normal distribution `N(0, 1)`.
pub weight: Param<Tensor<B, 2>>,
}
impl<B: Backend> ModuleDisplay for Embedding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [n_embedding, d_model] = self.weight.shape().dims();
content
.add("n_embedding", &n_embedding)
.add("d_model", &d_model)
.optional()
}
}
impl EmbeddingConfig {
/// Initialize a new [embedding](Embedding) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
let weight = self
.initializer
.init([self.n_embedding, self.d_model], device);
Embedding { weight }
}
}
impl<B: Backend> Embedding<B> {
/// Applies the forward pass on the input tensor.
///
/// See also [embedding](burn::tensor::module::embedding).
///
/// # Shapes
///
/// - input: `[batch_size, seq_length]`
/// - output: `[batch_size, seq_length, d_model]`
pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
embedding(self.weight.val(), input)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);
let embed = config.init::<TestBackend>(&Default::default());
assert_eq!(config.initializer, Initializer::Zeros);
embed.weight.to_data().assert_approx_eq::<FT>(
&TensorData::zeros::<f32, _>(embed.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn display() {
let config = EmbeddingConfig::new(100, 10);
let embed = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{embed}"),
"Embedding {n_embedding: 100, d_model: 10, params: 1000}"
);
}
}

View File

@@ -0,0 +1,258 @@
use alloc::format;
use burn::tensor::module::interpolate;
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::InterpolateOptions;
use super::InterpolateMode;
/// Configuration for the 1D interpolation module.
///
/// This struct defines the configuration options for the 1D interpolation operation.
/// It allows specifying the output size, scale factor, and interpolation mode.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
/// Output size of the interpolated tensor.
/// If specified, this takes precedence over `scale_factor`.
#[config(default = "None")]
pub output_size: Option<usize>,
/// Scale factor for resizing the input tensor.
/// This is used when `output_size` is not specified.
#[config(default = "None")]
pub scale_factor: Option<f32>,
/// Interpolation mode to use for resizing.
/// Determines how the output values are calculated.
#[config(default = "InterpolateMode::Nearest")]
pub mode: InterpolateMode,
/// If `true`, the input and output tensors are aligned by their corner pixels.
/// If `false`, half-pixel coordinate mapping is used instead.
#[config(default = true)]
pub align_corners: bool,
}
/// Interpolate module for resizing 1D tensors with shape [N, C, L].
///
/// This struct represents a 1D interpolation module that can resize tensors
/// using various interpolation methods. It provides flexibility in specifying
/// either an output size or a scale factor for resizing, along with options
/// for the interpolation mode.
///
/// The module can be used to upsample or downsample 1D tensors, preserving the
/// number of channels and batch size while adjusting the length dimension.
///
/// The module can be created using the [Interpolate1dConfig] struct and the
/// `init` method, which returns an instance of the [Interpolate1d] struct.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
/// Output size of the interpolated tensor
pub output_size: Option<usize>,
/// Scale factor for resizing the input tensor
pub scale_factor: Option<f32>,
/// Interpolation mode used for resizing
pub mode: Ignored<InterpolateMode>,
/// Whether to align corner pixels
pub align_corners: bool,
}
impl Interpolate1dConfig {
/// Initialize the interpolation module
pub fn init(self) -> Interpolate1d {
Interpolate1d {
output_size: self.output_size,
scale_factor: self.scale_factor,
mode: Ignored(self.mode),
align_corners: self.align_corners,
}
}
}
impl Interpolate1d {
/// Performs the forward pass of the 1D interpolation module
///
/// # Arguments
///
/// * `input` - Input tensor with shape [N, C, L]
///
/// # Returns
///
/// Resized tensor with shape [N, C, L'], where L' is determined by
/// the output_size or scale_factor specified in the module configuration
///
/// # Example
///
/// ```ignore
/// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);
/// let interpolate = Interpolate1dConfig::new()
/// .with_output_size(Some(128))
/// .init();
/// let output = interpolate.forward(input);
/// assert_eq!(output.dims(), [1, 3, 128]);
/// ```
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
// Use the interpolate operation to resize the temporal input tensor
// by adding a new dimension for the interpolation axis
let input = input.unsqueeze_dim(2);
let result = interpolate(
input,
[1, output_size],
InterpolateOptions::new(self.mode.0.clone().into())
.with_align_corners(self.align_corners),
);
result.squeeze_dims(&[2])
}
}
/// Calculate output size based on input dimensions, output size, and scale factor
///
/// # Arguments
///
/// * `input_dims` - Input dimensions of the tensor
/// * `output_size` - Output size for the interpolated tensor
/// * `scale_factor` - Scale factor for resizing the tensor
///
/// # Returns
///
/// Output size for the interpolated tensor
///
/// # Panics
///
/// Panics if neither output_size nor scale_factor is provided
/// or if the scale factor is too large
fn calculate_output_size(
input_dims: [usize; 3],
output_size: Option<usize>,
scale_factor: Option<f32>,
) -> usize {
match (output_size, scale_factor) {
(Some(output_size), None) => {
// Use provided
output_size
}
(None, Some(scale_factor)) => {
// Calculate output size based on scale factor
let [_, _, l] = input_dims;
let new_dim = (l as f64) * (scale_factor as f64);
if new_dim > usize::MAX as f64 {
panic!("Scale factor is too large");
}
new_dim as usize
}
_ => panic!("Either output_size or scale_factor must be provided"),
}
}
impl ModuleDisplay for Interpolate1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("mode", &self.mode)
.add("output_size", &format!("{:?}", self.output_size))
.add("scale_factor", &self.scale_factor)
.optional()
}
}
#[cfg(test)]
mod tests {
use burn::tensor::Distribution;
use super::*;
use crate::TestBackend;
#[test]
fn test_calculate_output_size() {
let input_dims = [1, 1, 4];
let output_size = calculate_output_size(input_dims, Some(2), None);
assert_eq!(output_size, 2);
let output_size = calculate_output_size(input_dims, None, Some(2.0));
assert_eq!(output_size, 8);
let output_size = calculate_output_size(input_dims, None, Some(0.5));
assert_eq!(output_size, 2);
let output_size = calculate_output_size(input_dims, None, Some(1.5));
assert_eq!(output_size, 6);
}
#[test]
#[should_panic(expected = "Either output_size or scale_factor must be provided")]
fn test_panic() {
let input_dims = [1, 1, 4];
calculate_output_size(input_dims, None, None);
}
#[test]
#[should_panic(expected = "Scale factor is too large")]
fn test_large_scale_factor() {
let input_dims = [1, 1, usize::MAX - 1];
calculate_output_size(input_dims, None, Some(2.0));
}
#[test]
fn test_module() {
let input = Tensor::<TestBackend, 3>::random(
[2, 3, 4],
Distribution::Uniform(0.0, 1.0),
&Default::default(),
);
// Test with output_size
let config = Interpolate1dConfig::new().with_output_size(Some(8));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 8]);
// Test with scale_factor
let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 2]);
// Test with different interpolation mode
let config = Interpolate1dConfig::new()
.with_output_size(Some(6))
.with_mode(InterpolateMode::Linear);
let interpolate = config.init();
let output = interpolate.forward(input);
assert_eq!(output.dims(), [2, 3, 6]);
}
#[test]
fn display() {
let config = Interpolate1dConfig::new().with_output_size(Some(20));
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"Interpolate1d {mode: Nearest, output_size: Some(20), \
scale_factor: None}"
);
}
}

View File

@@ -0,0 +1,261 @@
use alloc::format;
use burn::tensor::module::interpolate;
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::InterpolateOptions;
use super::InterpolateMode;
/// Configuration for the 2D interpolation module.
///
/// This struct defines the configuration options for the 2D interpolation operation.
/// It allows specifying the output size, scale factor, and interpolation mode.
#[derive(Config, Debug)]
pub struct Interpolate2dConfig {
/// Output size of the interpolated tensor.
/// If specified, this takes precedence over `scale_factor`.
#[config(default = "None")]
pub output_size: Option<[usize; 2]>,
/// Scale factor for resizing the input tensor.
/// This is used when `output_size` is not specified.
#[config(default = "None")]
pub scale_factor: Option<[f32; 2]>,
/// Interpolation mode to use for resizing.
/// Determines how the output values are calculated.
#[config(default = "InterpolateMode::Nearest")]
pub mode: InterpolateMode,
/// If `true`, the input and output tensors are aligned by their corner pixels.
/// If `false`, half-pixel coordinate mapping is used instead.
#[config(default = true)]
pub align_corners: bool,
}
/// Interpolate module for resizing tensors with shape [N, C, H, W].
///
/// This struct represents an interpolation module that can resize tensors
/// using various interpolation methods. It provides flexibility in specifying
/// either an output size or a scale factor for resizing, along with options
/// for the interpolation mode.
///
/// The module can be used to upsample or downsample tensors, preserving the
/// number of channels and batch size while adjusting the height and width
/// dimensions.
///
/// The module can be created using the [Interpolate2dConfig] struct and the
/// `init` method, which returns an instance of the [Interpolate2d] struct.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate2d {
/// Output size of the interpolated tensor
pub output_size: Option<[usize; 2]>,
/// Scale factor for resizing the input tensor
pub scale_factor: Option<[f32; 2]>,
/// Interpolation mode used for resizing
pub mode: Ignored<InterpolateMode>,
/// Whether to align corner pixels
pub align_corners: bool,
}
impl Interpolate2dConfig {
/// Initialize the interpolation module
pub fn init(self) -> Interpolate2d {
Interpolate2d {
output_size: self.output_size,
scale_factor: self.scale_factor,
mode: Ignored(self.mode),
align_corners: self.align_corners,
}
}
}
impl Interpolate2d {
/// Performs the forward pass of the interpolation module
///
/// # Arguments
///
/// * `input` - Input tensor with shape [N, C, H, W]
///
/// # Returns
///
/// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by
/// the output_size or scale_factor specified in the module configuration
///
/// # Example
///
/// ```ignore
/// let input = Tensor::<Backend, 2>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);
/// let interpolate = Interpolate2dConfig::new()
/// .with_output_size(Some([128, 128]))
/// .init();
/// let output = interpolate.forward(input);
/// assert_eq!(output.dims(), [1, 3, 128, 128]);
/// ```
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
interpolate(
input,
output_size,
InterpolateOptions::new(self.mode.0.clone().into())
.with_align_corners(self.align_corners),
)
}
}
/// Calculates the output size for tensor interpolation.
///
/// # Arguments
///
/// * `input_dims` - The dimensions of the input tensor [N, C, H, W].
/// * `output_size` - Optional desired output size [H', W'].
/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w].
///
/// # Returns
///
/// A tuple [H', W'] representing the calculated output size.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if the scale factor results in dimensions exceeding usize::MAX.
fn calculate_output_size(
input_dims: [usize; 4],
output_size: Option<[usize; 2]>,
scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
match (output_size, scale_factor) {
(Some(output_size), None) => {
// Use provided
output_size
}
(None, Some(scale_factor)) => {
// Calculate output size based on scale factor
let [_, _, h, w] = input_dims;
let new_dim_h = (h as f64) * (scale_factor[0] as f64);
if new_dim_h > usize::MAX as f64 {
panic!("Scale factor for height is too large");
}
let new_dim_w = (w as f64) * (scale_factor[1] as f64);
if new_dim_w > usize::MAX as f64 {
panic!("Scale factor for width is too large");
}
[new_dim_h as usize, new_dim_w as usize]
}
_ => panic!("Either output_size or scale_factor must be provided"),
}
}
impl ModuleDisplay for Interpolate2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("mode", &self.mode)
.add("output_size", &format!("{:?}", self.output_size))
.add("scale_factor", &self.scale_factor)
.optional()
}
}
#[cfg(test)]
mod tests {
use burn::tensor::Distribution;
use crate::TestBackend;
use super::*;
#[test]
fn test_calculate_output_size() {
let input_dims = [1, 1, 4, 4];
let output_size = calculate_output_size(input_dims, Some([2, 2]), None);
assert_eq!(output_size, [2, 2]);
let output_size = calculate_output_size(input_dims, None, Some([2.0, 2.0]));
assert_eq!(output_size, [8, 8]);
let output_size = calculate_output_size([1, 1, 4, 4], None, Some([0.5, 0.5]));
assert_eq!(output_size, [2, 2]);
let output_size = calculate_output_size([1, 1, 4, 4], None, Some([2.0, 1.5]));
assert_eq!(output_size, [8, 6]);
}
#[test]
#[should_panic(expected = "Either output_size or scale_factor must be provided")]
fn test_missing_params() {
calculate_output_size([1, 1, 4, 4], None, None);
}
#[test]
#[should_panic(expected = "Scale factor for height is too large")]
fn test_infinite_height() {
calculate_output_size([1, 1, usize::MAX - 1, 4], None, Some([2.0, 1.0]));
}
#[test]
#[should_panic(expected = "Scale factor for width is too large")]
fn test_infinite_width() {
calculate_output_size([1, 1, 4, usize::MAX - 1], None, Some([1.0, 2.0]));
}
#[test]
fn test_module() {
let input = Tensor::<TestBackend, 4>::random(
[2, 3, 4, 4],
Distribution::Uniform(0.0, 1.0),
&Default::default(),
);
// Test with output_size
let config = Interpolate2dConfig::new().with_output_size(Some([8, 8]));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 8, 8]);
// Test with scale_factor
let config = Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5]));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 2, 2]);
// Test with different interpolation mode
let config = Interpolate2dConfig::new()
.with_output_size(Some([6, 6]))
.with_mode(InterpolateMode::Linear);
let interpolate = config.init();
let output = interpolate.forward(input);
assert_eq!(output.dims(), [2, 3, 6, 6]);
}
#[test]
fn display() {
let config = Interpolate2dConfig::new().with_output_size(Some([20, 20]));
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
scale_factor: None}"
);
}
}

View File

@@ -0,0 +1,49 @@
mod interpolate1d;
mod interpolate2d;
pub use interpolate1d::*;
pub use interpolate2d::*;
use burn_core as burn;
use burn::config::Config;
use burn::tensor::ops::InterpolateMode as OpsInterpolateMode;
/// Algorithm used for downsampling and upsampling
///
/// This enum defines different interpolation modes for resampling data.
#[derive(Config, Debug)]
pub enum InterpolateMode {
/// Nearest-neighbor interpolation
///
/// This mode selects the value of the nearest sample point for each output pixel.
/// It is applicable for both temporal and spatial data.
Nearest,
/// Linear interpolation
///
/// This mode calculates the output value using linear
/// interpolation between nearby sample points.
///
/// It is applicable for both temporal and spatial data.
Linear,
/// Cubic interpolation
///
/// This mode uses cubic interpolation to calculate the output value
/// based on surrounding sample points.
///
/// It is applicable for both temporal and spatial data and generally
/// provides smoother results than linear interpolation.
Cubic,
}
impl From<InterpolateMode> for OpsInterpolateMode {
fn from(mode: InterpolateMode) -> Self {
match mode {
InterpolateMode::Nearest => OpsInterpolateMode::Nearest,
InterpolateMode::Linear => OpsInterpolateMode::Bilinear,
InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,
}
}
}

View File

@@ -0,0 +1,340 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Param;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::module::linear;
use burn::tensor::{Tensor, backend::Backend};
/// Configuration to create a [`Linear`] layer using the [init function](LinearConfig::init).
#[derive(Config, Debug)]
pub struct LinearConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the output features.
pub d_output: usize,
/// If a bias should be applied during the linear transformation.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
/// The layout in which the linear parameters are stored.
#[config(default = "LinearLayout::Row")]
pub layout: LinearLayout,
}
#[derive(Config, Debug, Copy)]
/// The layout in which the linear parameters are stored.
///
/// This can have performance impacts.
pub enum LinearLayout {
/// Parameters are stored in Row major.
Row,
/// Parameters are stored in Col major.
Col,
}
/// Applies a linear transformation to the input tensor.
///
/// Should be created with [LinearConfig]
///
/// `O = IW + b`
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
/// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
pub weight: Param<Tensor<B, 2>>,
/// Vector of size `d_output` initialized from a uniform distribution:
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
pub bias: Option<Param<Tensor<B, 1>>>,
}
impl LinearConfig {
/// Initialize a new [`Linear`] module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
let weight = match self.layout {
LinearLayout::Row => {
let shape = [self.d_input, self.d_output];
self.initializer
.init_with(shape, Some(self.d_input), Some(self.d_output), device)
}
LinearLayout::Col => {
let shape = [self.d_output, self.d_input];
self.initializer
.init_with(shape, Some(self.d_output), Some(self.d_input), device)
// The param is already transposed when init. We re-transpose to have
// [d_output, d_input] while saving.
.save_mapper(move |tensor| {
B::sync(&tensor.device()).unwrap();
let tensor = tensor.transpose();
B::sync(&tensor.device()).unwrap();
tensor
})
// When loading from record we have to transpose.
.load_mapper(move |tensor| {
B::sync(&tensor.device()).unwrap();
let tensor = tensor.transpose();
B::sync(&tensor.device()).unwrap();
tensor
})
// When loading from initialization, we have to transpose.
.init_mapper(|tensor| {
B::sync(&tensor.device()).unwrap();
let tensor = tensor.transpose();
B::sync(&tensor.device()).unwrap();
tensor
})
}
};
let bias = if self.bias {
Some(self.initializer.init_with(
[self.d_output],
Some(self.d_input),
Some(self.d_output),
device,
))
} else {
None
};
Linear { weight, bias }
}
}
impl<B: Backend> Linear<B> {
/// Applies the forward pass on the input tensor.
///
/// # Arguments
///
/// - `input` - The input tensor of shape `[..., d_input]`.
///
/// # Shapes
///
/// - input: `[..., d_input]`
/// - output: `[..., d_output]`
///
/// # Returns
///
/// The transformed tensor of shape `[..., d_output]`.
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
linear(
input,
self.weight.val(),
self.bias.as_ref().map(|b| b.val()),
)
}
}
impl<B: Backend> ModuleDisplay for Linear<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, d_output] = self.weight.shape().dims();
content
.add("d_input", &d_input)
.add("d_output", &d_output)
.add("bias", &self.bias.is_some())
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::module::ParamId;
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::ElementConversion;
use burn::tensor::{Shape, TensorData};
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn initializer_default() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = LinearConfig::new(5, 5);
let k = (1.0 / config.d_input as f64).sqrt().elem::<FT>();
let linear = config.init::<TestBackend>(&device);
assert_eq!(
config.initializer,
Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: false
}
);
linear.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
let linear = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, Initializer::Zeros);
linear.weight.to_data().assert_approx_eq::<FT>(
&TensorData::zeros::<f32, _>(linear.weight.shape()),
Tolerance::default(),
);
}
#[test]
fn test_linear_forward_no_bias() {
let device = Default::default();
TestBackend::seed(&device, 0);
let value = 2.;
let config = LinearConfig::new(2, 3)
.with_initializer(Initializer::Constant { value })
.with_bias(false);
let linear = config.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
let result = linear.forward(input);
let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);
assert_eq!(result.into_data(), expected_result.into_data());
}
#[test]
fn test_linear_forward_with_bias() {
let device = Default::default();
TestBackend::seed(&device, 0);
let device = Default::default();
let value = 2.;
let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
let linear = config.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
let result = linear.forward(input);
let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);
assert_eq!(result.into_data(), expected_result.into_data());
}
#[test]
fn test_linear_1d() {
let device = Default::default();
TestBackend::seed(&device, 0);
let device = Default::default();
let value = 2.;
let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
let linear = config.init::<TestBackend>(&device);
let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
let result_1d = linear.forward(input_1d).unsqueeze::<2>();
let result_2d = linear.forward(input_2d);
assert_eq!(result_1d.into_data(), result_2d.into_data());
}
#[test]
fn display() {
let config = LinearConfig::new(3, 5);
let linear = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{linear}"),
"Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
);
}
#[test]
fn layout() {
let device = Default::default();
let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
let linear = config.init::<TestBackend>(&device);
assert_eq!(linear.weight.dims(), [6, 12], "Shape is as configured");
let recorder = BinBytesRecorder::<FullPrecisionSettings>::new();
// We go through serialization to trigger the mappers..
let record = linear.into_record();
let data = recorder.record(record, ()).unwrap();
let record = recorder.load(data.clone(), &device).unwrap();
let config = LinearConfig::new(12, 6).with_layout(LinearLayout::Row);
let linear_row = config.init::<TestBackend>(&device).load_record(record);
assert_eq!(
linear_row.weight.dims(),
[12, 6],
"Shape should be transposed"
);
let record = recorder.load(data.clone(), &device).unwrap();
let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
let linear_col = config.init::<TestBackend>(&device).load_record(record);
assert_eq!(
linear_col.weight.dims(),
[6, 12],
"Shape should be as configured"
);
// We go through serialization to trigger the mappers.
//
// The test will fail if the mapper is not correctly given to the module after loading a
// record.
let record = linear_col.into_record();
let data = recorder.record(record, ()).unwrap();
let record = recorder.load(data, &device).unwrap();
let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
let linear_col = config.init::<TestBackend>(&device).load_record(record);
assert_eq!(
linear_col.weight.dims(),
[6, 12],
"Shape should be as configured"
);
}
#[test]
fn col_row_same_result() {
let device = Default::default();
let config_col = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
let linear_col = config_col.init::<TestBackend>(&device);
let signal = Tensor::<_, 2>::random([8, 6], burn::tensor::Distribution::Default, &device);
let value = linear_col.forward(signal.clone());
let data_1 = value.into_data();
let weights = linear_col.weight.val().into_data();
let weights = Tensor::from_data(weights, &device);
let linear = Linear {
weight: Param::initialized(ParamId::new(), weights),
bias: linear_col
.bias
.map(|b| Param::initialized(ParamId::new(), b.val())),
};
let value = linear.forward(signal);
let data_2 = value.into_data();
data_1.assert_approx_eq::<f32>(&data_2, Default::default());
}
}

View File

@@ -0,0 +1,38 @@
/// Attention module
pub mod attention;
/// Cache module
pub mod cache;
/// Convolution module
pub mod conv;
/// Pooling module
pub mod pool;
/// Transformer module
pub mod transformer;
/// Interpolate module
pub mod interpolate;
mod dropout;
mod embedding;
mod linear;
mod noise;
mod pos_encoding;
mod rnn;
mod rope_encoding;
mod unfold;
pub mod norm;
pub use norm::{batch::*, group::*, instance::*, layer::*, rms::*};
pub use dropout::*;
pub use embedding::*;
pub use linear::*;
pub use noise::*;
pub use pos_encoding::*;
pub use rnn::*;
pub use rope_encoding::*;
pub use unfold::*;

View File

@@ -0,0 +1,123 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Tensor};
/// Configuration to create a [GaussianNoise](GaussianNoise) layer using the [init function](GaussianNoiseConfig::init).
#[derive(Config, Debug)]
pub struct GaussianNoiseConfig {
/// Standard deviation of the normal noise distribution.
pub std: f64,
}
/// Add pseudorandom Gaussian noise to an arbitrarily shaped tensor.
///
/// This is an effective regularization technique that also contributes to data augmentation.
/// Please keep in mind that the value of [std](GaussianNoise::std) should be chosen with care in order to avoid
/// distortion.
///
/// Should be created with [GaussianNoiseConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct GaussianNoise {
/// Standard deviation of the normal noise distribution.
pub std: f64,
}
impl GaussianNoiseConfig {
/// Initialize a new [Gaussian noise](GaussianNoise) module.
pub fn init(&self) -> GaussianNoise {
if self.std.is_sign_negative() {
panic!(
"Standard deviation is required to be non-negative, but got {}",
self.std
);
}
GaussianNoise { std: self.std }
}
}
impl GaussianNoise {
/// Applies the forward pass on the input tensor.
///
/// See [GaussianNoise](GaussianNoise) for more information.
///
/// # Shapes
///
/// - input: `[..., any]`
/// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
if B::ad_enabled(&input.device()) && self.std != 0.0 {
let noise = Tensor::random(
input.shape(),
Distribution::Normal(0.0, self.std),
&input.device(),
);
input + noise
} else {
input
}
}
}
impl ModuleDisplay for GaussianNoise {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("std", &self.std).optional()
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::tensor::Shape;
#[cfg(feature = "std")]
use crate::{TestAutodiffBackend, TestBackend};
#[cfg(not(feature = "std"))]
use crate::TestBackend;
#[cfg(feature = "std")]
#[test]
fn with_ad_backend_should_mark_input() {
let tensor =
Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
let noise = GaussianNoiseConfig::new(0.5).init();
let output = noise.forward(tensor.clone());
assert_ne!(tensor.to_data(), output.to_data());
}
#[test]
fn without_ad_backend_should_not_change_input() {
let tensor = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
let noise = GaussianNoiseConfig::new(0.5).init();
let output = noise.forward(tensor.clone());
assert_eq!(tensor.to_data(), output.to_data());
}
#[test]
#[should_panic(expected = "Standard deviation is required to be non-negative")]
fn negative_std_should_panic() {
GaussianNoiseConfig { std: -0.5 }.init();
}
#[test]
fn display() {
let config = GaussianNoiseConfig::new(0.5);
let layer = config.init();
assert_eq!(alloc::format!("{layer}"), "GaussianNoise {std: 0.5}");
}
}

View File

@@ -0,0 +1,484 @@
use burn_core as burn;
use burn::module::Initializer;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::{Tensor, backend::Backend};
use burn::{
config::Config,
module::{Module, Param, RunningState},
};
/// [`BatchNorm`] Configuration.
///
/// Used to create a [`BatchNorm`] layer using the [`BatchNormConfig::init`].
#[derive(Config, Debug)]
pub struct BatchNormConfig {
/// The number of features.
pub num_features: usize,
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
pub epsilon: f64,
/// Momentum used to update the metrics. Default: 0.1
#[config(default = 0.1)]
pub momentum: f64,
}
/// Applies Batch Normalization over a tensor.
///
/// Based upon the paper [Batch Normalization](https://arxiv.org/abs/1502.03167).
///
/// Assumes input tensor is of shape ``[batch_size, channels, ...]``.
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `norm` is the normalization function
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
/// Should be created using [`BatchNormConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend> {
/// The learnable weight gamma.
pub gamma: Param<Tensor<B, 1>>,
/// The learnable weight beta.
pub beta: Param<Tensor<B, 1>>,
/// The running mean.
pub running_mean: RunningState<Tensor<B, 1>>,
/// The running variance.
pub running_var: RunningState<Tensor<B, 1>>,
/// Momentum used to update the metrics.
pub momentum: f64,
/// A value required for numerical stability.
pub epsilon: f64,
}
impl BatchNormConfig {
/// Initializes a new [batch norm](BatchNorm) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {
let gamma = Initializer::Ones.init([self.num_features], device);
let beta = Initializer::Zeros.init([self.num_features], device);
let running_mean = Tensor::zeros([self.num_features], device);
let running_var = Tensor::ones([self.num_features], device);
BatchNorm {
gamma,
beta,
running_mean: RunningState::new(running_mean),
running_var: RunningState::new(running_var),
momentum: self.momentum,
epsilon: self.epsilon,
}
}
}
impl<B: Backend> BatchNorm<B> {
/// Applies the forward pass on the input tensor.
///
/// See [`BatchNorm`] for more information.
///
/// # Shapes
///
/// - `input`: ``[batch_size, channels, ...]``
/// - `output`: ``[batch_size, channels, ...]``
///
/// # Panics
///
/// This function will panic if the input tensor has rank < 2.
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
// Should be move to a compilation error when const generic support that kind of
// validation. https://github.com/rust-lang/rust/issues/76560
if D < 2 {
panic!(
"BatchNorm can only be applied on tensors of rank >= 2 with the following shape \
[batch_size, channels, ...], received {}D tensor",
D
);
}
match B::ad_enabled(&input.device()) {
true => self.forward_train(input),
false => self.forward_inference(input),
}
}
fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let device = input.device();
let channels = input.dims()[1];
let mean = self.running_mean.value().to_device(&device);
let var = self.running_var.value().to_device(&device);
let mut shape = [1; D];
shape[1] = channels;
self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
}
fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let device = input.device();
let dims = input.dims();
let batch_size = dims[0];
let channels = dims[1];
let mut shape_unsqueeze = [1; D];
let mut flatten_size = batch_size;
shape_unsqueeze[1] = channels;
for dim in dims.iter().take(D).skip(2) {
flatten_size *= dim;
}
let mean = input
.clone()
.swap_dims(0, 1)
.reshape([channels, flatten_size])
.mean_dim(1)
.reshape(shape_unsqueeze);
let var = input
.clone()
.sub(mean.clone())
.square()
.swap_dims(0, 1)
.reshape([channels, flatten_size])
.mean_dim(1)
.reshape(shape_unsqueeze);
let running_mean = self.running_mean.value_sync().to_device(&device);
let running_var = self.running_var.value_sync().to_device(&device);
let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
mean.clone()
.detach()
.mul_scalar(self.momentum)
.reshape([channels]),
);
let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
var.clone()
.detach()
.mul_scalar(self.momentum)
.reshape([channels]),
);
self.running_mean.update(running_mean.detach());
self.running_var.update(running_var.detach());
self.forward_shared(input, mean, var)
}
fn forward_shared<const D: usize>(
&self,
x: Tensor<B, D>,
mean: Tensor<B, D>,
var: Tensor<B, D>,
) -> Tensor<B, D> {
let channels = x.dims()[1];
let mut shape = [1; D];
shape[1] = channels;
let std = var.add_scalar(self.epsilon).sqrt();
let x = x.sub(mean);
let x = x.div(std);
let x = x.mul(self.gamma.val().reshape(shape));
x.add(self.beta.val().reshape(shape))
}
}
impl<B: Backend> ModuleDisplay for BatchNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [num_features] = self.beta.shape().dims();
content
.add("num_features", &num_features)
.add("momentum", &self.momentum)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
use super::*;
use crate::TestAutodiffBackend;
use burn::module::AutodiffModule;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestAutodiffBackend>;
#[test]
fn batch_norm_forward_train() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
let output = module.forward(input_tensor(&device));
output
.to_data()
.assert_approx_eq::<FT>(&expected_train(), Tolerance::rel_abs(0.1, 0.001));
}
#[test]
fn batch_norm_forward_inference() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
module.forward(input_tensor(&device));
let module = module.valid();
let output = module.forward(input_tensor(&device));
output
.to_data()
.assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
}
fn expected_valid() -> TensorData {
TensorData::from([
[[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
[[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
])
}
fn expected_train() -> TensorData {
TensorData::from([
[
[1.1483e+00, 3.7521e-01],
[1.6272e-03, 7.5067e-01],
[1.6204e+00, -4.5168e-02],
],
[
[6.8856e-02, -1.5923e+00],
[-1.6318e+00, 8.7949e-01],
[-5.3368e-01, -1.0416e+00],
],
])
}
fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
Tensor::<B, 3>::from_floats(
[
[[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
[[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
],
device,
)
}
#[test]
fn batch_norm_forward_train_inference() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
module.forward(input_tensor(&device));
let module = module.valid();
let output = module.forward(input_tensor(&device));
output
.to_data()
.assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
let module = module.train::<TestAutodiffBackend>();
let output = module.forward(input_tensor(&device));
output
.to_data()
.assert_approx_eq::<FT>(&expected_train(), Tolerance::default());
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
use super::*;
use crate::TestAutodiffBackend;
use burn::module::AutodiffModule;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestAutodiffBackend>;
#[test]
fn batch_norm_forward_train() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
let output = module.forward(input_tensor(&device));
let expected = TensorData::from([
[
[[1.5136, 0.7506], [-1.2216, 0.1477]],
[[0.3135, 1.2252], [-0.4150, 0.6130]],
[[1.4186, 0.3372], [-1.5183, 1.5262]],
],
[
[[0.4483, -1.1914], [-1.2010, 0.7537]],
[[-1.6752, 1.3822], [-0.5058, -0.9381]],
[[0.0200, -0.3097], [-0.5715, -0.9026]],
],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
}
#[test]
fn batch_norm_forward_inference() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
module.forward(input_tensor(&device));
let module = module.valid();
let output = module.forward(input_tensor(&device));
let expected = TensorData::from([
[
[[0.9538, 0.7103], [0.0808, 0.5179]],
[[0.6015, 0.8910], [0.3703, 0.6966]],
[[0.9171, 0.6912], [0.3037, 0.9395]],
],
[
[[0.6138, 0.0904], [0.0874, 0.7113]],
[[-0.0297, 0.9408], [0.3415, 0.2042]],
[[0.6250, 0.5561], [0.5013, 0.4323]],
],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn batch_norm_running_mean() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
let _output = module.forward(input_tensor(&device));
let running_mean = module.running_mean.value_sync();
let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
running_mean
.reshape([3])
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn batch_norm_running_var() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
let _output = module.forward(input_tensor(&device));
let running_var = module.running_var.value_sync();
let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
running_var
.reshape([3])
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn batch_norm_running_mean_inner_module() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
let _output = module.forward(input_tensor(&device));
let module_valid = module.valid();
let running_mean = module_valid.running_mean.value();
let running_mean_after = module.running_mean.value();
running_mean_after
.into_data()
.assert_approx_eq::<FT>(&running_mean.into_data(), Tolerance::default());
}
#[test]
fn batch_norm_grads() {
let device = Default::default();
let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
let input = input_tensor(&device).require_grad();
let output = module.forward(input.clone());
let grads = output.backward();
let tolerance = Tolerance::rel_abs(0.1, 0.001);
let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
module
.gamma
.grad(&grads)
.unwrap()
.reshape([3])
.into_data()
.assert_approx_eq::<FT>(&expected, tolerance);
let expected = TensorData::from([8., 8., 8.]);
module
.beta
.grad(&grads)
.unwrap()
.reshape([3])
.into_data()
.assert_approx_eq::<FT>(&expected, tolerance);
let expected = TensorData::from([
[
[[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
[[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
[[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
],
[
[[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
[[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
[[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
],
]);
input
.grad(&grads)
.unwrap()
.into_data()
.assert_approx_eq::<FT>(&expected, tolerance);
}
fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
Tensor::<B, 4>::from_floats(
[
[
[[0.9601, 0.7277], [0.1270, 0.5441]],
[[0.6272, 0.9034], [0.4066, 0.7179]],
[[0.9378, 0.7230], [0.3544, 0.9591]],
],
[
[[0.6356, 0.1362], [0.1333, 0.7287]],
[[0.0249, 0.9509], [0.3791, 0.2481]],
[[0.6600, 0.5945], [0.5424, 0.4767]],
],
],
device,
)
}
#[test]
fn display() {
let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());
assert_eq!(
format!("{batch_norm}"),
"BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
);
}
}

View File

@@ -0,0 +1,336 @@
use burn::module::Initializer;
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::Param;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init).
#[derive(Debug, Config)]
pub struct GroupNormConfig {
/// The number of groups to separate the channels into
pub num_groups: usize,
/// The number of channels expected in the input
pub num_channels: usize,
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
pub epsilon: f64,
/// A boolean value that when set to `true`, this module has learnable
/// per-channel affine parameters initialized to ones (for weights)
/// and zeros (for biases). Default: `true`
#[config(default = true)]
pub affine: bool,
}
/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
///
/// `Y = groupnorm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
/// Should be created using [GroupNormConfig](GroupNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct GroupNorm<B: Backend> {
/// The learnable weight
pub gamma: Option<Param<Tensor<B, 1>>>,
/// The learnable bias
pub beta: Option<Param<Tensor<B, 1>>>,
/// The number of groups to separate the channels into
pub num_groups: usize,
/// The number of channels expected in the input
pub num_channels: usize,
/// A value required for numerical stability
pub epsilon: f64,
/// A boolean value that when set to `true`, this module has learnable
pub affine: bool,
}
impl<B: Backend> ModuleDisplay for GroupNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("num_groups", &self.num_groups)
.add("num_channels", &self.num_channels)
.add("epsilon", &self.epsilon)
.add("affine", &self.affine)
.optional()
}
}
impl GroupNormConfig {
/// Initialize a new [group norm](GroupNorm) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert_eq!(
self.num_channels % self.num_groups,
0,
"The number of channels must be divisible by the number of groups"
);
let (gamma, beta) = if self.affine {
let gamma = Initializer::Ones.init([self.num_channels], device);
let beta = Initializer::Zeros.init([self.num_channels], device);
(Some(gamma), Some(beta))
} else {
(None, None)
};
GroupNorm {
num_groups: self.num_groups,
num_channels: self.num_channels,
gamma,
beta,
epsilon: self.epsilon,
affine: self.affine,
}
}
}
impl<B: Backend> GroupNorm<B> {
/// Applies the forward pass on the input tensor.
///
/// See [GroupNorm](GroupNorm) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, num_channels, *]`
/// - output: `[batch_size, num_channels, *]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
if input.shape()[1] != self.num_channels {
panic!(
"The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
self.num_channels,
input.shape()[1]
);
}
let gamma = self.gamma.as_ref().map(|x| x.val());
let beta = self.beta.as_ref().map(|x| x.val());
group_norm(
input,
gamma,
beta,
self.num_groups,
self.epsilon,
self.affine,
)
}
}
/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
///
/// `Y = groupnorm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
pub(crate) fn group_norm<B: Backend, const D: usize>(
input: Tensor<B, D>,
gamma: Option<Tensor<B, 1>>,
beta: Option<Tensor<B, 1>>,
num_groups: usize,
epsilon: f64,
affine: bool,
) -> Tensor<B, D> {
if (beta.is_none() || gamma.is_none()) && affine {
panic!("Affine is set to true, but gamma or beta is None");
}
let shape = input.shape();
if shape.num_elements() <= 2 {
panic!(
"input rank for GroupNorm should be at least 3, but got {}",
shape.num_elements()
);
}
let batch_size = shape[0];
let num_channels = shape[1];
let hidden_size = shape[2..].iter().product::<usize>() * num_channels / num_groups;
let input = input.reshape([batch_size, num_groups, hidden_size]);
let mean = input.clone().sum_dim(2) / hidden_size as f64;
let input = input.sub(mean);
let var = input.clone().square().sum_dim(2) / hidden_size as f64;
let input_normalized = input.div(var.add_scalar(epsilon).sqrt());
if affine {
let mut affine_shape = [1; D];
affine_shape[1] = num_channels;
input_normalized
.reshape(shape)
.mul(gamma.clone().unwrap().reshape(affine_shape))
.add(beta.clone().unwrap().reshape(affine_shape))
} else {
input_normalized.reshape(shape)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use alloc::format;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn group_norm_forward_affine_false() {
let device = Default::default();
let module = GroupNormConfig::new(2, 6)
.with_affine(false)
.init::<TestBackend>(&device);
assert!(module.gamma.is_none());
assert!(module.beta.is_none());
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
[
[-0.3034, 0.2726, -0.9659],
[-1.1845, -1.3236, 0.0172],
[1.9507, 1.2554, -0.8625],
[1.0682, 0.3604, 0.3985],
[-0.4957, -0.4461, -0.9721],
[1.5157, -0.1546, -0.5596],
],
[
[-1.6698, -0.4040, -0.7927],
[0.3736, -0.0975, -0.1351],
[-0.9461, 0.5461, -0.6334],
[-1.0919, -0.1158, 0.1213],
[-0.9535, 0.1281, 0.4372],
[-0.2845, 0.3488, 0.5641],
],
]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([
[
[-0.1653, 0.3748, -0.7866],
[-0.9916, -1.1220, 0.1353],
[1.9485, 1.2965, -0.6896],
[1.2769, 0.3628, 0.4120],
[-0.7427, -0.6786, -1.3578],
[1.8547, -0.3022, -0.8252],
],
[
[-1.9342, 0.0211, -0.5793],
[1.2223, 0.4945, 0.4365],
[-0.8163, 1.4887, -0.3333],
[-1.7960, -0.0392, 0.3875],
[-1.5469, 0.3998, 0.9561],
[-0.3428, 0.7970, 1.1845],
],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn group_norm_forward_affine_true() {
let device = Default::default();
let module = GroupNormConfig::new(3, 6)
.with_affine(true)
.init::<TestBackend>(&device);
let tolerance = Tolerance::permissive();
module
.gamma
.as_ref()
.expect("gamma should not be None")
.val()
.to_data()
.assert_approx_eq::<FT>(&TensorData::ones::<f32, _>([6]), tolerance);
module
.beta
.as_ref()
.expect("beta should not be None")
.val()
.to_data()
.assert_approx_eq::<FT>(&TensorData::zeros::<f32, _>([6]), tolerance);
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
[
[0.3345, 0.4429, 0.6639],
[0.5041, 0.4175, 0.8437],
[0.6159, 0.3758, 0.4071],
[0.5417, 0.5785, 0.7671],
[0.3837, 0.9883, 0.0420],
[0.4808, 0.8989, 0.6144],
],
[
[0.3930, 0.2098, 0.0602],
[0.2298, 0.9425, 0.0333],
[0.7409, 0.8172, 0.8879],
[0.4846, 0.0486, 0.2029],
[0.6741, 0.9765, 0.6864],
[0.2827, 0.5534, 0.2125],
],
]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([
[
[-1.1694, -0.5353, 0.7572],
[-0.1775, -0.6838, 1.8087],
[0.5205, -1.3107, -1.0723],
[-0.0459, 0.2351, 1.6734],
[-0.5796, 1.3218, -1.6544],
[-0.2744, 1.0406, 0.1459],
],
[
[0.2665, -0.3320, -0.8205],
[-0.2667, 2.0612, -0.9085],
[0.6681, 0.9102, 1.1345],
[-0.1453, -1.5287, -1.0389],
[0.4253, 1.5962, 0.4731],
[-1.0903, -0.0419, -1.3623],
],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, tolerance);
}
#[test]
fn display() {
let config = GroupNormConfig::new(3, 6);
let group_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{group_norm}"),
"GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
);
}
}

View File

@@ -0,0 +1,228 @@
use burn_core as burn;
use crate::norm::group_norm;
use burn::config::Config;
use burn::module::Initializer;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Module, Param};
use burn::tensor::{Tensor, backend::Backend};
/// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).
#[derive(Debug, Config)]
pub struct InstanceNormConfig {
/// The number of channels expected in the input
pub num_channels: usize,
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
pub epsilon: f64,
/// A boolean value that when set to `true`, this module has learnable
/// per-channel affine parameters initialized to ones (for weights)
/// and zeros (for biases). Default: `true`
#[config(default = true)]
pub affine: bool,
}
/// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022)
///
/// Should be created using [InstanceNormConfig](InstanceNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct InstanceNorm<B: Backend> {
/// The learnable weight
pub gamma: Option<Param<Tensor<B, 1>>>,
/// The learnable bias
pub beta: Option<Param<Tensor<B, 1>>>,
/// The number of channels expected in the input
pub num_channels: usize,
/// A value required for numerical stability
pub epsilon: f64,
/// A boolean value that when set to `true`, this module has learnable
pub affine: bool,
}
impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("num_channels", &self.num_channels)
.add("epsilon", &self.epsilon)
.add("affine", &self.affine)
.optional()
}
}
impl InstanceNormConfig {
/// Initialize a new [instance norm](InstanceNorm) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
let (gamma, beta) = if self.affine {
let gamma = Initializer::Ones.init([self.num_channels], device);
let beta = Initializer::Zeros.init([self.num_channels], device);
(Some(gamma), Some(beta))
} else {
(None, None)
};
InstanceNorm {
gamma,
beta,
num_channels: self.num_channels,
epsilon: self.epsilon,
affine: self.affine,
}
}
}
impl<B: Backend> InstanceNorm<B> {
/// Applies the forward pass on the input tensor.
///
/// See also [InstanceNormConfig](InstanceNormConfig) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, num_channels, *]`
/// - output: `[batch_size, num_channels, *]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
// Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.
let num_groups = self.num_channels;
let gamma = self.gamma.as_ref().map(|x| x.val());
let beta = self.beta.as_ref().map(|x| x.val());
group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use alloc::format;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn instance_norm_forward_affine_false() {
let device = Default::default();
let module = InstanceNormConfig::new(6)
.with_affine(false)
.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
[
[-0.3034, 0.2726, -0.9659],
[-1.1845, 1.4078, 0.9774],
[0.3963, -1.3738, 1.4125],
[1.0682, 0.3604, 0.3985],
[-0.4957, -0.4461, -0.9721],
[1.5157, -0.1546, -0.5596],
],
[
[-1.6698, -0.4040, -0.7927],
[0.3736, -0.0975, -0.1351],
[-0.9461, 0.5461, -0.6334],
[-1.0919, -0.1158, 0.1213],
[-0.9535, 0.1281, 0.4372],
[-0.2845, 0.3488, 0.5641],
],
]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([
[
[0.0569, 1.1952, -1.2522],
[-1.3971, 0.8883, 0.5088],
[0.2183, -1.3192, 1.1009],
[1.4126, -0.7649, -0.6477],
[0.5999, 0.8091, -1.409],
[1.39, -0.4696, -0.9205],
],
[
[-1.3492, 1.0417, 0.3075],
[1.411, -0.6243, -0.7867],
[-0.9363, 1.386, -0.4497],
[-1.3899, 0.4692, 0.9208],
[-1.3822, 0.4319, 0.9503],
[-1.3714, 0.3868, 0.9846],
],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn instance_norm_forward_affine_true() {
let device = Default::default();
let module = InstanceNormConfig::new(6)
.with_affine(true)
.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
[
[0.3345, 0.4429, 0.6639],
[0.5041, 0.4175, 0.8437],
[0.6159, 0.3758, 0.4071],
[0.5417, 0.5785, 0.7671],
[0.3837, 0.9883, 0.0420],
[0.4808, 0.8989, 0.6144],
],
[
[0.3930, 0.2098, 0.0602],
[0.2298, 0.9425, 0.0333],
[0.7409, 0.8172, 0.8879],
[0.4846, 0.0486, 0.2029],
[0.6741, 0.9765, 0.6864],
[0.2827, 0.5534, 0.2125],
],
]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([
[
[-1.06458, -0.2738, 1.33838],
[-0.45848, -0.92929, 1.38777],
[1.40388, -0.84877, -0.55511],
[-0.88515, -0.51245, 1.3976],
[-0.22397, 1.32124, -1.09727],
[-1.05468, 1.34316, -0.28848],
],
[
[1.26372, -0.08229, -1.18144],
[-0.44049, 1.38403, -0.94354],
[-1.23828, 0.03109, 1.2072],
[1.32524, -1.08999, -0.23524],
[-0.75061, 1.4132, -0.66259],
[-0.45469, 1.38697, -0.93228],
],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = InstanceNormConfig::new(6);
let instance_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{instance_norm}"),
"InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
);
}
}

View File

@@ -0,0 +1,232 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Content;
use burn::module::DisplaySettings;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::ModuleDisplay;
use burn::module::Param;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
#[derive(Debug, Config)]
pub struct LayerNormConfig {
/// The size of the input features.
pub d_model: usize,
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
pub epsilon: f64,
/// If a bias (beta) should be applied during the normalization. Default: true
#[config(default = true)]
pub bias: bool,
}
/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight (scale)
/// - `β` is the learnable bias (optional)
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
/// The learnable weight (scale).
pub gamma: Param<Tensor<B, 1>>,
/// The learnable bias (optional).
pub beta: Option<Param<Tensor<B, 1>>>,
/// A value required for numerical stability.
epsilon: f64,
}
impl LayerNormConfig {
/// Initialize a new [layer norm](LayerNorm) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
let gamma = Initializer::Ones.init([self.d_model], device);
let beta = if self.bias {
Some(Initializer::Zeros.init([self.d_model], device))
} else {
None
};
LayerNorm {
gamma,
beta,
epsilon: self.epsilon,
}
}
}
impl<B: Backend> LayerNorm<B> {
/// Applies the forward pass on the input tensor.
///
/// See the [LayerNorm](LayerNorm) documentation for more information.
///
/// # Shapes
///
/// - input: `[..., any, d_model]`
/// - output: `[..., any, d_model]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let (var, mean) = input.clone().var_mean_bias(D - 1);
let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());
let output = input_normalized.mul(self.gamma.val().unsqueeze());
match &self.beta {
Some(beta) => output.add(beta.val().unsqueeze()),
None => output,
}
}
}
impl<B: Backend> ModuleDisplay for LayerNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_model] = self.gamma.shape().dims();
content
.add("d_model", &d_model)
.add("epsilon", &self.epsilon)
.add("bias", &self.beta.is_some())
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::format;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[cfg(feature = "std")]
use crate::{TestAutodiffBackend, TestBackend};
#[cfg(not(feature = "std"))]
use crate::TestBackend;
#[test]
fn layer_norm_forward() {
let device = Default::default();
let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[
-0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
]]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([[
-0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
]]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn layer_norm_forward_large_epsilon() {
let device = Default::default();
let module = LayerNormConfig::new(10)
.with_epsilon(1e-1)
.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[
-0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
]]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([[
-0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
]]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[cfg(feature = "std")]
#[test]
fn layer_norm_backward() {
let device = Default::default();
let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
&device,
)
.require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
&device,
)
.require_grad();
let x = tensor_1.clone().matmul(tensor_2.clone());
let output = module.forward(x);
let grads = output.backward();
let tensor_1_grad = tensor_1.grad(&grads).unwrap();
let tensor_2_grad = tensor_2.grad(&grads).unwrap();
let gamma_grad = module.gamma.grad(&grads).unwrap();
let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap();
let expected = TensorData::from([-2.0, 2.0]);
gamma_grad
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([2.0, 2.0]);
beta_grad
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
tensor_1_grad
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
tensor_2_grad
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = LayerNormConfig::new(6);
let layer_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{layer_norm}"),
"LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}"
);
}
#[test]
fn display_no_bias() {
let config = LayerNormConfig::new(6).with_bias(false);
let layer_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{layer_norm}"),
"LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}"
);
}
}

View File

@@ -0,0 +1,28 @@
//! # Normalization Layers
//!
//! Users who wish to provide an abstraction over swappable normalization
//! layers can use the [`Normalization`] wrapper, with support for:
//! * [`Normalization::Batch`] - [`BatchNorm`]
//! * [`Normalization::Group`] - [`GroupNorm`]
//! * [`Normalization::Instance`] - [`InstanceNorm`]
//! * [`Normalization::Layer`] - [`LayerNorm`]
//! * [`Normalization::Rms`] - [`RmsNorm`]
//!
//! [`NormalizationConfig`] can be used as a generic normalization policy:
//! * Construct a config with arbitrary input features (we suggest `0`).
//! * Clone and match that config to the target input layer,
//! using the [`NormalizationConfig::with_num_features()`] method.
pub(crate) mod batch;
pub(crate) mod group;
pub(crate) mod instance;
pub(crate) mod layer;
pub(crate) mod rms;
mod normalization_wrapper;
pub use batch::*;
pub use group::*;
pub use instance::*;
pub use layer::*;
pub use normalization_wrapper::*;
pub use rms::*;

View File

@@ -0,0 +1,368 @@
use burn_core as burn;
use crate::{
BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,
LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig,
};
use burn::prelude::{Config, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// ['Normalization'] Configuration.
///
/// The enum is non-exhaustive to prepare for future additions.
///
/// Can be used as a generic configuration for normalization layers:
/// * Construct a config with arbitrary input features (we suggest `0`).
/// * Clone and match that config to the target input layer,
/// using the [`NormalizationConfig::with_num_features()`] method.
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum NormalizationConfig {
/// ['BatchNorm'] Configuration.
Batch(BatchNormConfig),
/// ['GroupNorm'] Configuration.
Group(GroupNormConfig),
/// ['InstanceNorm'] Configuration.
Instance(InstanceNormConfig),
/// ['LayerNorm'] Configuration.
Layer(LayerNormConfig),
/// ['RmsNorm'] Configuration.
Rms(RmsNormConfig),
}
impl From<BatchNormConfig> for NormalizationConfig {
fn from(config: BatchNormConfig) -> Self {
Self::Batch(config)
}
}
impl From<GroupNormConfig> for NormalizationConfig {
fn from(config: GroupNormConfig) -> Self {
Self::Group(config)
}
}
impl From<InstanceNormConfig> for NormalizationConfig {
fn from(config: InstanceNormConfig) -> Self {
Self::Instance(config)
}
}
impl From<LayerNormConfig> for NormalizationConfig {
fn from(config: LayerNormConfig) -> Self {
Self::Layer(config)
}
}
impl From<RmsNormConfig> for NormalizationConfig {
fn from(config: RmsNormConfig) -> Self {
Self::Rms(config)
}
}
impl NormalizationConfig {
/// Initialize a ['Norm'] layer.
pub fn init<B: Backend>(&self, device: &B::Device) -> Normalization<B> {
match self {
NormalizationConfig::Batch(config) => config.init(device).into(),
NormalizationConfig::Group(config) => config.init(device).into(),
NormalizationConfig::Instance(config) => config.init(device).into(),
NormalizationConfig::Layer(config) => config.init(device).into(),
NormalizationConfig::Rms(config) => config.init(device).into(),
}
}
/// Set the number of features.
pub fn with_num_features(self, num_features: usize) -> Self {
match self {
NormalizationConfig::Batch(config) => BatchNormConfig {
num_features,
..config
}
.into(),
NormalizationConfig::Group(config) => GroupNormConfig {
num_channels: num_features,
..config
}
.into(),
NormalizationConfig::Instance(config) => InstanceNormConfig {
num_channels: num_features,
..config
}
.into(),
NormalizationConfig::Layer(config) => LayerNormConfig {
d_model: num_features,
..config
}
.into(),
NormalizationConfig::Rms(config) => RmsNormConfig {
d_model: num_features,
..config
}
.into(),
}
}
/// Get the number of features.
pub fn num_features(&self) -> usize {
match self {
NormalizationConfig::Batch(config) => config.num_features,
NormalizationConfig::Group(config) => config.num_channels,
NormalizationConfig::Instance(config) => config.num_channels,
NormalizationConfig::Layer(config) => config.d_model,
NormalizationConfig::Rms(config) => config.d_model,
}
}
}
/// Normalization Layer Wrapper
///
/// Provides support for built-in ``burn::nn::norm`` norm layers:
/// * [`Normalization::Batch`] - [`BatchNorm`]
/// * [`Normalization::Group`] - [`GroupNorm`]
/// * [`Normalization::Instance`] - [`InstanceNorm`]
/// * [`Normalization::Layer`] - [`LayerNorm`]
/// * [`Normalization::Rms`] - [`RmsNorm`]
///
/// The enum is non-exhaustive, to prepare for future additions.
#[derive(Module, Debug)]
#[non_exhaustive]
pub enum Normalization<B: Backend> {
/// [`BatchNorm`] layer.
Batch(BatchNorm<B>),
/// [`GroupNorm`] layer.
Group(GroupNorm<B>),
/// ['InstanceNorm'] layer.
Instance(InstanceNorm<B>),
/// [`LayerNorm`] layer.
Layer(LayerNorm<B>),
/// ['RmsNorm'] layer.
Rms(RmsNorm<B>),
}
impl<B: Backend> From<BatchNorm<B>> for Normalization<B> {
fn from(layer: BatchNorm<B>) -> Self {
Self::Batch(layer)
}
}
impl<B: Backend> From<GroupNorm<B>> for Normalization<B> {
fn from(layer: GroupNorm<B>) -> Self {
Self::Group(layer)
}
}
impl<B: Backend> From<InstanceNorm<B>> for Normalization<B> {
fn from(layer: InstanceNorm<B>) -> Self {
Self::Instance(layer)
}
}
impl<B: Backend> From<LayerNorm<B>> for Normalization<B> {
fn from(layer: LayerNorm<B>) -> Self {
Self::Layer(layer)
}
}
impl<B: Backend> From<RmsNorm<B>> for Normalization<B> {
fn from(layer: RmsNorm<B>) -> Self {
Self::Rms(layer)
}
}
impl<B: Backend> Normalization<B> {
/// Applies normalization to a tensor.
///
/// The normalization contract depends upon the wrapped norm layer;
/// but all norm layers assume an input of at least rank 2;
/// and produce an output of the same rank and shape.
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
match self {
Normalization::Batch(norm) => norm.forward(input),
Normalization::Group(norm) => norm.forward(input),
Normalization::Instance(norm) => norm.forward(input),
Normalization::Layer(norm) => norm.forward(input),
Normalization::Rms(norm) => norm.forward(input),
}
}
/// Get the number of features.
pub fn num_features(&self) -> usize {
match self {
Normalization::Batch(norm) => norm.gamma.shape()[0],
Normalization::Group(norm) => norm.num_channels,
Normalization::Instance(norm) => norm.num_channels,
Normalization::Layer(norm) => norm.gamma.shape()[0],
Normalization::Rms(norm) => norm.gamma.shape()[0],
}
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestAutodiffBackend>;
#[test]
fn test_match_feature_size() {
let config: NormalizationConfig = BatchNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = GroupNormConfig::new(4, 0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = InstanceNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = LayerNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = RmsNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
}
#[test]
fn test_batch_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
let config: NormalizationConfig = BatchNormConfig::new(12).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Batch(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output.to_data().assert_eq(&expected.to_data(), true);
}
#[test]
fn test_group_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Group(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_instance_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
let config: NormalizationConfig = InstanceNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Instance(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_layer_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);
let config: NormalizationConfig = LayerNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Layer(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_rms_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);
let config: NormalizationConfig = RmsNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Rms(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,134 @@
use burn::tensor::DType;
use burn_core as burn;
use burn::config::Config;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::Param;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init).
#[derive(Config, Debug)]
pub struct RmsNormConfig {
/// The size of the input features.
pub d_model: usize,
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
pub epsilon: f64,
}
impl RmsNormConfig {
/// Initialize a new [RMS Norm](RmsNorm) module.
///
/// # Panics
///
/// Panics if `epsilon` is not positive.
pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
assert!(self.epsilon > 0.0, "epsilon must be positive.");
let gamma = Initializer::Ones.init([self.d_model], device);
RmsNorm {
gamma,
epsilon: self.epsilon,
}
}
}
/// Applies RMS Normalization over an input tensor along the last dimension.
///
/// `Y = X / sqrt(mean(X^2) + eps) * gamma`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `gamma` is the learnable weight
/// - `mean` is the mean operation
/// - `eps` is a small value to avoid division by zero.
///
/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RmsNorm<B: Backend> {
/// The learnable parameter to scale the normalized tensor
pub gamma: Param<Tensor<B, 1>>,
/// A value required for numerical stability
pub epsilon: f64,
}
impl<B: Backend> RmsNorm<B> {
/// Applies the forward pass on the input tensor.
///
/// See the [RmsNorm](RmsNorm) documentation for more information.
///
/// # Shapes
///
/// - input: `[..., any, d_model]`
/// - output: `[..., any, d_model]`
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
// Calculate the root-mean-square norm of the input tensor along the last dimension
let dtype = x.dtype();
let rms = (x.clone().cast(DType::F32).square().mean_dim(D - 1) + self.epsilon).sqrt();
(x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
}
}
impl<B: Backend> ModuleDisplay for RmsNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_model] = self.gamma.shape().dims();
content
.add("d_model", &d_model)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use alloc::format;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn rms_norm_forward() {
let device = Default::default();
let module = RmsNormConfig::new(3)
.with_epsilon(1e-5)
.init::<TestBackend>(&device);
let input = Tensor::arange(0..9, &device).float().reshape([3, 3]);
let output = module.forward(input);
let expected = TensorData::from([
[0.0000, 0.7746, 1.5492],
[0.7348, 0.9798, 1.2247],
[0.8514, 0.9933, 1.1352],
]);
output
.to_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn display() {
let config = RmsNormConfig::new(6);
let layer_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{layer_norm}"),
"RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
);
}
}

View File

@@ -0,0 +1,77 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::adaptive_avg_pool1d;
/// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer using the [init function](AdaptiveAvgPool1dConfig::init).
#[derive(Config, Debug)]
pub struct AdaptiveAvgPool1dConfig {
/// The size of the output.
pub output_size: usize,
}
/// Applies a 1D adaptive avg pooling over input tensors.
///
/// Should be created with [AdaptiveAvgPool1dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AdaptiveAvgPool1d {
/// The size of the output.
pub output_size: usize,
}
impl ModuleDisplay for AdaptiveAvgPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("output_size", &self.output_size).optional()
}
}
impl AdaptiveAvgPool1dConfig {
/// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module.
pub fn init(&self) -> AdaptiveAvgPool1d {
AdaptiveAvgPool1d {
output_size: self.output_size,
}
}
}
impl AdaptiveAvgPool1d {
/// Applies the forward pass on the input tensor.
///
/// See [adaptive_avg_pool1d](burn::tensor::module::adaptive_avg_pool1d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels, length]`
/// - output: `[batch_size, channels, length_out]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
adaptive_avg_pool1d(input, self.output_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = AdaptiveAvgPool1dConfig::new(3);
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"AdaptiveAvgPool1d {output_size: 3}"
);
}
}

View File

@@ -0,0 +1,79 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::adaptive_avg_pool2d;
/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer using the [init function](AdaptiveAvgPool2dConfig::init).
#[derive(Config, Debug)]
pub struct AdaptiveAvgPool2dConfig {
/// The size of the output.
pub output_size: [usize; 2],
}
/// Applies a 2D adaptive avg pooling over input tensors.
///
/// Should be created with [AdaptiveAvgPool2dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AdaptiveAvgPool2d {
/// The size of the output.
pub output_size: [usize; 2],
}
impl ModuleDisplay for AdaptiveAvgPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let output_size = alloc::format!("{:?}", self.output_size);
content.add("output_size", &output_size).optional()
}
}
impl AdaptiveAvgPool2dConfig {
/// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module.
pub fn init(&self) -> AdaptiveAvgPool2d {
AdaptiveAvgPool2d {
output_size: self.output_size,
}
}
}
impl AdaptiveAvgPool2d {
/// Applies the forward pass on the input tensor.
///
/// See [adaptive_avg_pool2d](burn::tensor::module::adaptive_avg_pool2d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels, height_in, width_in]`
/// - output: `[batch_size, channels, height_out, width_out]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
adaptive_avg_pool2d(input, self.output_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = AdaptiveAvgPool2dConfig::new([3, 3]);
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"AdaptiveAvgPool2d {output_size: [3, 3]}"
);
}
}

View File

@@ -0,0 +1,220 @@
use burn_core as burn;
use crate::PaddingConfig1d;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::PadMode;
use burn::tensor::module::avg_pool1d;
/// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init).
#[derive(Config, Debug)]
pub struct AvgPool1dConfig {
/// The size of the kernel.
pub kernel_size: usize,
/// The stride.
#[config(default = "kernel_size")]
pub stride: usize,
/// The padding configuration.
///
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
/// will automatically use asymmetric padding to preserve input dimensions.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
pub count_include_pad: bool,
/// If true, use ceiling instead of floor for output size calculation.
#[config(default = "false")]
pub ceil_mode: bool,
}
/// Applies a 1D avg pooling over input tensors.
///
/// Should be created with [AvgPool1dConfig](AvgPool1dConfig).
///
/// # Remarks
///
/// The zero-padding values will be included in the calculation
/// of the average. This means that the zeros are counted as
/// legitimate values, and they contribute to the denominator
/// when calculating the average. This is equivalent to
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool1d {
/// The stride.
pub stride: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig1d>,
/// If the padding is counted in the denominator when computing the average.
pub count_include_pad: bool,
/// If true, use ceiling instead of floor for output size calculation.
pub ceil_mode: bool,
}
impl ModuleDisplay for AvgPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &self.kernel_size)
.add("stride", &self.stride)
.add("padding", &self.padding)
.add("count_include_pad", &self.count_include_pad)
.add("ceil_mode", &self.ceil_mode)
.optional()
}
}
impl AvgPool1dConfig {
/// Initialize a new [avg pool 1d](AvgPool1d) module.
pub fn init(&self) -> AvgPool1d {
AvgPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: Ignored(self.padding.clone()),
count_include_pad: self.count_include_pad,
ceil_mode: self.ceil_mode,
}
}
}
impl AvgPool1d {
/// Applies the forward pass on the input tensor.
///
/// See [avg_pool1d](burn::tensor::module::avg_pool1d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels, length_in]`
/// - output: `[batch_size, channels, length_out]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let [_batch_size, _channels, length] = input.dims();
// Calculate padding as pair - handles Same, Valid, and Explicit uniformly
let (left, right) =
self.padding
.calculate_padding_1d_pair(length, self.kernel_size, self.stride);
// TODO: Move asymmetric padding to functional level via PoolOptions
// See: https://github.com/tracel-ai/burn/issues/4362
// Handle asymmetric padding by applying explicit pad operation first
if left != right {
// Burn's pad takes (left, right, top, bottom) for the last two dimensions
// For 1D (NCL format), we only pad L (last dim), so top/bottom = 0
let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
// Use zero padding for the pool operation since we already padded
avg_pool1d(
padded,
self.kernel_size,
self.stride,
0,
self.count_include_pad,
self.ceil_mode,
)
} else {
// Symmetric padding
avg_pool1d(
input,
self.kernel_size,
self.stride,
left,
self.count_include_pad,
self.ceil_mode,
)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use rstest::rstest;
#[test]
fn same_with_even_kernel_uses_asymmetric_padding() {
let device = Default::default();
let config = AvgPool1dConfig::new(2)
.with_stride(1)
.with_padding(PaddingConfig1d::Same);
let pool = config.init();
// Input: [batch=1, channels=2, length=5]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);
let output = pool.forward(input);
// Same padding should preserve spatial dimensions
assert_eq!(output.dims(), [1, 2, 5]);
}
#[test]
fn display() {
let config = AvgPool1dConfig::new(3);
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}"
);
}
#[rstest]
#[case(1)]
#[case(2)]
fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
let config = AvgPool1dConfig::new(kernel_size);
assert_eq!(
config.stride, kernel_size,
"Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor",
config.stride, config.kernel_size
);
}
#[test]
fn asymmetric_padding_forward() {
let device = Default::default();
// Create avg pool with asymmetric padding: left=1, right=2
let config = AvgPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Explicit(1, 2));
let pool = config.init();
// Input: [batch=1, channels=2, length=4]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
let output = pool.forward(input);
// With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
// Output length = (7 - 3) / 1 + 1 = 5
assert_eq!(output.dims(), [1, 2, 5]);
}
#[test]
fn symmetric_explicit_padding_forward() {
let device = Default::default();
// Create avg pool with symmetric explicit padding: left=2, right=2
let config = AvgPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Explicit(2, 2));
let pool = config.init();
// Input: [batch=1, channels=2, length=4]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
let output = pool.forward(input);
// With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
// Output length = (8 - 3) / 1 + 1 = 6
assert_eq!(output.dims(), [1, 2, 6]);
}
}

View File

@@ -0,0 +1,223 @@
use burn_core as burn;
use crate::PaddingConfig2d;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::PadMode;
use burn::tensor::module::avg_pool2d;
/// Configuration to create a [2D avg pooling](AvgPool2d) layer using the [init function](AvgPool2dConfig::init).
#[derive(Config, Debug)]
pub struct AvgPool2dConfig {
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The strides.
#[config(default = "kernel_size")]
pub strides: [usize; 2],
/// The padding configuration.
///
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
/// will automatically use asymmetric padding to preserve input dimensions.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
pub count_include_pad: bool,
/// If true, use ceiling instead of floor for output size calculation.
#[config(default = "false")]
pub ceil_mode: bool,
}
/// Applies a 2D avg pooling over input tensors.
///
/// Should be created with [AvgPool2dConfig](AvgPool2dConfig).
///
/// # Remarks
///
/// The zero-padding values will be included in the calculation
/// of the average. This means that the zeros are counted as
/// legitimate values, and they contribute to the denominator
/// when calculating the average. This is equivalent to
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool2d {
/// Stride of the pooling.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Padding configuration.
pub padding: Ignored<PaddingConfig2d>,
/// If the padding is counted in the denominator when computing the average.
pub count_include_pad: bool,
/// If true, use ceiling instead of floor for output size calculation.
pub ceil_mode: bool,
}
impl ModuleDisplay for AvgPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("padding", &self.padding)
.add("count_include_pad", &self.count_include_pad)
.add("ceil_mode", &self.ceil_mode)
.optional()
}
}
impl AvgPool2dConfig {
/// Initialize a new [avg pool 2d](AvgPool2d) module.
pub fn init(&self) -> AvgPool2d {
AvgPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
padding: Ignored(self.padding.clone()),
count_include_pad: self.count_include_pad,
ceil_mode: self.ceil_mode,
}
}
}
impl AvgPool2d {
/// Applies the forward pass on the input tensor.
///
/// See [avg_pool2d](burn::tensor::module::avg_pool2d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels, height_in, width_in]`
/// - output: `[batch_size, channels, height_out, width_out]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
// Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
height_in,
width_in,
&self.kernel_size,
&self.stride,
);
// TODO: Move asymmetric padding to functional level via PoolOptions
// See: https://github.com/tracel-ai/burn/issues/4362
// Handle asymmetric padding by applying explicit pad operation first
if top != bottom || left != right {
// Burn's pad takes (left, right, top, bottom) for the last two dimensions
let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0));
// Use zero padding for the pool operation since we already padded
avg_pool2d(
padded,
self.kernel_size,
self.stride,
[0, 0],
self.count_include_pad,
self.ceil_mode,
)
} else {
// Symmetric padding
avg_pool2d(
input,
self.kernel_size,
self.stride,
[top, left],
self.count_include_pad,
self.ceil_mode,
)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use rstest::rstest;
#[test]
fn same_with_even_kernel_uses_asymmetric_padding() {
let device = Default::default();
let config = AvgPool2dConfig::new([2, 2])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Same);
let pool = config.init();
// Input: [batch=1, channels=2, height=5, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device);
let output = pool.forward(input);
// Same padding should preserve spatial dimensions
assert_eq!(output.dims(), [1, 2, 5, 5]);
}
#[test]
fn display() {
let config = AvgPool2dConfig::new([3, 3]);
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"AvgPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, count_include_pad: true, ceil_mode: false}"
);
}
#[rstest]
#[case([2, 2])]
#[case([1, 2])]
fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
let config = AvgPool2dConfig::new(kernel_size);
assert_eq!(
config.strides, kernel_size,
"Expected strides ({:?}) to match kernel size ({:?}) in default AvgPool2dConfig::new constructor",
config.strides, config.kernel_size
);
}
#[test]
fn asymmetric_padding_forward() {
let device = Default::default();
// Create avg pool with asymmetric padding: top=1, left=2, bottom=3, right=4
let config = AvgPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4));
let pool = config.init();
// Input: [batch=1, channels=2, height=4, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
let output = pool.forward(input);
// Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
// Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
assert_eq!(output.dims(), [1, 2, 6, 9]);
}
#[test]
fn symmetric_explicit_padding_forward() {
let device = Default::default();
// Create avg pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2
let config = AvgPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2));
let pool = config.init();
// Input: [batch=1, channels=2, height=4, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
let output = pool.forward(input);
// Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
// Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
assert_eq!(output.dims(), [1, 2, 6, 7]);
}
}

View File

@@ -0,0 +1,214 @@
use burn_core as burn;
use crate::PaddingConfig1d;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::PadMode;
use burn::tensor::module::max_pool1d;
/// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init).
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
/// The size of the kernel.
pub kernel_size: usize,
/// The stride.
#[config(default = "kernel_size")]
pub stride: usize,
/// The padding configuration.
///
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
/// will automatically use asymmetric padding to preserve input dimensions.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// The dilation.
#[config(default = "1")]
pub dilation: usize,
/// If true, use ceiling instead of floor for output size calculation.
#[config(default = "false")]
pub ceil_mode: bool,
}
/// Applies a 1D max pooling over input tensors.
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool1d {
/// The stride.
pub stride: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig1d>,
/// The dilation.
pub dilation: usize,
/// If true, use ceiling instead of floor for output size calculation.
pub ceil_mode: bool,
}
impl ModuleDisplay for MaxPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &self.kernel_size)
.add("stride", &self.stride)
.add("padding", &self.padding)
.add("dilation", &self.dilation)
.add("ceil_mode", &self.ceil_mode)
.optional()
}
}
impl MaxPool1dConfig {
/// Initialize a new [max pool 1d](MaxPool1d) module.
pub fn init(&self) -> MaxPool1d {
MaxPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
ceil_mode: self.ceil_mode,
}
}
}
impl MaxPool1d {
/// Applies the forward pass on the input tensor.
///
/// See [max_pool1d](burn::tensor::module::max_pool1d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels, length_in]`
/// - output: `[batch_size, channels, length_out]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let [_batch_size, _channels, length] = input.dims();
// Calculate padding as pair - handles Same, Valid, and Explicit uniformly
let (left, right) =
self.padding
.calculate_padding_1d_pair(length, self.kernel_size, self.stride);
// TODO: Move asymmetric padding to functional level via PoolOptions
// See: https://github.com/tracel-ai/burn/issues/4362
// Handle asymmetric padding by applying explicit pad operation first
if left != right {
// For 1D (NCL format), pad the length dimension with (left, right)
// and no padding for channel dimension (top=0, bottom=0)
// Use -inf for max pooling so padded values don't affect the max
let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY));
// Use zero padding for the pool operation since we already padded
max_pool1d(
padded,
self.kernel_size,
self.stride,
0,
self.dilation,
self.ceil_mode,
)
} else {
// Symmetric padding
max_pool1d(
input,
self.kernel_size,
self.stride,
left,
self.dilation,
self.ceil_mode,
)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use rstest::rstest;
#[test]
fn same_with_even_kernel_uses_asymmetric_padding() {
let device = Default::default();
let config = MaxPool1dConfig::new(2)
.with_stride(1)
.with_padding(PaddingConfig1d::Same);
let pool = config.init();
// Input: [batch=1, channels=2, length=5]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);
let output = pool.forward(input);
// Same padding should preserve spatial dimensions
assert_eq!(output.dims(), [1, 2, 5]);
}
#[test]
fn display() {
let config = MaxPool1dConfig::new(3);
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}"
);
}
#[rstest]
#[case(1)]
#[case(2)]
fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
let config = MaxPool1dConfig::new(kernel_size);
assert_eq!(
config.stride, kernel_size,
"Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor",
config.stride, config.kernel_size
);
}
#[test]
fn asymmetric_padding_forward() {
let device = Default::default();
// Create max pool with asymmetric padding: left=1, right=2
let config = MaxPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Explicit(1, 2));
let pool = config.init();
// Input: [batch=1, channels=2, length=4]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
let output = pool.forward(input);
// With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
// Output length = (7 - 3) / 1 + 1 = 5
assert_eq!(output.dims(), [1, 2, 5]);
}
#[test]
fn symmetric_explicit_padding_forward() {
let device = Default::default();
// Create max pool with symmetric explicit padding: left=2, right=2
let config = MaxPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Explicit(2, 2));
let pool = config.init();
// Input: [batch=1, channels=2, length=4]
let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
let output = pool.forward(input);
// With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
// Output length = (8 - 3) / 1 + 1 = 6
assert_eq!(output.dims(), [1, 2, 6]);
}
}

View File

@@ -0,0 +1,219 @@
use burn_core as burn;
use crate::PaddingConfig2d;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::PadMode;
use burn::tensor::module::max_pool2d;
/// Configuration to create a [2D max pooling](MaxPool2d) layer using the [init function](MaxPool2dConfig::init).
#[derive(Debug, Config)]
pub struct MaxPool2dConfig {
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The strides.
#[config(default = "kernel_size")]
pub strides: [usize; 2],
/// The padding configuration.
///
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
/// will automatically use asymmetric padding to preserve input dimensions.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// The dilation.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// If true, use ceiling instead of floor for output size calculation.
#[config(default = "false")]
pub ceil_mode: bool,
}
/// Applies a 2D max pooling over input tensors.
///
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool2d {
/// The strides.
pub stride: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
/// The dilation.
pub dilation: [usize; 2],
/// If true, use ceiling instead of floor for output size calculation.
pub ceil_mode: bool,
}
impl ModuleDisplay for MaxPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("padding", &self.padding)
.add("dilation", &alloc::format!("{:?}", &self.dilation))
.add("ceil_mode", &self.ceil_mode)
.optional()
}
}
impl MaxPool2dConfig {
/// Initialize a new [max pool 2d](MaxPool2d) module.
pub fn init(&self) -> MaxPool2d {
MaxPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
ceil_mode: self.ceil_mode,
}
}
}
impl MaxPool2d {
/// Applies the forward pass on the input tensor.
///
/// See [max_pool2d](burn::tensor::module::max_pool2d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels, height_in, width_in]`
/// - output: `[batch_size, channels, height_out, width_out]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
// Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
height_in,
width_in,
&self.kernel_size,
&self.stride,
);
// TODO: Move asymmetric padding to functional level via PoolOptions
// See: https://github.com/tracel-ai/burn/issues/4362
// Handle asymmetric padding by applying explicit pad operation first
if top != bottom || left != right {
// Burn's pad takes (left, right, top, bottom) for the last two dimensions
// Use -inf for max pooling so padded values don't affect the max
let padded = input.pad(
(left, right, top, bottom),
PadMode::Constant(f32::NEG_INFINITY),
);
// Use zero padding for the pool operation since we already padded
max_pool2d(
padded,
self.kernel_size,
self.stride,
[0, 0],
self.dilation,
self.ceil_mode,
)
} else {
// Symmetric padding
max_pool2d(
input,
self.kernel_size,
self.stride,
[top, left],
self.dilation,
self.ceil_mode,
)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use rstest::rstest;
#[test]
fn same_with_even_kernel_uses_asymmetric_padding() {
let device = Default::default();
let config = MaxPool2dConfig::new([2, 2])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Same);
let pool = config.init();
// Input: [batch=1, channels=2, height=5, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device);
let output = pool.forward(input);
// Same padding should preserve spatial dimensions
assert_eq!(output.dims(), [1, 2, 5, 5]);
}
#[test]
fn display() {
let config = MaxPool2dConfig::new([3, 3]);
let layer = config.init();
assert_eq!(
alloc::format!("{layer}"),
"MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}"
);
}
#[rstest]
#[case([2, 2])]
#[case([1, 2])]
fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
let config = MaxPool2dConfig::new(kernel_size);
assert_eq!(
config.strides, kernel_size,
"Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor",
config.strides, config.kernel_size
);
}
#[test]
fn asymmetric_padding_forward() {
let device = Default::default();
// Create max pool with asymmetric padding: top=1, left=2, bottom=3, right=4
let config = MaxPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4));
let pool = config.init();
// Input: [batch=1, channels=2, height=4, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
let output = pool.forward(input);
// Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
// Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
assert_eq!(output.dims(), [1, 2, 6, 9]);
}
#[test]
fn symmetric_explicit_padding_forward() {
let device = Default::default();
// Create max pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2
let config = MaxPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2));
let pool = config.init();
// Input: [batch=1, channels=2, height=4, width=5]
let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
let output = pool.forward(input);
// Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
// Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
assert_eq!(output.dims(), [1, 2, 6, 7]);
}
}

View File

@@ -0,0 +1,13 @@
mod adaptive_avg_pool1d;
mod adaptive_avg_pool2d;
mod avg_pool1d;
mod avg_pool2d;
mod max_pool1d;
mod max_pool2d;
pub use adaptive_avg_pool1d::*;
pub use adaptive_avg_pool2d::*;
pub use avg_pool1d::*;
pub use avg_pool2d::*;
pub use max_pool1d::*;
pub use max_pool2d::*;

View File

@@ -0,0 +1,291 @@
use burn_core as burn;
use alloc::vec::Vec;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::TensorData;
use burn::tensor::backend::Backend;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init).
#[derive(Config, Debug)]
pub struct PositionalEncodingConfig {
/// Maximum sequence size to use.
#[config(default = "5_000")]
pub max_sequence_size: usize,
/// The size of each vector.
pub d_model: usize,
/// Max time scale to use.
#[config(default = "10_000")]
pub max_timescale: usize,
}
/// Positional encoding layer for transformer models.
///
/// This layer adds positional information to the input embeddings, allowing the transformer model
/// to take into account the order of the sequence. The positional encoding is added to the input
/// embeddings by computing a set of sinusoidal functions with different frequencies and phases.
///
/// Sinusoids are used for positional embedding introduced in
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
///
/// The reference implementation can be found here:
/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT
/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
///
/// Should be created using [PositionalEncodingConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
/// The sinusoids used to add positional information to the input embeddings.
pub sinusoids: Tensor<B, 3>,
/// The maximum sequence size to use.
pub max_sequence_size: usize,
/// Max time scale to use.
pub max_timescale: usize,
}
impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [_, _, d_model] = self.sinusoids.shape().dims();
content
.add("d_model", &d_model)
.add("max_sequence_size", &self.max_sequence_size)
.add("max_timescale", &self.max_timescale)
.optional()
}
}
impl PositionalEncodingConfig {
/// Initialize a new [PositionalEncoding](PositionalEncoding) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
let sinusoids = generate_sinusoids::<B>(
self.max_sequence_size,
self.d_model,
self.max_timescale,
device,
)
.unsqueeze::<3>();
PositionalEncoding {
sinusoids,
max_sequence_size: self.max_sequence_size,
max_timescale: self.max_timescale,
}
}
}
impl<B: Backend> PositionalEncoding<B> {
/// Applies the forward pass on the input tensor by adding the sinusoids to the input.
///
/// # Shapes
///
/// * input: [batch_size, seq_length, d_model]
/// * output: [batch_size, seq_length, d_model]
///
///
/// # Panics
///
/// * Panics if the input sequence length is greater than the maximum sequence size.
/// * Panics if the input d_model is not equal to the d_model of the sinusoids.
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let [_, seq_length, d_model_input] = input.dims();
let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();
assert!(
max_sequence_size >= seq_length,
"max_sequence_size({max_sequence_size}) must be greater or equal than length({seq_length})"
);
assert!(
d_model_input == d_model,
"d_model({d_model_input}) of the input must be equal to d_model of encoding({d_model})"
);
let slices = [0..batch_size, 0..seq_length, 0..d_model];
input.add(self.sinusoids.clone().slice(slices))
}
}
/// Returns sinusoids for positional embedding introduced in
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
///
/// The reference implementation can be found here:
/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT
/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
///
/// # Arguments
///
/// * `length` - The length of the sequence.
/// * `d_model` - The size of each vector.
/// * `max_timescale` - The maximum time scale to use.
///
/// # Returns
///
/// A tensor of shape [length, d_model] containing the sinusoids.
pub fn generate_sinusoids<B: Backend>(
length: usize,
d_model: usize,
max_timescale: usize,
device: &B::Device,
) -> Tensor<B, 2> {
assert!(d_model.is_multiple_of(2), "d_model must be even");
assert!(
max_timescale >= length,
"max_timescale must be greater than length"
);
// Calculate the increment for the logarithmic timescale
let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;
// Create a vector to hold the sinusoids
let mut scaled_time_sin_cos = Vec::with_capacity(length);
// Loop over each position in the sequence
for i in 0..length {
// Create a vector to hold the sinusoids for this position
let mut row = Vec::with_capacity(d_model / 2);
// Loop over each dimension of the sinusoids
for k in (0..d_model).step_by(2) {
// Calculate the division term for this dimension
let div_term = (k as f32 * log_timescale_increment).exp();
// Calculate the sine and cosine values for this dimension and position
row.push((div_term * i as f32).sin());
row.push((div_term * i as f32).cos());
}
// Add the sinusoids for this position to the vector
scaled_time_sin_cos.push(row);
}
// Convert the sinusoids to a tensor and return it
let data = TensorData::new(
scaled_time_sin_cos.into_iter().flatten().collect(),
[length, d_model],
);
Tensor::<B, 2>::from_data(data, device)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_module() {
let d_model = 6;
let length = 3;
// expected to broadcast
let batch_size = 2;
let device = Default::default();
let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
// Use a tensor of zeros as input for easy verification of the output
// The output should be the sinusoids broadcasted to the input shape
let tensor = Tensor::zeros([batch_size, length, d_model], &device);
let output = pe.forward(tensor);
assert_eq!(&*output.shape(), [batch_size, length, d_model]);
let expected = Tensor::<TestBackend, 3>::from_floats(
[
[
[0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
[0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
[0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
],
[
[0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
[0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
[0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
],
],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_generate_sinusoids() {
let device = Default::default();
let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);
// The values are taken from the pytorch reference implementation
let expected = Tensor::<TestBackend, 2>::from_floats(
[
[0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
[0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
[0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
[0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
[-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
[-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
[-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
[0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
[0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
[0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
[-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
[-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
],
&device,
);
sinusoids
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
#[should_panic]
fn d_model_input_should_match() {
let d_model = 8;
let device = Default::default();
let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
let input = Tensor::zeros([1, 5, 10], &device);
let _output = pe.forward(input);
}
#[test]
#[should_panic]
fn input_length_should_be_less_than_max_len() {
let d_model = 8;
let device = Default::default();
let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
let input = Tensor::zeros([1, 6_000, d_model], &device);
let _output = pe.forward(input);
}
#[test]
fn display() {
let config = PositionalEncodingConfig::new(4);
let pe = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{pe}"),
"PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
);
}
}

View File

@@ -0,0 +1,742 @@
use burn_core as burn;
use crate::GateController;
use crate::activation::{Activation, ActivationConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// A RnnState is used to store hidden state in RNN.
pub struct RnnState<B: Backend, const D: usize> {
/// The hidden state.
pub hidden: Tensor<B, D>,
}
impl<B: Backend, const D: usize> RnnState<B, D> {
/// Initialize a new [RNN State](RnnState).
pub fn new(hidden: Tensor<B, D>) -> Self {
Self { hidden }
}
}
/// Configuration to create a [Rnn](Rnn) module using the [init function](RnnConfig::init).
#[derive(Config, Debug)]
pub struct RnnConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the hidden state.
pub d_hidden: usize,
/// If a bias should be applied during the Rnn transformation.
pub bias: bool,
/// Rnn initializer
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
pub initializer: Initializer,
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
#[config(default = true)]
pub batch_first: bool,
/// If true, process the sequence in reverse order.
/// This is useful for implementing reverse-direction RNNs (e.g., ONNX reverse direction).
#[config(default = false)]
pub reverse: bool,
/// Optional hidden state clip threshold. If provided, hidden state values are clipped
/// to the range `[-clip, +clip]` after each timestep. This can help prevent
/// exploding values during inference.
pub clip: Option<f64>,
/// Activation function applied to the hidden state before computing hidden output.
/// Default is Tanh, which is standard for Rnn.
#[config(default = "ActivationConfig::Tanh")]
pub hidden_activation: ActivationConfig,
}
/// The Rnn module. This implementation is for a unidirectional, stateless, Rnn.
/// Should be created with [RnnConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Rnn<B: Backend> {
/// gate controller for Rnn (has single gate).
pub gate: GateController<B>,
/// The hidden state of the Rnn.
pub d_hidden: usize,
/// If true, input is `[batch_size, seq_length, input_size]`.
/// If false, input is `[seq_length, batch_size, input_size]`.
pub batch_first: bool,
/// If true, process the sequence in reverse order.
pub reverse: bool,
/// Optional hidden state clip threshold.
pub clip: Option<f64>,
/// Activation function for hidden output.
pub hidden_activation: Activation<B>,
}
impl<B: Backend> ModuleDisplay for Rnn<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self.gate.input_transform.weight.shape().dims();
let bias = self.gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl RnnConfig {
/// Initialize a new [Rnn](Rnn) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Rnn<B> {
let d_output = self.d_hidden;
let new_gate = || {
GateController::new(
self.d_input,
d_output,
self.bias,
self.initializer.clone(),
device,
)
};
Rnn {
gate: new_gate(),
d_hidden: self.d_hidden,
batch_first: self.batch_first,
reverse: self.reverse,
clip: self.clip,
hidden_activation: self.hidden_activation.init(device),
}
}
}
impl<B: Backend> Rnn<B> {
/// Applies the forward pass on the input tensor. This RNN implementation
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
///
/// ## Parameters:
/// - batched_input: The input tensor of shape:
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
/// - state: An optional `RnnState` representing the initial hidden state.
/// The state tensor has shape `[batch_size, hidden_size]`.
/// If no initial state is provided, these tensors are initialized to zeros.
///
/// ## Returns:
/// - output: A tensor represents the output features of Rnn. Shape:
/// - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true
/// - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false
/// - state: A `RnnState` represents the final hidden state. The hidden state tensor has the shape
/// `[batch_size, hidden_size]`.
pub fn forward(
&self,
batched_input: Tensor<B, 3>,
state: Option<RnnState<B, 2>>,
) -> (Tensor<B, 3>, RnnState<B, 2>) {
// Convert to batch-first layout internally if needed
let batched_input = if self.batch_first {
batched_input
} else {
batched_input.swap_dims(0, 1)
};
let device = batched_input.device();
let [batch_size, seq_length, _] = batched_input.dims();
// Process sequence in forward or reverse order based on config
let (output, state) = if self.reverse {
self.forward_iter(
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
state,
batch_size,
seq_length,
&device,
)
} else {
self.forward_iter(
batched_input.iter_dim(1).zip(0..seq_length),
state,
batch_size,
seq_length,
&device,
)
};
// Convert output back to seq-first layout if needed
let output = if self.batch_first {
output
} else {
output.swap_dims(0, 1)
};
(output, state)
}
fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
&self,
input_timestep_iter: I,
state: Option<RnnState<B, 2>>,
batch_size: usize,
seq_length: usize,
device: &B::Device,
) -> (Tensor<B, 3>, RnnState<B, 2>) {
let mut batched_hidden_state =
Tensor::empty([batch_size, seq_length, self.d_hidden], device);
let mut hidden_state = match state {
Some(state) => state.hidden,
None => Tensor::zeros([batch_size, self.d_hidden], device),
};
for (input_t, t) in input_timestep_iter {
let input_t = input_t.squeeze_dim(1);
// Compute gate output: h_t = activation(W_i @ x_t + W_h @ h_{t-1} + b)
let biased_gate_sum = self
.gate
.gate_product(input_t.clone(), hidden_state.clone());
let output_values = self.hidden_activation.forward(biased_gate_sum);
// Update hidden state
hidden_state = output_values;
// Apply hidden state clipping if configured
if let Some(clip) = self.clip {
hidden_state = hidden_state.clamp(-clip, clip);
}
let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
// store the hidden state for this timestep
batched_hidden_state = batched_hidden_state.slice_assign(
[0..batch_size, t..(t + 1), 0..self.d_hidden],
unsqueezed_hidden_state.clone(),
);
}
(batched_hidden_state, RnnState::new(hidden_state))
}
}
/// Configuration to create a [BiRnn](BiRnn) module using the [init function](BiRnnConfig::init).
#[derive(Config, Debug)]
pub struct BiRnnConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the hidden state.
pub d_hidden: usize,
/// If a bias should be applied during the BiRnn transformation.
pub bias: bool,
/// BiRnn initializer
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
pub initializer: Initializer,
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
#[config(default = true)]
pub batch_first: bool,
/// Optional hidden state clip threshold.
pub clip: Option<f64>,
/// Activation function applied to the hidden state before computing hidden output.
#[config(default = "ActivationConfig::Tanh")]
pub hidden_activation: ActivationConfig,
}
/// The BiRnn module. This implementation is for Bidirectional RNN.
/// Should be created with [BiRnnConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiRnn<B: Backend> {
/// RNN for the forward direction.
pub forward: Rnn<B>,
/// RNN for the reverse direction.
pub reverse: Rnn<B>,
/// The size of the hidden state.
pub d_hidden: usize,
/// If true, input is `[batch_size, seq_length, input_size]`.
/// If false, input is `[seq_length, batch_size, input_size]`.
pub batch_first: bool,
}
impl<B: Backend> ModuleDisplay for BiRnn<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self.forward.gate.input_transform.weight.shape().dims();
let bias = self.forward.gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl BiRnnConfig {
/// Initialize a new [Bidirectional RNN](BiRnn) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> BiRnn<B> {
// Internal RNNs always use batch_first=true; BiRnn handles layout conversion
let base_config = RnnConfig::new(self.d_input, self.d_hidden, self.bias)
.with_initializer(self.initializer.clone())
.with_batch_first(true)
.with_clip(self.clip)
.with_hidden_activation(self.hidden_activation.clone());
BiRnn {
forward: base_config.clone().init(device),
reverse: base_config.init(device),
d_hidden: self.d_hidden,
batch_first: self.batch_first,
}
}
}
impl<B: Backend> BiRnn<B> {
/// Applies the forward pass on the input tensor. This Bidirectional RNN implementation
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
///
/// ## Parameters:
/// - batched_input: The input tensor of shape:
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
/// - state: An optional `RnnState` representing the hidden state.
/// Each state tensor has shape `[2, batch_size, hidden_size]`.
/// If no initial state is provided, these tensors are initialized to zeros.
///
/// ## Returns:
/// - output: A tensor represents the output features of RNN. Shape:
/// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
/// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
/// - state: A `RnnState` represents the final forward and reverse states.
/// The `state.hidden` have the shape `[2, batch_size, hidden_size]`.
pub fn forward(
&self,
batched_input: Tensor<B, 3>,
state: Option<RnnState<B, 3>>,
) -> (Tensor<B, 3>, RnnState<B, 3>) {
// Convert to batch-first layout internally if needed
let batched_input = if self.batch_first {
batched_input
} else {
batched_input.swap_dims(0, 1)
};
let device = batched_input.clone().device();
let [batch_size, seq_length, _] = batched_input.shape().dims();
let [init_state_forward, init_state_reverse] = match state {
Some(state) => {
let hidden_state_forward = state
.hidden
.clone()
.slice([0..1, 0..batch_size, 0..self.d_hidden])
.squeeze_dim(0);
let hidden_state_reverse = state
.hidden
.slice([1..2, 0..batch_size, 0..self.d_hidden])
.squeeze_dim(0);
[
Some(RnnState::new(hidden_state_forward)),
Some(RnnState::new(hidden_state_reverse)),
]
}
None => [None, None],
};
// forward direction
let (batched_hidden_state_forward, final_state_forward) = self
.forward
.forward(batched_input.clone(), init_state_forward);
// reverse direction
let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
init_state_reverse,
batch_size,
seq_length,
&device,
);
let output = Tensor::cat(
[batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
2,
);
// Convert output back to seq-first layout if needed
let output = if self.batch_first {
output
} else {
output.swap_dims(0, 1)
};
let state = RnnState::new(Tensor::stack(
[final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
0,
));
(output, state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{LinearRecord, TestBackend};
use burn::module::Param;
use burn::tensor::{Device, Distribution, TensorData};
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[cfg(feature = "std")]
use crate::TestAutodiffBackend;
fn create_single_feature_gate_controller(
weights: f32,
biases: f32,
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
device: &Device<TestBackend>,
) -> GateController<TestBackend> {
let record_1 = LinearRecord {
weight: Param::from_data(TensorData::from([[weights]]), device),
bias: Some(Param::from_data(TensorData::from([biases]), device)),
};
let record_2 = LinearRecord {
weight: Param::from_data(TensorData::from([[weights]]), device),
bias: Some(Param::from_data(TensorData::from([biases]), device)),
};
GateController::create_with_weights(
d_input,
d_output,
bias,
initializer,
record_1,
record_2,
)
}
#[test]
fn test_with_uniform_initializer() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = RnnConfig::new(5, 5, false)
.with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
let rnn = config.init::<TestBackend>(&Default::default());
let gate_to_data =
|gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();
gate_to_data(rnn.gate).assert_within_range::<FT>(0.elem()..1.elem());
}
/// Test forward pass with simple input vector.
///
/// Simple RNN: h_t = tanh(W_input @ x_t + W_hidden @ h_{t-1} + b)
/// With input=0.1, weight_input=0.5, bias=0.0, h_0=0.0, weight_hidden=0.5
/// h_t = tanh(0.5*0.1 + 0.5*0) = tanh(0.05) = 0.04995
#[test]
fn test_forward_single_input_single_feature() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = RnnConfig::new(1, 1, false);
let device = Default::default();
let mut rnn = config.init::<TestBackend>(&device);
rnn.gate = create_single_feature_gate_controller(
0.5,
0.0,
1,
1,
false,
Initializer::XavierUniform { gain: 1.0 },
&device,
);
// single timestep with single feature
let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
let (output, state) = rnn.forward(input, None);
let tolerance = Tolerance::default();
let expected = TensorData::from([[0.04995]]);
state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected, tolerance);
output
.select(0, Tensor::arange(0..1, &device))
.squeeze_dim::<2>(0)
.to_data()
.assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
}
#[test]
fn test_batched_forward_pass_batch_of_one() {
let device = Default::default();
let rnn = RnnConfig::new(64, 1024, true).init(&device);
let batched_input =
Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
let (output, state) = rnn.forward(batched_input, None);
assert_eq!(output.dims(), [1, 2, 1024]);
assert_eq!(state.hidden.dims(), [1, 1024]);
}
#[test]
#[cfg(feature = "std")]
fn test_batched_backward_pass() {
use burn::tensor::Shape;
let device = Default::default();
let rnn = RnnConfig::new(64, 32, true).init(&device);
let shape: Shape = [8, 10, 64].into();
let batched_input =
Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);
let (output, _) = rnn.forward(batched_input.clone(), None);
let fake_loss = output;
let grads = fake_loss.backward();
let some_gradient = rnn.gate.hidden_transform.weight.grad(&grads).unwrap();
// Asserts that the gradients exist and are non-zero
assert_ne!(
some_gradient
.any()
.into_data()
.iter::<f32>()
.next()
.unwrap(),
0.0
);
}
#[test]
fn test_bidirectional() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = BiRnnConfig::new(2, 3, true);
let mut rnn = config.init(&device);
fn create_gate_controller<const D1: usize, const D2: usize>(
input_weights: [[f32; D1]; D2],
input_biases: [f32; D1],
hidden_weights: [[f32; D1]; D1],
hidden_biases: [f32; D1],
device: &Device<TestBackend>,
) -> GateController<TestBackend> {
let d_input = input_weights[0].len();
let d_output = input_weights.len();
let input_record = LinearRecord {
weight: Param::from_data(TensorData::from(input_weights), device),
bias: Some(Param::from_data(TensorData::from(input_biases), device)),
};
let hidden_record = LinearRecord {
weight: Param::from_data(TensorData::from(hidden_weights), device),
bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
};
GateController::create_with_weights(
d_input,
d_output,
true,
Initializer::XavierUniform { gain: 1.0 },
input_record,
hidden_record,
)
}
// [batch_size=1, seq_length=4, input_size=2]
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[
[0.949, -0.861],
[0.892, 0.927],
[-0.173, -0.301],
[-0.081, 0.992],
]]),
&device,
);
// [2, batch_size=1, hidden_size=3]
let h0 = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
&device,
);
rnn.forward.gate = create_gate_controller(
// input_weights: [input_size=2, hidden_size=3]
[[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
// input_biases: [hidden_size=3]
[-0.196, 0.354, 0.209],
// hidden_weights: [hidden_size=3, hidden_size=3]
[
[-0.320, 0.232, -0.165],
[0.093, -0.572, -0.315],
[-0.467, 0.325, 0.046],
],
// hidden_biases: [hidden_size=3]
[0.181, -0.190, -0.245],
&device,
);
rnn.reverse.gate = create_gate_controller(
[[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
[0.540, -0.164, 0.033],
[
[0.159, 0.180, -0.037],
[-0.443, 0.485, -0.488],
[0.098, -0.085, -0.140],
],
[-0.510, 0.105, 0.114],
&device,
);
// [batch_size=1, sequence_length=4, hidden_size * 2 = 6]
// The expected output values were computed from PyTorch
let expected_output_with_init_state = TensorData::from([[
[0.5226, -0.6370, 0.0210, 0.0685, 0.3867, 0.3602],
[0.3580, 0.8431, 0.4129, -0.3175, 0.4374, 0.1766],
[-0.3837, -0.2703, -0.3957, -0.1542, -0.1122, 0.0725],
[0.5059, 0.5527, 0.1244, -0.6779, 0.3725, -0.3387],
]]);
let expected_output_without_init_state = TensorData::from([[
[0.0560, -0.2056, 0.2334, 0.0892, 0.3912, 0.3607],
[0.4340, 0.7378, 0.3714, -0.2394, 0.4235, 0.2002],
[-0.3962, -0.2097, -0.3798, 0.0532, -0.2067, 0.1727],
[0.5075, 0.5298, 0.1083, -0.3200, 0.0764, -0.1282],
]]);
//`[2, batch_size=1, hidden_size=3]`
let expected_hn_with_init_state =
TensorData::from([[[0.5059, 0.5527, 0.1244]], [[0.0685, 0.3867, 0.3602]]]);
let expected_hn_without_init_state =
TensorData::from([[[0.5075, 0.5298, 0.1083]], [[0.0892, 0.3912, 0.3607]]]);
let (output_with_init_state, state_with_init_state) =
rnn.forward(input.clone(), Some(RnnState::new(h0)));
let (output_without_init_state, state_without_init_state) = rnn.forward(input, None);
let tolerance = Tolerance::permissive();
output_with_init_state
.to_data()
.assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
output_without_init_state
.to_data()
.assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
state_with_init_state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
state_without_init_state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
}
#[test]
fn display_rnn() {
let config = RnnConfig::new(2, 3, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{layer}"),
"Rnn {d_input: 2, d_hidden: 3, bias: true, params: 21}"
);
}
#[test]
fn display_birnn() {
let config = BiRnnConfig::new(2, 3, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{layer}"),
"BiRnn {d_input: 2, d_hidden: 3, bias: true, params: 42}"
);
}
#[test]
fn test_rnn_clipping() {
let device = Default::default();
// Create Rnn with clipping enabled
let clip_value = 0.3;
let config = RnnConfig::new(4, 8, true).with_clip(Some(clip_value));
let rnn = config.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);
let (_, state) = rnn.forward(input, None);
// Verify output values are within the clip range
let hidden_state: Vec<f32> = state.hidden.to_data().to_vec().unwrap();
for val in hidden_state {
assert!(
val >= -clip_value as f32 && val <= clip_value as f32,
"Value {} is outside clip range [-{}, {}]",
val,
clip_value,
clip_value
);
}
}
#[test]
fn test_forward_reverse_sequence() {
let device = Default::default();
TestBackend::seed(&device, 0);
// Create RNN with reverse=true to process sequence in reverse order
let config = RnnConfig::new(1, 1, false).with_reverse(true);
let mut rnn = config.init::<TestBackend>(&device);
rnn.gate = create_single_feature_gate_controller(
0.5,
0.0,
1,
1,
false,
Initializer::XavierUniform { gain: 1.0 },
&device,
);
// Create input with 3 timesteps: [0.1, 0.2, 0.3]
// Shape: [batch_size=1, seq_length=3, input_features=1]
let input =
Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
let (output, state) = rnn.forward(input, None);
// With reverse=true and weight=0.5, sequence is processed in reverse:
// t=2 (last): h = tanh(0.5*0.3 + 0.5*0) = tanh(0.15) ≈ 0.1488850
// t=1 (mid): h = tanh(0.5*0.2 + 0.5*0.1488850) ≈ 0.17269433
// t=0 (first): h = tanh(0.5*0.1 + 0.5*0.17269433) ≈ 0.135508
let expected_final_hidden = TensorData::from([[0.135508]]);
let tolerance = Tolerance::default();
state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected_final_hidden, tolerance);
// Verify output tensor has correct shape and matches state at final timestep
assert_eq!(output.dims(), [1, 3, 1]);
}
}

View File

@@ -0,0 +1,101 @@
use burn_core as burn;
use crate::{Linear, LinearConfig, LinearLayout};
use burn::module::{Initializer, Module};
use burn::tensor::{Tensor, backend::Backend};
/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// forget gate, and output gate. Additionally, cell gate
/// is just used to compute the cell state.
///
/// An Lstm gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
/// Represents the affine transformation applied to input vector
pub input_transform: Linear<B>,
/// Represents the affine transformation applied to the hidden state
pub hidden_transform: Linear<B>,
}
impl<B: Backend> GateController<B> {
/// Initialize a new [gate_controller](GateController) module.
pub fn new(
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
device: &B::Device,
) -> Self {
Self {
input_transform: LinearConfig {
d_input,
d_output,
bias,
initializer: initializer.clone(),
layout: LinearLayout::Row,
}
.init(device),
hidden_transform: LinearConfig {
d_input: d_output,
d_output,
bias,
initializer,
layout: LinearLayout::Row,
}
.init(device),
}
}
/// Helper function for performing weighted matrix product for a gate and adds
/// bias, if any.
///
/// Mathematically, performs `Wx*X + Wh*H + b`, where:
/// Wx = weight matrix for the connection to input vector X
/// Wh = weight matrix for the connection to hidden state H
/// X = input vector
/// H = hidden state
/// b = bias terms
pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
}
/// Used to initialize a gate controller with known weight layers,
/// allowing for predictable behavior. Used only for testing in
/// lstm.
#[cfg(test)]
pub fn create_with_weights(
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
input_record: crate::LinearRecord<B>,
hidden_record: crate::LinearRecord<B>,
) -> Self {
let l1 = LinearConfig {
d_input,
d_output,
bias,
initializer: initializer.clone(),
layout: LinearLayout::Row,
}
.init(&input_record.weight.device())
.load_record(input_record);
let l2 = LinearConfig {
d_input,
d_output,
bias,
initializer,
layout: LinearLayout::Row,
}
.init(&hidden_record.weight.device())
.load_record(hidden_record);
Self {
input_transform: l1,
hidden_transform: l2,
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,922 @@
use burn_core as burn;
use crate::GateController;
use crate::activation::{Activation, ActivationConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// A LstmState is used to store cell state and hidden state in LSTM.
pub struct LstmState<B: Backend, const D: usize> {
/// The cell state.
pub cell: Tensor<B, D>,
/// The hidden state.
pub hidden: Tensor<B, D>,
}
impl<B: Backend, const D: usize> LstmState<B, D> {
/// Initialize a new [LSTM State](LstmState).
pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
Self { cell, hidden }
}
}
/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).
#[derive(Config, Debug)]
pub struct LstmConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the hidden state.
pub d_hidden: usize,
/// If a bias should be applied during the Lstm transformation.
pub bias: bool,
/// Lstm initializer
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
pub initializer: Initializer,
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
#[config(default = true)]
pub batch_first: bool,
/// If true, process the sequence in reverse order.
/// This is useful for implementing reverse-direction LSTMs (e.g., ONNX reverse direction).
#[config(default = false)]
pub reverse: bool,
/// Optional cell state clip threshold. If provided, cell state values are clipped
/// to the range `[-clip, +clip]` after each timestep. This can help prevent
/// exploding values during inference.
pub clip: Option<f64>,
/// If true, couples the input and forget gates: `f_t = 1 - i_t`.
/// This reduces the number of parameters and is based on GRU-style simplification.
#[config(default = false)]
pub input_forget: bool,
/// Activation function for the input, forget, and output gates.
/// Default is Sigmoid, which is standard for LSTM gates.
#[config(default = "ActivationConfig::Sigmoid")]
pub gate_activation: ActivationConfig,
/// Activation function for the cell gate (candidate cell state).
/// Default is Tanh, which is standard for LSTM.
#[config(default = "ActivationConfig::Tanh")]
pub cell_activation: ActivationConfig,
/// Activation function applied to the cell state before computing hidden output.
/// Default is Tanh, which is standard for LSTM.
#[config(default = "ActivationConfig::Tanh")]
pub hidden_activation: ActivationConfig,
}
/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.
///
/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).
///
/// Should be created with [LstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
/// The input gate regulates which information to update and store in the cell state at each time step.
pub input_gate: GateController<B>,
/// The forget gate is used to control which information to discard or keep in the memory cell at each time step.
/// Note: When `input_forget` is true, this gate is not used (forget = 1 - input).
pub forget_gate: GateController<B>,
/// The output gate determines which information from the cell state to output at each time step.
pub output_gate: GateController<B>,
/// The cell gate is used to compute the cell state that stores and carries information through time.
pub cell_gate: GateController<B>,
/// The hidden state of the LSTM.
pub d_hidden: usize,
/// If true, input is `[batch_size, seq_length, input_size]`.
/// If false, input is `[seq_length, batch_size, input_size]`.
pub batch_first: bool,
/// If true, process the sequence in reverse order.
pub reverse: bool,
/// Optional cell state clip threshold.
pub clip: Option<f64>,
/// If true, couples input and forget gates: f_t = 1 - i_t.
pub input_forget: bool,
/// Activation function for gates (input, forget, output).
pub gate_activation: Activation<B>,
/// Activation function for cell gate (candidate cell state).
pub cell_activation: Activation<B>,
/// Activation function for hidden output.
pub hidden_activation: Activation<B>,
}
impl<B: Backend> ModuleDisplay for Lstm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
let bias = self.input_gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl LstmConfig {
/// Initialize a new [lstm](Lstm) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
let d_output = self.d_hidden;
let new_gate = || {
GateController::new(
self.d_input,
d_output,
self.bias,
self.initializer.clone(),
device,
)
};
Lstm {
input_gate: new_gate(),
forget_gate: new_gate(),
output_gate: new_gate(),
cell_gate: new_gate(),
d_hidden: self.d_hidden,
batch_first: self.batch_first,
reverse: self.reverse,
clip: self.clip,
input_forget: self.input_forget,
gate_activation: self.gate_activation.init(device),
cell_activation: self.cell_activation.init(device),
hidden_activation: self.hidden_activation.init(device),
}
}
}
impl<B: Backend> Lstm<B> {
/// Applies the forward pass on the input tensor. This LSTM implementation
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
///
/// ## Parameters:
/// - batched_input: The input tensor of shape:
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
/// - state: An optional `LstmState` representing the initial cell state and hidden state.
/// Each state tensor has shape `[batch_size, hidden_size]`.
/// If no initial state is provided, these tensors are initialized to zeros.
///
/// ## Returns:
/// - output: A tensor represents the output features of LSTM. Shape:
/// - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true
/// - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false
/// - state: A `LstmState` represents the final states. Both `state.cell` and `state.hidden` have the shape
/// `[batch_size, hidden_size]`.
pub fn forward(
&self,
batched_input: Tensor<B, 3>,
state: Option<LstmState<B, 2>>,
) -> (Tensor<B, 3>, LstmState<B, 2>) {
// Convert to batch-first layout internally if needed
let batched_input = if self.batch_first {
batched_input
} else {
batched_input.swap_dims(0, 1)
};
let device = batched_input.device();
let [batch_size, seq_length, _] = batched_input.dims();
// Process sequence in forward or reverse order based on config
let (output, state) = if self.reverse {
self.forward_iter(
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
state,
batch_size,
seq_length,
&device,
)
} else {
self.forward_iter(
batched_input.iter_dim(1).zip(0..seq_length),
state,
batch_size,
seq_length,
&device,
)
};
// Convert output back to seq-first layout if needed
let output = if self.batch_first {
output
} else {
output.swap_dims(0, 1)
};
(output, state)
}
fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
&self,
input_timestep_iter: I,
state: Option<LstmState<B, 2>>,
batch_size: usize,
seq_length: usize,
device: &B::Device,
) -> (Tensor<B, 3>, LstmState<B, 2>) {
let mut batched_hidden_state =
Tensor::empty([batch_size, seq_length, self.d_hidden], device);
let (mut cell_state, mut hidden_state) = match state {
Some(state) => (state.cell, state.hidden),
None => (
Tensor::zeros([batch_size, self.d_hidden], device),
Tensor::zeros([batch_size, self.d_hidden], device),
),
};
for (input_t, t) in input_timestep_iter {
let input_t = input_t.squeeze_dim(1);
// i(nput)g(ate) tensors
let biased_ig_input_sum = self
.input_gate
.gate_product(input_t.clone(), hidden_state.clone());
let input_values = self.gate_activation.forward(biased_ig_input_sum);
// f(orget)g(ate) tensors - either computed or coupled to input gate
let forget_values = if self.input_forget {
// Coupled mode: f_t = 1 - i_t
input_values.clone().neg().add_scalar(1.0)
} else {
let biased_fg_input_sum = self
.forget_gate
.gate_product(input_t.clone(), hidden_state.clone());
self.gate_activation.forward(biased_fg_input_sum)
};
// o(output)g(ate) tensors
let biased_og_input_sum = self
.output_gate
.gate_product(input_t.clone(), hidden_state.clone());
let output_values = self.gate_activation.forward(biased_og_input_sum);
// c(ell)g(ate) tensors
let biased_cg_input_sum = self
.cell_gate
.gate_product(input_t.clone(), hidden_state.clone());
let candidate_cell_values = self.cell_activation.forward(biased_cg_input_sum);
cell_state = forget_values * cell_state.clone() + input_values * candidate_cell_values;
// Apply cell state clipping if configured
if let Some(clip) = self.clip {
cell_state = cell_state.clamp(-clip, clip);
}
hidden_state = output_values * self.hidden_activation.forward(cell_state.clone());
let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
// store the hidden state for this timestep
batched_hidden_state = batched_hidden_state.slice_assign(
[0..batch_size, t..(t + 1), 0..self.d_hidden],
unsqueezed_hidden_state.clone(),
);
}
(
batched_hidden_state,
LstmState::new(cell_state, hidden_state),
)
}
}
/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).
#[derive(Config, Debug)]
pub struct BiLstmConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the hidden state.
pub d_hidden: usize,
/// If a bias should be applied during the BiLstm transformation.
pub bias: bool,
/// BiLstm initializer
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
pub initializer: Initializer,
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
#[config(default = true)]
pub batch_first: bool,
/// Optional cell state clip threshold.
pub clip: Option<f64>,
/// If true, couples the input and forget gates.
#[config(default = false)]
pub input_forget: bool,
/// Activation function for the input, forget, and output gates.
#[config(default = "ActivationConfig::Sigmoid")]
pub gate_activation: ActivationConfig,
/// Activation function for the cell gate (candidate cell state).
#[config(default = "ActivationConfig::Tanh")]
pub cell_activation: ActivationConfig,
/// Activation function applied to the cell state before computing hidden output.
#[config(default = "ActivationConfig::Tanh")]
pub hidden_activation: ActivationConfig,
}
/// The BiLstm module. This implementation is for Bidirectional LSTM.
///
/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).
///
/// Should be created with [BiLstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
/// LSTM for the forward direction.
pub forward: Lstm<B>,
/// LSTM for the reverse direction.
pub reverse: Lstm<B>,
/// The size of the hidden state.
pub d_hidden: usize,
/// If true, input is `[batch_size, seq_length, input_size]`.
/// If false, input is `[seq_length, batch_size, input_size]`.
pub batch_first: bool,
}
impl<B: Backend> ModuleDisplay for BiLstm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self
.forward
.input_gate
.input_transform
.weight
.shape()
.dims();
let bias = self.forward.input_gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl BiLstmConfig {
/// Initialize a new [Bidirectional LSTM](BiLstm) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
// Internal LSTMs always use batch_first=true; BiLstm handles layout conversion
let base_config = LstmConfig::new(self.d_input, self.d_hidden, self.bias)
.with_initializer(self.initializer.clone())
.with_batch_first(true)
.with_clip(self.clip)
.with_input_forget(self.input_forget)
.with_gate_activation(self.gate_activation.clone())
.with_cell_activation(self.cell_activation.clone())
.with_hidden_activation(self.hidden_activation.clone());
BiLstm {
forward: base_config.clone().init(device),
reverse: base_config.init(device),
d_hidden: self.d_hidden,
batch_first: self.batch_first,
}
}
}
impl<B: Backend> BiLstm<B> {
/// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
///
/// ## Parameters:
/// - batched_input: The input tensor of shape:
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
/// - state: An optional `LstmState` representing the initial cell state and hidden state.
/// Each state tensor has shape `[2, batch_size, hidden_size]`.
/// If no initial state is provided, these tensors are initialized to zeros.
///
/// ## Returns:
/// - output: A tensor represents the output features of LSTM. Shape:
/// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
/// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
/// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and
/// `state.hidden` have the shape `[2, batch_size, hidden_size]`.
pub fn forward(
&self,
batched_input: Tensor<B, 3>,
state: Option<LstmState<B, 3>>,
) -> (Tensor<B, 3>, LstmState<B, 3>) {
// Convert to batch-first layout internally if needed
let batched_input = if self.batch_first {
batched_input
} else {
batched_input.swap_dims(0, 1)
};
let device = batched_input.clone().device();
let [batch_size, seq_length, _] = batched_input.shape().dims();
let [init_state_forward, init_state_reverse] = match state {
Some(state) => {
let cell_state_forward = state
.cell
.clone()
.slice([0..1, 0..batch_size, 0..self.d_hidden])
.squeeze_dim(0);
let hidden_state_forward = state
.hidden
.clone()
.slice([0..1, 0..batch_size, 0..self.d_hidden])
.squeeze_dim(0);
let cell_state_reverse = state
.cell
.slice([1..2, 0..batch_size, 0..self.d_hidden])
.squeeze_dim(0);
let hidden_state_reverse = state
.hidden
.slice([1..2, 0..batch_size, 0..self.d_hidden])
.squeeze_dim(0);
[
Some(LstmState::new(cell_state_forward, hidden_state_forward)),
Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
]
}
None => [None, None],
};
// forward direction
let (batched_hidden_state_forward, final_state_forward) = self
.forward
.forward(batched_input.clone(), init_state_forward);
// reverse direction
let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
init_state_reverse,
batch_size,
seq_length,
&device,
);
let output = Tensor::cat(
[batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
2,
);
// Convert output back to seq-first layout if needed
let output = if self.batch_first {
output
} else {
output.swap_dims(0, 1)
};
let state = LstmState::new(
Tensor::stack(
[final_state_forward.cell, final_state_reverse.cell].to_vec(),
0,
),
Tensor::stack(
[final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
0,
),
);
(output, state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{LinearRecord, TestBackend};
use burn::module::Param;
use burn::tensor::{Device, Distribution, TensorData};
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[cfg(feature = "std")]
use crate::TestAutodiffBackend;
#[test]
fn test_with_uniform_initializer() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = LstmConfig::new(5, 5, false)
.with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
let lstm = config.init::<TestBackend>(&Default::default());
let gate_to_data =
|gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();
gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());
gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());
gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());
gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());
}
/// Test forward pass with simple input vector.
///
/// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928
/// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725
/// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723
/// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937
/// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243
/// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648
#[test]
fn test_forward_single_input_single_feature() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = LstmConfig::new(1, 1, false);
let device = Default::default();
let mut lstm = config.init::<TestBackend>(&device);
fn create_gate_controller(
weights: f32,
biases: f32,
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
device: &Device<TestBackend>,
) -> GateController<TestBackend> {
let record_1 = LinearRecord {
weight: Param::from_data(TensorData::from([[weights]]), device),
bias: Some(Param::from_data(TensorData::from([biases]), device)),
};
let record_2 = LinearRecord {
weight: Param::from_data(TensorData::from([[weights]]), device),
bias: Some(Param::from_data(TensorData::from([biases]), device)),
};
GateController::create_with_weights(
d_input,
d_output,
bias,
initializer,
record_1,
record_2,
)
}
lstm.input_gate = create_gate_controller(
0.5,
0.0,
1,
1,
false,
Initializer::XavierUniform { gain: 1.0 },
&device,
);
lstm.forget_gate = create_gate_controller(
0.7,
0.0,
1,
1,
false,
Initializer::XavierUniform { gain: 1.0 },
&device,
);
lstm.cell_gate = create_gate_controller(
0.9,
0.0,
1,
1,
false,
Initializer::XavierUniform { gain: 1.0 },
&device,
);
lstm.output_gate = create_gate_controller(
1.1,
0.0,
1,
1,
false,
Initializer::XavierUniform { gain: 1.0 },
&device,
);
// single timestep with single feature
let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
let (output, state) = lstm.forward(input, None);
let expected = TensorData::from([[0.046]]);
let tolerance = Tolerance::default();
state
.cell
.to_data()
.assert_approx_eq::<FT>(&expected, tolerance);
let expected = TensorData::from([[0.0242]]);
state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected, tolerance);
output
.select(0, Tensor::arange(0..1, &device))
.squeeze_dim::<2>(0)
.to_data()
.assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
}
#[test]
fn test_batched_forward_pass() {
let device = Default::default();
let lstm = LstmConfig::new(64, 1024, true).init(&device);
let batched_input =
Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);
let (output, state) = lstm.forward(batched_input, None);
assert_eq!(output.dims(), [8, 10, 1024]);
assert_eq!(state.cell.dims(), [8, 1024]);
assert_eq!(state.hidden.dims(), [8, 1024]);
}
#[test]
fn test_batched_forward_pass_batch_of_one() {
let device = Default::default();
let lstm = LstmConfig::new(64, 1024, true).init(&device);
let batched_input =
Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
let (output, state) = lstm.forward(batched_input, None);
assert_eq!(output.dims(), [1, 2, 1024]);
assert_eq!(state.cell.dims(), [1, 1024]);
assert_eq!(state.hidden.dims(), [1, 1024]);
}
#[test]
#[cfg(feature = "std")]
fn test_batched_backward_pass() {
use burn::tensor::Shape;
let device = Default::default();
let lstm = LstmConfig::new(64, 32, true).init(&device);
let shape: Shape = [8, 10, 64].into();
let batched_input =
Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);
let (output, _) = lstm.forward(batched_input.clone(), None);
let fake_loss = output;
let grads = fake_loss.backward();
let some_gradient = lstm
.output_gate
.hidden_transform
.weight
.grad(&grads)
.unwrap();
// Asserts that the gradients exist and are non-zero
assert_ne!(
some_gradient
.any()
.into_data()
.iter::<f32>()
.next()
.unwrap(),
0.0
);
}
#[test]
fn test_bidirectional() {
let device = Default::default();
TestBackend::seed(&device, 0);
let config = BiLstmConfig::new(2, 3, true);
let device = Default::default();
let mut lstm = config.init(&device);
fn create_gate_controller<const D1: usize, const D2: usize>(
input_weights: [[f32; D1]; D2],
input_biases: [f32; D1],
hidden_weights: [[f32; D1]; D1],
hidden_biases: [f32; D1],
device: &Device<TestBackend>,
) -> GateController<TestBackend> {
let d_input = input_weights[0].len();
let d_output = input_weights.len();
let input_record = LinearRecord {
weight: Param::from_data(TensorData::from(input_weights), device),
bias: Some(Param::from_data(TensorData::from(input_biases), device)),
};
let hidden_record = LinearRecord {
weight: Param::from_data(TensorData::from(hidden_weights), device),
bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
};
GateController::create_with_weights(
d_input,
d_output,
true,
Initializer::XavierUniform { gain: 1.0 },
input_record,
hidden_record,
)
}
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[
[0.949, -0.861],
[0.892, 0.927],
[-0.173, -0.301],
[-0.081, 0.992],
]]),
&device,
);
let h0 = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
&device,
);
let c0 = Tensor::<TestBackend, 3>::from_data(
TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
&device,
);
lstm.forward.input_gate = create_gate_controller(
[[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
[-0.196, 0.354, 0.209],
[
[-0.320, 0.232, -0.165],
[0.093, -0.572, -0.315],
[-0.467, 0.325, 0.046],
],
[0.181, -0.190, -0.245],
&device,
);
lstm.forward.forget_gate = create_gate_controller(
[[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
[0.315, -0.413, -0.041],
[
[0.453, 0.063, 0.561],
[0.211, 0.149, 0.213],
[-0.499, -0.158, 0.068],
],
[-0.431, -0.535, 0.125],
&device,
);
lstm.forward.cell_gate = create_gate_controller(
[[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
[-0.358, 0.282, -0.078],
[
[-0.358, 0.109, 0.139],
[-0.345, 0.091, -0.368],
[-0.508, 0.221, -0.507],
],
[0.502, -0.509, -0.247],
&device,
);
lstm.forward.output_gate = create_gate_controller(
[[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
[-0.227, -0.274, 0.039],
[
[-0.383, 0.449, 0.222],
[-0.357, -0.093, 0.449],
[-0.106, 0.236, 0.360],
],
[-0.361, -0.209, -0.454],
&device,
);
lstm.reverse.input_gate = create_gate_controller(
[[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
[0.540, -0.164, 0.033],
[
[0.159, 0.180, -0.037],
[-0.443, 0.485, -0.488],
[0.098, -0.085, -0.140],
],
[-0.510, 0.105, 0.114],
&device,
);
lstm.reverse.forget_gate = create_gate_controller(
[[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
[0.141, 0.004, 0.055],
[
[-0.005, -0.277, -0.515],
[-0.011, -0.101, -0.365],
[0.426, 0.379, 0.337],
],
[-0.382, 0.331, -0.176],
&device,
);
lstm.reverse.cell_gate = create_gate_controller(
[[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
[-0.206, -0.546, 0.462],
[
[0.449, -0.240, 0.071],
[-0.045, 0.131, 0.124],
[0.138, -0.201, 0.191],
],
[-0.030, 0.211, -0.352],
&device,
);
lstm.reverse.output_gate = create_gate_controller(
[[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
[-0.387, -0.250, 0.066],
[
[-0.030, 0.268, 0.299],
[-0.019, -0.280, -0.314],
[0.466, -0.365, -0.248],
],
[-0.398, -0.199, -0.566],
&device,
);
let expected_output_with_init_state = TensorData::from([[
[0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
[0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
[0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
[-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
]]);
let expected_output_without_init_state = TensorData::from([[
[0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
[-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
[0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
[-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
]]);
let expected_hn_with_init_state = TensorData::from([
[[-0.03420, 0.07774, -0.09774]],
[[-0.15635, -0.03366, -0.05798]],
]);
let expected_cn_with_init_state = TensorData::from([
[[-0.13593, 0.17125, -0.22395]],
[[-0.45425, -0.11206, -0.12908]],
]);
let expected_hn_without_init_state = TensorData::from([
[[-0.04026, 0.07178, -0.10189]],
[[-0.15969, -0.05322, -0.08863]],
]);
let expected_cn_without_init_state = TensorData::from([
[[-0.15839, 0.15923, -0.23569]],
[[-0.47407, -0.17493, -0.19643]],
]);
let (output_with_init_state, state_with_init_state) =
lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);
let tolerance = Tolerance::permissive();
output_with_init_state
.to_data()
.assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
output_without_init_state
.to_data()
.assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
state_with_init_state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
state_with_init_state
.cell
.to_data()
.assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);
state_without_init_state
.hidden
.to_data()
.assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
state_without_init_state
.cell
.to_data()
.assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);
}
#[test]
fn display_lstm() {
let config = LstmConfig::new(2, 3, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{layer}"),
"Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
);
}
#[test]
fn display_bilstm() {
let config = BiLstmConfig::new(2, 3, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{layer}"),
"BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
);
}
}

View File

@@ -0,0 +1,15 @@
mod gate_controller;
/// Basic RNN.
pub mod basic;
/// Gated Recurrent Unit module.
pub mod gru;
/// Long Short-Term Memory module.
pub mod lstm;
pub use basic::*;
pub use gate_controller::*;
pub use gru::*;
pub use lstm::*;

View File

@@ -0,0 +1,581 @@
use burn_core as burn;
use alloc::vec;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::Int;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use core::ops::Range;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
/// Maximum sequence length of input
pub max_sequence_length: usize,
/// Size of the input embedding or hidden dimension
pub d_model: usize,
/// Scaling factor for frequency computation. Defaults to 10000.0
#[config(default = "10000.0")]
pub theta: f32,
}
impl RotaryEncodingConfig {
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
///
/// # Panics
///
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
self.initialize(|x| x, device)
}
/// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
/// This is useful to apply different RoPE extensions.
///
/// # Panics
///
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
pub fn init_with_frequency_scaling<B: Backend>(
&self,
scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
device: &B::Device,
) -> RotaryEncoding<B> {
self.initialize(scaling, device)
}
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
///
/// # Panics
///
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
fn initialize<B: Backend>(
&self,
scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
device: &B::Device,
) -> RotaryEncoding<B> {
assert_eq!(
self.d_model % 2,
0,
"The input embedding dimension must be even"
);
assert!(
self.theta > 0.0,
"Theta parameter must be positive (default: 10000)."
);
// Calculate the rotation frequencies for positional embeddings based on the formula
// `theta = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
.float()
.div_scalar(self.d_model as f32);
// Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
// This is done since burn doesn't support exponentiation of scalar to tensor
let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();
let theta = scaling(theta);
let freq_complex =
RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());
RotaryEncoding {
freq_complex,
theta,
start_offset: 0,
}
}
}
/// A module that applies rotary positional encoding to a tensor.
/// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes
/// absolute positional information with rotation matrix and naturally incorporates
/// explicit relative position dependency in self-attention formulation.
///
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
///
/// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
/// Complex frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
// Essentially a cache of pre-computed RoPE values.
pub freq_complex: Tensor<B, 3>,
/// Frequency vector used to compute/apply the complex rotations.
pub theta: Tensor<B, 1>,
start_offset: usize,
}
impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();
content
.add("d_model", &d_model)
.add("max_sequence_length", &max_sequence_length)
.optional()
}
}
#[allow(clippy::single_range_in_vec_init)]
impl<B: Backend> RotaryEncoding<B> {
/// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
///
/// # Arguments:
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
/// respectively.
///
/// # Returns:
/// Output tensor with the same shape as input tensor after applying rotary encoding.
///
/// # Panics
/// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
self.apply(x, 0)
}
/// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
///
/// # Arguments:
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
/// respectively.
/// * `start` - Sequence start position index.
///
/// # Returns:
/// Output tensor with the same shape as input tensor after applying rotary encoding.
///
/// # Panics
/// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
assert!(
D >= 2,
"Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
);
let device = x.device();
let input_shape = x.shape();
// Extract the sequence length and embedding dimension, other dimensions are kept generic
// to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads)
let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
// Create a dummy tensor with signed ones based on the 2D rotation matrix
// [[cos, -sin], [sin, cos]]
let sign_tensor =
Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
// Rotate input using the frequency tensor. Slice the frequencies till input sequence length
let out: Tensor<B, 4> = x
.reshape([dummy_dim_size, seq_len, d_model / 2, 2])
.matmul(sign_tensor.unsqueeze())
.reshape([dummy_dim_size, seq_len, d_model, 2])
* self
.freq_complex
.clone()
.slice([start..start + seq_len])
.unsqueeze();
// Sum the real and imaginary components to get output tensor and reshape to original shape
out.sum_dim(-1).reshape(input_shape)
}
/// Shifts the pre-computed rotary frequency to cover a new range of positions.
///
/// This method updates the internal frequency tensor `freq_complex` to store
/// the rotary positional encodings for a new window of positions starting at `start`.
pub fn shift(&mut self, start: usize) {
let max_seq_len = self.freq_complex.dims()[0];
assert!(
start > self.start_offset,
"Shift start position must be monotonically increasing"
);
let current_end = self.start_offset + max_seq_len;
if start >= current_end {
// Overwrite the whole buffer
let new_freqs =
Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());
self.freq_complex
.inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));
} else {
// Shift the tail
let num_keep = current_end - start;
let start_rel = start - self.start_offset;
let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);
self.freq_complex
.inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));
// Compute the rest and assign
let new_freqs = Self::compute_rotary_frequencies(
current_end..start + max_seq_len,
self.theta.clone(),
);
self.freq_complex
.inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));
}
self.start_offset = start;
}
/// Computes the positional rotation frequencies (cosine and sine values) used in RoPE.
///
/// # Arguments
/// - `range`: Range of position indices `[start, end)`.
/// - `theta`: 1D tensor of shape `(d_model / 2)` containing base angular frequencies.
///
/// # Returns
/// Tensor of shape `(range.len(), d_model, 2)` containing `[cos, sin]` pairs for each position and frequency.
fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {
let d_model = theta.dims()[0] * 2;
let num_positions = range.end - range.start;
// Generate frequency values for positional embeddings
let frequencies: Tensor<B, 2> =
Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())
.float()
.unsqueeze()
.transpose()
.repeat_dim(1, d_model / 2)
* theta.unsqueeze();
// Convert frequency values to complex numbers (polar form)
let p_cos = frequencies.clone().cos();
let p_sin = frequencies.sin();
Tensor::cat(vec![p_cos, p_sin], 1)
.reshape([num_positions, 2, d_model / 2])
.transpose()
.unsqueeze_dim::<4>(2)
.repeat_dim(2, 2)
.reshape([num_positions, d_model, 2])
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_rotary_encoding_forward() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 3>::from_floats(
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
],
&device,
);
// Input = [Batch size, Num of heads, Seq_len, d_model]
let input = input.unsqueeze::<4>();
let output = rotary_encoding.forward(input);
let expected_output = Tensor::<TestBackend, 3>::from_floats(
[
[
[1.0000, 2.0000, 3.0000, 4.0000],
[-2.3473, 7.4492, 6.9197, 8.0696],
],
[
[9.0000, 10.0000, 11.0000, 12.0000],
[-4.7567, 18.5034, 14.8393, 16.1492],
],
],
&device,
);
output
.squeeze_dim::<3>(0)
.to_data()
.assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
}
#[test]
fn test_rotary_encoding_3d() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 3>::from_floats(
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
],
&device,
);
// Input = [Batch size, Num of heads, Seq_len, d_model]
// let input = input.unsqueeze::<4>();
let output = rotary_encoding.forward(input);
let expected_output = Tensor::<TestBackend, 3>::from_floats(
[
[
[1.0000, 2.0000, 3.0000, 4.0000],
[-2.3473, 7.4492, 6.9197, 8.0696],
],
[
[9.0000, 10.0000, 11.0000, 12.0000],
[-4.7567, 18.5034, 14.8393, 16.1492],
],
],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
}
#[test]
fn test_zero_input_rotary_encoding_forward() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
// Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);
let output = rotary_encoding.forward(input);
let expected_output = Tensor::<TestBackend, 3>::from_floats(
[
[
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
],
[
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
],
],
&device,
);
output
.squeeze_dim::<3>(0)
.to_data()
.assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
}
#[test]
#[should_panic]
fn test_valid_input_hidden_dim() {
// Hidden dimension must be even to be able to split into real and imaginary components
// for rotation
let d_model = 15;
let device = Default::default();
let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
let _output = pe.forward(input);
}
#[test]
fn test_rotary_encoding_frequencies() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);
let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
[
[
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
],
[
[5.4030e-01, 8.4147e-01],
[9.9500e-01, 9.9833e-02],
[9.9995e-01, 9.9998e-03],
[9.9999e-01, 9.9999e-04],
],
],
&device,
)
.unsqueeze_dim::<4>(2)
.repeat_dim(2, 2)
.reshape([2, 8, 2]);
rotary_encoding
.freq_complex
.to_data()
.assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
}
fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
// Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
let scale_factor = 8.;
let low_freq_factor = 1.;
let high_freq_factor = 4.;
let old_context_len = 8192.;
let low_freq_wavelen = old_context_len / low_freq_factor;
let high_freq_wavelen = old_context_len / high_freq_factor;
let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);
// if wavelen >= high_freq_wavelen
let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
let smooth = wavelen
.clone()
.recip()
.mul_scalar(old_context_len)
.sub_scalar(low_freq_factor)
.div_scalar(high_freq_factor - low_freq_factor);
// (1 - smooth) * freq / scale_factor + smooth * freq
let new_freqs = smooth
.clone()
.neg()
.add_scalar(1.)
.mul(freqs.clone().div_scalar(scale_factor))
.add(smooth.clone().mul(freqs.clone()));
let new_freqs = freqs.clone().mask_where(cond, new_freqs);
// if wavelen > low_freq_wavelen
let cond = wavelen.clone().greater_elem(low_freq_wavelen);
let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));
// if wavelen < high_freq_wavelen
let cond = wavelen.lower_elem(high_freq_wavelen);
new_freqs.mask_where(cond, freqs)
}
#[test]
fn test_rotary_encoding_with_frequency_scaling() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(2, 8)
.init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);
let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
[
[
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
],
[
[5.4030e-01, 8.4148e-01],
[9.9500e-01, 9.9833e-02],
[9.9995e-01, 9.9998e-03],
[1.0000, 2.1361e-04],
],
],
&device,
)
.unsqueeze_dim::<4>(2)
.repeat_dim(2, 2)
.reshape([2, 8, 2]);
rotary_encoding
.freq_complex
.to_data()
.assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
}
#[test]
fn test_rotary_encoding_shift_full() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
// Input = [Batch size, Num of heads, Seq_len, d_model]
let input = Tensor::<TestBackend, 3>::from_floats(
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
],
&device,
)
.unsqueeze::<4>();
// Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result
// as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same
// initial position
let expected_output = rotary_encoding.apply(input.clone(), 6);
let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
rotary_encoding.shift(6); // start > 4 will perform a full re-compute
let output = rotary_encoding.apply(input, 0);
output
.into_data()
.assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
}
#[test]
fn test_rotary_encoding_shift() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
// Input = [Batch size, Num of heads, Seq_len, d_model]
let input = Tensor::<TestBackend, 3>::from_floats(
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
],
&device,
)
.unsqueeze::<4>();
// Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result
// as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same
// initial position
let expected_output = rotary_encoding.apply(input.clone(), 2);
let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
rotary_encoding.shift(2); // start < 4 will shift the (current_end - start) freqs and compute the rest
let output = rotary_encoding.apply(input, 0);
output
.into_data()
.assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
}
#[test]
fn test_rotary_encoding_shift_multiple() {
let device = Default::default();
let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
rotary_encoding.shift(2);
rotary_encoding.shift(5);
}
#[test]
#[should_panic = "Shift start position must be monotonically increasing"]
fn test_rotary_encoding_shift_should_increase() {
let device = Default::default();
let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
rotary_encoding.shift(6);
rotary_encoding.shift(4); // should be monotonically increasing
}
#[test]
fn display() {
let config = RotaryEncodingConfig::new(10, 4);
let pe = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{pe}"),
"RotaryEncoding {d_model: 4, max_sequence_length: 10}"
);
}
}

View File

@@ -0,0 +1,573 @@
use burn_core as burn;
use alloc::vec::Vec;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};
use crate::activation::ActivationConfig;
use crate::cache::TensorCache;
use crate::{
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
};
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).
#[derive(Config, Debug)]
pub struct TransformerDecoderConfig {
/// The size of the model.
pub d_model: usize,
/// The size of the position-wise feed-forward network.
pub d_ff: usize,
/// The number of attention heads.
pub n_heads: usize,
/// The number of layers.
pub n_layers: usize,
/// The dropout rate. Default: 0.1
#[config(default = 0.1)]
pub dropout: f64,
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
/// The activation function used in the position-wise feed-forward network. Default: Gelu
#[config(default = "ActivationConfig::Gelu")]
pub activation: ActivationConfig,
/// The epsilon value for layer normalization. Default: 1e-5
#[config(default = 1e-5)]
pub layer_norm_eps: f64,
}
/// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer decoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerDecoderConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
/// Transformer decoder layers.
pub layers: Vec<TransformerDecoderLayer<B>>,
/// The size of the model.
pub d_model: usize,
/// The size of the position-wise feed-forward network.
pub d_ff: usize,
/// The number of attention heads.
pub n_heads: usize,
/// The number of layers.
pub n_layers: usize,
/// The dropout rate. Default: 0.1
pub dropout: f64,
/// Layer norm will be applied first instead of after the other modules.
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("d_ff", &self.d_ff)
.add("n_heads", &self.n_heads)
.add("n_layers", &self.n_layers)
.add("dropout", &self.dropout)
.add("norm_first", &self.norm_first)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
impl TransformerDecoderConfig {
/// Initialize a new [Transformer Decoder](TransformerDecoder) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
let layers = (0..self.n_layers)
.map(|_| TransformerDecoderLayer::new(self, device))
.collect::<Vec<_>>();
TransformerDecoder {
layers,
d_model: self.d_model,
d_ff: self.d_ff,
n_heads: self.n_heads,
n_layers: self.n_layers,
dropout: self.dropout,
norm_first: self.norm_first,
quiet_softmax: self.quiet_softmax,
}
}
}
/// [Transformer Decoder](TransformerDecoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
target: Tensor<B, 3>,
target_mask_pad: Option<Tensor<B, 2, Bool>>,
target_mask_attn: Option<Tensor<B, 3, Bool>>,
memory: Tensor<B, 3>,
memory_mask_pad: Option<Tensor<B, 2, Bool>>,
memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}
impl<B: Backend> TransformerDecoderInput<B> {
/// Create a [transformer decoder](TransformerDecoder) input argument.
pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
Self {
target,
target_mask_pad: None,
target_mask_attn: None,
memory,
memory_mask_pad: None,
memory_mask_attn: None,
}
}
/// Register the memory padding mask.
pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
self.memory_mask_pad = Some(mask_pad);
self
}
/// Register the memory attention mask.
pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
self.memory_mask_attn = Some(mask_attn);
self
}
/// Register the target padding mask.
pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
self.target_mask_pad = Some(mask_pad);
self
}
/// Register the target attention mask.
pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
self.target_mask_attn = Some(mask_attn);
self
}
}
/// [Transformer Decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
/// Cross-attention module.
pub cross_attn: MultiHeadAttention<B>,
/// Self-attention module.
pub self_attn: MultiHeadAttention<B>,
/// Position-wise feed-forward module.
pub pwff: PositionWiseFeedForward<B>,
/// First layer norm.
pub norm_1: LayerNorm<B>,
/// Second layer norm.
pub norm_2: LayerNorm<B>,
/// Third layer norm.
pub norm_3: LayerNorm<B>,
/// Dropout.
pub dropout: Dropout,
/// Whether to apply norm first.
pub norm_first: bool,
}
/// Autoregressive cache for a single [Transformer Decoder Layer](TransformerDecoderLayer).
pub struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
/// Cross-attention cache.
pub cross_attn: MhaCache<B>,
/// Self-attention cache.
pub self_attn: MhaCache<B>,
/// Position-wise feed-forward cache.
pub pwff: TensorCache<B, 3>,
/// First layer norm cache.
pub norm_1: TensorCache<B, 3>,
/// Second layer norm cache.
pub norm_2: TensorCache<B, 3>,
/// Third layer norm cache.
pub norm_3: TensorCache<B, 3>,
}
impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
/// Create an empty cache.
pub fn empty() -> Self {
Self {
cross_attn: MhaCache::autoregressive_cross_attention(),
self_attn: MhaCache::autoregressive(),
pwff: TensorCache::empty(),
norm_1: TensorCache::empty(),
norm_2: TensorCache::empty(),
norm_3: TensorCache::empty(),
}
}
}
/// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer.
///
/// To be used during inference when decoding tokens.
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}
impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
fn empty(num_layers: usize) -> Self {
Self {
layers: (0..num_layers)
.map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
.collect(),
}
}
}
impl<B: Backend> TransformerDecoderLayer<B> {
/// Create a new [TransformerDecoderLayer](TransformerDecoderLayer).
pub fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init(device);
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init(device);
let norm_1 = LayerNormConfig::new(config.d_model)
.with_epsilon(config.layer_norm_eps)
.init(device);
let norm_2 = LayerNormConfig::new(config.d_model)
.with_epsilon(config.layer_norm_eps)
.init(device);
let norm_3 = LayerNormConfig::new(config.d_model)
.with_epsilon(config.layer_norm_eps)
.init(device);
let dropout = DropoutConfig::new(config.dropout).init();
let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_activation(config.activation.clone())
.init(device);
Self {
cross_attn,
self_attn,
norm_1,
norm_2,
norm_3,
pwff,
dropout,
norm_first: config.norm_first,
}
}
/// Applies the TransformerDecoder forward pass to the input tensor.
pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
// Self attention residual path.
let x = input.target;
let mut residual_path = x.clone();
// Normalize.
if self.norm_first {
residual_path = self.norm_3.forward(residual_path);
}
// Self attention.
let mut self_attn_input = MhaInput::self_attn(residual_path);
if let Some(mask_pad) = &input.target_mask_pad {
self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.target_mask_attn {
self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self.self_attn.forward(self_attn_input).context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Cross attention residual path.
// Normalize.
let residual_path = if self.norm_first {
self.norm_1.forward(x.clone())
} else {
x = self.norm_1.forward(x);
x.clone()
};
// Cross attention.
let mut cross_attn_input =
MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
if let Some(mask_pad) = &input.memory_mask_pad {
cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.memory_mask_attn {
cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self.cross_attn.forward(cross_attn_input).context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Feed forward residual path.
// Normalize.
let residual_path = if self.norm_first {
self.norm_2.forward(x.clone())
} else {
x = self.norm_2.forward(x);
x.clone()
};
let residual_path = self.pwff.forward(residual_path);
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Main path.
// Normalize.
if !self.norm_first {
x = self.norm_3.forward(x)
}
input.target = x;
input
}
/// Applies the forward pass using an autoregressive cache.
pub fn forward_autoregressive_inference(
&self,
mut input: TransformerDecoderInput<B>,
cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
) -> TransformerDecoderInput<B> {
// Self attention residual path.
let x = input.target;
let mut residual_path = x.clone();
// Normalize.
if self.norm_first {
residual_path = cache
.norm_3
.forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
}
// Self attention.
let mut self_attn_input = MhaInput::self_attn(residual_path);
if let Some(mask_pad) = &input.target_mask_pad {
self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.target_mask_attn {
self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self
.self_attn
.forward_cache(self_attn_input, &mut cache.self_attn)
.context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Cross attention residual path.
// Normalize.
let residual_path = if self.norm_first {
cache
.norm_1
.forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
} else {
x = cache
.norm_1
.forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
x.clone()
};
// Cross attention.
let mut cross_attn_input =
MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
if let Some(mask_pad) = &input.memory_mask_pad {
cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.memory_mask_attn {
cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self
.cross_attn
.forward_cache(cross_attn_input, &mut cache.cross_attn)
.context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Feed forward residual path.
// Normalize.
let residual_path = if self.norm_first {
cache
.norm_2
.forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
} else {
x = cache
.norm_2
.forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
x.clone()
};
let residual_path = cache
.pwff
.forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Main path.
// Normalize.
if !self.norm_first {
x = cache
.norm_3
.forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
}
input.target = x;
input
}
}
impl<B: Backend> TransformerDecoder<B> {
/// Applies the forward pass.
pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
for layer in self.layers.iter() {
input = layer.forward(input);
}
input.target
}
/// Applies the forward pass on the input using autoregressive cache.
pub fn forward_autoregressive_inference(
&self,
mut input: TransformerDecoderInput<B>,
cache: &mut TransformerDecoderAutoregressiveCache<B>,
) -> Tensor<B, 3> {
for i in 0..self.layers.len() {
let layer = self.layers.get(i).unwrap();
let cache = cache.layers.get_mut(i).unwrap();
input = layer.forward_autoregressive_inference(input, cache);
}
input.target
}
/// Create an empty autoregressive cache.
pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
TransformerDecoderAutoregressiveCache::empty(self.layers.len())
}
}
#[cfg(test)]
mod tests {
use burn::tensor::Device;
use super::*;
use crate::{TestBackend, attention::generate_autoregressive_mask};
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_autoregressive_norm_last() {
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
let device = Default::default();
TestBackend::seed(&device, 0);
test_autoregressive(
TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
.with_norm_first(false),
)
}
#[test]
fn test_autoregressive_norm_first() {
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
let device = Default::default();
TestBackend::seed(&device, 0);
test_autoregressive(
TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
)
}
fn test_autoregressive(config: TransformerDecoderConfig) {
let device: Device<TestBackend> = Default::default();
let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
let transformer = config.init::<TestBackend>(&device);
let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
.float()
.reshape([batch_size, seq_length, d_model]);
let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
.float()
.reshape([batch_size, seq_length, d_model]);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
.target_mask_attn(mask_attn);
// Normal forward using masking.
let output_1 = transformer.forward(input);
// Forward using the autoregressive cache.
let mut output_2 = Vec::new();
let mut cache = transformer.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
.target_mask_attn(mask_attn);
let next_tok = transformer // Greedy sampling
.forward_autoregressive_inference(input, &mut cache)
.slice([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}
let output_2 = Tensor::cat(output_2, 1);
// Should produce the same tokens.
let tolerance = Tolerance::rel_abs(5e-3, 1e-4);
output_1
.into_data()
.assert_approx_eq::<FT>(&output_2.into_data(), tolerance);
}
#[test]
fn display() {
let config = TransformerDecoderConfig::new(2, 4, 2, 3);
let transformer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{transformer}"),
"TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
);
}
}

View File

@@ -0,0 +1,489 @@
use burn_core as burn;
use alloc::vec::Vec;
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
activation::ActivationConfig,
attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
cache::TensorCache,
};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};
/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).
#[derive(Config, Debug)]
pub struct TransformerEncoderConfig {
/// The size of the model.
pub d_model: usize,
/// The size of the position-wise feed-forward network.
pub d_ff: usize,
/// The number of attention heads.
pub n_heads: usize,
/// The number of layers.
pub n_layers: usize,
/// The dropout rate. Default: 0.1
#[config(default = 0.1)]
pub dropout: f64,
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
/// The activation function used in the position-wise feed-forward network. Default: Gelu
#[config(default = "ActivationConfig::Gelu")]
pub activation: ActivationConfig,
/// The epsilon value for layer normalization. Default: 1e-5
#[config(default = 1e-5)]
pub layer_norm_eps: f64,
}
/// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer encoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerEncoderConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
/// The transformer encoder layers.
pub layers: Vec<TransformerEncoderLayer<B>>,
/// The size of the model.
pub d_model: usize,
/// The size of the position-wise feed-forward network.
pub d_ff: usize,
/// The number of attention heads.
pub n_heads: usize,
/// The number of layers.
pub n_layers: usize,
/// The dropout rate. Default: 0.1
pub dropout: f64,
/// Layer norm will be applied first instead of after the other modules.
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("d_ff", &self.d_ff)
.add("n_heads", &self.n_heads)
.add("n_layers", &self.n_layers)
.add("dropout", &self.dropout)
.add("norm_first", &self.norm_first)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
tensor: Tensor<B, 3>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
}
impl<B: Backend> TransformerEncoderInput<B> {
/// Create a [transformer encoder](TransformerEncoder) input argument.
pub fn new(tensor: Tensor<B, 3>) -> Self {
Self {
tensor,
mask_pad: None,
mask_attn: None,
}
}
/// Register the padding mask.
pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
self.mask_pad = Some(mask_pad);
self
}
/// Register the attention mask.
pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
self.mask_attn = Some(mask_attn);
self
}
}
impl TransformerEncoderConfig {
/// Initialize a new [transformer encoder](TransformerEncoder) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
let layers = (0..self.n_layers)
.map(|_| TransformerEncoderLayer::new(self, device))
.collect::<Vec<_>>();
TransformerEncoder {
layers,
d_model: self.d_model,
d_ff: self.d_ff,
n_heads: self.n_heads,
n_layers: self.n_layers,
dropout: self.dropout,
norm_first: self.norm_first,
quiet_softmax: self.quiet_softmax,
}
}
}
impl<B: Backend> TransformerEncoder<B> {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - tensor: `[batch_size, seq_length, d_model]`
/// - output: `[batch_size, seq_length, d_model]`
pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
let mut x = input.tensor;
for layer in self.layers.iter() {
x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
}
x
}
/// Applies the forward pass on the input tensor using autoregressive cache.
///
/// # Shapes
///
/// - tensor: `[batch_size, seq_length, d_model]`
/// - output: `[batch_size, seq_length, d_model]`
pub fn forward_autoregressive_inference(
&self,
input: TransformerEncoderInput<B>,
cache: &mut TransformerEncoderAutoregressiveCache<B>,
) -> Tensor<B, 3> {
let mut x = input.tensor;
for i in 0..self.layers.len() {
let layer = self.layers.get(i).unwrap();
let cache = cache.layers.get_mut(i).unwrap();
x = layer.forward_autoregressive_inference(
x,
input.mask_pad.clone(),
input.mask_attn.clone(),
cache,
);
}
x
}
/// Create an empty autoregressive cache.
pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
TransformerEncoderAutoregressiveCache::empty(self.layers.len())
}
}
/// Transformer encoder layer module.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
/// Multi-head self-attention sub-layer.
pub mha: MultiHeadAttention<B>,
/// Position-wise feed-forward sub-layer.
pub pwff: PositionWiseFeedForward<B>,
/// Layer normalization applied around the feed-forward sub-layer.
pub norm_1: LayerNorm<B>,
/// Layer normalization applied around the attention sub-layer.
pub norm_2: LayerNorm<B>,
/// Dropout module applied to residual connections.
pub dropout: Dropout,
/// If `true`, apply layer normalization before sub-layers (pre-norm),
/// otherwise apply it after (post-norm).
pub norm_first: bool,
}
impl<B: Backend> TransformerEncoderLayer<B> {
/// Create a new transformer encoder layer from the given configuration.
pub fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init(device);
let norm_1 = LayerNormConfig::new(config.d_model)
.with_epsilon(config.layer_norm_eps)
.init(device);
let norm_2 = LayerNormConfig::new(config.d_model)
.with_epsilon(config.layer_norm_eps)
.init(device);
let dropout = DropoutConfig::new(config.dropout).init();
let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_activation(config.activation.clone())
.init(device);
Self {
mha,
norm_1,
norm_2,
pwff,
dropout,
norm_first: config.norm_first,
}
}
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: `[batch_size, seq_length, d_model]`
/// - output: `[batch_size, seq_length, d_model]`
pub fn forward(
&self,
input: Tensor<B, 3>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
) -> Tensor<B, 3> {
// Multi-head attention residual path.
let x = input;
let mut residual_path = x.clone();
// Normalize.
if self.norm_first {
residual_path = self.norm_2.forward(residual_path)
}
// Multi-head attention.
let mut input_mhs = MhaInput::self_attn(residual_path);
if let Some(mask_pad) = mask_pad {
input_mhs = input_mhs.mask_pad(mask_pad);
}
if let Some(mask_attn) = mask_attn {
input_mhs = input_mhs.mask_attn(mask_attn);
}
let residual_path = self.mha.forward(input_mhs).context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Feed forward residual path.
// Normalize.
let residual_path = if self.norm_first {
self.norm_1.forward(x.clone())
} else {
x = self.norm_1.forward(x);
x.clone()
};
// Feed forward.
let residual_path = self.pwff.forward(residual_path);
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Main path.
// Normalize.
if !self.norm_first {
x = self.norm_2.forward(x)
}
x
}
/// Applies the forward pass using an autoregressive cache.
pub fn forward_autoregressive_inference(
&self,
input: Tensor<B, 3>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
) -> Tensor<B, 3> {
// Multi-head attention residual path.
let x = input;
let mut residual_path = x.clone();
// Normalize.
if self.norm_first {
residual_path = cache
.norm_2
.forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
}
// Multi-head attention.
let mut input_mhs = MhaInput::self_attn(residual_path);
if let Some(mask_pad) = mask_pad {
input_mhs = input_mhs.mask_pad(mask_pad);
}
if let Some(mask_attn) = mask_attn {
input_mhs = input_mhs.mask_attn(mask_attn);
}
let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Feed forward residual path.
// Normalize.
let residual_path = if self.norm_first {
cache
.norm_1
.forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
} else {
x = cache
.norm_1
.forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
x.clone()
};
// Feed forward.
let residual_path = cache
.pwff
.forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
// Main path.
// Normalize.
if !self.norm_first {
x = cache
.norm_2
.forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
}
x
}
}
/// Autoregressive cache for a single [Transformer Encoder Layer](TransformerEncoderLayer).
pub struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
/// Multi-head attention cache.
pub mha: MhaCache<B>,
/// Position-wise feed-forward cache.
pub pwff: TensorCache<B, 3>,
/// First layer norm cache.
pub norm_1: TensorCache<B, 3>,
/// Second layer norm cache.
pub norm_2: TensorCache<B, 3>,
}
impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
/// Create an empty cache.
pub fn empty() -> Self {
Self {
mha: MhaCache::autoregressive(),
pwff: TensorCache::empty(),
norm_1: TensorCache::empty(),
norm_2: TensorCache::empty(),
}
}
}
/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
///
/// To be used during inference when decoding tokens.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}
impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
fn empty(num_layers: usize) -> Self {
Self {
layers: (0..num_layers)
.map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
.collect(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{TestBackend, attention::generate_autoregressive_mask};
use burn::tensor::Distribution;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn test_autoregressive_norm_last() {
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
test_autoregressive(
TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
.with_norm_first(false),
)
}
#[test]
fn test_autoregressive_norm_first() {
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
test_autoregressive(
TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
)
}
fn test_autoregressive(config: TransformerEncoderConfig) {
let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
let device = Default::default();
let transformer = config.init(&device);
let tensor = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);
let output_1 = transformer.forward(input);
let mut output_2 = Vec::new();
let mut cache = transformer.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
let input = TransformerEncoderInput::new(tensor.clone());
let next_tok = transformer
.forward_autoregressive_inference(input, &mut cache)
.slice([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}
let output_2 = Tensor::cat(output_2, 1);
output_1
.into_data()
.assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::permissive());
}
#[test]
fn display() {
let config = TransformerEncoderConfig::new(2, 4, 2, 3);
let transformer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{transformer}"),
"TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
);
}
}

View File

@@ -0,0 +1,7 @@
mod decoder;
mod encoder;
mod pwff;
pub use decoder::*;
pub use encoder::*;
pub use pwff::*;

View File

@@ -0,0 +1,117 @@
use burn_core as burn;
use crate::activation::{Activation, ActivationConfig};
use crate::{Dropout, DropoutConfig, Linear, LinearConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Tensor, backend::Backend};
/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).
#[derive(Config, Debug)]
pub struct PositionWiseFeedForwardConfig {
/// The size of the input and output features.
pub d_model: usize,
/// The size of the hidden inner features.
pub d_ff: usize,
/// The dropout rate. Default: 0.1
#[config(default = 0.1)]
pub dropout: f64,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
/// The activation function used between the two linear layers. Default: Gelu
#[config(default = "ActivationConfig::Gelu")]
pub activation: ActivationConfig,
}
/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).
///
/// # Params
///
/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
///
/// `FFN(x) = max(0, xW1 + b1)W2 + b2`
///
/// Should be created using [PositionWiseFeedForwardConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
/// Linear layer with `d_model` input features and `d_ff` output features.
pub linear_inner: Linear<B>,
/// Linear layer with `d_ff` input features and `d_model` output features.
pub linear_outer: Linear<B>,
/// Dropout layer.
pub dropout: Dropout,
/// Activation function.
pub activation: Activation<B>,
}
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_model, dff] = self.linear_inner.weight.shape().dims();
content
.add("d_model", &d_model)
.add("d_ff", &dff)
.add("prob", &self.dropout.prob)
.optional()
}
}
impl PositionWiseFeedForwardConfig {
/// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
PositionWiseFeedForward {
linear_inner: LinearConfig::new(self.d_model, self.d_ff)
.with_initializer(self.initializer.clone())
.init(device),
linear_outer: LinearConfig::new(self.d_ff, self.d_model)
.with_initializer(self.initializer.clone())
.init(device),
dropout: DropoutConfig::new(self.dropout).init(),
activation: self.activation.init(device),
}
}
}
impl<B: Backend> PositionWiseFeedForward<B> {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - tensor: `[batch_size, seq_length, d_model]`
/// - output: `[batch_size, seq_length, d_model]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let x = self.linear_inner.forward(input);
let x = self.activation.forward(x);
let x = self.dropout.forward(x);
self.linear_outer.forward(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn display() {
let config = PositionWiseFeedForwardConfig::new(2, 4);
let pwff = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{pwff}"),
"PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
);
}
}

View File

@@ -0,0 +1,104 @@
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::module::unfold4d;
use burn::tensor::ops::UnfoldOptions;
/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
#[derive(Config, Debug)]
pub struct Unfold4dConfig {
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
#[config(default = "[1, 1]")]
pub stride: [usize; 2],
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// The padding configuration.
#[config(default = "[0, 0]")]
pub padding: [usize; 2],
}
/// Four-dimensional unfolding.
///
/// Should be created with [Unfold4dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Unfold4d {
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
pub stride: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// The padding configuration.
pub padding: [usize; 2],
}
impl ModuleDisplay for Unfold4d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("dilation", &alloc::format!("{:?}", &self.dilation))
.add("padding", &alloc::format!("{:?}", &self.padding))
.optional()
}
}
impl Unfold4dConfig {
/// Initializes a new [Unfold4d] module.
pub fn init(&self) -> Unfold4d {
Unfold4d {
kernel_size: self.kernel_size,
stride: self.stride,
dilation: self.dilation,
padding: self.padding,
}
}
}
impl Unfold4d {
/// Applies the forward pass on the input tensor.
///
/// See [unfold4d](burn::tensor::module::unfold4d) for more information.
///
/// # Shapes
///
/// input: `[batch_size, channels_in, height, width]`
/// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
unfold4d(
input,
self.kernel_size,
UnfoldOptions::new(self.stride, self.padding, self.dilation),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = Unfold4dConfig::new([3, 3]);
let unfold = config.init();
assert_eq!(
alloc::format!("{unfold}"),
"Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}"
);
}
}

View File

@@ -0,0 +1,247 @@
use burn_core as burn;
use burn::config::Config;
/// Calculate asymmetric padding for "same" convolution.
/// Returns (start_padding, end_padding) where start is applied first (top/left).
/// For odd total padding, the extra pad goes to the end (bottom/right) following ONNX convention.
fn calculate_same_padding(kernel_size: usize, stride: usize, size_in: usize) -> (usize, usize) {
let size_out = size_in.div_ceil(stride); // ceil division for same padding
let total_padding = if size_out > 0 {
let needed = (size_out - 1) * stride + kernel_size;
needed.saturating_sub(size_in)
} else {
0
};
let pad_start = total_padding / 2;
let pad_end = total_padding - pad_start;
(pad_start, pad_end)
}
/// Padding configuration for 1D operators.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
/// Dynamically calculates padding to ensure output size matches input size.
Same,
/// No padding applied.
Valid,
/// Applies explicit padding values.
/// Format: (left, right)
/// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`).
Explicit(usize, usize),
}
impl PaddingConfig1d {
/// Calculate padding as (left, right) pair for 1D operations.
/// For `Same` padding, this computes the actual asymmetric padding if needed.
pub(crate) fn calculate_padding_1d_pair(
&self,
length: usize,
kernel_size: usize,
stride: usize,
) -> (usize, usize) {
match self {
Self::Valid => (0, 0),
Self::Same => calculate_same_padding(kernel_size, stride, length),
Self::Explicit(left, right) => (*left, *right),
}
}
}
/// Padding configuration for 2D operators.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
/// Dynamically calculates padding to preserve input dimensions in output.
Same,
/// No padding applied.
Valid,
/// Applies explicit padding values.
/// Format: (top, left, bottom, right)
/// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`).
Explicit(usize, usize, usize, usize),
}
impl PaddingConfig2d {
/// Calculate padding as ((top, bottom), (left, right)) pairs for 2D operations.
/// For `Same` padding, this computes the actual asymmetric padding if needed.
pub(crate) fn calculate_padding_2d_pairs(
&self,
height: usize,
width: usize,
kernel_size: &[usize; 2],
stride: &[usize; 2],
) -> ((usize, usize), (usize, usize)) {
match self {
Self::Valid => ((0, 0), (0, 0)),
Self::Same => {
let (top, bottom) = calculate_same_padding(kernel_size[0], stride[0], height);
let (left, right) = calculate_same_padding(kernel_size[1], stride[1], width);
((top, bottom), (left, right))
}
Self::Explicit(top, left, bottom, right) => ((*top, *bottom), (*left, *right)),
}
}
/// Calculate symmetric padding for 2D operations.
/// Returns padding values [height, width] (same for both sides).
/// Panics if asymmetric padding is detected.
pub(crate) fn calculate_padding_2d(
&self,
height: usize,
width: usize,
kernel_size: &[usize; 2],
stride: &[usize; 2],
) -> [usize; 2] {
let ((top, bottom), (left, right)) =
self.calculate_padding_2d_pairs(height, width, kernel_size, stride);
if top != bottom || left != right {
panic!("Asymmetric padding should be handled via calculate_padding_2d_pairs()")
}
[top, left]
}
}
/// Padding configuration for 3D operators.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig3d {
/// Dynamically calculates padding to preserve input dimensions in output.
Same,
/// No padding applied.
Valid,
/// Applies explicit symmetric padding values.
/// Format: (depth, height, width) — same padding on both sides of each dimension.
Explicit(usize, usize, usize),
}
impl PaddingConfig3d {
/// Calculate symmetric padding for 3D operations.
/// Returns padding values [depth, height, width] (same for both sides).
pub(crate) fn calculate_padding_3d(
&self,
depth: usize,
height: usize,
width: usize,
kernel_size: &[usize; 3],
stride: &[usize; 3],
) -> [usize; 3] {
match self {
Self::Valid => [0, 0, 0],
Self::Same => {
let (front, back) = calculate_same_padding(kernel_size[0], stride[0], depth);
let (top, bottom) = calculate_same_padding(kernel_size[1], stride[1], height);
let (left, right) = calculate_same_padding(kernel_size[2], stride[2], width);
if front != back || top != bottom || left != right {
panic!(
"Asymmetric 3D 'Same' padding is not supported. \
Use odd kernel sizes for symmetric padding."
)
}
[front, top, left]
}
Self::Explicit(depth, height, width) => [*depth, *height, *width],
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// ==================== PaddingConfig1d Tests ====================
#[test]
fn test_padding_config_1d_calculate_pair_valid() {
let padding = PaddingConfig1d::Valid;
assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (0, 0));
}
#[test]
fn test_padding_config_1d_calculate_pair_explicit() {
let padding = PaddingConfig1d::Explicit(1, 2);
assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 2));
}
#[test]
fn test_padding_config_1d_calculate_pair_same() {
let padding = PaddingConfig1d::Same;
// kernel=3, stride=1, length=10: total=2, start=1, end=1
assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 1));
}
// ==================== PaddingConfig2d Tests ====================
#[test]
fn test_padding_config_2d_calculate_pairs_valid() {
let padding = PaddingConfig2d::Valid;
assert_eq!(
padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),
((0, 0), (0, 0))
);
}
#[test]
fn test_padding_config_2d_calculate_pairs_explicit() {
let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
assert_eq!(
padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),
((1, 3), (2, 4))
);
}
#[test]
fn test_padding_config_2d_calculate_symmetric_valid() {
let padding = PaddingConfig2d::Valid;
assert_eq!(
padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
[0, 0]
);
}
#[test]
fn test_padding_config_2d_calculate_symmetric_explicit() {
let padding = PaddingConfig2d::Explicit(2, 3, 2, 3);
assert_eq!(
padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
[2, 3]
);
}
#[test]
#[should_panic(
expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs"
)]
fn test_padding_config_2d_calculate_symmetric_asymmetric_panics() {
let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
let _ = padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
}
// ==================== PaddingConfig3d Tests ====================
#[test]
fn test_padding_config_3d_calculate_valid() {
let padding = PaddingConfig3d::Valid;
assert_eq!(
padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
[0, 0, 0]
);
}
#[test]
fn test_padding_config_3d_calculate_explicit() {
let padding = PaddingConfig3d::Explicit(1, 2, 3);
assert_eq!(
padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
[1, 2, 3]
);
}
#[test]
fn test_padding_config_3d_calculate_same_odd_kernel() {
let padding = PaddingConfig3d::Same;
// kernel=3, stride=1: total=2, symmetric (1,1) per dim
assert_eq!(
padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
[1, 1, 1]
);
}
}

View File

@@ -0,0 +1,167 @@
use burn_core as burn;
use burn::module::{Module, Quantizer};
use burn::tensor::{
Device, Distribution, Tensor, Tolerance,
ops::{FloatElem, QuantizedTensor},
quantization::{
Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantScheme, QuantValue,
},
};
use burn_nn::{
Linear, LinearConfig,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
};
#[cfg(all(
test,
not(feature = "test-wgpu"),
not(feature = "test-cuda"),
not(feature = "test-rocm")
))]
pub type B = burn_ndarray::NdArray<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases
pub type B = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases
pub type B = burn_cuda::Cuda;
#[cfg(all(test, feature = "test-rocm"))]
/// Backend for test cases
pub type B = burn_rocm::Rocm;
fn should_quantize_module<M: Module<B>, const D: usize, F: Fn(&M) -> Tensor<B, D>>(
module: M,
scheme: QuantScheme,
func: F,
tolerance: Tolerance<FloatElem<B>>,
) {
let result = func(&module);
let calibration = Calibration::MinMax;
let mut quantizer = Quantizer {
calibration,
scheme,
};
let q_module = module.quantize_weights(&mut quantizer);
let q_result = func(&q_module);
result
.into_data()
.assert_approx_eq::<f32>(&q_result.into_data(), tolerance);
}
#[test]
fn should_quantize_transformer() {
let device: Device<B> = Default::default();
let transformer: TransformerEncoder<B> =
TransformerEncoderConfig::new(128, 256, 2, 2).init(&device);
let signal = Tensor::random([2, 32, 128], Distribution::Default, &device);
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::block([32]))
.with_param(QuantParam::F32);
should_quantize_module(
transformer,
scheme,
|tr| tr.forward(TransformerEncoderInput::new(signal.clone())),
Tolerance::rel_abs(1e-2, 2e-2), // slightly higher abs tolerance (permissive: 1e-2)
);
}
#[test]
fn should_quantize_linear_128_256() {
let device: Device<B> = Default::default();
let transformer: Linear<B> = LinearConfig::new(128, 256).with_bias(false).init(&device);
let signal = Tensor::<B, 2>::random([1, 128], Distribution::Default, &device);
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::Tensor)
.with_param(QuantParam::F32);
should_quantize_module(
transformer,
scheme,
|tr| tr.forward(signal.clone()),
Tolerance::permissive(),
);
}
#[test]
fn should_quantize_linear() {
let device: Device<B> = Default::default();
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
let signal = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);
// Default scheme should select supported QuantStore default
// TODO: set native if dtype is supported by the test backend
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::Tensor)
// .with_store(QuantStore::Native)
.with_param(QuantParam::F32);
should_quantize_module(
transformer,
scheme,
|tr| tr.forward(signal.clone()),
Tolerance::permissive(),
);
}
#[test]
fn should_quantize_linear_weights() {
let device: Device<B> = Default::default();
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::Tensor)
.with_param(QuantParam::F32);
should_quantize_module(
transformer,
scheme,
|tr| tr.weight.val().dequantize(),
Tolerance::permissive(),
);
}
#[test]
fn should_quantize_linear_blocks() {
let device: Device<B> = Default::default();
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
let signal = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::block([16]))
// .with_store(QuantStore::Native)
.with_param(QuantParam::F32);
should_quantize_module(
transformer,
scheme,
|tr| tr.forward(signal.clone()),
Tolerance::permissive(),
);
}
#[test]
fn should_quantize_linear_weights_blocks() {
let device: Device<B> = Default::default();
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::block([16]))
// .with_store(QuantStore::Native)
.with_param(QuantParam::F32);
should_quantize_module(
transformer,
scheme,
|tr| tr.weight.val().dequantize(),
Tolerance::permissive(),
);
}