feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
87
crates/stable-diffusion-burn/burn-crates/burn-nn/Cargo.toml
Normal file
87
crates/stable-diffusion-burn/burn-crates/burn-nn/Cargo.toml
Normal file
@@ -0,0 +1,87 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Neural network building blocks for the Burn deep learning framework"
|
||||
documentation = "https://docs.rs/burn-nn"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-nn"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-nn"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"std",
|
||||
"burn-core/default",
|
||||
]
|
||||
doc = [
|
||||
"std",
|
||||
# Doc features
|
||||
"burn-core/doc",
|
||||
]
|
||||
std = [
|
||||
"burn-core/std",
|
||||
"num-traits/std",
|
||||
]
|
||||
tracing = [
|
||||
"burn-core/tracing",
|
||||
"burn-cuda?/tracing",
|
||||
"burn-rocm?/tracing",
|
||||
"burn-tch?/tracing",
|
||||
"burn-wgpu?/tracing",
|
||||
"burn-fusion?/tracing",
|
||||
]
|
||||
|
||||
test-cuda = [
|
||||
"burn-cuda/default",
|
||||
] # To use cuda during testing, default uses ndarray.
|
||||
test-rocm = [
|
||||
"burn-rocm/default",
|
||||
] # To use hip during testing, default uses ndarray.
|
||||
test-tch = [
|
||||
"burn-tch/default",
|
||||
] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = [
|
||||
"burn-wgpu/default",
|
||||
] # To use wgpu during testing, default uses ndarray.
|
||||
test-vulkan = [
|
||||
"test-wgpu",
|
||||
"burn-wgpu/vulkan",
|
||||
] # To use wgpu-spirv during testing, default uses ndarray.
|
||||
test-metal = [
|
||||
"test-wgpu",
|
||||
"burn-wgpu/metal",
|
||||
] # To use wgpu (Metal) during testing, default uses ndarray.
|
||||
|
||||
# Memory checks are disabled by default
|
||||
test-memory-checks = ["burn-fusion/memory-checks"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
num-traits = { workspace = true }
|
||||
|
||||
# FOR TESTING
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
|
||||
rstest = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Burn Neural Networks
|
||||
|
||||
Core building blocks for Burn neural networks.
|
||||
@@ -0,0 +1,598 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::activation::{
|
||||
Celu, CeluConfig, Elu, EluConfig, Gelu, HardShrink, HardShrinkConfig, HardSigmoid,
|
||||
HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu, PReluConfig, Relu, Selu,
|
||||
Shrink, ShrinkConfig, Sigmoid, SoftShrink, SoftShrinkConfig, Softplus, SoftplusConfig,
|
||||
Softsign, SwiGlu, SwiGluConfig, Tanh, ThresholdedRelu, ThresholdedReluConfig,
|
||||
};
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// [`Activation`] Configuration.
///
/// Each variant selects one activation layer. Variants that carry a value hold
/// the configuration of the wrapped layer; the other variants need no settings.
/// Call [`ActivationConfig::init`] to build the matching [`Activation`].
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum ActivationConfig {
    /// [`Gelu`] activation layer.
    Gelu,

    /// [`Gelu`] activation layer with tanh approximation.
    GeluApproximate,

    /// [`PRelu`] activation layer, described by a [`PReluConfig`].
    PRelu(PReluConfig),

    /// [`Relu`] activation layer.
    Relu,

    /// [`LeakyRelu`] activation layer, described by a [`LeakyReluConfig`].
    LeakyRelu(LeakyReluConfig),

    /// [`SwiGlu`] activation layer, described by a [`SwiGluConfig`].
    SwiGlu(SwiGluConfig),

    /// [`Selu`] activation layer.
    Selu,

    /// [`Sigmoid`] activation layer.
    Sigmoid,

    /// [`Tanh`] activation layer.
    Tanh,

    /// [`HardSigmoid`] activation layer, described by a [`HardSigmoidConfig`].
    HardSigmoid(HardSigmoidConfig),

    /// [`HardSwish`] activation layer.
    HardSwish,

    /// [`Softplus`] activation layer, described by a [`SoftplusConfig`].
    Softplus(SoftplusConfig),

    /// [`Softsign`] activation layer.
    Softsign,

    /// [`Elu`] activation layer, described by an [`EluConfig`].
    Elu(EluConfig),

    /// [`Celu`] activation layer, described by a [`CeluConfig`].
    Celu(CeluConfig),

    /// [`ThresholdedRelu`] activation layer, described by a [`ThresholdedReluConfig`].
    ThresholdedRelu(ThresholdedReluConfig),

    /// [`HardShrink`] activation layer, described by a [`HardShrinkConfig`].
    HardShrink(HardShrinkConfig),

    /// [`SoftShrink`] activation layer, described by a [`SoftShrinkConfig`].
    SoftShrink(SoftShrinkConfig),

    /// [`Shrink`] activation layer, described by a [`ShrinkConfig`].
    Shrink(ShrinkConfig),
}
|
||||
|
||||
impl From<PReluConfig> for ActivationConfig {
|
||||
fn from(config: PReluConfig) -> Self {
|
||||
Self::PRelu(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LeakyReluConfig> for ActivationConfig {
|
||||
fn from(config: LeakyReluConfig) -> Self {
|
||||
Self::LeakyRelu(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SwiGluConfig> for ActivationConfig {
|
||||
fn from(config: SwiGluConfig) -> Self {
|
||||
Self::SwiGlu(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HardSigmoidConfig> for ActivationConfig {
|
||||
fn from(config: HardSigmoidConfig) -> Self {
|
||||
Self::HardSigmoid(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SoftplusConfig> for ActivationConfig {
|
||||
fn from(config: SoftplusConfig) -> Self {
|
||||
Self::Softplus(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EluConfig> for ActivationConfig {
|
||||
fn from(config: EluConfig) -> Self {
|
||||
Self::Elu(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CeluConfig> for ActivationConfig {
|
||||
fn from(config: CeluConfig) -> Self {
|
||||
Self::Celu(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ThresholdedReluConfig> for ActivationConfig {
|
||||
fn from(config: ThresholdedReluConfig) -> Self {
|
||||
Self::ThresholdedRelu(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HardShrinkConfig> for ActivationConfig {
|
||||
fn from(config: HardShrinkConfig) -> Self {
|
||||
Self::HardShrink(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SoftShrinkConfig> for ActivationConfig {
|
||||
fn from(config: SoftShrinkConfig) -> Self {
|
||||
Self::SoftShrink(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ShrinkConfig> for ActivationConfig {
|
||||
fn from(config: ShrinkConfig) -> Self {
|
||||
Self::Shrink(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl ActivationConfig {
|
||||
/// Initialize a wrapped activation layer.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {
|
||||
match self {
|
||||
ActivationConfig::Relu => Relu.into(),
|
||||
ActivationConfig::LeakyRelu(conf) => conf.init().into(),
|
||||
ActivationConfig::Gelu => Gelu::new().into(),
|
||||
ActivationConfig::GeluApproximate => Gelu::new_approximate().into(),
|
||||
ActivationConfig::PRelu(conf) => conf.init(device).into(),
|
||||
ActivationConfig::SwiGlu(conf) => conf.init(device).into(),
|
||||
ActivationConfig::HardSigmoid(conf) => conf.init().into(),
|
||||
ActivationConfig::HardSwish => HardSwish.into(),
|
||||
ActivationConfig::Softplus(conf) => conf.init().into(),
|
||||
ActivationConfig::Selu => Selu.into(),
|
||||
ActivationConfig::Sigmoid => Sigmoid.into(),
|
||||
ActivationConfig::Tanh => Tanh.into(),
|
||||
ActivationConfig::Softsign => Softsign.into(),
|
||||
ActivationConfig::Elu(conf) => conf.init().into(),
|
||||
ActivationConfig::Celu(conf) => conf.init().into(),
|
||||
ActivationConfig::HardShrink(conf) => conf.init().into(),
|
||||
ActivationConfig::SoftShrink(conf) => conf.init().into(),
|
||||
ActivationConfig::Shrink(conf) => conf.init().into(),
|
||||
ActivationConfig::ThresholdedRelu(conf) => conf.init().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Activation Layer Wrapper.
///
/// Provides support for many in-built `burn::nn` activations.
///
/// Each variant holds one concrete activation layer; [`Activation::forward`]
/// delegates to the layer it holds. Values are normally built with
/// [`ActivationConfig::init`] or converted from a layer via `From`.
#[derive(Module, Debug)]
#[non_exhaustive]
// NOTE(review): presumably allowed because the backend-parameterized variants
// (`PRelu`, `SwiGlu`) are larger than the others — confirm.
#[allow(clippy::large_enum_variant)]
pub enum Activation<B: Backend> {
    /// [`Gelu`] activation layer (exact or tanh-approximate, per the layer's `approximate` flag).
    Gelu(Gelu),

    /// [`PRelu`] activation layer.
    PRelu(PRelu<B>),

    /// [`Relu`] activation layer.
    Relu(Relu),

    /// [`LeakyRelu`] activation layer.
    LeakyRelu(LeakyRelu),

    /// [`SwiGlu`] activation layer.
    SwiGlu(SwiGlu<B>),

    /// [`Selu`] activation layer.
    Selu(Selu),

    /// [`Sigmoid`] activation layer.
    Sigmoid(Sigmoid),

    /// [`Tanh`] activation layer.
    Tanh(Tanh),

    /// [`HardSigmoid`] activation layer.
    HardSigmoid(HardSigmoid),

    /// [`HardSwish`] activation layer.
    HardSwish(HardSwish),

    /// [`Softplus`] activation layer.
    Softplus(Softplus),

    /// [`Softsign`] activation layer.
    Softsign(Softsign),

    /// [`Elu`] activation layer.
    Elu(Elu),

    /// [`Celu`] activation layer.
    Celu(Celu),

    /// [`ThresholdedRelu`] activation layer.
    ThresholdedRelu(ThresholdedRelu),

    /// [`HardShrink`] activation layer.
    HardShrink(HardShrink),

    /// [`SoftShrink`] activation layer.
    SoftShrink(SoftShrink),

    /// [`Shrink`] activation layer.
    Shrink(Shrink),
}
|
||||
|
||||
impl<B: Backend> From<Gelu> for Activation<B> {
|
||||
fn from(layer: Gelu) -> Self {
|
||||
Self::Gelu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<PRelu<B>> for Activation<B> {
|
||||
fn from(layer: PRelu<B>) -> Self {
|
||||
Self::PRelu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Relu> for Activation<B> {
|
||||
fn from(layer: Relu) -> Self {
|
||||
Self::Relu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<LeakyRelu> for Activation<B> {
|
||||
fn from(layer: LeakyRelu) -> Self {
|
||||
Self::LeakyRelu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<SwiGlu<B>> for Activation<B> {
|
||||
fn from(layer: SwiGlu<B>) -> Self {
|
||||
Self::SwiGlu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Selu> for Activation<B> {
|
||||
fn from(layer: Selu) -> Self {
|
||||
Self::Selu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Sigmoid> for Activation<B> {
|
||||
fn from(layer: Sigmoid) -> Self {
|
||||
Self::Sigmoid(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Tanh> for Activation<B> {
|
||||
fn from(layer: Tanh) -> Self {
|
||||
Self::Tanh(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<HardSigmoid> for Activation<B> {
|
||||
fn from(layer: HardSigmoid) -> Self {
|
||||
Self::HardSigmoid(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<HardSwish> for Activation<B> {
|
||||
fn from(layer: HardSwish) -> Self {
|
||||
Self::HardSwish(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Softplus> for Activation<B> {
|
||||
fn from(layer: Softplus) -> Self {
|
||||
Self::Softplus(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Softsign> for Activation<B> {
|
||||
fn from(layer: Softsign) -> Self {
|
||||
Self::Softsign(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Elu> for Activation<B> {
|
||||
fn from(layer: Elu) -> Self {
|
||||
Self::Elu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Celu> for Activation<B> {
|
||||
fn from(layer: Celu) -> Self {
|
||||
Self::Celu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<ThresholdedRelu> for Activation<B> {
|
||||
fn from(layer: ThresholdedRelu) -> Self {
|
||||
Self::ThresholdedRelu(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<HardShrink> for Activation<B> {
|
||||
fn from(layer: HardShrink) -> Self {
|
||||
Self::HardShrink(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<SoftShrink> for Activation<B> {
|
||||
fn from(layer: SoftShrink) -> Self {
|
||||
Self::SoftShrink(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<Shrink> for Activation<B> {
|
||||
fn from(layer: Shrink) -> Self {
|
||||
Self::Shrink(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Activation<B> {
|
||||
/// Forward pass.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
match self {
|
||||
Activation::Relu(layer) => layer.forward(input),
|
||||
Activation::LeakyRelu(layer) => layer.forward(input),
|
||||
Activation::Gelu(layer) => layer.forward(input),
|
||||
Activation::PRelu(layer) => layer.forward(input),
|
||||
Activation::SwiGlu(layer) => layer.forward(input),
|
||||
Activation::HardSigmoid(layer) => layer.forward(input),
|
||||
Activation::HardSwish(layer) => layer.forward(input),
|
||||
Activation::Softplus(layer) => layer.forward(input),
|
||||
Activation::Selu(layer) => layer.forward(input),
|
||||
Activation::Sigmoid(layer) => layer.forward(input),
|
||||
Activation::Tanh(layer) => layer.forward(input),
|
||||
Activation::Softsign(layer) => layer.forward(input),
|
||||
Activation::Elu(layer) => layer.forward(input),
|
||||
Activation::Celu(layer) => layer.forward(input),
|
||||
Activation::ThresholdedRelu(layer) => layer.forward(input),
|
||||
Activation::HardShrink(layer) => layer.forward(input),
|
||||
Activation::SoftShrink(layer) => layer.forward(input),
|
||||
Activation::Shrink(layer) => layer.forward(input),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::Module;

    /// Builds the 2x3 input tensor shared by the tests in this module.
    fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)
    }

    /// Asserts that the data of `actual` equals the data of `expected`.
    fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {
        actual.to_data().assert_eq(&expected.to_data(), true);
    }

    /// Initializes the activation described by `config` on `device`, runs it
    /// on `input` and asserts that the output equals `expected`.
    fn check_stateless_config_output<B: Backend, const D: usize>(
        config: ActivationConfig,
        input: Tensor<B, D>,
        expected: Tensor<B, D>,
        device: &B::Device,
    ) {
        let act = config.init(device);
        let output = act.forward(input);
        expect_tensor(output, expected);
    }

    #[test]
    fn test_gelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Gelu::new().forward(input.clone());

        check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)
    }

    #[test]
    fn test_gelu_approximate() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Gelu::new_approximate().forward(input.clone());

        check_stateless_config_output(ActivationConfig::GeluApproximate, input, expected, &device)
    }

    #[test]
    fn test_prelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = PReluConfig::new();
        let expected = inner_config.init(&device).forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Relu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)
    }

    #[test]
    fn test_leaky_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = LeakyReluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_swi_glu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let d_input = input.shape()[1];
        let d_output = 2 * d_input;

        let inner_config = SwiGluConfig::new(d_input, d_output);
        let mut reference: SwiGlu<TestBackend> = inner_config.init(&device);

        let config: ActivationConfig = inner_config.into();
        let layer = config.init(&device);

        match &layer {
            Activation::SwiGlu(inner) => {
                // Copy the wrapped layer's weights into the reference layer
                // so both layers can be compared on the same input.
                let state = inner.clone().into_record();
                reference = reference.load_record(state);
            }
            _ => unreachable!(),
        };

        expect_tensor(layer.forward(input.clone()), reference.forward(input))
    }

    #[test]
    fn test_selu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Selu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Selu, input, expected, &device)
    }

    #[test]
    fn test_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Sigmoid.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)
    }

    #[test]
    fn test_tanh() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Tanh.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)
    }

    #[test]
    fn test_hard_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = HardSigmoidConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    // Covers the `HardSwish` variant, which previously had no test here.
    #[test]
    fn test_hard_swish() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = HardSwish.forward(input.clone());

        check_stateless_config_output(ActivationConfig::HardSwish, input, expected, &device)
    }

    #[test]
    fn test_softsign() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Softsign.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Softsign, input, expected, &device)
    }

    #[test]
    fn test_elu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = EluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_softplus() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = SoftplusConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_celu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = CeluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_thresholded_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = ThresholdedReluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_hard_shrink() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = HardShrinkConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_soft_shrink() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = SoftShrinkConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_shrink() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = ShrinkConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }
}
|
||||
@@ -0,0 +1,104 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::celu;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// CELU (Continuously Differentiable Exponential Linear Unit) layer.
///
/// Applies the CELU function element-wise:
/// `celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`
///
/// The layer stores only the `alpha` value, which [`Celu::forward`] passes to
/// [celu](burn::tensor::activation::celu).
///
/// Should be created with [CeluConfig](CeluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Celu {
    /// The alpha value for the CELU formulation.
    pub alpha: f64,
}
|
||||
|
||||
/// Configuration to create a [Celu](Celu) layer using the [init function](CeluConfig::init).
///
/// `alpha` is the only setting and is copied into the layer unchanged.
#[derive(Config, Debug)]
pub struct CeluConfig {
    /// The alpha value for the CELU formulation. Default is 1.0
    #[config(default = "1.0")]
    pub alpha: f64,
}
|
||||
|
||||
impl CeluConfig {
|
||||
/// Initialize a new [Celu](Celu) Layer
|
||||
pub fn init(&self) -> Celu {
|
||||
Celu { alpha: self.alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Celu {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("alpha", &self.alpha).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Celu {
|
||||
/// Forward pass for the Celu layer.
|
||||
///
|
||||
/// See [celu](burn::tensor::activation::celu) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
celu(input, self.alpha)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Default `alpha`: positive inputs pass through, negatives follow the exponential branch.
    #[test]
    fn test_celu_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Celu = CeluConfig::new().init();
        let data = TensorData::from([[0.5, -0.5, -1.0]]);
        let output = layer.forward(Tensor::<TestBackend, 2>::from_data(data, &device));
        // celu(0.5, 1) = 0.5
        // celu(-0.5, 1) = 1 * (exp(-0.5) - 1) = -0.393469
        // celu(-1.0, 1) = 1 * (exp(-1) - 1) = -0.632121
        let expected = TensorData::from([[0.5, -0.393469, -0.632121]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Non-default `alpha` of 2.0.
    #[test]
    fn test_celu_with_alpha() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Celu = CeluConfig::new().with_alpha(2.0).init();
        let data = TensorData::from([[0.0, -2.0]]);
        let output = layer.forward(Tensor::<TestBackend, 2>::from_data(data, &device));
        // celu(0, 2) = 0
        // celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241
        let expected = TensorData::from([[0.0, -1.264241]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display prints the layer name and its `alpha`.
    #[test]
    fn display() {
        let layer = CeluConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "Celu {alpha: 1}");
    }
}
|
||||
@@ -0,0 +1,85 @@
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::tensor::activation::elu;
|
||||
|
||||
/// ELU (Exponential Linear Unit) layer.
///
/// The layer stores only the `alpha` value, which [`Elu::forward`] passes to
/// [elu](burn::tensor::activation::elu).
///
/// Should be created with [EluConfig](EluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Elu {
    /// The alpha value.
    pub alpha: f64,
}
|
||||
/// Configuration to create an [Elu](Elu) layer using the [init function](EluConfig::init).
///
/// `alpha` is the only setting and is copied into the layer unchanged.
#[derive(Config, Debug)]
pub struct EluConfig {
    /// The alpha value. Default is 1.0
    #[config(default = "1.0")]
    pub alpha: f64,
}
|
||||
impl EluConfig {
|
||||
/// Initialize a new [Elu](Elu) Layer
|
||||
pub fn init(&self) -> Elu {
|
||||
Elu { alpha: self.alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Elu {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("alpha", &self.alpha).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Elu {
|
||||
/// Forward pass for the ELU layer.
|
||||
///
|
||||
/// See [elu](burn::tensor::activation::elu) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
elu(input, self.alpha)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Default `alpha`: one positive and one negative input.
    #[test]
    fn test_elu_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Elu = EluConfig::new().init();
        let data = TensorData::from([[0.4410, -0.2507]]);
        let output = layer.forward(Tensor::<TestBackend, 2>::from_data(data, &device));
        // elu(0.4410, 1.0) = 0.4410
        // elu(-0.2507, 1.0) = 1.0 * (exp(-0.2507) - 1) = -0.22186
        let expected = TensorData::from([[0.4410, -0.22186]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display prints the layer name and its `alpha`.
    #[test]
    fn display() {
        let layer = EluConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "Elu {alpha: 1}");
    }
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the Gaussian Error Linear Units function element-wise.
///
/// See also [gelu](burn::tensor::activation::gelu)
///
/// When `approximate` is true, uses the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
///
/// The derived `Default` leaves `approximate` at `false`, i.e. exact GELU.
#[derive(Module, Clone, Debug, Default)]
pub struct Gelu {
    /// Whether to use tanh approximation.
    pub approximate: bool,
}
|
||||
|
||||
impl Gelu {
|
||||
/// Create the module with exact GELU.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create the module with tanh approximation.
|
||||
pub fn new_approximate() -> Self {
|
||||
Self { approximate: true }
|
||||
}
|
||||
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
if self.approximate {
|
||||
burn::tensor::activation::gelu_approximate(input)
|
||||
} else {
|
||||
burn::tensor::activation::gelu(input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    type FT = FloatElem<TestBackend>;

    /// The derived display prints the layer name and its `approximate` flag.
    #[test]
    fn display() {
        let rendered = alloc::format!("{}", Gelu::new());

        // NOTE(review): expected string reproduced verbatim; confirm the
        // whitespace inside the braces against the actual display output.
        assert_eq!(rendered, "Gelu {\n approximate: false\n}");
    }

    /// The tanh approximation matches the reference values.
    #[test]
    fn forward_approximate() {
        let device = Default::default();
        let layer = Gelu::new_approximate();
        let input =
            Tensor::<TestBackend, 2>::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device);

        let actual = layer.forward(input).into_data();

        // PyTorch: torch.nn.functional.gelu(x, approximate="tanh")
        let expected = Tensor::<TestBackend, 2>::from_data(
            [
                [-0.1588079929, 0.0000000000, 0.8411920071],
                [0.3457140028, -0.1542859972, 1.9545977116],
            ],
            &device,
        )
        .into_data();

        actual.assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-5, 1e-5));
    }
}
|
||||
@@ -0,0 +1,51 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the gated linear unit function.
///
/// See also [glu](burn::tensor::activation::glu)
#[derive(Module, Clone, Debug, Default)]
pub struct GLU {
    /// The dimension on which to split the input.
    dim: usize,
}
|
||||
|
||||
impl GLU {
|
||||
/// Create the module.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - The dimension on which to split the input.
|
||||
pub fn new(dim: usize) -> Self {
|
||||
Self { dim }
|
||||
}
|
||||
|
||||
/// Applies the gated linear unit function.
|
||||
///
|
||||
/// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.
|
||||
///
|
||||
/// **Note**:
|
||||
/// * The size of the input tensor along `dim` must be divisible by 2.
|
||||
///
|
||||
/// ### Arguments
|
||||
/// * `tensor` - The input tensor.
|
||||
///
|
||||
/// ### Returns
|
||||
/// * A tensor with the same shape as the input, except the size along `dim` is halved.
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::glu(input, self.dim)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = GLU::new(1);
|
||||
|
||||
assert_eq!(alloc::format!("{layer}"), "GLU {\n dim: 1\n}");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::hard_shrink;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Hard Shrink layer.
|
||||
///
|
||||
/// Applies the Hard Shrink function element-wise:
|
||||
/// `hard_shrink(x) = x if |x| > lambda else 0`
|
||||
///
|
||||
/// Should be created with [HardShrinkConfig](HardShrinkConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct HardShrink {
|
||||
/// The lambda value for the Hard Shrink formulation.
|
||||
pub lambda: f64,
|
||||
}
|
||||
|
||||
/// Configuration to create a [HardShrink](HardShrink) layer using the [init function](HardShrinkConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct HardShrinkConfig {
|
||||
/// The lambda value for the Hard Shrink formulation. Default is 0.5
|
||||
#[config(default = "0.5")]
|
||||
pub lambda: f64,
|
||||
}
|
||||
|
||||
impl HardShrinkConfig {
|
||||
/// Initialize a new [HardShrink](HardShrink) Layer
|
||||
pub fn init(&self) -> HardShrink {
|
||||
HardShrink {
|
||||
lambda: self.lambda,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for HardShrink {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("lambda", &self.lambda).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl HardShrink {
|
||||
/// Forward pass for the Hard Shrink layer.
|
||||
///
|
||||
/// See [hard_shrink](burn::tensor::activation::hard_shrink) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
hard_shrink(input, self.lambda)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_hard_shrink_forward() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: HardShrink = HardShrinkConfig::new().init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, -1.0], [8.0, 0.0, 0.0]]);
|
||||
assert_eq!(out.into_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hard_shrink_with_lambda() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: HardShrink = HardShrinkConfig::new().with_lambda(0.2).init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[0.1, -0.1, -0.3], [0.5, 0.1, 0.0]], &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, -0.3], [0.5, 0.0, 0.0]]);
|
||||
assert_eq!(out.into_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = HardShrinkConfig::new().init();
|
||||
assert_eq!(alloc::format!("{config}"), "HardShrink {lambda: 0.5}");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::hard_sigmoid;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Hard Sigmoid layer.
|
||||
///
|
||||
/// Should be created with [HardSigmoidConfig](HardSigmoidConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct HardSigmoid {
|
||||
/// The alpha value.
|
||||
pub alpha: f64,
|
||||
/// The beta value.
|
||||
pub beta: f64,
|
||||
}
|
||||
/// Configuration to create a [Hard Sigmoid](HardSigmoid) layer using the [init function](HardSigmoidConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct HardSigmoidConfig {
|
||||
/// The alpha value. Default is 0.2
|
||||
#[config(default = "0.2")]
|
||||
pub alpha: f64,
|
||||
/// The beta value. Default is 0.5
|
||||
#[config(default = "0.5")]
|
||||
pub beta: f64,
|
||||
}
|
||||
impl HardSigmoidConfig {
|
||||
/// Initialize a new [Hard Sigmoid](HardSigmoid) Layer
|
||||
pub fn init(&self) -> HardSigmoid {
|
||||
HardSigmoid {
|
||||
alpha: self.alpha,
|
||||
beta: self.beta,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for HardSigmoid {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("alpha", &self.alpha)
|
||||
.add("beta", &self.beta)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl HardSigmoid {
|
||||
/// Forward pass for the Hard Sigmoid layer.
|
||||
///
|
||||
/// See [hard_sigmoid](burn::tensor::activation::hard_sigmoid) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
hard_sigmoid(input, self.alpha, self.beta)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::TensorData;
|
||||
use burn::tensor::{Tolerance, ops::FloatElem};
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
#[test]
|
||||
fn test_hard_sigmoid_forward() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: HardSigmoid = HardSigmoidConfig::new().init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.5882, 0.44986]]);
|
||||
out.to_data()
|
||||
.assert_approx_eq::<FT>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = HardSigmoidConfig::new().init();
|
||||
assert_eq!(
|
||||
alloc::format!("{config}"),
|
||||
"HardSigmoid {alpha: 0.2, beta: 0.5}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::hard_swish;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Hard Swish layer.
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct HardSwish;
|
||||
|
||||
impl HardSwish {
|
||||
/// Create the module.
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
/// Forward pass for the Hard Swish layer.
|
||||
///
|
||||
/// See [hard_swish](burn::tensor::activation::hard_swish) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
hard_swish(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::TensorData;
|
||||
use burn::tensor::{Tolerance, ops::FloatElem};
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
#[test]
|
||||
fn test_hard_swish_forward() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model = HardSwish::new();
|
||||
|
||||
let input = Tensor::<TestBackend, 2>::from_data(
|
||||
TensorData::from([[3.0f32, -3.0], [0.0, 1.0]]),
|
||||
&device,
|
||||
);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[3.0f32, 0.0], [0.0, 0.6666667]]);
|
||||
out.to_data()
|
||||
.assert_approx_eq::<FT>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = HardSwish::new();
|
||||
assert_eq!(alloc::format!("{layer}"), "HardSwish");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::tensor::activation::leaky_relu;
|
||||
|
||||
/// Leaky ReLu layer.
|
||||
///
|
||||
/// Should be created with [LeakyReluConfig](LeakyReluConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct LeakyRelu {
|
||||
/// The negative slope.
|
||||
pub negative_slope: f64,
|
||||
}
|
||||
/// Configuration to create a [Leaky Relu](LeakyRelu) layer using the [init function](LeakyReluConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LeakyReluConfig {
|
||||
/// The negative slope. Default is 0.01
|
||||
#[config(default = "0.01")]
|
||||
pub negative_slope: f64,
|
||||
}
|
||||
impl LeakyReluConfig {
|
||||
/// Initialize a new [Leaky Relu](LeakyRelu) Layer
|
||||
pub fn init(&self) -> LeakyRelu {
|
||||
LeakyRelu {
|
||||
negative_slope: self.negative_slope,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for LeakyRelu {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("negative_slope", &self.negative_slope)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl LeakyRelu {
|
||||
/// Forward pass for the Leaky ReLu layer.
|
||||
///
|
||||
/// See [leaky_relu](burn::tensor::activation::leaky_relu) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
leaky_relu(input, self.negative_slope)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Default configuration on a small 2-D input: the positive value is kept,
    /// the negative value is scaled by the default slope of 0.01.
    #[test]
    fn test_leaky_relu_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model: LeakyRelu = LeakyReluConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
        let out = model.forward(input);
        let expected = TensorData::from([[0.4410, -0.002507]]);
        // Floating-point results are compared approximately, like the other
        // float tests in this crate, rather than by exact equality.
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Default configuration on a 3-D input mixing positive and negative values.
    #[test]
    fn test_leaky_relu_forward_multi_dim() {
        let input = [
            [
                [-1.0222, 1.5810, 0.3457, -1.3530],
                [0.0231, 0.8681, 0.2473, -0.0377],
                [0.3520, -1.1199, 1.2219, 0.2804],
            ],
            [
                [1.0002, 0.7259, 0.8779, 0.2084],
                [1.5615, -0.1057, -0.4886, -1.5184],
                [-0.5523, -0.2741, -0.0210, -1.1352],
            ],
        ];
        let expected = TensorData::from([
            [
                [-1.0222e-02, 1.5810e+00, 3.457e-01, -1.3530e-02],
                [2.31e-02, 8.681e-01, 2.473e-01, -3.77e-04],
                [3.52e-01, -1.1199e-02, 1.2219e+00, 2.804e-01],
            ],
            [
                [1.0002e+00, 7.259e-01, 8.779e-01, 2.084e-01],
                [1.5615e+00, -1.057e-03, -4.886e-03, -1.5184e-02],
                [-5.523e-03, -2.741e-03, -2.1e-04, -1.1352e-02],
            ],
        ]);

        let device = <TestBackend as Backend>::Device::default();
        let model: LeakyRelu = LeakyReluConfig::new().init();
        let input_data = Tensor::<TestBackend, 3>::from_data(TensorData::from(input), &device);
        let actual_output = model.forward(input_data);
        actual_output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display prints the layer name and its slope on one line.
    #[test]
    fn display() {
        let config = LeakyReluConfig::new().init();
        assert_eq!(
            alloc::format!("{config}"),
            "LeakyRelu {negative_slope: 0.01}"
        );
    }
}
|
||||
@@ -0,0 +1,68 @@
|
||||
//! # Activation Layers
|
||||
//!
|
||||
//! Users who desire a selectable activation function should
|
||||
//! consider [`Activation`], which provides an abstraction over:
|
||||
//! * [`Relu`] - the default,
|
||||
//! * [`PRelu`]
|
||||
//! * [`Gelu`]
|
||||
//! * [`LeakyRelu`]
|
||||
//! * [`SwiGlu`]
|
||||
//! * [`Selu`]
|
||||
//! * [`Sigmoid`]
|
||||
//! * [`HardSigmoid`]
|
||||
//! * [`HardSwish`]
|
||||
//! * [`Softplus`]
|
||||
//! * [`Softsign`]
|
||||
//! * [`Tanh`]
|
||||
//! * [`Elu`]
|
||||
//! * [`Celu`]
|
||||
//! * [`ThresholdedRelu`]
|
||||
//!
|
||||
//! The activation layer [`GLU`] has shape-changing behaviors
|
||||
//! not compatible with the common API, and is not included
|
||||
//! in the abstraction wrappers.
|
||||
|
||||
mod activation_wrapper;
|
||||
|
||||
// These are pub(crate) for dual-export in `nn` without re-exporting
|
||||
// all of `nn.activation`, or manually listing each symbol.
|
||||
pub(crate) mod celu;
|
||||
pub(crate) mod elu;
|
||||
pub(crate) mod gelu;
|
||||
pub(crate) mod glu;
|
||||
pub(crate) mod hard_shrink;
|
||||
pub(crate) mod hard_sigmoid;
|
||||
pub(crate) mod hard_swish;
|
||||
pub(crate) mod leaky_relu;
|
||||
pub(crate) mod prelu;
|
||||
pub(crate) mod relu;
|
||||
pub(crate) mod selu;
|
||||
pub(crate) mod shrink;
|
||||
pub(crate) mod sigmoid;
|
||||
pub(crate) mod soft_shrink;
|
||||
pub(crate) mod softplus;
|
||||
pub(crate) mod softsign;
|
||||
pub(crate) mod swiglu;
|
||||
pub(crate) mod tanh;
|
||||
pub(crate) mod thresholded_relu;
|
||||
|
||||
pub use activation_wrapper::*;
|
||||
pub use celu::*;
|
||||
pub use elu::*;
|
||||
pub use gelu::*;
|
||||
pub use glu::*;
|
||||
pub use hard_shrink::*;
|
||||
pub use hard_sigmoid::*;
|
||||
pub use hard_swish::*;
|
||||
pub use leaky_relu::*;
|
||||
pub use prelu::*;
|
||||
pub use relu::*;
|
||||
pub use selu::*;
|
||||
pub use shrink::*;
|
||||
pub use sigmoid::*;
|
||||
pub use soft_shrink::*;
|
||||
pub use softplus::*;
|
||||
pub use softsign::*;
|
||||
pub use swiglu::*;
|
||||
pub use tanh::*;
|
||||
pub use thresholded_relu::*;
|
||||
@@ -0,0 +1,87 @@
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn_core as burn;
|
||||
/// Parametric Relu layer.
|
||||
///
|
||||
/// Should be created using [PReluConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PRelu<B: Backend> {
|
||||
/// the weights learnt for PReLu. can be of shape \[1\] or \[num_parameters\] in which case it must
|
||||
/// be the same as number of channels in the input tensor
|
||||
pub alpha: Param<Tensor<B, 1>>,
|
||||
|
||||
/// Alpha value for the PRelu layer
|
||||
pub alpha_value: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PRelu<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [num_parameters] = self.alpha.shape().dims();
|
||||
|
||||
content
|
||||
.add("num_parameters", &num_parameters)
|
||||
.add("alpha_value", &self.alpha_value)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PReluConfig {
|
||||
/// The number of parameters.
|
||||
#[config(default = "1")]
|
||||
pub num_parameters: usize,
|
||||
/// The learnable weight alpha. Default is 0.25
|
||||
#[config(default = "0.25")]
|
||||
pub alpha: f64,
|
||||
}
|
||||
|
||||
impl PReluConfig {
|
||||
/// Initialize a new [Parametric Relu](PRelu) Layer
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
|
||||
PRelu {
|
||||
// alpha is a tensor of length num_parameters
|
||||
alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
|
||||
alpha_value: self.alpha,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PRelu<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
///
|
||||
/// See also [prelu](burn::tensor::activation::prelu) for more information.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::prelu(input, self.alpha.val())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = PReluConfig::new().init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{layer}"),
|
||||
"PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the rectified linear unit function element-wise
|
||||
/// See also [relu](burn::tensor::activation::relu)
|
||||
///
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Relu;
|
||||
|
||||
impl Relu {
|
||||
/// Create the module.
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::relu(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = Relu::new();
|
||||
|
||||
assert_eq!(alloc::format!("{layer}"), "Relu");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the Scaled Exponential Linear Unit function element-wise.
|
||||
/// See also [selu](burn::tensor::activation::selu)
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Selu;
|
||||
|
||||
impl Selu {
|
||||
/// Create the module.
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::selu(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = Selu::new();
|
||||
|
||||
assert_eq!(alloc::format!("{layer}"), "Selu");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::shrink;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Shrink layer.
|
||||
///
|
||||
/// Applies the Shrink function element-wise:
|
||||
/// `shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`
|
||||
///
|
||||
/// Should be created with [ShrinkConfig](ShrinkConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Shrink {
|
||||
/// The lambda value for the Shrink formulation.
|
||||
pub lambda: f64,
|
||||
/// The bias value for the Shrink formulation.
|
||||
// Usually bias = lambda, but need this to handle onnx spec https://onnx.ai/onnx/operators/onnx__Shrink.html
|
||||
pub bias: f64,
|
||||
}
|
||||
|
||||
/// Configuration to create a [Shrink](Shrink) layer using the [init function](ShrinkConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ShrinkConfig {
|
||||
/// The lambda value for the Shrink formulation. Default is 0.5
|
||||
#[config(default = "0.5")]
|
||||
pub lambda: f64,
|
||||
/// The bias value for the Shrink formulation. Default is 0.5.
|
||||
#[config(default = "0.5")]
|
||||
pub bias: f64,
|
||||
}
|
||||
|
||||
impl ShrinkConfig {
|
||||
/// Initialize a new [Shrink](Shrink) Layer
|
||||
pub fn init(&self) -> Shrink {
|
||||
Shrink {
|
||||
lambda: self.lambda,
|
||||
bias: self.bias,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Shrink {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("lambda", &self.lambda)
|
||||
.add("bias", &self.bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Shrink {
|
||||
/// Forward pass for the Shrink layer.
|
||||
///
|
||||
/// See [shrink](burn::tensor::activation::shrink) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
shrink(input, self.lambda, self.bias)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_shrink_forward() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: Shrink = ShrinkConfig::new().init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]);
|
||||
assert_eq!(out.into_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shrink_with_lambda_and_bias() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: Shrink = ShrinkConfig::new()
|
||||
.with_lambda(0.25)
|
||||
.with_bias(0.125)
|
||||
.init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, -0.375], [0.625, 0.0, 0.0]]);
|
||||
assert_eq!(out.into_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = ShrinkConfig::new().init();
|
||||
assert_eq!(
|
||||
alloc::format!("{config}"),
|
||||
"Shrink {lambda: 0.5, bias: 0.5}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the sigmoid function element-wise
|
||||
/// See also [sigmoid](burn::tensor::activation::sigmoid)
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Sigmoid;
|
||||
|
||||
impl Sigmoid {
|
||||
/// Create the module.
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::sigmoid(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = Sigmoid::new();
|
||||
|
||||
assert_eq!(alloc::format!("{layer}"), "Sigmoid");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::soft_shrink;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Soft Shrink layer.
|
||||
///
|
||||
/// Applies the Soft Shrink function element-wise:
|
||||
/// `soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`
|
||||
///
|
||||
/// Should be created with [SoftShrinkConfig](SoftShrinkConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct SoftShrink {
|
||||
/// The lambda value for the Soft Shrink formulation.
|
||||
pub lambda: f64,
|
||||
}
|
||||
|
||||
/// Configuration to create a [SoftShrink](SoftShrink) layer using the [init function](SoftShrinkConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct SoftShrinkConfig {
|
||||
/// The lambda value for the Soft Shrink formulation. Default is 0.5
|
||||
#[config(default = "0.5")]
|
||||
pub lambda: f64,
|
||||
}
|
||||
|
||||
impl SoftShrinkConfig {
|
||||
/// Initialize a new [SoftShrink](SoftShrink) Layer
|
||||
pub fn init(&self) -> SoftShrink {
|
||||
SoftShrink {
|
||||
lambda: self.lambda,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for SoftShrink {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("lambda", &self.lambda).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl SoftShrink {
|
||||
/// Forward pass for the Soft Shrink layer.
|
||||
///
|
||||
/// See [soft_shrink](burn::tensor::activation::soft_shrink) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
soft_shrink(input, self.lambda)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_soft_shrink_forward() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: SoftShrink = SoftShrinkConfig::new().init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]);
|
||||
assert_eq!(out.into_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_soft_shrink_with_lambda() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: SoftShrink = SoftShrinkConfig::new().with_lambda(0.25).init();
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);
|
||||
let out = model.forward(input);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, -0.25], [0.5, 0.0, 0.0]]);
|
||||
assert_eq!(out.into_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = SoftShrinkConfig::new().init();
|
||||
assert_eq!(alloc::format!("{config}"), "SoftShrink {lambda: 0.5}");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::activation::softplus;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Softplus layer.
|
||||
///
|
||||
/// Applies the softplus function element-wise:
|
||||
/// `softplus(x) = (1/beta) * log(1 + exp(beta * x))`
|
||||
///
|
||||
/// Should be created with [SoftplusConfig](SoftplusConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Softplus {
|
||||
/// The beta value.
|
||||
pub beta: f64,
|
||||
}
|
||||
|
||||
/// Configuration to create a [Softplus](Softplus) layer using the [init function](SoftplusConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct SoftplusConfig {
|
||||
/// The beta value. Default is 1.0
|
||||
#[config(default = "1.0")]
|
||||
pub beta: f64,
|
||||
}
|
||||
|
||||
impl SoftplusConfig {
|
||||
/// Initialize a new [Softplus](Softplus) Layer
|
||||
pub fn init(&self) -> Softplus {
|
||||
Softplus { beta: self.beta }
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Softplus {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("beta", &self.beta).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Softplus {
|
||||
/// Forward pass for the Softplus layer.
|
||||
///
|
||||
/// See [softplus](burn::tensor::activation::softplus) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
softplus(input, self.beta)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Default configuration (the display test below shows `beta: 1`).
    #[test]
    fn test_softplus_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Softplus = SoftplusConfig::new().init();
        let data = TensorData::from([[0.0, 1.0, -1.0]]);
        let output = layer.forward(Tensor::<TestBackend, 2>::from_data(data, &device));

        // Reference values:
        //   softplus(0)  = log(2)          ≈ 0.6931
        //   softplus(1)  = log(1 + e)      ≈ 1.3133
        //   softplus(-1) = log(1 + e^-1)   ≈ 0.3133
        let reference = TensorData::from([[0.6931, 1.3133, 0.3133]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    /// Non-default `beta` scales the input before and the result after the log.
    #[test]
    fn test_softplus_with_beta() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Softplus = SoftplusConfig::new().with_beta(2.0).init();
        let data = TensorData::from([[0.0, 1.0]]);
        let output = layer.forward(Tensor::<TestBackend, 2>::from_data(data, &device));

        // Reference values:
        //   softplus(0, beta=2) = (1/2) * log(1 + exp(0)) = 0.5 * log(2)     ≈ 0.3466
        //   softplus(1, beta=2) = (1/2) * log(1 + exp(2)) = 0.5 * log(8.389) ≈ 1.0635
        let reference = TensorData::from([[0.3466, 1.0635]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    /// The custom display prints the layer name followed by its `beta`.
    #[test]
    fn display() {
        let layer = SoftplusConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "Softplus {beta: 1}");
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the softsign function element-wise
|
||||
/// See also [softsign](burn::tensor::activation::softsign)
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Softsign;
|
||||
|
||||
impl Softsign {
|
||||
/// Create the module.
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::softsign(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A parameter-free module displays as just its type name.
    #[test]
    fn display() {
        let rendered = alloc::format!("{}", Softsign::new());

        assert_eq!(rendered, "Softsign");
    }
}
|
||||
@@ -0,0 +1,153 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::activation::silu;
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
use crate::{Linear, LinearConfig, LinearLayout};
|
||||
|
||||
/// Configuration to create a [SwiGlu](SwiGlu) activation layer using the [init function](SwiGluConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct SwiGluConfig {
|
||||
/// The size of the input features.
|
||||
pub d_input: usize,
|
||||
/// The size of the output features.
|
||||
pub d_output: usize,
|
||||
/// If a bias should be applied during the linear transformation. Default behaviour is False
|
||||
/// for SwiGLU activation implementations.
|
||||
#[config(default = false)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize the linear layer parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
/// The layout in which the linear parameters are stored.
|
||||
#[config(default = "LinearLayout::Row")]
|
||||
pub layout: LinearLayout,
|
||||
}
|
||||
|
||||
/// Applies the SwiGLU or Swish Gated Linear Unit to the input tensor.
|
||||
/// The SwiGLU activation function is defined as:
|
||||
/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`
|
||||
///
|
||||
/// Should be created with [SwiGluConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct SwiGlu<B: Backend> {
|
||||
/// The inner linear layer for Swish activation function
|
||||
/// with `d_input` input features and `d_output` output features.
|
||||
pub linear_inner: Linear<B>,
|
||||
/// The outer linear layer for element wise multiplication
|
||||
/// with `d_input` input features and `d_output` output features.
|
||||
pub linear_outer: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for SwiGlu<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, d_output] = self.linear_inner.weight.shape().dims();
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_output", &d_output)
|
||||
.add("bias", &self.linear_inner.bias.is_some())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl SwiGluConfig {
|
||||
/// Initialize a new [SwiGLU](SwiGlu) activation layer.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
|
||||
SwiGlu {
|
||||
linear_inner: LinearConfig::new(self.d_input, self.d_output)
|
||||
.with_bias(self.bias)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.with_layout(self.layout)
|
||||
.init(device),
|
||||
linear_outer: LinearConfig::new(self.d_input, self.d_output)
|
||||
.with_bias(self.bias)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.with_layout(self.layout)
|
||||
.init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SwiGlu<B> {
|
||||
/// Applies the Swish Gated Linear Unit to the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, seq_length, d_input]`
|
||||
/// - output: `[batch_size, seq_length, d_output]`
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let x = self.linear_inner.forward(input.clone());
|
||||
let x = silu(x);
|
||||
x.mul(self.linear_outer.forward(input))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// The 2x3 input shared by both forward tests.
    fn sample_input(device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device)
    }

    /// Constant 0.5 weights and no bias.
    #[test]
    fn test_swiglu_forward_no_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let layer = SwiGluConfig::new(3, 3)
            .with_initializer(Initializer::Constant { value: 0.5 })
            .init(&device);
        let actual = layer.forward(sample_input(&device));

        let reference = Tensor::<TestBackend, 2>::from_data(
            [[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]],
            &device,
        );
        actual
            .to_data()
            .assert_approx_eq::<FT>(&reference.to_data(), Tolerance::default());
    }

    /// Constant 0.5 initializer with the bias enabled.
    #[test]
    fn test_swiglu_forward_with_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let layer = SwiGluConfig::new(3, 3)
            .with_bias(true)
            .with_initializer(Initializer::Constant { value: 0.5 })
            .init(&device);
        let actual = layer.forward(sample_input(&device));

        let reference = Tensor::<TestBackend, 2>::from_data(
            [[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]],
            &device,
        );
        actual
            .to_data()
            .assert_approx_eq::<FT>(&reference.to_data(), Tolerance::default());
    }

    /// The display lists the dimensions, the bias flag and the parameter count.
    #[test]
    fn display() {
        let layer = SwiGluConfig::new(3, 5).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}"
        );
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Applies the tanh activation function element-wise
|
||||
/// See also [tanh](burn::tensor::activation::tanh)
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Tanh;
|
||||
|
||||
impl Tanh {
|
||||
/// Create the module.
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
burn::tensor::activation::tanh(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A parameter-free module displays as just its type name.
    #[test]
    fn display() {
        let rendered = alloc::format!("{}", Tanh::new());

        assert_eq!(rendered, "Tanh");
    }
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::tensor::activation::thresholded_relu;
|
||||
|
||||
/// Thresholded ReLU layer.
|
||||
///
|
||||
/// Should be created with [ThresholdedReluConfig](ThresholdedReluConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct ThresholdedRelu {
|
||||
/// The alpha threshold.
|
||||
pub alpha: f64,
|
||||
}
|
||||
|
||||
/// Configuration to create a [ThresholdedRelu](ThresholdedRelu) layer using the [init function](ThresholdedReluConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ThresholdedReluConfig {
|
||||
/// The alpha threshold. Default is 1.0
|
||||
#[config(default = "1.0")]
|
||||
pub alpha: f64,
|
||||
}
|
||||
|
||||
impl ThresholdedReluConfig {
|
||||
/// Initialize a new [ThresholdedRelu](ThresholdedRelu) layer.
|
||||
pub fn init(&self) -> ThresholdedRelu {
|
||||
ThresholdedRelu { alpha: self.alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for ThresholdedRelu {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("alpha", &self.alpha).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl ThresholdedRelu {
|
||||
/// Forward pass for the Thresholded ReLU layer.
|
||||
///
|
||||
/// See [thresholded_relu](burn::tensor::activation::thresholded_relu) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
thresholded_relu(input, self.alpha)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// With the default threshold, only the 1.5 entry of the input is kept;
    /// the 0.5 and -0.2 entries become zero.
    #[test]
    fn test_thresholded_relu_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: ThresholdedRelu = ThresholdedReluConfig::new().init();
        let data = TensorData::from([[0.5, 1.5, -0.2]]);
        let output = layer.forward(Tensor::<TestBackend, 2>::from_data(data, &device));

        let reference = TensorData::from([[0.0, 1.5, 0.0]]);
        output.to_data().assert_eq(&reference, false);
    }

    /// The custom display prints the layer name followed by its `alpha`.
    #[test]
    fn display() {
        let layer = ThresholdedReluConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "ThresholdedRelu {alpha: 1}");
    }
}
|
||||
63
crates/stable-diffusion-burn/burn-crates/burn-nn/src/lib.rs
Normal file
63
crates/stable-diffusion-burn/burn-crates/burn-nn/src/lib.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
//! Burn neural network module.
|
||||
|
||||
/// Loss module
|
||||
pub mod loss;
|
||||
|
||||
/// Neural network modules implementations.
|
||||
pub mod modules;
|
||||
pub use modules::*;
|
||||
|
||||
pub mod activation;
|
||||
pub use activation::{
|
||||
celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*,
|
||||
relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*,
|
||||
tanh::*, thresholded_relu::*,
|
||||
};
|
||||
|
||||
mod padding;
|
||||
pub use padding::*;
|
||||
|
||||
// For backward compat, `burn::nn::Initializer`
|
||||
pub use burn_core::module::Initializer;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Backend for test cases
|
||||
#[cfg(all(
|
||||
test,
|
||||
not(feature = "test-tch"),
|
||||
not(feature = "test-wgpu"),
|
||||
not(feature = "test-cuda"),
|
||||
not(feature = "test-rocm")
|
||||
))]
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(all(test, feature = "test-tch"))]
|
||||
/// Backend for test cases
|
||||
pub type TestBackend = burn_tch::LibTorch<f32>;
|
||||
|
||||
#[cfg(all(test, feature = "test-wgpu"))]
|
||||
/// Backend for test cases
|
||||
pub type TestBackend = burn_wgpu::Wgpu;
|
||||
|
||||
#[cfg(all(test, feature = "test-cuda"))]
|
||||
/// Backend for test cases
|
||||
pub type TestBackend = burn_cuda::Cuda;
|
||||
|
||||
#[cfg(all(test, feature = "test-rocm"))]
|
||||
/// Backend for test cases
|
||||
pub type TestBackend = burn_rocm::Rocm;
|
||||
|
||||
/// Backend for autodiff test cases
|
||||
#[cfg(test)]
|
||||
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
|
||||
|
||||
#[cfg(all(test, feature = "test-memory-checks"))]
|
||||
mod tests {
|
||||
burn_fusion::memory_checks!();
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::activation::log_sigmoid;
|
||||
use burn::tensor::{Int, Tensor, backend::Backend};
|
||||
use burn::{config::Config, module::Module};
|
||||
|
||||
/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct BinaryCrossEntropyLossConfig {
|
||||
/// Create weighted binary cross-entropy with a weight for each class.
|
||||
///
|
||||
/// The loss of a specific sample will simply be multiplied by its label weight.
|
||||
pub weights: Option<Vec<f32>>,
|
||||
|
||||
/// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).
|
||||
///
|
||||
/// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
|
||||
/// Alpha = 0 would be the same as default.
|
||||
pub smoothing: Option<f32>,
|
||||
|
||||
/// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
|
||||
#[config(default = false)]
|
||||
pub logits: bool,
|
||||
}
|
||||
|
||||
impl BinaryCrossEntropyLossConfig {
|
||||
/// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
|
||||
self.assertions();
|
||||
BinaryCrossEntropyLoss {
|
||||
weights: self
|
||||
.weights
|
||||
.as_ref()
|
||||
.map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
|
||||
smoothing: self.smoothing,
|
||||
logits: self.logits,
|
||||
}
|
||||
}
|
||||
|
||||
fn assertions(&self) {
|
||||
if let Some(alpha) = self.smoothing {
|
||||
assert!(
|
||||
(0.0..=1.).contains(&alpha),
|
||||
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
|
||||
);
|
||||
};
|
||||
if let Some(weights) = self.weights.as_ref() {
|
||||
assert!(
|
||||
weights.iter().all(|e| e > &0.),
|
||||
"Weights of cross-entropy have to be positive."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the binary cross entropy loss from the input logits and the targets.
|
||||
///
|
||||
/// Should be created using [BinaryCrossEntropyLossConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BinaryCrossEntropyLoss<B: Backend> {
|
||||
/// Weights for cross-entropy.
|
||||
pub weights: Option<Tensor<B, 1>>,
|
||||
/// Label smoothing alpha.
|
||||
pub smoothing: Option<f32>,
|
||||
/// Treat the inputs as logits
|
||||
pub logits: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("weights", &self.weights)
|
||||
.add("smoothing", &self.smoothing)
|
||||
.add("logits", &self.logits)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BinaryCrossEntropyLoss<B> {
|
||||
/// Compute the criterion on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// Binary:
|
||||
/// - logits: `[batch_size]`
|
||||
/// - targets: `[batch_size]`
|
||||
///
|
||||
/// Multi-label:
|
||||
/// - logits: `[batch_size, num_classes]`
|
||||
/// - targets: `[batch_size, num_classes]`
|
||||
pub fn forward<const D: usize>(
|
||||
&self,
|
||||
logits: Tensor<B, D>,
|
||||
targets: Tensor<B, D, Int>,
|
||||
) -> Tensor<B, 1> {
|
||||
self.assertions(&logits, &targets);
|
||||
|
||||
let mut targets_float = targets.clone().float();
|
||||
let shape = targets.dims();
|
||||
|
||||
if let Some(alpha) = self.smoothing {
|
||||
let num_classes = if D > 1 { shape[D - 1] } else { 2 };
|
||||
targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
|
||||
}
|
||||
|
||||
let mut loss = if self.logits {
|
||||
// Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
|
||||
(targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
|
||||
} else {
|
||||
// - (target * log(input) + (1 - target) * log(1 - input))
|
||||
// https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
|
||||
(targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
|
||||
- targets_float * logits.log().clamp_min(-100.0)
|
||||
};
|
||||
|
||||
if let Some(weights) = &self.weights {
|
||||
let weights = if D > 1 {
|
||||
weights.clone().expand(shape)
|
||||
} else {
|
||||
// Flatten targets and expand resulting weights to make it compatible with
|
||||
// Tensor<B, D> for binary 1-D case
|
||||
weights
|
||||
.clone()
|
||||
.gather(0, targets.flatten(0, 0))
|
||||
.expand(shape)
|
||||
};
|
||||
loss = loss * weights;
|
||||
}
|
||||
|
||||
loss.mean()
|
||||
}
|
||||
|
||||
fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
|
||||
let logits_dims = logits.dims();
|
||||
let targets_dims = targets.dims();
|
||||
assert!(
|
||||
logits_dims == targets_dims,
|
||||
"Shape of targets ({targets_dims:?}) should correspond to outer shape of logits ({logits_dims:?})."
|
||||
);
|
||||
|
||||
if let Some(weights) = &self.weights
|
||||
&& D > 1
|
||||
{
|
||||
let targets_classes = targets_dims[D - 1];
|
||||
let weights_classes = weights.dims()[0];
|
||||
assert!(
|
||||
weights_classes == targets_classes,
|
||||
"The number of classes ({weights_classes}) does not match the weights provided ({targets_classes})."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{TensorData, activation::sigmoid};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;
    type TestDevice = <TestBackend as Backend>::Device;

    /// 1-D inputs used by the binary tests; same values as the PyTorch snippets:
    /// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
    /// target = torch.tensor([0, 1, 0, 1])
    fn binary_fixture(
        device: &TestDevice,
    ) -> (Tensor<TestBackend, 1>, Tensor<TestBackend, 1, Int>) {
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), device);
        (logits, targets)
    }

    /// 2-D inputs used by the multi-label tests; same values as the PyTorch snippets:
    /// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
    /// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
    fn multilabel_fixture(
        device: &TestDevice,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Int>) {
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            device,
        );
        (logits, targets)
    }

    #[test]
    fn test_binary_cross_entropy_preds_all_correct() {
        let device = TestDevice::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let criterion = BinaryCrossEntropyLossConfig::new().init(&device);
        let actual = criterion.forward(preds, targets).into_data();

        let reference = TensorData::from([0.000]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_preds_all_incorrect() {
        let device = TestDevice::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let criterion = BinaryCrossEntropyLossConfig::new().init(&device);
        let actual = criterion.forward(preds, targets).into_data();

        let reference = TensorData::from([100.000]); // clamped value
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target) # tensor(0.7491)

        let device = TestDevice::default();
        let (logits, targets) = binary_fixture(&device);

        let criterion = BinaryCrossEntropyLossConfig::new().init(&device);
        let actual = criterion.forward(sigmoid(logits), targets).into_data();

        let reference = TensorData::from([0.7491]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        let device = TestDevice::default();
        let (logits, targets) = binary_fixture(&device);

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device);
        let actual = criterion.forward(logits, targets).into_data();

        let reference = TensorData::from([0.7491]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0, 1, 0, 1])
        // weights = torch.tensor([3., 7.]).gather(0, target)
        // loss = nn.BCELoss(weights)
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target.float()) # tensor(3.1531)

        let device = TestDevice::default();
        let (logits, targets) = binary_fixture(&device);
        let weights = [3., 7.];

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device);
        let actual = criterion.forward(sigmoid(logits), targets).into_data();

        let reference = TensorData::from([3.1531]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // target_smooth = target * (1 - 0.1) + (0.1 / 2)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7490)

        let device = TestDevice::default();
        let (logits, targets) = binary_fixture(&device);

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device);
        let actual = criterion.forward(sigmoid(logits), targets).into_data();

        let reference = TensorData::from([0.7490]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // loss = nn.BCEWithLogitsLoss()
        // out = loss(input, target) # tensor(0.7112)

        let device = TestDevice::default();
        let (logits, targets) = multilabel_fixture(&device);

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device);
        let actual = criterion.forward(logits, targets).into_data();

        let reference = TensorData::from([0.7112]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // weights = torch.tensor([3., 7., 0.9])
        // loss = nn.BCEWithLogitsLoss(weight=weights)  # presumably; confirm against the reference run
        // out = loss(input, target) # tensor(3.1708)

        let device = TestDevice::default();
        let (logits, targets) = multilabel_fixture(&device);
        let weights = [3., 7., 0.9];

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device);
        let actual = criterion.forward(logits, targets).into_data();

        let reference = TensorData::from([3.1708]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // target_smooth = target * (1 - 0.1) + (0.1 / 3)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7228)

        let device = TestDevice::default();
        let (logits, targets) = multilabel_fixture(&device);

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device);
        let actual = criterion.forward(sigmoid(logits), targets).into_data();

        let reference = TensorData::from([0.7228]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        // Three classes in the targets but only two weights: `forward` must panic.
        let device = TestDevice::default();
        let (logits, targets) = multilabel_fixture(&device);
        let weights = [3., 7.];

        let criterion = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device);
        let _loss = criterion.forward(logits, targets);
    }

    #[test]
    fn display() {
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let criterion = config.init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{criterion}");

        assert_eq!(
            rendered,
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}
|
||||
@@ -0,0 +1,317 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn::tensor::linalg::cosine_similarity;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::loss::reduction::Reduction;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::module::{Ignored, Module};
|
||||
use burn::tensor::{Int, Tensor, activation::relu, backend::Backend};
|
||||
|
||||
/// Configuration for CosineEmbeddingLoss.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct CosineEmbeddingLossConfig {
|
||||
/// Margin for negative samples.
|
||||
#[config(default = 0.0)]
|
||||
pub margin: f32,
|
||||
|
||||
/// Specifies the reduction to apply to the output.
|
||||
#[config(default = "Reduction::Mean")]
|
||||
pub reduction: Reduction,
|
||||
}
|
||||
|
||||
impl CosineEmbeddingLossConfig {
|
||||
/// Initialize CosineEmbeddingLoss.
|
||||
pub fn init(&self) -> CosineEmbeddingLoss {
|
||||
CosineEmbeddingLoss {
|
||||
margin: self.margin,
|
||||
reduction: Ignored(self.reduction.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine embedding loss between two tensors.
|
||||
///
|
||||
/// Measures cosine distance between tensors.
|
||||
/// Used for learning embeddings or similarity.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct CosineEmbeddingLoss {
|
||||
/// Margin value. Default: 0.0
|
||||
pub margin: f32,
|
||||
|
||||
/// Reduction method
|
||||
pub reduction: Ignored<Reduction>,
|
||||
}
|
||||
|
||||
impl Default for CosineEmbeddingLoss {
|
||||
fn default() -> Self {
|
||||
CosineEmbeddingLossConfig::new().init()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for CosineEmbeddingLoss {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("margin", &self.margin)
|
||||
.add("reduction", format!("{:?}", &self.reduction.0).as_str())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl CosineEmbeddingLoss {
|
||||
/// Creates a new instance
|
||||
pub fn new() -> Self {
|
||||
CosineEmbeddingLossConfig::new().init()
|
||||
}
|
||||
|
||||
/// Compute loss with reduction.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input1: ``[batch_size, embedding_dim]``
|
||||
/// - input2: ``[batch_size, embedding_dim]``
|
||||
/// - target: ``[batch_size]`` with values 1 or -1
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Loss tensor of shape ``[1]``
|
||||
pub fn forward<B: Backend>(
|
||||
&self,
|
||||
input1: Tensor<B, 2>,
|
||||
input2: Tensor<B, 2>,
|
||||
target: Tensor<B, 1, Int>,
|
||||
) -> Tensor<B, 1> {
|
||||
let tensor = self.forward_no_reduction(input1, input2, target);
|
||||
match &self.reduction.0 {
|
||||
Reduction::Mean | Reduction::Auto => tensor.mean(),
|
||||
Reduction::Sum => tensor.sum(),
|
||||
other => panic!("{other:?} reduction is not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute loss without applying reduction.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input1` - First input tensor of shape ``[batch_size, embedding_dim]``
|
||||
/// * `input2` - Second input tensor of shape ``[batch_size, embedding_dim]``
|
||||
/// * `target` - Target tensor of shape ``[batch_size]`` with values 1 or -1
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Tensor of per-element losses with shape ``[batch_size]``
|
||||
pub fn forward_no_reduction<B: Backend>(
|
||||
&self,
|
||||
input1: Tensor<B, 2>,
|
||||
input2: Tensor<B, 2>,
|
||||
target: Tensor<B, 1, Int>,
|
||||
) -> Tensor<B, 1> {
|
||||
self.assertions(&input1, &input2, &target);
|
||||
|
||||
// cos_sim shape: [batch_size, 1]
|
||||
let cos_sim = cosine_similarity(input1, input2, 1, None);
|
||||
// cos_sim shape: [batch_size]
|
||||
let cos_sim: Tensor<B, 1> = cos_sim.squeeze_dim(1);
|
||||
|
||||
let mut loss = cos_sim.zeros_like();
|
||||
|
||||
// Similar pairs (target == 1) - Formula: L = 1 - cos_sim
|
||||
let similar_mask = target.clone().equal_elem(1);
|
||||
let similar_loss = cos_sim.clone().neg().add_scalar(1);
|
||||
loss = loss.mask_where(similar_mask, similar_loss);
|
||||
|
||||
// Dissimilar pairs (target == -1) - Formula: L = max(0, cos_sim - margin)
|
||||
let dissimilar_mask = target.equal_elem(-1);
|
||||
let dissimilar_loss = relu(cos_sim.clone().sub_scalar(self.margin));
|
||||
loss = loss.mask_where(dissimilar_mask, dissimilar_loss);
|
||||
|
||||
// return loss shape: [batch_size]
|
||||
loss
|
||||
}
|
||||
|
||||
fn assertions<B: Backend>(
|
||||
&self,
|
||||
input1: &Tensor<B, 2>,
|
||||
input2: &Tensor<B, 2>,
|
||||
target: &Tensor<B, 1, Int>,
|
||||
) {
|
||||
let [batch_size1, dim1] = input1.dims();
|
||||
let [batch_size2, dim2] = input2.dims();
|
||||
let [batch_size_target] = target.dims();
|
||||
|
||||
assert_eq!(
|
||||
batch_size1, batch_size2,
|
||||
"Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
dim1, dim2,
|
||||
"Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
batch_size1, batch_size_target,
|
||||
"Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn cosine_embedding_loss_positive_target() {
        let device = Default::default();

        // Two identical vectors should have cosine similarity of 1
        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );
        let input2 = input1.clone();

        // Target 1 means that inputs should be similar
        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());

        // The sum must come from a loss configured with `Reduction::Sum`; the
        // default loss would silently compute the mean a second time.
        let loss_sum = CosineEmbeddingLossConfig::new()
            .with_reduction(Reduction::Sum)
            .init()
            .forward(input1, input2, target);

        // For identical vectors, 1 - cos_sim = 1 - 1 = 0
        let expected_no_reduction = TensorData::from([0.0, 0.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([0.0]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());

        let expected_sum = TensorData::from([0.0]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
    }

    #[test]
    fn cosine_embedding_loss_negative_target() {
        let device = Default::default();

        // Two identical vectors should have cosine similarity of 1
        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );
        let input2 = input1.clone();

        // Target -1 means that inputs should be dissimilar
        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([-1, -1]), &device);

        // With margin 0.0, max(0, cos_sim - margin) = max(0, 1 - 0) = 1
        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());

        // Create a loss with Sum reduction for testing
        let loss_sum = CosineEmbeddingLossConfig::new()
            .with_reduction(Reduction::Sum)
            .init()
            .forward(input1.clone(), input2.clone(), target.clone());

        let expected_no_reduction = TensorData::from([1.0, 1.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([1.0]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());

        let expected_sum = TensorData::from([2.0]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());

        // With margin 0.5, max(0, cos_sim - margin) = max(0, 1 - 0.5) = 0.5
        let loss_with_margin = CosineEmbeddingLossConfig::new()
            .with_margin(0.5)
            .init()
            .forward(input1, input2, target);

        let expected = TensorData::from([0.5]);
        loss_with_margin
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn cosine_embedding_loss_mixed_targets() {
        let device = Default::default();

        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );
        let input2 = input1.clone();

        // Mixed targets: first pair similar, second pair dissimilar
        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1, input2, target);

        let expected_no_reduction = TensorData::from([0.0, 1.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([0.5]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = CosineEmbeddingLossConfig::new().with_margin(0.5);
        let loss = config.init();

        assert_eq!(
            alloc::format!("{loss}"),
            "CosineEmbeddingLoss {margin: 0.5, reduction: Mean}"
        );
    }
}
|
||||
@@ -0,0 +1,466 @@
|
||||
use burn_core as burn;
|
||||
use burn_core::tensor::IndexingUpdateOp;
|
||||
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::activation::log_softmax;
|
||||
use burn::tensor::{Bool, Int, Tensor, backend::Backend};
|
||||
use burn::{config::Config, module::Module};
|
||||
|
||||
/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct CrossEntropyLossConfig {
|
||||
/// Create padded cross entropy.
|
||||
///
|
||||
/// Prevents pad tokens from impacting loss calculation.
|
||||
pub pad_tokens: Option<Vec<usize>>,
|
||||
|
||||
/// Create weighted cross-entropy.
|
||||
///
|
||||
/// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
|
||||
///
|
||||
/// # Pre-conditions
|
||||
/// - The order of the weight vector should correspond to the label integer assignment.
|
||||
/// - Targets assigned negative Int's will not be allowed.
|
||||
pub weights: Option<Vec<f32>>,
|
||||
|
||||
/// Create cross-entropy with label smoothing.
|
||||
///
|
||||
/// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
|
||||
/// Alpha = 0 would be the same as default.
|
||||
pub smoothing: Option<f32>,
|
||||
|
||||
/// Create cross-entropy with probabilities as input instead of logits.
|
||||
///
|
||||
#[config(default = true)]
|
||||
pub logits: bool,
|
||||
}
|
||||
|
||||
impl CrossEntropyLossConfig {
|
||||
/// Initialize [Cross-entropy loss](CrossEntropyLoss).
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
|
||||
self.assertions();
|
||||
CrossEntropyLoss {
|
||||
pad_tokens: self.pad_tokens.clone(),
|
||||
weights: self
|
||||
.weights
|
||||
.as_ref()
|
||||
.map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
|
||||
smoothing: self.smoothing,
|
||||
logits: self.logits,
|
||||
}
|
||||
}
|
||||
|
||||
fn assertions(&self) {
|
||||
if let Some(alpha) = self.smoothing {
|
||||
assert!(
|
||||
(0.0..=1.).contains(&alpha),
|
||||
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
|
||||
);
|
||||
};
|
||||
if let Some(weights) = self.weights.as_ref() {
|
||||
assert!(
|
||||
weights.iter().all(|e| e > &0.),
|
||||
"Weights of cross-entropy have to be positive."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the cross entropy loss from the input logits and the targets.
|
||||
///
|
||||
/// Should be created using [CrossEntropyLossConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct CrossEntropyLoss<B: Backend> {
|
||||
/// Pad tokens to ignore in the loss calculation.
|
||||
pub pad_tokens: Option<Vec<usize>>,
|
||||
/// Weights for cross-entropy.
|
||||
pub weights: Option<Tensor<B, 1>>,
|
||||
/// Label smoothing factor.
|
||||
pub smoothing: Option<f32>,
|
||||
/// Use logits as input.
|
||||
pub logits: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
|
||||
alloc::format!("Vec<0..{}>", pad_tokens.len())
|
||||
} else {
|
||||
"None".to_string()
|
||||
};
|
||||
|
||||
content
|
||||
.add("pad_tokens", &pad_tokens)
|
||||
.add("weights", &self.weights)
|
||||
.add("smoothing", &self.smoothing)
|
||||
.add("logits", &self.logits)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> CrossEntropyLoss<B> {
|
||||
/// For backward compatibility.
|
||||
pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
|
||||
CrossEntropyLossConfig::new()
|
||||
.with_pad_tokens(pad_index.map(|e| vec![e]))
|
||||
.init(device)
|
||||
}
|
||||
|
||||
/// Compute the criterion on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - logits: `[batch_size, num_targets]`
|
||||
/// - targets: `[batch_size]`
|
||||
pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
|
||||
Self::assertions(logits.clone(), targets.clone());
|
||||
match self.smoothing {
|
||||
Some(alpha) => self.forward_smoothed(logits, targets, alpha),
|
||||
_ => self.forward_default(logits, targets),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_smoothed(
|
||||
&self,
|
||||
logits: Tensor<B, 2>,
|
||||
targets: Tensor<B, 1, Int>,
|
||||
alpha: f32,
|
||||
) -> Tensor<B, 1> {
|
||||
let mask = self.padding_mask(&targets);
|
||||
let tensor = if self.logits {
|
||||
log_softmax(logits, 1)
|
||||
} else {
|
||||
logits.log()
|
||||
};
|
||||
let [batch_size, nr_classes] = tensor.dims();
|
||||
let tensor = tensor
|
||||
* Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
|
||||
|
||||
match &self.weights {
|
||||
Some(weights) => {
|
||||
let tensor = tensor
|
||||
* weights
|
||||
.clone()
|
||||
.reshape([1, nr_classes])
|
||||
.repeat_dim(0, batch_size);
|
||||
let weights = weights.clone().gather(0, targets);
|
||||
let tensor = Self::apply_mask_2d(tensor, mask);
|
||||
tensor.sum().neg() / weights.sum()
|
||||
}
|
||||
None => {
|
||||
let tensor = Self::apply_mask_2d(tensor, mask);
|
||||
tensor.sum_dim(1).mean().neg()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
|
||||
let [batch_size] = targets.dims();
|
||||
|
||||
let mask = self.padding_mask(&targets);
|
||||
let tensor = log_softmax(logits, 1);
|
||||
let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
|
||||
|
||||
match &self.weights {
|
||||
Some(weights) => {
|
||||
let weights = weights.clone().gather(0, targets);
|
||||
let tensor = tensor.reshape([batch_size]) * weights.clone();
|
||||
let tensor = Self::apply_mask_1d(tensor, mask);
|
||||
tensor.sum().neg() / weights.sum()
|
||||
}
|
||||
None => {
|
||||
let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
|
||||
tensor.mean().neg()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_smoothed_targets(
|
||||
shape: [usize; 2],
|
||||
targets: Tensor<B, 1, Int>,
|
||||
alpha: f32,
|
||||
) -> Tensor<B, 2> {
|
||||
let [batch_size, nr_classes] = shape;
|
||||
let device = &targets.device();
|
||||
let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
|
||||
1,
|
||||
targets.reshape([batch_size, 1]),
|
||||
Tensor::ones([batch_size, 1], device),
|
||||
IndexingUpdateOp::Add,
|
||||
);
|
||||
targets_matrix * (1. - alpha) + alpha / nr_classes as f32
|
||||
}
|
||||
|
||||
fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
|
||||
let mut mask = None;
|
||||
if let Some(pad_tokens) = &self.pad_tokens {
|
||||
let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
|
||||
for x in pad_tokens {
|
||||
res = res + targets.clone().equal_elem(*x as i64).int();
|
||||
}
|
||||
mask = Some(res.greater_elem(0));
|
||||
}
|
||||
|
||||
mask
|
||||
}
|
||||
|
||||
fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
|
||||
if let Some(mask) = mask {
|
||||
tensor = tensor.mask_fill(mask, 0);
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
|
||||
if let Some(mask) = mask {
|
||||
let [batch_size, nr_classes] = tensor.dims();
|
||||
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
|
||||
let [logits_height, _] = logits.dims();
|
||||
let [targets_height] = targets.dims();
|
||||
assert!(
|
||||
logits_height == targets_height,
|
||||
"Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// `(logits, targets, one-hot targets)` shared by the tests below.
    type Fixture = (
        Tensor<TestBackend, 2>,
        Tensor<TestBackend, 1, Int>,
        Tensor<TestBackend, 2>,
    );

    /// Random logits for 4 samples and 5 classes with targets `[2, 0, 4, 1]`.
    fn setup() -> Fixture {
        let [batch_size, num_targets] = [4, 5];
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::random(
            [batch_size, num_targets],
            Distribution::Normal(0., 1.0),
            &device,
        );
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.0, 0.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 0.0, 0.0, 0.0],
            ]),
            &device,
        );
        (logits, targets, targets_logits)
    }

    /// Same as `setup`, but the one-hot rows of the samples whose target is a
    /// pad token in the padded tests (targets 1 and 2) are zeroed out.
    fn setup_padded() -> Fixture {
        let [batch_size, num_targets, pad_index] = [4, 5, 1];
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::random(
            [batch_size, num_targets],
            Distribution::Normal(0., 1.0),
            &device,
        );
        let targets = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
            &device,
        );
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]),
            &device,
        );
        (logits, targets, targets_logits)
    }

    /// Expected smoothed targets for `setup()` with alpha = 0.05.
    fn smoothed_targets() -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        )
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();

        let actual = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);

        // Reference: weighted one-hot cross-entropy computed by hand.
        let class_weights = Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
            .unsqueeze::<2>()
            .repeat_dim(0, 4);
        let weighted = log_softmax(logits, 1) * targets_logits * class_weights;
        // Weights of the target classes [2, 0, 4, 1] are 3, 1, 5 and 2.
        let expected = weighted.sum().neg() / (1. + 2. + 3. + 5.);

        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];

        let plain = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let smoothed = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);

        plain
            .into_data()
            .assert_approx_eq::<FT>(&smoothed.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup();
        let device = Default::default();

        let actual = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let expected = cross_entropy_with_logits(logits, targets_logits);

        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup();
        let device = Default::default();

        let plain = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let smoothed = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        plain
            .into_data()
            .assert_approx_eq::<FT>(&smoothed.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded();
        let pad_index = 1;

        let actual = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let expected = cross_entropy_with_logits(logits, targets_logits);

        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded();
        let pad_index = 1;

        let plain = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let smoothed = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        plain
            .into_data()
            .assert_approx_eq::<FT>(&smoothed.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup();

        let actual = CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);

        actual
            .into_data()
            .assert_approx_eq::<FT>(&smoothed_targets().into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup();
        let device = Default::default();

        let actual = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);

        // Reference: cross-entropy against the smoothed target distribution.
        let log_probs = log_softmax(logits, 1);
        let expected = (log_probs * smoothed_targets()).sum_dim(1).mean().neg();

        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::default());
    }

    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}
|
||||
1730
crates/stable-diffusion-burn/burn-crates/burn-nn/src/loss/ctc.rs
Normal file
1730
crates/stable-diffusion-burn/burn-crates/burn-nn/src/loss/ctc.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,215 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::{config::Config, module::Module};
|
||||
|
||||
use super::Reduction;
|
||||
|
||||
/// Configuration to create a [Huber loss](HuberLoss).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct HuberLossConfig {
|
||||
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
|
||||
pub delta: f32,
|
||||
}
|
||||
|
||||
impl HuberLossConfig {
|
||||
/// Initialize [Huber loss](HuberLoss).
|
||||
pub fn init(&self) -> HuberLoss {
|
||||
self.assertions();
|
||||
HuberLoss {
|
||||
delta: self.delta,
|
||||
lin_bias: self.delta * self.delta * 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
fn assertions(&self) {
|
||||
assert!(
|
||||
self.delta >= 0., // This also tests for normality
|
||||
"Delta for Huber loss must be a non-negative number."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the Huber loss between the inputs and the target.
|
||||
///
|
||||
/// The loss for each element of the residuals `r = targets - predictions` is given by
|
||||
///
|
||||
/// ```text
|
||||
/// L(r) = 0.5 * r^2 if |r| <= d
|
||||
/// L(r) = 0.5 * d^2 + d * (|r| - d) if |r| > d
|
||||
/// ```
|
||||
///
|
||||
/// where `d` is the configured `delta`. In particular, this is equal to the
|
||||
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
|
||||
/// but behaves linearly instead of quadratically for large residuals.
|
||||
///
|
||||
/// This loss function is less sensitive to outliers than the mean squared error loss.
|
||||
///
|
||||
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
|
||||
#[derive(Module, Debug, Clone)]
|
||||
#[module(custom_display)]
|
||||
pub struct HuberLoss {
|
||||
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
|
||||
pub delta: f32,
|
||||
/// Precomputed value for the linear bias.
|
||||
pub lin_bias: f32, // delta * delta * 0.5 precomputed
|
||||
}
|
||||
|
||||
impl ModuleDisplay for HuberLoss {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("delta", &self.delta)
|
||||
.add("lin_bias", &self.lin_bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl HuberLoss {
|
||||
/// Compute the loss element-wise for the predictions and targets, then reduce
|
||||
/// to a single loss value.
|
||||
///
|
||||
/// `Reduction::Auto` behaves as `Reduction::Mean`.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: \[...dims\]
|
||||
/// - targets: \[...dims\]
|
||||
/// - output: \[1\]
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
reduction: Reduction,
|
||||
) -> Tensor<B, 1> {
|
||||
let loss = self.forward_no_reduction(predictions, targets);
|
||||
match reduction {
|
||||
Reduction::Mean | Reduction::Auto => loss.mean(),
|
||||
Reduction::Sum => loss.sum(),
|
||||
other => panic!("{other:?} reduction is not supported"),
|
||||
}
|
||||
}
|
||||
/// Compute the loss element-wise for the predictions and targets.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: [...dims]
|
||||
/// - targets: [...dims]
|
||||
/// - output: [...dims]
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let residuals = targets - predictions;
|
||||
self.forward_residuals(residuals)
|
||||
}
|
||||
/// Compute the loss element-wise for the given residuals.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - residuals: [...dims]
|
||||
/// - output: [...dims]
|
||||
pub fn forward_residuals<const D: usize, B: Backend>(
|
||||
&self,
|
||||
residuals: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let is_large = residuals.clone().abs().greater_elem(self.delta);
|
||||
// We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
|
||||
// `sign()` function, in general, suffers from a jump at 0.
|
||||
// Instead the following tensor implements `delta * sign(r)` for values outside
|
||||
// the bound:
|
||||
let softsign = residuals.clone().clamp(-self.delta, self.delta);
|
||||
|
||||
// 0.5 * d^2 + d * (|r| - d) =
|
||||
// d * |r| - 0.5 * d^2
|
||||
// Moreover |r| = sign(r) * r
|
||||
let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);
|
||||
|
||||
let inside = residuals.square().mul_scalar(0.5);
|
||||
inside.mask_where(is_large, outside)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_huber_loss() {
        let device = Default::default();

        let predict =
            TestTensor::<1>::from_data(TensorData::from([-2., -0.5, 0., 0.3, 1.]), &device);
        let targets = TestTensor::<1>::from_data(TensorData::from([0., 0., 0., 0., 0.]), &device);

        let huber = HuberLossConfig::new(0.5).init();

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss_auto = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        let expected_no_reduction = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_auto = TensorData::from([0.284]);
        loss_auto
            .into_data()
            .assert_approx_eq::<FT>(&expected_auto, Tolerance::default());

        let expected_sum = TensorData::from([1.42]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let device = Default::default();

        let predict =
            TestAutodiffTensor::from_data(TensorData::from([-2., -0.5, 0., 0.3, 1.]), &device)
                .require_grad();
        let targets =
            TestAutodiffTensor::from_data(TensorData::from([0., 0., 0., 0., 0.]), &device);

        let huber = HuberLossConfig::new(0.5).init();
        let loss = huber.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
        grads_predict
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let loss = HuberLossConfig::new(0.5).init();

        assert_eq!(
            alloc::format!("{loss}"),
            "HuberLoss {delta: 0.5, lin_bias: 0.125}"
        );
    }
}
|
||||
@@ -0,0 +1,200 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::Reduction;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::{config::Config, module::Module};
|
||||
|
||||
/// Configuration to create a [KLDiv loss](KLDivLoss).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct KLDivLossConfig {
|
||||
/// Specifies whether target is the log space. Default: False.
|
||||
#[config(default = false)]
|
||||
pub log_target: bool,
|
||||
}
|
||||
|
||||
impl KLDivLossConfig {
|
||||
/// Initialize [KLDiv Loss](KLDivLoss).
|
||||
pub fn init(&self) -> KLDivLoss {
|
||||
KLDivLoss {
|
||||
log_target: self.log_target,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Kullback-Leibler Divergence Loss
|
||||
///
|
||||
/// KL Divergence shows the difference between two probability distributions by measuring information loss
|
||||
///
|
||||
/// KLDivLoss =
|
||||
/// ```tex
|
||||
/// y_{true} \cdot (\log{y_{true}} - \log{y_{pred}})
|
||||
/// ```
|
||||
/// By default, the loss expects the input in the log-space.
|
||||
/// The targets may also be provided in the log-space if `log_target` is true.
|
||||
///
|
||||
/// See
|
||||
/// - [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence)
|
||||
#[derive(Module, Debug, Clone)]
|
||||
#[module(custom_display)]
|
||||
pub struct KLDivLoss {
|
||||
/// Specifies whether target is the log space. Default: False.
|
||||
pub log_target: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for KLDivLoss {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("log_target", &self.log_target).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl KLDivLoss {
|
||||
/// Compute the criterion on the input tensor.
|
||||
///
|
||||
/// `Reduction::Auto` behaves as `Reduction::BatchMean`,`Reduction::Mean` dose not align with the math definition.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: \[batch_size,num_targets\]
|
||||
/// - targets: \[batch_size,num_targets\]
|
||||
/// - output: \[1\]
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
reduction: Reduction,
|
||||
) -> Tensor<B, 1> {
|
||||
let loss = self.forward_no_reduction(predictions, targets);
|
||||
match reduction {
|
||||
Reduction::BatchMean | Reduction::Auto => {
|
||||
let batch_size = loss.dims()[0] as f32;
|
||||
loss.sum().div_scalar(batch_size)
|
||||
}
|
||||
Reduction::Mean => loss.mean(),
|
||||
Reduction::Sum => loss.sum(),
|
||||
}
|
||||
}
|
||||
/// Compute the criterion on the input tensor without reducing.
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
match self.log_target {
|
||||
true => targets.clone().exp().mul(targets.sub(predictions)),
|
||||
false => {
|
||||
let epsilon = 1e-8;
|
||||
let log_target = targets.clone().clamp(epsilon, 1.0).log();
|
||||
targets.mul(log_target.sub(predictions))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    type FT = FloatElem<TestBackend>;

    /// Targets given as probabilities: checks the unreduced, summed and batch-mean losses.
    #[test]
    fn test_kl_div_loss() {
        let device = Default::default();
        let predict = TestTensor::<2>::from_data(
            TensorData::from([[-1.0, -0.5], [-2.0, -0.2]]),
            &device,
        );
        let targets =
            TestTensor::<2>::from_data(TensorData::from([[0.4, 0.6], [0.1, 0.9]]), &device);

        let criterion = KLDivLossConfig { log_target: false }.init();
        let tolerance = || Tolerance::absolute(1e-5);

        let expected_no_reduction =
            TensorData::from([[0.0334837139, -0.0064953566], [-0.0302585065, 0.0851755068]]);
        criterion
            .forward_no_reduction(predict.clone(), targets.clone())
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, tolerance());

        let expected_sum = TensorData::from([0.08191]);
        criterion
            .forward(predict.clone(), targets.clone(), Reduction::Sum)
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, tolerance());

        let expected_batch_mean = TensorData::from([0.04095]);
        criterion
            .forward(predict, targets, Reduction::BatchMean)
            .into_data()
            .assert_approx_eq::<FT>(&expected_batch_mean, tolerance());
    }

    /// Targets given in log-space: checks the unreduced, batch-mean and summed losses.
    #[test]
    fn test_kl_div_loss_log_target() {
        let device = Default::default();
        let predict = TestTensor::<1>::from_data([-1.0, -2.0], &device);
        let targets = TestTensor::<1>::from_data([-0.5, -1.5], &device);

        let criterion = KLDivLossConfig { log_target: true }.init();
        let tolerance = || Tolerance::absolute(1e-5);

        let expected_none = TensorData::from([0.3032653299, 0.1115650801]);
        criterion
            .forward_no_reduction(predict.clone(), targets.clone())
            .into_data()
            .assert_approx_eq::<FT>(&expected_none, tolerance());

        let expected_bm = TensorData::from([0.207415204965]);
        criterion
            .forward(predict.clone(), targets.clone(), Reduction::BatchMean)
            .into_data()
            .assert_approx_eq::<FT>(&expected_bm, tolerance());

        let expected_sum = TensorData::from([0.414830409931]);
        criterion
            .forward(predict, targets, Reduction::Sum)
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, tolerance());
    }

    /// The gradient of the summed loss with respect to the predictions is `-target`.
    #[cfg(feature = "std")]
    #[test]
    fn test_kl_div_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 2>;

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data([[-1.0, -0.5]], &device).require_grad();
        let targets = TestAutodiffTensor::from_data([[0.4, 0.6]], &device);

        let criterion = KLDivLossConfig { log_target: false }.init();
        let gradients = criterion
            .forward(predict.clone(), targets, Reduction::Sum)
            .backward();
        let predict_grad = predict.grad(&gradients).unwrap();

        // d/d_pred [target * (log_target - pred)] = -target
        let expected = TensorData::from([[-0.4, -0.6]]);
        predict_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The display output lists the `log_target` flag.
    #[test]
    fn display() {
        let rendered = alloc::format!("{}", KLDivLossConfig { log_target: true }.init());

        assert_eq!(rendered, "KLDivLoss {log_target: true}");
    }
}
|
||||
@@ -0,0 +1,672 @@
|
||||
use super::Reduction;
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
use burn_core as burn;
|
||||
|
||||
/// Configuration for the [Lp Loss](LpLoss) module.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn_nn::loss::{LpLossConfig, Reduction};
|
||||
///
|
||||
/// // Create L1 loss (MAE when using mean reduction)
|
||||
/// let l1_loss = LpLossConfig::l1();
|
||||
///
|
||||
/// // Create L2 loss (MSE when using mean reduction)
|
||||
/// let l2_loss = LpLossConfig::l2();
|
||||
///
|
||||
/// // Create custom Lp loss with p=3
|
||||
/// let l3_loss = LpLossConfig::new(3.0).init();
|
||||
/// ```
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LpLossConfig {
|
||||
/// The exponent `p` determining the type of error measurement.
|
||||
///
|
||||
/// Common values:
|
||||
/// - `p = 1.0`: L1 loss (MAE with mean reduction) - robust to outliers
|
||||
/// - `p = 2.0`: L2 loss (MSE with mean reduction) - standard choice, differentiable everywhere
|
||||
/// - `p > 2.0`: Increasingly sensitive to large errors (outliers)
|
||||
/// - `0 < p < 1`: More robust to outliers than L1 (quasi-norm)
|
||||
pub p: f64,
|
||||
}
|
||||
|
||||
impl LpLossConfig {
|
||||
/// Initializes a [Lp Loss](LpLoss) module.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `p <= 0`.
|
||||
pub fn init(&self) -> LpLoss {
|
||||
self.assertions();
|
||||
LpLoss { p: self.p }
|
||||
}
|
||||
|
||||
/// Creates L1 loss (p=1).
|
||||
///
|
||||
/// When used with `Reduction::Mean`, this computes Mean Absolute Error (MAE).
|
||||
/// When used with `Reduction::Sum`, this computes Sum of Absolute Errors (SAE).
|
||||
pub fn l1() -> LpLoss {
|
||||
LpLoss { p: 1.0 }
|
||||
}
|
||||
|
||||
/// Creates L2 loss (p=2).
|
||||
///
|
||||
/// When used with `Reduction::Mean`, this computes Mean Squared Error (MSE).
|
||||
/// When used with `Reduction::Sum`, this computes Sum of Squared Errors (SSE).
|
||||
pub fn l2() -> LpLoss {
|
||||
LpLoss { p: 2.0 }
|
||||
}
|
||||
|
||||
fn assertions(&self) {
|
||||
assert!(self.p > 0.0, "The order of the norm p must be positive.")
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the Lp Loss between predictions and targets.
|
||||
///
|
||||
/// This loss function computes the element-wise p-th power of absolute errors,
|
||||
/// then reduces them via mean or sum.
|
||||
///
|
||||
/// # Mathematical Definition
|
||||
///
|
||||
/// For predictions `ŷ` and targets `y`, the element-wise loss is:
|
||||
///
|
||||
/// ```text
|
||||
/// Lᵢ = |ŷᵢ - yᵢ|ᵖ
|
||||
/// ```
|
||||
///
|
||||
/// With mean reduction (default), the final loss is:
|
||||
///
|
||||
/// ```text
|
||||
/// L = (1/n) × Σᵢ |ŷᵢ - yᵢ|ᵖ
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// - This implementation computes `|error|^p`, **not** the Lp norm `(Σ|error|^p)^(1/p)`.
|
||||
/// - The `p = 1` case uses an optimized `abs()` operation.
|
||||
/// - The `p = 2` case uses an optimized computation `error * error` instead of `powf`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn_nn::loss::{LpLossConfig, Reduction};
|
||||
/// use burn::tensor::Tensor;
|
||||
///
|
||||
/// // Create L2 loss
|
||||
/// let l2_loss = LpLossConfig::l2();
|
||||
///
|
||||
/// let predictions: Tensor<Backend, 2> = /* model output */;
|
||||
/// let targets: Tensor<Backend, 2> = /* ground truth */;
|
||||
///
|
||||
/// // Compute loss with mean reduction (MSE)
|
||||
/// let mse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Mean);
|
||||
///
|
||||
/// // Compute loss with sum reduction (SSE)
|
||||
/// let sse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Sum);
|
||||
///
|
||||
/// // Compute loss with no reduction
|
||||
/// let unreduced_l2_loss = l2_loss.forward_no_reduction(predictions, targets);
|
||||
/// ```
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct LpLoss {
|
||||
/// The order of the norm (e.g., 1 for L1, 2 for L2).
|
||||
/// Equivalently, the exponent `p` for computing `|error|^p`.
|
||||
pub p: f64,
|
||||
}
|
||||
|
||||
impl LpLoss {
|
||||
/// Computes the element-wise loss `|error|^p` with reduction.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `predictions` - The model's predicted values.
|
||||
/// * `targets` - The ground truth target values.
|
||||
/// * `reduction` - Specifies how to reduce the element-wise losses:
|
||||
/// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.
|
||||
/// - `Reduction::Sum`: Returns the sum of all element-wise losses.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A scalar tensor containing the reduced loss value.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: `[...dims]` - Any shape
|
||||
/// - targets: `[...dims]` - Must match predictions shape
|
||||
/// - output: `[1]` - Scalar loss value
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
reduction: Reduction,
|
||||
) -> Tensor<B, 1> {
|
||||
let unreduced_loss = self.forward_no_reduction(predictions, targets);
|
||||
|
||||
match reduction {
|
||||
Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
|
||||
Reduction::Sum => unreduced_loss.sum(),
|
||||
other => panic!("{other:?} reduction is not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the element-wise loss `|error|^p` without reduction.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `predictions` - The model's predicted values.
|
||||
/// * `targets` - The ground truth target values.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor of the same shape as the inputs, containing `|prediction - target|^p`
|
||||
/// for each element.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: `[...dims]` - Any shape
|
||||
/// - targets: `[...dims]` - Must match predictions shape
|
||||
/// - output: `[...dims]` - Same shape as inputs
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let error = predictions.sub(targets);
|
||||
|
||||
// Use simplified/optimized expressions for common cases (p = 1, p = 2)
|
||||
if self.p == 1.0 {
|
||||
// L1 loss
|
||||
error.abs()
|
||||
} else if self.p == 2.0 {
|
||||
// L2 loss
|
||||
error.clone().mul(error)
|
||||
} else {
|
||||
error.abs().powf_scalar(self.p)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the element-wise loss `|error|^p` with reduction over specified dimensions.
|
||||
///
|
||||
/// Calculates element-wise `|predictions - targets|^p`, then takes the mean
|
||||
/// over the specified dimensions. Useful for per-sample or per-channel losses (e.g., when
|
||||
/// working with images).
|
||||
///
|
||||
/// Dimensions can be provided in any order. They are sorted internally and
|
||||
/// reduced from highest to lowest to ensure indices remain valid.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `predictions` - The model's predicted values.
|
||||
/// * `targets` - The ground truth target values.
|
||||
/// * `dims` - Dimensions to reduce over.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the specified dimensions reduced to size 1.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// // Image tensor: [batch, C, H, W]
|
||||
/// let l2_loss = LpLossConfig::l2();
|
||||
///
|
||||
/// // Per-image MSE for PSNR: reduce over C, H, W → [batch, 1, 1, 1]
|
||||
/// let mse_per_image = l2_loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
|
||||
/// ```
|
||||
pub fn forward_reduce_dims<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
dims: &[usize],
|
||||
) -> Tensor<B, D> {
|
||||
let error = self.forward_no_reduction(predictions, targets);
|
||||
|
||||
// Sort the dimensions to ascending order
|
||||
let mut sorted_dims = dims.to_vec();
|
||||
sorted_dims.sort();
|
||||
|
||||
// Reduce over specified dimensions
|
||||
error.mean_dims(sorted_dims.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// Builds a `D`-dimensional test tensor on the default device.
    fn tensor<const D: usize>(values: impl Into<TensorData>) -> Tensor<TestBackend, D> {
        Tensor::from_data(values, &Default::default())
    }

    #[test]
    fn test_lp_loss_l1_constructor() {
        let from_shortcut = LpLossConfig::l1();
        let from_config = LpLossConfig::new(1.0).init();

        assert_eq!(from_shortcut.p, 1.0);
        assert_eq!(from_shortcut.p, from_config.p);
    }

    #[test]
    fn test_lp_loss_l2_constructor() {
        let from_shortcut = LpLossConfig::l2();
        let from_config = LpLossConfig::new(2.0).init();

        assert_eq!(from_shortcut.p, 2.0);
        assert_eq!(from_shortcut.p, from_config.p);
    }

    #[test]
    fn test_lp_loss_l1() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);
        let criterion = LpLossConfig::l1();

        let unreduced = criterion.forward_no_reduction(predictions.clone(), targets.clone());
        unreduced
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 2.0]]), false);

        let auto = criterion.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        auto.into_data().assert_eq(&TensorData::from([1.0]), false);

        let sum = criterion.forward(predictions, targets, Reduction::Sum);
        sum.into_data().assert_eq(&TensorData::from([4.0]), false);
    }

    #[test]
    fn test_lp_loss_l2() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);
        let criterion = LpLossConfig::l2();

        let unreduced = criterion.forward_no_reduction(predictions.clone(), targets.clone());
        unreduced
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 4.0]]), false);

        let auto = criterion.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        auto.into_data().assert_eq(&TensorData::from([1.5]), false);

        let sum = criterion.forward(predictions, targets, Reduction::Sum);
        sum.into_data().assert_eq(&TensorData::from([6.0]), false);
    }

    #[test]
    fn test_lp_loss_p_half() {
        // L0.5 quasi-norm: more robust to outliers than L1
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 0.0]]);
        let criterion = LpLossConfig::new(0.5).init();

        // |1-2|^0.5 = 1, |2-1|^0.5 = 1, |3-3|^0.5 = 0, |4-0|^0.5 = 2
        let unreduced = criterion.forward_no_reduction(predictions.clone(), targets.clone());
        unreduced
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 2.0]]), false);

        let auto = criterion.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        auto.into_data().assert_eq(&TensorData::from([1.0]), false);

        let sum = criterion.forward(predictions, targets, Reduction::Sum);
        sum.into_data().assert_eq(&TensorData::from([4.0]), false);
    }

    #[test]
    fn test_lp_loss_p3() {
        // L3 norm: more sensitive to outliers than L2
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);
        let criterion = LpLossConfig::new(3.0).init();

        // |1-2|^3 = 1, |2-1|^3 = 1, |3-3|^3 = 0, |4-2|^3 = 8
        let unreduced = criterion.forward_no_reduction(predictions.clone(), targets.clone());
        unreduced
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 8.0]]), false);

        let auto = criterion.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        auto.into_data().assert_eq(&TensorData::from([2.5]), false);

        let sum = criterion.forward(predictions, targets, Reduction::Sum);
        sum.into_data().assert_eq(&TensorData::from([10.0]), false);
    }

    #[test]
    fn test_lp_loss_zero_error() {
        // Predictions that exactly match the targets must give a zero loss
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = predictions.clone();

        let l1_loss =
            LpLossConfig::l1().forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let l2_loss = LpLossConfig::l2().forward(predictions, targets, Reduction::Auto);

        let expected = TensorData::from([0.0]);
        l1_loss.into_data().assert_eq(&expected, false);
        l2_loss.into_data().assert_eq(&expected, false);
    }

    #[test]
    fn test_lp_loss_negative_errors() {
        // Negative errors must be handled through their absolute value
        let predictions = tensor::<1>([1.0, 2.0, 3.0]);
        let targets = tensor::<1>([3.0, 4.0, 5.0]);

        let unreduced_l1 =
            LpLossConfig::l1().forward_no_reduction(predictions.clone(), targets.clone());
        let unreduced_p1 = LpLossConfig::new(1.0)
            .init()
            .forward_no_reduction(predictions, targets);

        // All errors are negative: 1-3=-2, 2-4=-2, 3-5=-2, but |error| = 2
        let expected = TensorData::from([2.0, 2.0, 2.0]);
        unreduced_l1.into_data().assert_eq(&expected, false);
        unreduced_p1.into_data().assert_eq(&expected, false);
    }

    #[test]
    fn test_lp_loss_3d_tensor() {
        let predictions = tensor::<3>([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
        let targets = tensor::<3>([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]);

        let loss_l2 =
            LpLossConfig::l2().forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_p2 = LpLossConfig::new(2.0)
            .init()
            .forward(predictions, targets, Reduction::Auto);

        // Errors: 1, 0, 0, -1, 1, 0, 0, -2
        // Squared: 1, 0, 0, 1, 1, 0, 0, 4
        // Mean: 7/8 = 0.875
        let expected = TensorData::from([0.875]);
        loss_l2.into_data().assert_eq(&expected, false);
        loss_p2.into_data().assert_eq(&expected, false);
    }

    #[test]
    #[should_panic(expected = "The order of the norm p must be positive.")]
    fn test_lp_loss_negative_p_panics() {
        let _ = LpLossConfig::new(-1.0).init();
    }

    #[test]
    #[should_panic(expected = "The order of the norm p must be positive.")]
    fn test_lp_loss_zero_p_panics() {
        let _ = LpLossConfig::new(0.0).init();
    }

    #[test]
    fn test_lp_loss_fractional_p() {
        // Test p = 1.5
        let predictions = tensor::<1>([0.0, 4.0]);
        let targets = tensor::<1>([1.0, 0.0]);

        let unreduced = LpLossConfig::new(1.5)
            .init()
            .forward_no_reduction(predictions, targets);

        // |0-1|^1.5 = 1, |4-0|^1.5 = 8
        let expected = TensorData::from([1.0, 8.0]);
        unreduced.into_data().assert_eq(&expected, false);
    }

    #[test]
    fn test_forward_reduce_dims_single_dim() {
        // Shape: [2, 3]
        let predictions = tensor::<2>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let targets = tensor::<2>([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]);

        // Reduce over dim 1 -> should give [2, 1] shape
        let loss_l2 =
            LpLossConfig::l2().forward_reduce_dims(predictions.clone(), targets.clone(), &[1]);
        let loss_p2 = LpLossConfig::new(2.0)
            .init()
            .forward_reduce_dims(predictions, targets, &[1]);

        // Errors row 0: [1, 0, -3] -> squared: [1, 0, 9] -> mean: 10/3
        // Errors row 1: [3, 0, 0] -> squared: [9, 0, 0] -> mean: 3
        let expected = TensorData::from([[10.0 / 3.0], [3.0]]);
        loss_l2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
        loss_p2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_forward_reduce_dims_first_dim() {
        // Shape: [2, 3]
        let predictions = tensor::<2>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let targets = tensor::<2>([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]);

        // Reduce over dim 0 -> should give [1, 3] shape
        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[0]);

        // Squared errors: [[1, 0, 9], [9, 0, 0]]
        // Mean over dim 0: [5, 0, 4.5]
        let expected = TensorData::from([[5.0, 0.0, 4.5]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_forward_reduce_dims_multiple_dims() {
        // Shape: [2, 2, 2]
        let predictions = tensor::<3>([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
        let targets = tensor::<3>([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]);

        // Reduce over dims 1 and 2 -> should give [2, 1, 1] shape
        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[1, 2]);

        // Batch 0 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25
        // Batch 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25
        let expected = TensorData::from([[[1.25]], [[1.25]]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_forward_reduce_dims_all_dims() {
        // Shape: [2, 2]
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);

        // Reduce over all dims -> should give [1, 1] shape
        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[0, 1]);

        // Errors: [[-1, 1], [0, 2]] -> squared: [[1, 1], [0, 4]] -> mean: 1.5
        let expected = TensorData::from([[1.5]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_forward_reduce_dims_image_batch() {
        // Simulate per-image loss for [batch, C, H, W] tensor (common use case for PSNR)
        // Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2)
        let predictions = tensor::<4>([
            [[[1.0, 2.0], [3.0, 4.0]]], // Image 1
            [[[5.0, 6.0], [7.0, 8.0]]], // Image 2
        ]);
        let targets = tensor::<4>([
            [[[0.0, 2.0], [3.0, 6.0]]], // Target 1
            [[[5.0, 5.0], [7.0, 7.0]]], // Target 2
        ]);

        // Reduce over C, H, W (dims 1, 2, 3) to get per-image MSE
        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[1, 2, 3]);

        // Image 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 1.25
        // Image 2 errors: [[0, 1], [0, 1]] -> squared: [[0, 1], [0, 1]] -> mean: 0.5
        let expected = TensorData::from([[[[1.25]]], [[[0.5]]]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_forward_reduce_dims_with_p1() {
        // Shape: [2, 3]
        let predictions = tensor::<2>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let targets = tensor::<2>([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]);

        // Reduce over dim 1 -> should give [2, 1] shape
        let loss = LpLossConfig::l1().forward_reduce_dims(predictions, targets, &[1]);

        // Abs errors row 0: [1, 3, 0] -> mean: 4/3
        // Abs errors row 1: [3, 0, 3] -> mean: 2
        let expected = TensorData::from([[4.0 / 3.0], [2.0]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_forward_reduce_dims_empty_dims() {
        // Reducing over no dimensions should return the unreduced loss
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[0.0, 2.0], [3.0, 6.0]]);
        let criterion = LpLossConfig::l2();

        let reduced_over_nothing =
            criterion.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
        let unreduced = criterion.forward_no_reduction(predictions, targets);

        // Should be equivalent
        reduced_over_nothing
            .into_data()
            .assert_eq(&unreduced.into_data(), true);
    }

    #[test]
    fn test_forward_reduce_dims_zero_error() {
        // Shape: [2, 2, 2]
        let predictions = tensor::<3>([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
        let targets = predictions.clone();

        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[1, 2]);

        // All zeros, reduced to shape: [2, 1, 1]
        let expected = TensorData::from([[[0.0]], [[0.0]]]);
        loss.into_data().assert_eq(&expected, false);
    }
}
|
||||
@@ -0,0 +1,23 @@
|
||||
mod binary_cross_entropy;
|
||||
mod cosine_embedding;
|
||||
mod cross_entropy;
|
||||
mod ctc;
|
||||
mod huber;
|
||||
mod kldiv;
|
||||
mod lp_loss;
|
||||
mod mse;
|
||||
mod poisson;
|
||||
mod reduction;
|
||||
mod smooth_l1;
|
||||
|
||||
pub use binary_cross_entropy::*;
|
||||
pub use cosine_embedding::*;
|
||||
pub use cross_entropy::*;
|
||||
pub use ctc::*;
|
||||
pub use huber::*;
|
||||
pub use kldiv::*;
|
||||
pub use lp_loss::*;
|
||||
pub use mse::*;
|
||||
pub use poisson::*;
|
||||
pub use reduction::*;
|
||||
pub use smooth_l1::*;
|
||||
@@ -0,0 +1,93 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::loss::reduction::Reduction;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
/// Calculate the mean squared error loss from the input logits and the targets.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct MseLoss;
|
||||
|
||||
impl Default for MseLoss {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MseLoss {
|
||||
/// Create the criterion.
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
/// Compute the criterion on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - logits: [batch_size, num_targets]
|
||||
/// - targets: [batch_size, num_targets]
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
logits: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
reduction: Reduction,
|
||||
) -> Tensor<B, 1> {
|
||||
let tensor = self.forward_no_reduction(logits, targets);
|
||||
match reduction {
|
||||
Reduction::Mean | Reduction::Auto => tensor.mean(),
|
||||
Reduction::Sum => tensor.sum(),
|
||||
other => panic!("{other:?} reduction is not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the criterion on the input tensor without reducing.
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
logits: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
logits.sub(targets).square()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    #[test]
    fn test_mse_loss() {
        let device = Default::default();
        let logits_data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
        let targets_data = TensorData::from([[2.0, 1.0], [3.0, 2.0]]);
        let logits = Tensor::<TestBackend, 2>::from_data(logits_data, &device);
        let targets = Tensor::<TestBackend, 2>::from_data(targets_data, &device);

        let criterion = MseLoss::new();
        let elementwise = criterion.forward_no_reduction(logits.clone(), targets.clone());
        let mean = criterion.forward(logits.clone(), targets.clone(), Reduction::Auto);
        let total = criterion.forward(logits, targets, Reduction::Sum);

        // Squared differences, element by element.
        elementwise
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 4.0]]), false);

        // (1 + 1 + 0 + 4) / 4
        mean.into_data().assert_eq(&TensorData::from([1.5]), false);

        // 1 + 1 + 0 + 4
        total.into_data().assert_eq(&TensorData::from([6.0]), false);
    }

    #[test]
    fn display() {
        let rendered = alloc::format!("{}", MseLoss::new());
        assert_eq!(rendered, "MseLoss");
    }
}
|
||||
@@ -0,0 +1,417 @@
|
||||
use burn_core as burn;
|
||||
use core::f32::consts::PI;
|
||||
|
||||
use burn::tensor::cast::ToElement;
|
||||
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::{config::Config, module::Module};
|
||||
|
||||
use super::Reduction;
|
||||
|
||||
/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance.
|
||||
///
|
||||
/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss
|
||||
/// behavior, such as whether the input is in log-space, whether to include the Stirling
|
||||
/// approximation term, and a small epsilon value to avoid numerical instability.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PoissonNllLossConfig {
|
||||
/// If `true`, the predictions are expected to be in log-space.
|
||||
///
|
||||
/// When `log_input` is `true`, the loss is computed as:
|
||||
/// ```text
|
||||
/// L(predictions, target) = exp(predictions) - target * predictions
|
||||
/// ```
|
||||
/// When `log_input` is `false`, the loss is computed as:
|
||||
/// ```text
|
||||
/// L(predictions, target) = predictions - target * log(predictions + eps)
|
||||
/// ```
|
||||
#[config(default = true)]
|
||||
pub log_input: bool,
|
||||
/// Whether to compute the full loss, including the Stirling approximation term.
|
||||
///
|
||||
/// When `full` is `true`, the Stirling approximation term is added to the loss:
|
||||
/// ```text
|
||||
/// target * log(target) - target + 0.5 * log(2 * PI * target)
|
||||
/// ```
|
||||
#[config(default = false)]
|
||||
pub full: bool,
|
||||
/// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
|
||||
///
|
||||
/// This epsilon value is added to the predictions to ensure numerical stability
|
||||
/// when computing the logarithm.
|
||||
#[config(default = 1e-8)]
|
||||
pub eps: f64,
|
||||
}
|
||||
|
||||
impl PoissonNllLossConfig {
|
||||
/// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration.
|
||||
///
|
||||
/// # Panics
|
||||
/// - Panics if `eps` is not a positive number.
|
||||
pub fn init(&self) -> PoissonNllLoss {
|
||||
self.assertions();
|
||||
PoissonNllLoss {
|
||||
log_input: self.log_input,
|
||||
full: self.full,
|
||||
eps: self.eps,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validates the configuration parameters.
|
||||
///
|
||||
/// # Panics
|
||||
/// - Panics if `eps` is not a positive number.
|
||||
fn assertions(&self) {
|
||||
assert!(
|
||||
self.eps > 0.,
|
||||
"eps for PoissonNllLoss must be a positive number."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target.
|
||||
///
|
||||
/// This loss function is used when the target values are assumed to follow a Poisson distribution.
|
||||
/// The loss is defined as:
|
||||
/// ```text
|
||||
/// target ~ Poisson(input)
|
||||
/// L(predictions, target) = predictions - target * log(predictions) + log(target!)
|
||||
/// ```
|
||||
/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula.
|
||||
/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss.
|
||||
///
|
||||
/// For more details, see:
|
||||
/// <https://en.wikipedia.org/wiki/Poisson_regression#Maximum_likelihood-based_parameter_estimation>
|
||||
#[derive(Module, Debug, Clone)]
|
||||
#[module(custom_display)]
|
||||
pub struct PoissonNllLoss {
|
||||
/// If `true`, the predictions are expected to be in log-space.
|
||||
pub log_input: bool,
|
||||
/// Whether to compute the full loss, including the Stirling approximation term.
|
||||
pub full: bool,
|
||||
/// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
|
||||
pub eps: f64,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for PoissonNllLoss {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("log_input", &self.log_input)
|
||||
.add("full", &self.full)
|
||||
.add("eps", &self.eps)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl PoissonNllLoss {
|
||||
/// Computes the loss element-wise for the given predictions and targets, then reduces
|
||||
/// the result to a single loss value.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `predictions`: The predicted values.
|
||||
/// - `targets`: The target values.
|
||||
/// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - `predictions`: `[...dims]`
|
||||
/// - `targets`: `[...dims]`
|
||||
/// - `output`: `[1]`
|
||||
///
|
||||
/// # Panics
|
||||
/// - Panics if the shapes of `predictions` and `targets` do not match.
|
||||
/// - Panics if any target value is negative.
|
||||
/// - Panics if `log_input` is `false` and any prediction value is negative.
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
reduction: Reduction,
|
||||
) -> Tensor<B, 1> {
|
||||
let loss = self.forward_no_reduction(predictions, targets);
|
||||
match reduction {
|
||||
Reduction::Mean | Reduction::Auto => loss.mean(),
|
||||
Reduction::Sum => loss.sum(),
|
||||
other => panic!("{other:?} reduction is not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the loss element-wise for the given predictions and targets without reduction.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `predictions`: The predicted values.
|
||||
/// - `targets`: The target values.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - `predictions`: `[...dims]`
|
||||
/// - `targets`: `[...dims]`
|
||||
/// - `output`: `[...dims]`
|
||||
///
|
||||
/// # Panics
|
||||
/// - Panics if the shapes of `predictions` and `targets` do not match.
|
||||
/// - Panics if any target value is negative.
|
||||
/// - Panics if `log_input` is `false` and any prediction value is negative.
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
self.assertions(&predictions, &targets);
|
||||
let mut loss;
|
||||
if self.log_input {
|
||||
loss = predictions.clone().exp() - targets.clone() * predictions;
|
||||
} else {
|
||||
loss = predictions.clone() - targets.clone() * (predictions + self.eps).log();
|
||||
}
|
||||
if self.full {
|
||||
let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
|
||||
+ (targets.clone() * 2. * PI).log() * 0.5;
|
||||
loss = loss
|
||||
+ log_stirling_term
|
||||
.mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());
|
||||
}
|
||||
loss
|
||||
}
|
||||
|
||||
/// Validates the input tensors for the loss computation.
|
||||
///
|
||||
/// # Panics
|
||||
/// - Panics if the shapes of `predictions` and `targets` do not match.
|
||||
/// - Panics if any target value is negative.
|
||||
/// - Panics if `log_input` is `false` and any prediction value is negative.
|
||||
fn assertions<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: &Tensor<B, D>,
|
||||
targets: &Tensor<B, D>,
|
||||
) {
|
||||
let predictions_dims = predictions.dims();
|
||||
let targets_dims = targets.dims();
|
||||
assert!(
|
||||
predictions_dims == targets_dims,
|
||||
"Shape of targets ({targets_dims:?}) should correspond to outer shape of predictions ({predictions_dims:?})."
|
||||
);
|
||||
assert!(
|
||||
targets
|
||||
.clone()
|
||||
.greater_equal_elem(0.)
|
||||
.all()
|
||||
.into_scalar()
|
||||
.to_bool(),
|
||||
"All the values of `targets` must be non-negative."
|
||||
);
|
||||
if !self.log_input {
|
||||
assert!(
|
||||
predictions
|
||||
.clone()
|
||||
.greater_equal_elem(0.)
|
||||
.all()
|
||||
.into_scalar()
|
||||
.to_bool(),
|
||||
"When `log_input` is `false`, all the values of `predictions` must be non-negative."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Builds a rank-1 test tensor on the default device.
    fn tensor_1d<const N: usize>(values: [f64; N]) -> TestTensor<1> {
        TestTensor::<1>::from_data(TensorData::from(values), &Default::default())
    }

    /// Compares tensor data approximately, using the default tolerance.
    fn assert_close(actual: TensorData, expected: &TensorData) {
        actual.assert_approx_eq::<FT>(expected, Tolerance::default());
    }

    #[test]
    fn test_poisson_nll_loss() {
        let predictions = tensor_1d([0., 0., -40., 1., 2., 3.]);
        let targets = tensor_1d([1., 4.5, 2.5, 0., 0., 2.]);

        let poisson = PoissonNllLossConfig::new().init();

        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);

        assert_close(
            loss_no_reduction.into_data(),
            &TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]),
        );
        assert_close(loss.into_data(), &TensorData::from([21.0321]));
        assert_close(loss_sum.into_data(), &TensorData::from([126.1929]));
    }

    #[test]
    fn test_poisson_nll_loss_no_log_input() {
        let predictions = tensor_1d([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
        let targets = tensor_1d([2., 3., 1., 4.5, 0., 0., 2.]);

        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

        let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());

        assert_close(
            loss_no_reduction.into_data(),
            &TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]),
        );
    }

    #[test]
    fn test_poisson_nll_loss_full() {
        let predictions = tensor_1d([0., 0., -40., 1., 2., 3.]);
        let targets = tensor_1d([1., 4.5, 2.5, 0., 0., 2.]);

        let poisson = PoissonNllLossConfig::new().with_full(true).init();

        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);

        assert_close(
            loss_no_reduction.into_data(),
            &TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]),
        );
        assert_close(loss.into_data(), &TensorData::from([21.9920]));
        assert_close(loss_sum.into_data(), &TensorData::from([131.9518]));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_poisson_nll_loss_gradients() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let device = Default::default();
        let predictions_data = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets_data = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let predictions1 = TestAutodiffTensor::from_data(predictions_data, &device).require_grad();
        let predictions2 = predictions1.clone();
        let targets = TestAutodiffTensor::from_data(targets_data, &device);

        let poisson = PoissonNllLossConfig::new().with_full(false).init();
        let poisson_full = PoissonNllLossConfig::new().with_full(true).init();

        let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
        let loss_full_sum =
            poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);

        let grads = loss_sum.backward();
        let grads_full = loss_full_sum.backward();

        let grads_predictions1 = predictions1.grad(&grads).unwrap();
        let grads_predictions2 = predictions2.grad(&grads_full).unwrap();

        // The same gradient is expected with and without the Stirling term.
        let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);

        assert_close(grads_predictions1.into_data(), &expected);
        assert_close(grads_predictions2.into_data(), &expected);
    }

    #[test]
    #[should_panic = "eps for PoissonNllLoss must be a positive number."]
    fn test_negative_eps() {
        let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
    }

    #[test]
    #[should_panic = "All the values of `targets` must be non-negative."]
    fn test_targets_with_negative_values() {
        let predictions = tensor_1d([0., 0., -40., 1., 2., 3., 4.]);
        let targets = tensor_1d([1., 4.5, 2.5, 0., 0., 2., -0.42]);

        let poisson = PoissonNllLossConfig::new().init();

        let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
    }

    #[test]
    #[should_panic = "Shape of targets"]
    fn test_shape_tensors() {
        let predictions = tensor_1d([0., 1., 2.]);
        let targets = tensor_1d([0., 1.]);

        let poisson = PoissonNllLossConfig::new().init();

        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
    }

    #[test]
    #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
    fn test_exp_predictions_non_negative() {
        let predictions = tensor_1d([0.3, -0.1, 0.4]);
        let targets = tensor_1d([0., 1., 0.]);

        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
    }

    #[test]
    fn display() {
        let loss = PoissonNllLossConfig::new().init();

        assert_eq!(
            alloc::format!("{loss}"),
            "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
        );
    }
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
|
||||
/// The reduction type for the loss.
///
/// Passed to a loss module to select how its element-wise losses are combined.
/// Not every loss supports every variant; the loss modules that match on this enum
/// panic with "... reduction is not supported" for variants they do not handle.
|
||||
#[derive(Config, Debug)]
|
||||
pub enum Reduction {
|
||||
/// The mean of the losses will be returned.
|
||||
Mean,
|
||||
|
||||
/// The sum of the losses will be returned.
|
||||
Sum,
|
||||
|
||||
/// The sum of the losses divided by the batch_size will be returned.
|
||||
BatchMean,
|
||||
|
||||
/// The mean of the losses will be returned.
/// Losses that accept this variant treat it the same way as [`Reduction::Mean`].
|
||||
Auto,
|
||||
}
|
||||
@@ -0,0 +1,520 @@
|
||||
use super::Reduction;
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
use burn_core as burn;
|
||||
|
||||
/// Configuration for the [SmoothL1Loss](SmoothL1Loss) module.
|
||||
///
|
||||
/// Smooth L1 loss combines L1 and L2 loss, using L2 loss for small errors (below beta)
|
||||
/// and L1 loss for large errors (above beta). This makes it less sensitive to outliers
|
||||
/// than MSE while maintaining smooth gradients near zero.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn_nn::loss::{SmoothL1LossConfig, Reduction};
|
||||
///
|
||||
/// // Create Smooth L1 loss with default beta=1.0
|
||||
/// let smooth_l1 = SmoothL1LossConfig::new().init();
|
||||
///
|
||||
/// // Create with custom beta
|
||||
/// let smooth_l1_custom = SmoothL1LossConfig::new().with_beta(0.5).init();
|
||||
/// ```
|
||||
#[derive(Config, Debug)]
|
||||
pub struct SmoothL1LossConfig {
|
||||
/// Specifies the threshold at which to change between L1 and L2 loss.
|
||||
/// The value must be positive. Default: 1.0
|
||||
#[config(default = 1.0)]
|
||||
pub beta: f32,
|
||||
}
|
||||
|
||||
impl SmoothL1LossConfig {
|
||||
/// Initializes a [Smooth L1 Loss](SmoothL1Loss) module.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `beta <= 0`.
|
||||
pub fn init(&self) -> SmoothL1Loss {
|
||||
self.assertions();
|
||||
SmoothL1Loss { beta: self.beta }
|
||||
}
|
||||
|
||||
fn assertions(&self) {
|
||||
assert!(self.beta > 0.0, "The parameter beta must be positive.")
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the Smooth L1 Loss between predictions and targets.
|
||||
///
|
||||
/// This loss function uses L2 loss for small errors (below beta) and L1 loss for
|
||||
/// large errors (above beta), providing robustness to outliers while maintaining
|
||||
/// smooth gradients near |x - y| = 0.
|
||||
///
|
||||
/// # Mathematical Definition
|
||||
///
|
||||
/// For predictions `x` and targets `y`, the element-wise loss is:
|
||||
///
|
||||
/// - L_i = 0.5 * (x_i - y_i)² / beta , if |x_i - y_i| < beta
|
||||
/// - L_i = |x_i - y_i| - 0.5 * beta , otherwise
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Smooth L1 loss is closely related to HuberLoss since it is equivalent to HuberLoss
|
||||
/// scaled by `1/beta`:
|
||||
/// `SmoothL1(x, y, beta) = Huber(x, y, beta) / beta`
|
||||
///
|
||||
/// This leads to the following differences:
|
||||
///
|
||||
/// - As beta approaches 0, Smooth L1 loss converges to L1Loss, while HuberLoss converges to 0.
|
||||
/// When beta = 0, Smooth L1 loss is equivalent to L1 loss. Thus, the `beta`
|
||||
/// parameter in Burn must be positive. L1Loss should be used for beta = 0.
|
||||
/// - As beta approaches positive infinity, Smooth L1 loss converges to a constant 0 loss, while
|
||||
/// HuberLoss converges to L2Loss.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_nn::loss::{SmoothL1LossConfig, Reduction};
|
||||
/// use burn::tensor::Tensor;
|
||||
///
|
||||
/// // Create Smooth L1 loss with the default beta=1.0
|
||||
/// let smooth_l1 = SmoothL1LossConfig::new().init();
|
||||
///
|
||||
/// let predictions: Tensor<Backend, 2> = /* model output */;
|
||||
/// let targets: Tensor<Backend, 2> = /* ground truth */;
|
||||
///
|
||||
/// // Compute element-wise loss without reduction
|
||||
/// let element_wise = smooth_l1.forward(predictions.clone(), targets.clone());
|
||||
///
|
||||
/// // Compute loss with mean reduction
|
||||
/// let loss = smooth_l1.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean);
|
||||
///
|
||||
/// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1]
|
||||
/// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
|
||||
/// ```
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct SmoothL1Loss {
|
||||
/// Specifies the threshold at which to change between L1 and L2 loss.
|
||||
/// The value must be positive. Default: 1.0
|
||||
pub beta: f32,
|
||||
}
|
||||
|
||||
impl SmoothL1Loss {
|
||||
/// Computes the element-wise smooth L1 loss without reduction.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `predictions` - The model's predicted values.
|
||||
/// - `targets` - The ground truth target values.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor of the same shape as the inputs, containing the smooth L1 loss
|
||||
/// for each element.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: `[...dims]` - Any shape
|
||||
/// - targets: `[...dims]` - Must match predictions shape
|
||||
/// - output: `[...dims]` - Same shape as inputs
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let error = predictions.sub(targets);
|
||||
let abs_error = error.clone().abs();
|
||||
|
||||
// The L1 case: |error| - 0.5 * beta (when |error| >= beta)
|
||||
let l1_loss = abs_error.clone().sub_scalar(0.5 * self.beta);
|
||||
|
||||
// The L2 case: 0.5 * (error)^2 / beta (when |error| < beta)
|
||||
let l2_loss = error.square().mul_scalar(0.5).div_scalar(self.beta);
|
||||
|
||||
let l2_mask = abs_error.lower_elem(self.beta);
|
||||
l1_loss.mask_where(l2_mask, l2_loss)
|
||||
}
|
||||
|
||||
/// Computes the smooth L1 loss with reduction.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `predictions` - The model's predicted values.
|
||||
/// - `targets` - The ground truth target values.
|
||||
/// - `reduction` - Specifies how to reduce the element-wise losses:
|
||||
/// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.
|
||||
/// - `Reduction::Sum`: Returns the sum of all element-wise losses.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A scalar tensor containing the reduced loss value.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - predictions: `[...dims]` - Any shape
|
||||
/// - targets: `[...dims]` - Must match predictions shape
|
||||
/// - output: `[1]` - Scalar loss value
|
||||
pub fn forward_with_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
reduction: Reduction,
|
||||
) -> Tensor<B, 1> {
|
||||
let unreduced_loss = self.forward(predictions, targets);
|
||||
|
||||
match reduction {
|
||||
Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
|
||||
Reduction::Sum => unreduced_loss.sum(),
|
||||
other => panic!("{other:?} reduction is not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the smooth L1 loss with reduction over specified dimensions.
|
||||
///
|
||||
/// Calculates element-wise smooth L1 loss, then takes the mean
|
||||
/// over the specified dimensions. Useful for per-sample or per-channel losses.
|
||||
///
|
||||
/// Dimensions can be provided in any order. They are sorted internally and
|
||||
/// reduced from highest to lowest to ensure indices remain valid.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `predictions` - The model's predicted values.
|
||||
/// - `targets` - The ground truth target values.
|
||||
/// - `dims` - Dimensions to reduce over.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the specified dimensions reduced to size 1.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// // Consider image tensor with shape [batch, C, H, W]
|
||||
/// let smooth_l1 = SmoothL1LossConfig::new().init();
|
||||
///
|
||||
/// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1]
|
||||
/// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
|
||||
/// ```
|
||||
pub fn forward_reduce_dims<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
dims: &[usize],
|
||||
) -> Tensor<B, D> {
|
||||
let error = self.forward(predictions, targets);
|
||||
|
||||
// Sort the dimensions to ascending order
|
||||
let mut sorted_dims = dims.to_vec();
|
||||
sorted_dims.sort();
|
||||
|
||||
// Reduce over specified dimensions
|
||||
error.mean_dims(sorted_dims.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::TensorData;
|
||||
use burn::tensor::{Tolerance, ops::FloatElem};
|
||||
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
// =========================================================================
|
||||
// Configuration Tests
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_smooth_l1_config_default_beta() {
|
||||
let loss = SmoothL1LossConfig::new().init();
|
||||
assert_eq!(loss.beta, 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smooth_l1_config_custom_beta() {
|
||||
let loss = SmoothL1LossConfig::new().with_beta(2.5).init();
|
||||
assert_eq!(loss.beta, 2.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "The parameter beta must be positive")]
|
||||
fn test_smooth_l1_config_beta_zero_panics() {
|
||||
SmoothL1LossConfig::new().with_beta(0.0).init();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "The parameter beta must be positive")]
|
||||
fn test_smooth_l1_config_beta_negative_panics() {
|
||||
SmoothL1LossConfig::new().with_beta(-1.0).init();
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Forward Pass (Element-wise) Tests
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_smooth_l1_forward_l2_region() {
|
||||
// Beta = 1.0, errors = 0.0 and 0.5 (both < beta, use L2 formula)
|
||||
// L2 formula: 0.5 * error^2 / beta
|
||||
// error = 0.0 -> loss = 0.5 * 0.0 / 1.0 = 0.0
|
||||
// error = 0.5 -> loss = 0.5 * 0.25 / 1.0 = 0.125
|
||||
let device = Default::default();
|
||||
let loss = SmoothL1LossConfig::new().init();
|
||||
|
||||
let predictions =
|
||||
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.5]]), &device);
|
||||
let targets =
|
||||
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
|
||||
|
||||
let output = loss.forward(predictions, targets);
|
||||
let expected = TensorData::from([[0.0_f32, 0.125]]);
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smooth_l1_forward_l1_region() {
|
||||
// Beta = 1.0, errors = 0.0 and 2.0 (2.0 >= beta, use L1 formula)
|
||||
// L1 formula: |error| - 0.5 * beta
|
||||
// L2 formula: 0.5 * (error)^2 / beta
|
||||
// error = 0.0 -> loss = 0.0
|
||||
// error = 2.0 -> loss = 2.0 - 0.5 = 1.5
|
||||
let device = Default::default();
|
||||
let loss = SmoothL1LossConfig::new().init();
|
||||
|
||||
let predictions =
|
||||
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 2.0]]), &device);
|
||||
let targets =
|
||||
Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
|
||||
|
||||
let output = loss.forward(predictions, targets);
|
||||
let expected = TensorData::from([[0.0_f32, 1.5]]);
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smooth_l1_forward_zero_error() {
|
||||
let device = Default::default();
|
||||
let loss = SmoothL1LossConfig::new().init();
|
||||
|
||||
let predictions =
|
||||
Tensor::<TestBackend, 2>::from_data(TensorData::from([[1.0_f32, 2.0, 3.0]]), &device);
|
||||
let targets = predictions.clone();
|
||||
|
||||
let output = loss.forward(predictions, targets);
|
||||
let expected = TensorData::from([[0.0_f32, 0.0, 0.0]]);
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smooth_l1_forward_negative_errors() {
|
||||
// Ensure absolute value is used correctly
|
||||
// L1 formula: |error| - 0.5 * beta
|
||||
// L2 formula: 0.5 * (error)^2 / beta
|
||||
// Beta = 1.0, error = -3.0 (L1: 3.0 - 0.5 = 2.5)
|
||||
let device = Default::default();
|
||||
let loss = SmoothL1LossConfig::new().init();
|
||||
|
||||
let predictions =
|
||||
Tensor::<TestBackend, 1>::from_data(TensorData::from([-3.0_f32]), &device);
|
||||
let targets = Tensor::<TestBackend, 1>::zeros([1], &device);
|
||||
|
||||
let output = loss.forward(predictions, targets);
|
||||
let expected = TensorData::from([2.5_f32]);
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
fn test_smooth_l1_forward_mixed_regions() {
    // One input that exercises both branches of the loss with beta = 1.0.
    //   L1 branch: |error| - 0.5 * beta
    //   L2 branch: 0.5 * error^2 / beta
    //   error = 0.5 -> L2: 0.5 * 0.25 / 1 = 0.125
    //   error = 1.5 -> L1: 1.5 - 0.5 = 1.0
    //   error = 3.0 -> L1: 3.0 - 0.5 = 2.5
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init();

    let errors =
        Tensor::<TestBackend, 1>::from_data(TensorData::from([0.5_f32, 1.5, 3.0]), &device);
    let zeros = Tensor::<TestBackend, 1>::zeros([3], &device);

    let actual = criterion.forward(errors, zeros).into_data();

    actual.assert_eq(&TensorData::from([0.125_f32, 1.0, 2.5]), false);
}
|
||||
|
||||
#[test]
fn test_smooth_l1_custom_beta_values() {
    // A non-default beta of 0.5 moves the L1/L2 switch point.
    //   error = 0.25 (<  beta) -> L2: 0.5 * 0.0625 / 0.5 = 0.0625
    //   error = 1.0  (>= beta) -> L1: 1.0 - 0.25 = 0.75
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().with_beta(0.5).init();

    let errors = Tensor::<TestBackend, 1>::from_data(TensorData::from([0.25_f32, 1.0]), &device);
    let zeros = Tensor::<TestBackend, 1>::zeros([2], &device);

    let actual = criterion.forward(errors, zeros).into_data();

    actual.assert_eq(&TensorData::from([0.0625_f32, 0.75]), false);
}
|
||||
|
||||
// =========================================================================
|
||||
// forward_with_reduction Tests
|
||||
// =========================================================================
|
||||
|
||||
#[test]
fn test_smooth_l1_reduction_mean() {
    // Element losses: error 0.5 -> 0.125 (L2), error 2.0 -> 1.5 (L1).
    // Mean reduction: (0.125 + 1.5) / 2 = 0.8125.
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init();

    let as_row = |row: [f32; 2]| {
        Tensor::<TestBackend, 2>::from_data(TensorData::from([row]), &device)
    };

    let reduced = criterion
        .forward_with_reduction(as_row([0.5, 2.0]), as_row([0.0, 0.0]), Reduction::Mean)
        .into_data();

    reduced.assert_eq(&TensorData::from([0.8125_f32]), false);
}
|
||||
|
||||
#[test]
fn test_smooth_l1_reduction_sum() {
    // Element losses: error 0.5 -> 0.125 (L2), error 2.0 -> 1.5 (L1).
    // Sum reduction: 0.125 + 1.5 = 1.625.
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init();

    let as_row = |row: [f32; 2]| {
        Tensor::<TestBackend, 2>::from_data(TensorData::from([row]), &device)
    };

    let reduced = criterion
        .forward_with_reduction(as_row([0.5, 2.0]), as_row([0.0, 0.0]), Reduction::Sum)
        .into_data();

    reduced.assert_eq(&TensorData::from([1.625_f32]), false);
}
|
||||
|
||||
#[test]
fn test_smooth_l1_reduction_auto_equals_mean() {
    // `Reduction::Auto` is expected to give the same result as `Reduction::Mean`.
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init();

    let predictions = Tensor::<TestBackend, 1>::from_data(TensorData::from([2.0_f32]), &device);
    let targets = Tensor::<TestBackend, 1>::zeros([1], &device);

    let with_mean = criterion
        .forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean)
        .into_data();
    let with_auto = criterion
        .forward_with_reduction(predictions, targets, Reduction::Auto)
        .into_data();

    with_mean.assert_eq(&with_auto, false);
}
|
||||
|
||||
// =========================================================================
|
||||
// Dimension Reduction Tests
|
||||
// =========================================================================
|
||||
|
||||
#[test]
fn test_smooth_l1_forward_reduce_dims_single_dim() {
    // Beta = 2.0.
    //   L1 branch: |error| - 0.5 * beta
    //   L2 branch: 0.5 * error^2 / beta
    // Row 0, errors [0.0, 1.0, 4.0]:
    //   0.0 -> L2: 0.0
    //   1.0 -> L2: 0.5 * 1.0 / 2.0 = 0.25
    //   4.0 -> L1: 4.0 - 1.0 = 3.0
    //   mean = 3.25 / 3 = 1.0833...
    // Row 1, errors [0.0, 0.0, 0.0] -> mean = 0.0
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().with_beta(2.0).init();

    // Row 1 is identical in predictions and targets, giving zero error.
    let shared_row = [5.0_f32, 5.0, 5.0];
    let predictions = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[0.0_f32, 1.0, 4.0], shared_row]),
        &device,
    );
    let targets = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[0.0_f32, 0.0, 0.0], shared_row]),
        &device,
    );

    let reduced = criterion
        .forward_reduce_dims(predictions, targets, &[1])
        .into_data();

    let expected = TensorData::from([[3.25_f32 / 3.0], [0.0]]);
    reduced.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
|
||||
|
||||
#[test]
fn test_smooth_l1_forward_reduce_dims_image_batch() {
    // Per-image Smooth L1 loss for a [batch, C, H, W] tensor, the kind of
    // reduction used in object detection (e.g. Fast R-CNN).
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init(); // beta = 1.0

    // Shape [2, 1, 2, 2]: batch = 2, C = 1, H = 2, W = 2.
    let image_1 = [[[0.5_f32, 2.0], [0.0, 3.0]]];
    let image_2 = [[[1.0_f32, 0.0], [0.5, 1.5]]];
    let predictions =
        Tensor::<TestBackend, 4>::from_data(TensorData::from([image_1, image_2]), &device);
    let targets = Tensor::<TestBackend, 4>::zeros([2, 1, 2, 2], &device);

    // Reducing over C, H and W (dims 1, 2, 3) leaves one loss value per image.
    let per_image = criterion
        .forward_reduce_dims(predictions, targets, &[1, 2, 3])
        .into_data();

    // Image 1 losses [[0.125, 1.5], [0.0, 2.5]]   -> mean 4.125 / 4 = 1.03125
    // Image 2 losses [[0.5, 0.0], [0.125, 1.0]]   -> mean 1.625 / 4 = 0.40625
    let expected = TensorData::from([[[[1.03125_f32]]], [[[0.40625_f32]]]]);
    per_image.assert_eq(&expected, false);
}
|
||||
|
||||
#[test]
fn test_smooth_l1_forward_reduce_dims_unsorted() {
    // The order in which the reduction dims are listed must not change the result.
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init();

    let predictions = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[[1.0_f32, 2.0], [3.0, 4.0]], [[5.0_f32, 6.0], [7.0, 8.0]]]),
        &device,
    );
    let targets = Tensor::<TestBackend, 3>::zeros([2, 2, 2], &device);

    // Same dims, once in descending and once in ascending order.
    let descending = criterion
        .forward_reduce_dims(predictions.clone(), targets.clone(), &[2, 1])
        .into_data();
    let ascending = criterion
        .forward_reduce_dims(predictions, targets, &[1, 2])
        .into_data();

    descending.assert_eq(&ascending, false);
}
|
||||
|
||||
#[test]
fn test_smooth_l1_forward_reduce_dims_empty_dims() {
    // With no dims to reduce over, `forward_reduce_dims` must match the
    // unreduced `forward` output.
    let device = Default::default();
    let criterion = SmoothL1LossConfig::new().init();

    let predictions = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[0.5_f32, 2.0], [0.0, 3.0]]),
        &device,
    );
    let targets = Tensor::<TestBackend, 2>::zeros([2, 2], &device);

    let via_reduce_dims = criterion
        .forward_reduce_dims(predictions.clone(), targets.clone(), &[])
        .into_data();
    let unreduced = criterion.forward(predictions, targets).into_data();

    via_reduce_dims.assert_eq(&unreduced, false);
}
|
||||
}
|
||||
@@ -0,0 +1,621 @@
|
||||
//! Cross-Attention Module for Burn
|
||||
//!
|
||||
//! Features:
|
||||
//! - Asymmetric Input Shapes (Query vs Context)
|
||||
//! - Grouped Query Attention (GQA) & Multi-Query Attention (MQA) support
|
||||
//! - Quantization-Safe Masking (min_float)
|
||||
//! - Sparse-Ready (quiet_softmax)
|
||||
//! - KV Caching for Streaming Inference
|
||||
|
||||
use crate::cache::TensorCache;
|
||||
use crate::modules::{Linear, LinearConfig};
|
||||
use crate::{Dropout, DropoutConfig};
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Initializer, Module},
|
||||
tensor::{
|
||||
Bool, Tensor,
|
||||
activation::{quiet_softmax, softmax},
|
||||
backend::Backend,
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
/// Configuration to create a [CrossAttention](CrossAttention) layer using the [init function](CrossAttentionConfig::init).
|
||||
pub struct CrossAttentionConfig {
|
||||
/// Dimension of the Query (e.g., Decoder state).
|
||||
pub d_model: usize,
|
||||
/// Dimension of the Context (e.g., Encoder audio embeddings).
|
||||
pub d_context: usize,
|
||||
/// Number of heads for the Query.
|
||||
pub n_heads: usize,
|
||||
/// Number of heads for Key/Value (Set to 1 for MQA, set to n_heads for MHA).
|
||||
pub n_heads_kv: usize,
|
||||
/// Dimension of a single head.
|
||||
pub d_head: usize,
|
||||
/// Dropout rate.
|
||||
#[config(default = 0.1)]
|
||||
pub dropout: f64,
|
||||
/// Masking value. Use -1.0e4 for f16/bf16 safety.
|
||||
#[config(default = -1.0e4)]
|
||||
pub min_float: f64,
|
||||
/// Use quiet_softmax to allow zero-attention (good for sparse/quantized models).
|
||||
#[config(default = false)]
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
/// The Cross attention module
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - `query`: [`Linear`] layer with `d_model` input and output features.
|
||||
/// - `key`: [`Linear`] layer with `d_model` input and output features.
|
||||
/// - `value`: [`Linear`] layer with `d_model` input and output features.
|
||||
/// - `output`: [`Linear`] layer with `d_model` input and output features.
|
||||
///
|
||||
/// Should be created with [CrossAttentionConfig].
|
||||
pub struct CrossAttention<B: Backend> {
|
||||
query: Linear<B>,
|
||||
key: Linear<B>,
|
||||
value: Linear<B>,
|
||||
output: Linear<B>,
|
||||
dropout: Dropout,
|
||||
|
||||
n_heads: usize,
|
||||
n_heads_kv: usize,
|
||||
d_head: usize,
|
||||
scale: f64,
|
||||
min_float: f64,
|
||||
quiet_softmax: bool,
|
||||
}
|
||||
|
||||
/// Cache for the [Cross Attention](CrossAttention) layer.
|
||||
///
|
||||
/// To be used during inference when context is constant.
|
||||
pub struct CrossAttentionCache<B: Backend> {
|
||||
/// Cached key tensor.
|
||||
pub k: TensorCache<B, 4>,
|
||||
/// Cached value tensor.
|
||||
pub v: TensorCache<B, 4>,
|
||||
}
|
||||
|
||||
impl<B: Backend> CrossAttentionCache<B> {
|
||||
/// Create a new empty cache.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
k: TensorCache::empty(),
|
||||
v: TensorCache::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for CrossAttentionCache<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CrossAttentionConfig {
|
||||
/// Initializes a new cross-attention module.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The device on which to initialize the module.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new [CrossAttention] module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {
|
||||
// Safety Rail for GQA
|
||||
assert_eq!(
|
||||
self.n_heads % self.n_heads_kv,
|
||||
0,
|
||||
"Query heads must be divisible by KV heads"
|
||||
);
|
||||
|
||||
let init_linear = |in_dim, out_dim| {
|
||||
LinearConfig::new(in_dim, out_dim)
|
||||
.with_initializer(Initializer::KaimingUniform {
|
||||
gain: 1.0 / (self.d_head as f64).sqrt(),
|
||||
fan_out_only: false,
|
||||
})
|
||||
.init(device)
|
||||
};
|
||||
|
||||
CrossAttention {
|
||||
// ADVICE: Asymmetric Projections
|
||||
query: init_linear(self.d_model, self.n_heads * self.d_head),
|
||||
key: init_linear(self.d_context, self.n_heads_kv * self.d_head),
|
||||
value: init_linear(self.d_context, self.n_heads_kv * self.d_head),
|
||||
output: init_linear(self.n_heads * self.d_head, self.d_model),
|
||||
|
||||
dropout: DropoutConfig::new(self.dropout).init(),
|
||||
n_heads: self.n_heads,
|
||||
n_heads_kv: self.n_heads_kv,
|
||||
d_head: self.d_head,
|
||||
scale: (self.d_head as f64).sqrt().recip(),
|
||||
min_float: self.min_float,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> CrossAttention<B> {
|
||||
/// Applies cross-attention to query using context as key and value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
|
||||
/// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
|
||||
/// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Output tensor of shape `[batch, seq_len_query, d_model]`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
query: Tensor<B, 3>,
|
||||
context: Tensor<B, 3>,
|
||||
mask: Option<Tensor<B, 2, Bool>>,
|
||||
) -> Tensor<B, 3> {
|
||||
let [batch, l_q, _] = query.dims();
|
||||
let [_, l_k, _] = context.dims();
|
||||
|
||||
// 1. Projections
|
||||
let q = self.query.forward(query);
|
||||
let k = self.key.forward(context.clone());
|
||||
let v = self.value.forward(context);
|
||||
|
||||
// 2. Reshape Heads
|
||||
// Q: [Batch, Heads, L_q, D_head]
|
||||
let q = q
|
||||
.reshape([batch, l_q, self.n_heads, self.d_head])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
// K, V: [Batch, Heads_KV, L_k, D_head]
|
||||
let k = k
|
||||
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
|
||||
.swap_dims(1, 2);
|
||||
let v = v
|
||||
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
// 3. GQA Expansion
|
||||
// ADVICE: Handle GQA by repeating KV heads to match Query heads
|
||||
let (k, v) = if self.n_heads != self.n_heads_kv {
|
||||
let n_rep = self.n_heads / self.n_heads_kv;
|
||||
(self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
|
||||
} else {
|
||||
(k, v)
|
||||
};
|
||||
|
||||
// 4. Score Calculation
|
||||
let scores = q.matmul(k.transpose()) * self.scale;
|
||||
|
||||
// 5. Masking
|
||||
// ADVICE: Use min_float for F16/FP8 safety
|
||||
let scores = if let Some(mask) = mask {
|
||||
let mask = mask.reshape([batch, 1, 1, l_k]);
|
||||
scores.mask_fill(mask, self.min_float)
|
||||
} else {
|
||||
scores
|
||||
};
|
||||
|
||||
// 6. Softmax
|
||||
// ADVICE: Optional Quiet Softmax for sparse networks
|
||||
let weights = if self.quiet_softmax {
|
||||
quiet_softmax(scores, 3)
|
||||
} else {
|
||||
softmax(scores, 3)
|
||||
};
|
||||
|
||||
let weights = self.dropout.forward(weights);
|
||||
|
||||
// 7. Aggregate & Output
|
||||
let output = weights.matmul(v);
|
||||
let output = output
|
||||
.swap_dims(1, 2)
|
||||
.reshape([batch, l_q, self.n_heads * self.d_head]);
|
||||
|
||||
self.output.forward(output)
|
||||
}
|
||||
|
||||
/// Applies cross-attention to query using context as key and value.
|
||||
///
|
||||
/// This method uses a cache to avoid recomputing key and value tensors when the context is the same.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
|
||||
/// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
|
||||
/// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
|
||||
/// * `cache` - The cache to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Output tensor of shape `[batch, seq_len_query, d_model]`.
|
||||
pub fn forward_cache(
|
||||
&self,
|
||||
query: Tensor<B, 3>,
|
||||
context: Tensor<B, 3>,
|
||||
mask: Option<Tensor<B, 2, Bool>>,
|
||||
cache: &mut CrossAttentionCache<B>,
|
||||
) -> Tensor<B, 3> {
|
||||
let [batch, l_q, _] = query.dims();
|
||||
|
||||
// 1. Projections
|
||||
let q = self.query.forward(query);
|
||||
|
||||
let k_compute = |context: Tensor<B, 3>| {
|
||||
let [batch, l_k, _] = context.dims();
|
||||
self.key
|
||||
.forward(context)
|
||||
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
|
||||
.swap_dims(1, 2)
|
||||
};
|
||||
let v_compute = |context: Tensor<B, 3>| {
|
||||
let [batch, l_k, _] = context.dims();
|
||||
self.value
|
||||
.forward(context)
|
||||
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
|
||||
.swap_dims(1, 2)
|
||||
};
|
||||
|
||||
let k = cache.k.forward_full(context.clone(), k_compute);
|
||||
let v = cache.v.forward_full(context, v_compute);
|
||||
|
||||
let [_, _, l_k, _] = k.dims();
|
||||
|
||||
// 2. Reshape Heads
|
||||
// Q: [Batch, Heads, L_q, D_head]
|
||||
let q = q
|
||||
.reshape([batch, l_q, self.n_heads, self.d_head])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
// K, V are already in their correct shape from k_compute and v_compute
|
||||
|
||||
// 3. GQA Expansion
|
||||
// ADVICE: Handle GQA by repeating KV heads to match Query heads
|
||||
let (k, v) = if self.n_heads != self.n_heads_kv {
|
||||
let n_rep = self.n_heads / self.n_heads_kv;
|
||||
(self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
|
||||
} else {
|
||||
(k, v)
|
||||
};
|
||||
|
||||
// 4. Score Calculation
|
||||
let scores = q.matmul(k.transpose()) * self.scale;
|
||||
|
||||
// 5. Masking
|
||||
// ADVICE: Use min_float for F16/FP8 safety
|
||||
let scores = if let Some(mask) = mask {
|
||||
let mask = mask.reshape([batch, 1, 1, l_k]);
|
||||
scores.mask_fill(mask, self.min_float)
|
||||
} else {
|
||||
scores
|
||||
};
|
||||
|
||||
// 6. Softmax
|
||||
// ADVICE: Optional Quiet Softmax for sparse networks
|
||||
let weights = if self.quiet_softmax {
|
||||
quiet_softmax(scores, 3)
|
||||
} else {
|
||||
softmax(scores, 3)
|
||||
};
|
||||
|
||||
let weights = self.dropout.forward(weights);
|
||||
|
||||
// 7. Aggregate & Output
|
||||
let output = weights.matmul(v);
|
||||
let output = output
|
||||
.swap_dims(1, 2)
|
||||
.reshape([batch, l_q, self.n_heads * self.d_head]);
|
||||
|
||||
self.output.forward(output)
|
||||
}
|
||||
|
||||
/// Helper for Grouped Query Attention
|
||||
fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
|
||||
let [b, h, l, d] = x.dims();
|
||||
x.reshape([b, h, 1, l, d])
|
||||
.expand([b, h, n_rep, l, d])
|
||||
.reshape([b, h * n_rep, l, d])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};

    type TestDevice = <TestBackend as Backend>::Device;

    /// Builds a config with the masking value and softmax variant shared by every test.
    fn config(
        d_model: usize,
        d_context: usize,
        n_heads: usize,
        n_heads_kv: usize,
        d_head: usize,
        dropout: f64,
    ) -> CrossAttentionConfig {
        CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv,
            d_head,
            dropout,
            min_float: -1.0e4,
            quiet_softmax: false,
        }
    }

    /// Random rank-3 tensor drawn from the default distribution.
    fn random(shape: [usize; 3], device: &TestDevice) -> Tensor<TestBackend, 3> {
        Tensor::random(shape, Distribution::Default, device)
    }

    /// Runs an unmasked forward pass with 4 query heads and the given number of
    /// KV heads, and checks the output shape is `[batch, seq_len_query, d_model]`.
    fn assert_forward_shape(n_heads_kv: usize) {
        let [batch_size, seq_len_query, seq_len_context] = [7, 13, 15];
        let [d_model, d_context, n_heads, d_head] = [32, 40, 4, 8];
        let device = TestDevice::default();
        let cross_attn = config(d_model, d_context, n_heads, n_heads_kv, d_head, 0.1)
            .init::<TestBackend>(&device);

        let query = random([batch_size, seq_len_query, d_model], &device);
        let context = random([batch_size, seq_len_context, d_context], &device);

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    #[test]
    fn test_cross_attention_mha_shapes() {
        // MHA case: as many KV heads as query heads.
        assert_forward_shape(4);
    }

    #[test]
    fn test_cross_attention_gqa_shapes() {
        // GQA case: two query heads share each KV head.
        assert_forward_shape(2);
    }

    #[test]
    fn test_cross_attention_mqa_shapes() {
        // MQA case: a single KV head shared by all query heads.
        assert_forward_shape(1);
    }

    #[test]
    fn test_cross_attention_mask() {
        let [batch_size, seq_len_query, seq_len_context] = [3, 6, 8];
        let [d_model, d_context, n_heads, d_head] = [12, 16, 4, 3];
        let num_padded = 2;
        let device = TestDevice::default();
        // No dropout for deterministic test
        let cross_attn =
            config(d_model, d_context, n_heads, n_heads, d_head, 0.0).init::<TestBackend>(&device);

        // Padding mask: the last `num_padded` context positions are masked.
        let padded = seq_len_context - num_padded..seq_len_context;
        let mask_bool = Tensor::<TestBackend, 2, Int>::zeros([batch_size, seq_len_context], &device)
            .slice_assign(
                [0..batch_size, padded.clone()],
                Tensor::ones([batch_size, num_padded], &device),
            )
            .equal_elem(1);

        let query = random([batch_size, seq_len_query, d_model], &device);
        let context_1 = random([batch_size, seq_len_context, d_context], &device);

        // Same context, except for fresh random values in the padded positions.
        let context_2 = context_1.clone().slice_assign(
            [0..batch_size, padded, 0..d_context],
            random([batch_size, num_padded, d_context], &device),
        );

        // The outputs should be the same since the changed part is masked.
        let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));
        let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));

        output_1
            .into_data()
            .assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());
    }

    #[test]
    #[should_panic]
    fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {
        let device = TestDevice::default();
        // 5 query heads cannot be split evenly across 2 KV heads.
        config(32, 32, 5, 2, 8, 0.1).init::<TestBackend>(&device);
    }

    #[test]
    fn test_cross_attention_cache() {
        let [batch_size, seq_len_query, seq_len_context] = [3, 6, 8];
        let [d_model, d_context, n_heads, d_head] = [12, 16, 4, 3];
        let device = TestDevice::default();
        // No dropout for deterministic test
        let cross_attn =
            config(d_model, d_context, n_heads, n_heads, d_head, 0.0).init::<TestBackend>(&device);

        let query1 = random([batch_size, seq_len_query, d_model], &device);
        let context = random([batch_size, seq_len_context, d_context], &device);

        // Reference pass without a cache, then the same inputs through a fresh cache.
        let uncached_1 = cross_attn.forward(query1.clone(), context.clone(), None);
        let mut cache = CrossAttentionCache::new();
        let cached_1 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache);

        // The two outputs should be identical
        uncached_1
            .into_data()
            .assert_approx_eq(&cached_1.into_data(), Tolerance::<f32>::default());

        // A different query against the same context and the now-populated cache,
        // compared with an uncached control pass.
        let query2 = random([batch_size, seq_len_query, d_model], &device);
        let cached_2 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);
        let uncached_2 = cross_attn.forward(query2.clone(), context.clone(), None);

        cached_2
            .into_data()
            .assert_approx_eq(&uncached_2.into_data(), Tolerance::<f32>::default());
    }
}
|
||||
@@ -0,0 +1,161 @@
|
||||
use burn_core as burn;
|
||||
use burn_core::config::Config;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use burn::tensor::ops::IntElem;
|
||||
|
||||
use burn::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};
|
||||
|
||||
/// Generate an autoregressive attention mask.
|
||||
///
|
||||
/// The mask can be used in Transformer modules to train models to generate tensors sequentially.
|
||||
pub fn generate_autoregressive_mask<B: Backend>(
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, 3, Bool> {
|
||||
let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);
|
||||
mask.expand([batch_size, seq_length, seq_length])
|
||||
}
|
||||
|
||||
/// Generate a padding attention mask.
|
||||
pub struct GeneratePaddingMask<B: Backend> {
|
||||
/// The generated tensor.
|
||||
pub tensor: Tensor<B, 2, Int>,
|
||||
|
||||
/// The generated mask.
|
||||
pub mask: Tensor<B, 2, Bool>,
|
||||
}
|
||||
|
||||
/// Defines an enumeration to specify sequence length options for padding
|
||||
#[derive(Config, Debug, Copy)]
|
||||
pub enum SeqLengthOption {
|
||||
/// No maximum length; use the longest sequence
|
||||
NoMax,
|
||||
/// Maximum length specified, truncate if necessary
|
||||
Max(usize),
|
||||
/// Fixed length, pad or truncate to this exact length
|
||||
Fixed(usize),
|
||||
}
|
||||
|
||||
impl From<Option<usize>> for SeqLengthOption {
|
||||
fn from(val: Option<usize>) -> Self {
|
||||
match val {
|
||||
Some(max) => SeqLengthOption::Max(max),
|
||||
None => SeqLengthOption::NoMax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a padding attention mask for a batch of token sequences.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pad_token` - The token ID used for padding
|
||||
/// * `tokens_list` - Vector of token sequences (each sequence is a vector of token IDs)
|
||||
/// * `seq_length` - Sequence length option (NoMax, Max, or Fixed)
|
||||
/// * `device` - The device for tensor operations
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `GeneratePaddingMask` containing the padded tensor and corresponding mask
|
||||
pub fn generate_padding_mask<B: Backend>(
|
||||
pad_token: usize,
|
||||
tokens_list: Vec<Vec<usize>>,
|
||||
seq_length: impl Into<SeqLengthOption>,
|
||||
device: &B::Device,
|
||||
) -> GeneratePaddingMask<B> {
|
||||
let tokens_max = || {
|
||||
tokens_list
|
||||
.iter()
|
||||
.map(|tokens| tokens.len())
|
||||
.max()
|
||||
.unwrap_or(1)
|
||||
};
|
||||
|
||||
let size = match seq_length.into() {
|
||||
SeqLengthOption::NoMax => tokens_max(),
|
||||
SeqLengthOption::Max(max) => usize::min(tokens_max(), max),
|
||||
SeqLengthOption::Fixed(limit) => limit,
|
||||
};
|
||||
let batch_size = tokens_list.len();
|
||||
|
||||
let mut tensor = Tensor::zeros([batch_size, size], device);
|
||||
tensor = tensor.add_scalar(pad_token as i64);
|
||||
|
||||
for (index, tokens) in tokens_list.into_iter().enumerate() {
|
||||
let seq_length = tokens.len().min(size);
|
||||
tensor = tensor.slice_assign(
|
||||
[index..index + 1, 0..seq_length],
|
||||
Tensor::from_data(
|
||||
TensorData::new(
|
||||
tokens
|
||||
.into_iter()
|
||||
.take(size)
|
||||
.map(|e| (e as i64).elem::<IntElem<B>>())
|
||||
.collect(),
|
||||
Shape::new([1, seq_length]),
|
||||
),
|
||||
device,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
let mask = tensor.clone().equal_elem(pad_token as i64);
|
||||
|
||||
GeneratePaddingMask { tensor, mask }
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::vec;
    use burn::tensor::TensorData;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = <TestBackend as Backend>::Device::default();

        // Expected pattern for one sequence of length 3: `true` strictly above
        // the diagonal. Both batch entries must carry the same pattern.
        let per_sequence = [
            [false, true, true],
            [false, false, true],
            [false, false, false],
        ];
        let expected = TensorData::from([per_sequence, per_sequence]);

        generate_autoregressive_mask::<TestBackend>(2, 3, &device)
            .into_data()
            .assert_eq(&expected, false);
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = <TestBackend as Backend>::Device::default();
        // Sequences of length 3, 3, 4 and 6; the longest sets the padded length.
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        // Pad token 0 and no maximum length (`None`).
        let generated = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        // `true` marks the padded tail of each row.
        let expected = TensorData::from([
            [false, false, false, true, true, true],
            [false, false, false, true, true, true],
            [false, false, false, false, true, true],
            [false, false, false, false, false, false],
        ]);
        generated.mask.into_data().assert_eq(&expected, false);
    }
}
|
||||
@@ -0,0 +1,531 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::activation::Gelu;
|
||||
use crate::cache::TensorCache;
|
||||
use crate::{Dropout, DropoutConfig, Linear, LinearConfig};
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::{Bool, Tensor, backend::Backend};
|
||||
|
||||
use burn::tensor::activation::{quiet_softmax, softmax};
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MultiHeadAttentionConfig {
|
||||
/// The size of each linear layer.
|
||||
pub d_model: usize,
|
||||
/// The number of heads.
|
||||
pub n_heads: usize,
|
||||
/// The dropout rate. Default: 0.1
|
||||
#[config(default = 0.1)]
|
||||
pub dropout: f64,
|
||||
/// The minimum value a float can take. Default: -1.0e4
|
||||
/// This is used to mask attention scores before calculating attention weights.
|
||||
/// A value too low might result in NaN.
|
||||
#[config(default = -1.0e4)]
|
||||
pub min_float: f64,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
///
|
||||
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
|
||||
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
|
||||
///
|
||||
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
|
||||
#[config(default = false)]
|
||||
pub quiet_softmax: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - `query`: [`Linear`] layer with `d_model` input and output features.
|
||||
/// - `key`: [`Linear`] layer with `d_model` input and output features.
|
||||
/// - `value`: [`Linear`] layer with `d_model` input and output features.
|
||||
/// - `output`: [`Linear`] layer with `d_model` input and output features.
|
||||
///
|
||||
/// Should be created with [MultiHeadAttentionConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct MultiHeadAttention<B: Backend> {
|
||||
/// Linear layer to transform the input features into the query space.
|
||||
pub query: Linear<B>,
|
||||
/// Linear layer to transform the input features into the key space.
|
||||
pub key: Linear<B>,
|
||||
/// Linear layer to transform the input features into the value space.
|
||||
pub value: Linear<B>,
|
||||
/// Linear layer to transform the output features back to the original space.
|
||||
pub output: Linear<B>,
|
||||
/// Dropout layer.
|
||||
pub dropout: Dropout,
|
||||
/// Activation function.
|
||||
pub activation: Gelu,
|
||||
/// The size of each linear layer.
|
||||
pub d_model: usize,
|
||||
/// The number of heads.
|
||||
pub n_heads: usize,
|
||||
/// Size of the key and query vectors.
|
||||
pub d_k: usize,
|
||||
/// Minimum value a float can take.
|
||||
pub min_float: f64,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("d_model", &self.d_model)
|
||||
.add("n_heads", &self.n_heads)
|
||||
.add("d_k", &self.d_k)
|
||||
.add("dropout", &self.dropout.prob)
|
||||
.add("min_float", &self.min_float)
|
||||
.add("quiet_softmax", &self.quiet_softmax)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MhaInput<B: Backend> {
|
||||
/// Shape `[batch_size, seq_length_1, d_model]`
|
||||
query: Tensor<B, 3>,
|
||||
/// Shape `[batch_size, seq_length_2, d_model]`
|
||||
key: Tensor<B, 3>,
|
||||
/// Shape `[batch_size, seq_length_2, d_model]`
|
||||
value: Tensor<B, 3>,
|
||||
mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
}
|
||||
|
||||
impl MultiHeadAttentionConfig {
|
||||
/// Initialize a new [multihead attention](MultiHeadAttention) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
|
||||
let linear = |config: &Self| {
|
||||
LinearConfig::new(config.d_model, config.d_model)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.init(device)
|
||||
};
|
||||
|
||||
MultiHeadAttention {
|
||||
query: linear(self),
|
||||
key: linear(self),
|
||||
value: linear(self),
|
||||
output: linear(self),
|
||||
dropout: DropoutConfig::new(self.dropout).init(),
|
||||
activation: Gelu::new(),
|
||||
n_heads: self.n_heads,
|
||||
d_k: self.d_model / self.n_heads,
|
||||
min_float: self.min_float,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
d_model: self.d_model,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MhaInput<B> {
|
||||
/// Create a [multihead attention](MultiHeadAttention) input argument
|
||||
/// by setting the query, key and value to the given tensor.
|
||||
///
|
||||
/// # Shape
|
||||
/// - tensor: `[batch_size, seq_length, d_model]`
|
||||
pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
|
||||
Self {
|
||||
query: tensor.clone(),
|
||||
key: tensor.clone(),
|
||||
value: tensor,
|
||||
mask_pad: None,
|
||||
mask_attn: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a [multihead attention](MultiHeadAttention) input argument.
|
||||
pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
|
||||
Self {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask_pad: None,
|
||||
mask_attn: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the padding mask.
|
||||
pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
|
||||
self.mask_pad = Some(mask_pad);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register the attention mask.
|
||||
pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
|
||||
self.mask_attn = Some(mask_attn);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// [Multihead attention](MultiHeadAttention) outputs.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MhaOutput<B: Backend> {
|
||||
/// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.
|
||||
pub weights: Tensor<B, 4>,
|
||||
/// The context tensor `[batch_size, seq_length_1, d_model]`.
|
||||
pub context: Tensor<B, 3>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadAttention<B> {
|
||||
/// Applies the forward pass on the input tensors.
|
||||
///
|
||||
/// See [MultiHeadAttention](MultiHeadAttention) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - query: `[batch_size, seq_length_1, d_model]`
|
||||
/// - key: `[batch_size, seq_length_2, d_model]`
|
||||
/// - value: `[batch_size, seq_length_2, d_model]`
|
||||
/// - output: `[batch_size, seq_length_1, d_model]`
|
||||
pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
|
||||
let [batch_size, seq_length_1, d_model] = input.query.dims();
|
||||
|
||||
let query = self.attention_linear(input.query, &self.query);
|
||||
let key = self.attention_linear(input.key, &self.key);
|
||||
let value = self.attention_linear(input.value, &self.value);
|
||||
|
||||
let attn_scores = self.attn_scores(query, key);
|
||||
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
|
||||
|
||||
let context = weights.clone().matmul(value);
|
||||
let context = context
|
||||
.swap_dims(1, 2)
|
||||
.reshape([batch_size, seq_length_1, d_model]);
|
||||
let context = self.output.forward(context);
|
||||
|
||||
MhaOutput { weights, context }
|
||||
}
|
||||
|
||||
/// Applies the forward pass using a cache.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - query: `[batch_size, seq_length_1, d_model]`
|
||||
/// - key: `[batch_size, seq_length_2, d_model]`
|
||||
/// - value: `[batch_size, seq_length_2, d_model]`
|
||||
/// - output: `[batch_size, seq_length_1, d_model]`
|
||||
pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
|
||||
let [batch_size, seq_length_1, d_model] = input.query.dims();
|
||||
|
||||
let query = cache
|
||||
.query
|
||||
.forward(input.query, |t| self.attention_linear(t, &self.query));
|
||||
let key = cache
|
||||
.key
|
||||
.forward(input.key, |t| self.attention_linear(t, &self.key));
|
||||
let value = cache
|
||||
.value
|
||||
.forward(input.value, |t| self.attention_linear(t, &self.value));
|
||||
|
||||
let attn_scores = self.attn_scores(query, key);
|
||||
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
|
||||
|
||||
let context = weights.clone().matmul(value);
|
||||
let context = context
|
||||
.swap_dims(1, 2)
|
||||
.reshape([batch_size, seq_length_1, d_model]);
|
||||
|
||||
let context = cache.output.forward(context, |t| self.output.forward(t));
|
||||
|
||||
MhaOutput { weights, context }
|
||||
}
|
||||
|
||||
fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let attn_scores = query
|
||||
.matmul(key.transpose())
|
||||
.div_scalar((self.d_k as f32).sqrt());
|
||||
|
||||
self.dropout.forward(attn_scores)
|
||||
}
|
||||
|
||||
fn attn_weights(
|
||||
&self,
|
||||
mut attn_scores: Tensor<B, 4>,
|
||||
mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
) -> Tensor<B, 4> {
|
||||
if let Some(mask_pad) = mask_pad {
|
||||
let [batch_size, seq_length] = mask_pad.dims();
|
||||
|
||||
attn_scores = attn_scores.mask_fill(
|
||||
mask_pad.reshape([batch_size, 1, 1, seq_length]),
|
||||
self.min_float,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(mask_attn) = mask_attn {
|
||||
let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();
|
||||
|
||||
attn_scores = attn_scores.mask_fill(
|
||||
mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
|
||||
self.min_float,
|
||||
);
|
||||
}
|
||||
|
||||
if self.quiet_softmax {
|
||||
quiet_softmax(attn_scores, 3)
|
||||
} else {
|
||||
softmax(attn_scores, 3)
|
||||
}
|
||||
}
|
||||
|
||||
fn attention_linear(&self, x: Tensor<B, 3>, linear: &Linear<B>) -> Tensor<B, 4> {
|
||||
let [batch_size, seq_length, _d_model] = x.dims();
|
||||
linear
|
||||
.forward(x)
|
||||
.reshape([batch_size, seq_length, self.n_heads, self.d_k])
|
||||
.swap_dims(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache for the [Multi Head Attention](MultiHeadAttention) layer.
|
||||
///
|
||||
/// To be used during inference when decoding tokens.
|
||||
pub struct MhaCache<B: Backend> {
|
||||
query: MhaLinearCache<B, 4>,
|
||||
key: MhaLinearCache<B, 4>,
|
||||
value: MhaLinearCache<B, 4>,
|
||||
output: MhaLinearCache<B, 3>,
|
||||
}
|
||||
|
||||
enum MhaLinearCache<B: Backend, const D: usize> {
|
||||
Autoregressive(TensorCache<B, D>, usize),
|
||||
Full(TensorCache<B, D>),
|
||||
}
|
||||
|
||||
impl<B: Backend> MhaCache<B> {
|
||||
/// Initialize a cache for autoregressive inference.
|
||||
pub fn autoregressive() -> Self {
|
||||
Self {
|
||||
query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
|
||||
key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
|
||||
value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
|
||||
output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and
|
||||
/// values (cross-attention).
|
||||
pub fn autoregressive_cross_attention() -> Self {
|
||||
Self {
|
||||
query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
|
||||
key: MhaLinearCache::Full(TensorCache::empty()),
|
||||
value: MhaLinearCache::Full(TensorCache::empty()),
|
||||
output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
|
||||
pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(
|
||||
&mut self,
|
||||
tensor: Tensor<B, 3>,
|
||||
func: F,
|
||||
) -> Tensor<B, D> {
|
||||
match self {
|
||||
MhaLinearCache::Autoregressive(cache, dim) => {
|
||||
cache.forward_autoregressive(tensor, *dim, func)
|
||||
}
|
||||
MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};
    use alloc::vec::Vec;
    use burn::tensor::Int;
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{Distribution, Shape};

    /// Self-attention keeps the input shape and yields square per-head weights.
    #[test]
    fn test_self_attention_shapes() {
        let (batch_size, seq_length, d_model, n_heads) = (7, 13, 32, 4);
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
        let tensor = Tensor::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );

        let output = mha.forward(MhaInput::self_attn(tensor));

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length, seq_length]),
            "Weights should have the correct shape",
        );
    }

    /// With distinct query and key/value lengths, the context follows the query length
    /// and the weights are `seq_length_1 x seq_length_2`.
    #[test]
    fn test_generic_mha_shapes() {
        let (batch_size, seq_length_1, seq_length_2, d_model, n_heads) = (7, 13, 15, 32, 4);
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads)
            .init::<TestBackend>(&Default::default());
        let device = Default::default();

        let query = Tensor::random(
            [batch_size, seq_length_1, d_model],
            Distribution::Default,
            &device,
        );
        let key = Tensor::random(
            [batch_size, seq_length_2, d_model],
            Distribution::Default,
            &device,
        );
        let value = Tensor::random(
            [batch_size, seq_length_2, d_model],
            Distribution::Default,
            &device,
        );

        let output = mha.forward(MhaInput::new(query, key, value));

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length_1, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
            "Weights should have the correct shape",
        );
    }

    /// Changing only the padded tail of the input must not change the context of the
    /// non-padded positions.
    #[test]
    fn test_self_attention_mask_pad() {
        let (batch_size, seq_length, d_model, n_heads, num_padded) = (3, 6, 32, 2, 2);
        let pad_start = seq_length - num_padded;
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        // Create a padding mask: the last `num_padded` positions of every sequence.
        let mask_pad: Tensor<TestBackend, 2, Int> =
            Tensor::zeros([batch_size, seq_length], &device);
        let mask_pad = mask_pad
            .slice_assign(
                [0..batch_size, pad_start..seq_length],
                Tensor::ones([batch_size, num_padded], &device),
            )
            .equal_elem(1)
            .to_device(&device);

        let tensor_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        // Change the end of the tensor
        let tensor_2 = tensor_1.clone().slice_assign(
            [0..batch_size, pad_start..seq_length, 0..d_model],
            Tensor::random(
                [batch_size, num_padded, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let output_1 = mha.forward(MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()));
        let output_2 = mha.forward(MhaInput::self_attn(tensor_2).mask_pad(mask_pad));

        // Check that the beginning of each tensor is the same
        let head_1 = output_1
            .context
            .slice([0..batch_size, 0..pad_start, 0..d_model])
            .into_data();
        let head_2 = output_2
            .context
            .slice([0..batch_size, 0..pad_start, 0..d_model])
            .into_data();

        head_1.assert_approx_eq(&head_2, Tolerance::<f32>::default());
    }

    /// A full forward pass with an autoregressive mask must match token-by-token
    /// decoding through the autoregressive cache.
    #[test]
    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
        let (batch_size, seq_length, d_model, n_heads) = (3, 4, 12, 2);
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let masked_input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);

        let output_1 = mha.forward(masked_input);

        let mut cache = MhaCache::autoregressive();
        let mut decoded = Vec::new();
        for end in 1..=seq_length {
            // Feed the prefix of length `end`, keep only the newest position.
            let prefix = tensor.clone().slice([0..batch_size, 0..end, 0..d_model]);
            let step = mha.forward_cache(MhaInput::self_attn(prefix), &mut cache);
            decoded.push(
                step.context
                    .slice([0..batch_size, end - 1..end, 0..d_model]),
            );
        }
        let output_2 = Tensor::cat(decoded, 1);

        output_1
            .context
            .into_data()
            .assert_approx_eq::<FloatElem<TestBackend>>(
                &output_2.into_data(),
                Tolerance::default(),
            );
    }

    /// The display output lists the hyper-parameters and the parameter count.
    #[test]
    fn display() {
        let mha = MultiHeadAttentionConfig::new(2, 4).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{mha}");

        assert_eq!(
            rendered,
            "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
            dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
        );
    }
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod cross_attention;
|
||||
mod mask;
|
||||
mod mha;
|
||||
|
||||
pub use cross_attention::*;
|
||||
pub use mask::*;
|
||||
pub use mha::*;
|
||||
52
crates/stable-diffusion-burn/burn-crates/burn-nn/src/modules/cache/autoregressive.rs
vendored
Normal file
52
crates/stable-diffusion-burn/burn-crates/burn-nn/src/modules/cache/autoregressive.rs
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
use alloc::vec;
|
||||
use burn_core as burn;
|
||||
|
||||
use super::{CacheState, TensorCache};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
impl<B: Backend, const D: usize> TensorCache<B, D> {
|
||||
pub(crate) fn forward_autoregressive<F>(
|
||||
&mut self,
|
||||
tensor: Tensor<B, 3>,
|
||||
dim_cat: usize,
|
||||
func: F,
|
||||
) -> Tensor<B, D>
|
||||
where
|
||||
F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
|
||||
{
|
||||
let mut tensor_old = CacheState::Empty;
|
||||
core::mem::swap(&mut self.state, &mut tensor_old);
|
||||
|
||||
let tensor_new = match tensor_old {
|
||||
CacheState::Value(tensor_old) => {
|
||||
let [batch_size, seq_length, d_model] = tensor.dims();
|
||||
let next_seq_token =
|
||||
tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
|
||||
let next_seq_token = func(next_seq_token);
|
||||
|
||||
Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)
|
||||
}
|
||||
_ => func(tensor),
|
||||
};
|
||||
|
||||
self.state = CacheState::Value(tensor_new.clone());
|
||||
tensor_new
|
||||
}
|
||||
|
||||
pub(crate) fn forward_full<F>(&mut self, tensor: Tensor<B, 3>, func: F) -> Tensor<B, D>
|
||||
where
|
||||
F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
|
||||
{
|
||||
let mut tensor_old = CacheState::Empty;
|
||||
core::mem::swap(&mut self.state, &mut tensor_old);
|
||||
|
||||
let tensor_new = match tensor_old {
|
||||
CacheState::Value(tensor_old) => tensor_old,
|
||||
_ => func(tensor),
|
||||
};
|
||||
|
||||
self.state = CacheState::Value(tensor_new.clone());
|
||||
tensor_new
|
||||
}
|
||||
}
|
||||
27
crates/stable-diffusion-burn/burn-crates/burn-nn/src/modules/cache/base.rs
vendored
Normal file
27
crates/stable-diffusion-burn/burn-crates/burn-nn/src/modules/cache/base.rs
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Content of a cache slot.
pub(crate) enum CacheState<T> {
    /// The slot holds a value.
    Value(T),
    /// Nothing has been stored yet.
    Empty,
}
|
||||
|
||||
/// A cache for a tensor.
|
||||
pub struct TensorCache<B: Backend, const D: usize> {
|
||||
pub(crate) state: CacheState<Tensor<B, D>>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> TensorCache<B, D> {
|
||||
/// Creates a new empty cache.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The empty cache.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
state: CacheState::Empty,
|
||||
}
|
||||
}
|
||||
}
|
||||
4
crates/stable-diffusion-burn/burn-crates/burn-nn/src/modules/cache/mod.rs
vendored
Normal file
4
crates/stable-diffusion-burn/burn-crates/burn-nn/src/modules/cache/mod.rs
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
mod autoregressive;
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,22 @@
|
||||
/// Validates that both channel counts of a grouped convolution can be split evenly
/// across `groups`.
///
/// # Panics
///
/// Panics if `groups` is zero, or if `channels_in` or `channels_out` is not a multiple
/// of `groups`.
pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) {
    // `0.is_multiple_of(0)` is `true`, so a zero group count must be rejected explicitly.
    // Otherwise `(0, 0, 0)` passes here and callers fail later when dividing by `groups`.
    let divisible = groups != 0
        && channels_in.is_multiple_of(groups)
        && channels_out.is_multiple_of(groups);

    if !divisible {
        panic!(
            "Both channels must be divisible by the number of groups. Got \
            channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
        );
    }
}
|
||||
|
||||
// https://github.com/tracel-ai/burn/issues/2676
/// Rejects kernel sizes for which symmetric `Same` padding cannot keep the output size.
///
/// With symmetric padding only, an even kernel size does not produce an output of the
/// same size as the input, so any even entry in `kernel_size` is treated as unsupported.
///
/// # Panics
///
/// Panics (via `unimplemented!`) if any kernel dimension is even.
pub(crate) fn check_same_padding_support(kernel_size: &[usize]) {
    let has_even_dim = kernel_size.iter().any(|&k| k % 2 == 0);

    if has_even_dim {
        unimplemented!("Same padding with an even kernel size is not supported");
    }
}
|
||||
@@ -0,0 +1,283 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::{PaddingConfig1d, conv::checks};
|
||||
use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::PaddedConvOptions};
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Content, DisplaySettings, Ignored, Initializer, Module, ModuleDisplay, Param},
|
||||
};
|
||||
|
||||
/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Conv1dConfig {
|
||||
/// The number of input channels.
|
||||
pub channels_in: usize,
|
||||
/// The number of output channels.
|
||||
pub channels_out: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "1")]
|
||||
pub stride: usize,
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "1")]
|
||||
pub dilation: usize,
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
|
||||
/// will automatically use asymmetric padding to preserve input dimensions.
|
||||
#[config(default = "PaddingConfig1d::Valid")]
|
||||
pub padding: PaddingConfig1d,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over input tensors.
|
||||
///
|
||||
/// Should be created with [Conv1dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Conv1d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
|
||||
pub weight: Param<Tensor<B, 3>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: usize,
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: usize,
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// Padding configuration.
|
||||
pub padding: Ignored<PaddingConfig1d>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Conv1d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Format padding
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
// Format stride/dilation as strings
|
||||
let stride = format!("{:?}", self.stride);
|
||||
let kernel_size = format!("{:?}", self.kernel_size);
|
||||
let dilation = format!("{:?}", self.dilation);
|
||||
|
||||
// Extract channels in/out from weight dims
|
||||
let [channels_out, group_channels_in, _] = self.weight.dims();
|
||||
let channels_in = group_channels_in * self.groups;
|
||||
let ch_out = format!("{:?}", channels_out);
|
||||
let ch_in = format!("{:?}", channels_in);
|
||||
|
||||
content
|
||||
.add("ch_in", &ch_in)
|
||||
.add("ch_out", &ch_out)
|
||||
.add("stride", &stride)
|
||||
.add("kernel_size", &kernel_size)
|
||||
.add("dilation", &dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
impl Conv1dConfig {
|
||||
/// Initialize a new [conv1d](Conv1d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
|
||||
checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
|
||||
|
||||
let shape = [
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size,
|
||||
];
|
||||
|
||||
let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), None, device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias =
|
||||
Some(
|
||||
self.initializer
|
||||
.init_with([self.channels_out], Some(fan_in), None, device),
|
||||
);
|
||||
}
|
||||
|
||||
Conv1d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
dilation: self.dilation,
|
||||
groups: self.groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Conv1d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [conv1d](burn::tensor::module::conv1d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, length_in]`
|
||||
/// - output: `[batch_size, channels_out, length_out]`
|
||||
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let length = input.dims()[2];
|
||||
|
||||
// Calculate padding as pair - handles Same, Valid, and Explicit uniformly
|
||||
let (left, right) =
|
||||
self.padding
|
||||
.calculate_padding_1d_pair(length, self.kernel_size, self.stride);
|
||||
|
||||
let options = PaddedConvOptions::asymmetric(
|
||||
[self.stride],
|
||||
[left],
|
||||
[right],
|
||||
[self.dilation],
|
||||
self.groups,
|
||||
);
|
||||
|
||||
conv1d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
options,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::{ElementConversion, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// Builds a bias-free `Conv1dConfig::new(2, 3, 3)` layer with constant weights of 1.0
    /// and explicit padding, runs it on a `[1, 2, 4]` tensor of ones, and returns the
    /// output dimensions.
    fn explicit_padding_output_dims(left: usize, right: usize) -> [usize; 3] {
        let device = Default::default();
        let conv = Conv1dConfig::new(2, 3, 3)
            .with_padding(PaddingConfig1d::Explicit(left, right))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false)
            .init::<TestBackend>(&device);

        // Input: [batch=1, channels=2, length=4]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);

        conv.forward(input).dims()
    }

    /// With the default initializer, all weights stay within `±sqrt(groups / fan_in)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5);
        let fan_in = (config.channels_in * config.kernel_size) as f64;
        let bound = (config.groups as f64 / fan_in).sqrt().elem::<FT>();

        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-bound..bound);
    }

    /// The zeros initializer is kept in the config and yields an all-zero weight.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<FT, _>(conv.weight.shape());
        conv.weight.to_data().assert_eq(&expected, false);
    }

    /// `Same` padding with an even kernel must still preserve the input length.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let conv = Conv1dConfig::new(4, 4, 2)
            .with_padding(PaddingConfig1d::Same)
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false)
            .init::<TestBackend>(&device);

        // Input: [batch=1, channels=4, length=5]
        let input = Tensor::<TestBackend, 3>::ones([1, 4, 5], &device);
        let output = conv.forward(input);

        // Same padding should preserve spatial dimensions
        assert_eq!(output.dims(), [1, 4, 5]);
    }

    /// The display output lists the hyper-parameters and the parameter count.
    #[test]
    fn display() {
        let conv = Conv1dConfig::new(5, 5, 5).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{conv}");

        assert_eq!(
            rendered,
            "Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
        );
    }

    /// Feeding an input with the wrong channel count must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let conv = Conv1dConfig::new(5, 3, 3).init::<TestBackend>(&Default::default());
        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());

        let _ = conv.forward(input);
    }

    #[test]
    fn asymmetric_padding_forward() {
        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
        // Output length = (7 - 3) / 1 + 1 = 5
        assert_eq!(explicit_padding_output_dims(1, 2), [1, 3, 5]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
        // Output length = (8 - 3) / 1 + 1 = 6
        assert_eq!(explicit_padding_output_dims(2, 2), [1, 3, 6]);
    }
}
|
||||
@@ -0,0 +1,349 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig2d;
|
||||
use burn::config::Config;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::conv2d;
|
||||
use burn::tensor::ops::PaddedConvOptions;
|
||||
|
||||
use crate::conv::checks;
|
||||
|
||||
/// Configuration to create a [2D convolution](Conv2d) layer, using the [init function](Conv2dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Conv2dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
|
||||
/// will automatically use asymmetric padding to preserve input dimensions.
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a 2D convolution over input tensors.
|
||||
///
|
||||
/// Should be created with [Conv2dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Conv2d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
|
||||
pub weight: Param<Tensor<B, 4>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
}
|
||||
|
||||
impl Conv2dConfig {
|
||||
/// Initialize a new [conv2d](Conv2d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
|
||||
|
||||
let shape = [
|
||||
self.channels[1],
|
||||
self.channels[0] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
];
|
||||
|
||||
let k = self.kernel_size.iter().product::<usize>();
|
||||
let fan_in = self.channels[0] / self.groups * k;
|
||||
let fan_out = self.channels[1] / self.groups * k;
|
||||
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), Some(fan_out), device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(self.initializer.init_with(
|
||||
[self.channels[1]],
|
||||
Some(fan_in),
|
||||
Some(fan_out),
|
||||
device,
|
||||
));
|
||||
}
|
||||
|
||||
Conv2d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
groups: self.groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Conv2d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Since padding does not implement ModuleDisplay, we need to format it manually.
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
|
||||
let stride = format!("{:?}", self.stride);
|
||||
let kernel_size = format!("{:?}", self.kernel_size);
|
||||
let dilation = format!("{:?}", self.dilation);
|
||||
let [channels_out, group_channels_in, _, _] = self.weight.dims();
|
||||
let channels_in = group_channels_in * self.groups;
|
||||
let ch_out = format!("{:?}", channels_out);
|
||||
let ch_in = format!("{:?}", channels_in);
|
||||
content
|
||||
.add("ch_in", &ch_in)
|
||||
.add("ch_out", &ch_out)
|
||||
.add("stride", &stride)
|
||||
.add("kernel_size", &kernel_size)
|
||||
.add("dilation", &dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Conv2d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [conv2d](burn::tensor::module::conv2d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - `input`: `[batch_size, channels_in, height_in, width_in]`
|
||||
/// - `output`: `[batch_size, channels_out, height_out, width_out]`
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// use burn::nn::conv::Conv2dConfig;
|
||||
/// use burn::tensor::Tensor;
|
||||
///
|
||||
/// // Assuming backend type alias `B`
|
||||
/// let device = Default::default();
|
||||
/// let conv = Conv2dConfig::new([3, 8], [3, 3]).init::<B>(&device);
|
||||
///
|
||||
/// let x = Tensor::<B, 4>::zeros([1, 3, 28, 28], &device);
|
||||
/// let y = conv.forward(x);
|
||||
///
|
||||
/// println!("{:?}", y.dims()); // [1, 8, 26, 26]
|
||||
/// ```
|
||||
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
|
||||
|
||||
// Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
|
||||
let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
|
||||
height_in,
|
||||
width_in,
|
||||
&self.kernel_size,
|
||||
&self.stride,
|
||||
);
|
||||
|
||||
let options = PaddedConvOptions::asymmetric(
|
||||
self.stride,
|
||||
[top, left],
|
||||
[bottom, right],
|
||||
self.dilation,
|
||||
self.groups,
|
||||
);
|
||||
|
||||
conv2d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
options,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{ElementConversion, Tolerance};

    /// Float element type of the backend under test.
    type FT = FloatElem<TestBackend>;

    /// Weights from the default initializer stay within `±sqrt(groups / fan_in)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv2dConfig::new([5, 1], [5, 5]);
        let fan_in = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
        let bound = (config.groups as f64 / fan_in).sqrt().elem::<FT>();

        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-bound..bound);
    }

    /// The `Zeros` initializer produces an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<FT, _>(conv.weight.shape());
        conv.weight
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Initialisation succeeds with a fan-out-only initializer.
    #[test]
    fn initializer_fan_out() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        // `fan_out_only` checks that fan_out is passed to `init_with()`.
        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// Initialisation succeeds when `groups` divides both channel counts.
    #[test]
    fn initializer_fan_with_groups_is_valid() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = Conv2dConfig::new([4, 4], [1, 1])
            .with_initializer(init.clone())
            .with_groups(4);
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// Initialisation panics when a channel count is not divisible by `groups`.
    #[test]
    #[should_panic = "Both channels must be divisible by the number of groups."]
    fn channels_with_groups_is_invalid() {
        let device = Default::default();
        let _ = Conv2dConfig::new([1, 4], [1, 1])
            .with_groups(4)
            .init::<TestBackend>(&device);
    }

    /// `Same` padding keeps the spatial dimensions even with an even kernel size.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let conv = Conv2dConfig::new([4, 4], [2, 2])
            .with_padding(PaddingConfig2d::Same)
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false)
            .init::<TestBackend>(&device);

        // Input layout: [batch=1, channels=4, height=5, width=5].
        let output = conv.forward(Tensor::<TestBackend, 4>::ones([1, 4, 5, 5], &device));

        // Same padding should preserve the spatial dimensions.
        assert_eq!(output.dims(), [1, 4, 5, 5]);
    }

    /// The `Display` output of a default-configured layer lists every attribute on one line.
    #[test]
    fn display() {
        let device = Default::default();
        let conv = Conv2dConfig::new([5, 1], [5, 5]).init::<TestBackend>(&device);

        let rendered = alloc::format!("{conv}");

        assert_eq!(
            rendered,
            "Conv2d {ch_in: 5, ch_out: 1, stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
        );
    }

    /// Feeding a tensor whose channel count differs from the configured one must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let device = Default::default();
        let conv = Conv2dConfig::new([5, 3], [3, 3]).init::<TestBackend>(&device);

        // The layer expects 5 input channels; this tensor only carries 4.
        let mismatched = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &device);
        let _ = conv.forward(mismatched);
    }

    /// Explicit padding with four different amounts yields the expected output shape.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();

        // Asymmetric padding: top=1, left=2, bottom=3, right=4.
        let conv = Conv2dConfig::new([2, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false)
            .init::<TestBackend>(&device);

        // Input layout: [batch=1, channels=2, height=4, width=5].
        let output = conv.forward(Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device));

        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
        assert_eq!(output.dims(), [1, 3, 6, 9]);
    }

    /// Explicit padding with equal amounts on every side yields the expected output shape.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();

        // Symmetric explicit padding: top=2, left=2, bottom=2, right=2.
        let conv = Conv2dConfig::new([2, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false)
            .init::<TestBackend>(&device);

        // Input layout: [batch=1, channels=2, height=4, width=5].
        let output = conv.forward(Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device));

        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
        assert_eq!(output.dims(), [1, 3, 6, 7]);
    }
}
|
||||
@@ -0,0 +1,276 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig3d;
|
||||
use burn::config::Config;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::conv3d;
|
||||
use burn::tensor::ops::ConvOptions;
|
||||
|
||||
use crate::conv::checks;
|
||||
|
||||
/// Configuration to create a [3D convolution](Conv3d) layer, using the [init function](Conv3dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Conv3dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 3],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1, 1]")]
|
||||
pub stride: [usize; 3],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1, 1]")]
|
||||
pub dilation: [usize; 3],
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "PaddingConfig3d::Valid")]
|
||||
pub padding: PaddingConfig3d,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a 3D convolution over input tensors.
|
||||
///
|
||||
/// Should be created with [Conv3dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Conv3d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2, kernel_size_3]`
|
||||
pub weight: Param<Tensor<B, 5>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 3],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 3],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 3],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig3d>,
|
||||
}
|
||||
|
||||
impl Conv3dConfig {
|
||||
/// Initialize a new [conv3d](Conv3d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv3d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
|
||||
if self.padding == PaddingConfig3d::Same {
|
||||
checks::check_same_padding_support(&self.kernel_size);
|
||||
}
|
||||
|
||||
let shape = [
|
||||
self.channels[1],
|
||||
self.channels[0] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
self.kernel_size[2],
|
||||
];
|
||||
|
||||
let k = self.kernel_size.iter().product::<usize>();
|
||||
let fan_in = self.channels[0] / self.groups * k;
|
||||
let fan_out = self.channels[1] / self.groups * k;
|
||||
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), Some(fan_out), device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(self.initializer.init_with(
|
||||
[self.channels[1]],
|
||||
Some(fan_in),
|
||||
Some(fan_out),
|
||||
device,
|
||||
));
|
||||
}
|
||||
|
||||
Conv3d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
groups: self.groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Conv3d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Padding doesn't implement ModuleDisplay, so format manually.
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
// Format arrays as strings (consistent with Conv2d/Conv1d).
|
||||
let stride = format!("{:?}", self.stride);
|
||||
let kernel_size = format!("{:?}", self.kernel_size);
|
||||
let dilation = format!("{:?}", self.dilation);
|
||||
|
||||
// Weight dims: [channels_out, channels_in/groups, k1, k2, k3]
|
||||
let [channels_out, group_channels_in, _, _, _] = self.weight.dims();
|
||||
let channels_in = group_channels_in * self.groups;
|
||||
let ch_out = format!("{:?}", channels_out);
|
||||
let ch_in = format!("{:?}", channels_in);
|
||||
|
||||
content
|
||||
.add("ch_in", &ch_in)
|
||||
.add("ch_out", &ch_out)
|
||||
.add("stride", &stride)
|
||||
.add("kernel_size", &kernel_size)
|
||||
.add("dilation", &dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Conv3d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [conv3d](burn::tensor::module::conv3d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`
|
||||
/// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`
|
||||
pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {
|
||||
let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims();
|
||||
let padding = self.padding.calculate_padding_3d(
|
||||
depth_in,
|
||||
height_in,
|
||||
width_in,
|
||||
&self.kernel_size,
|
||||
&self.stride,
|
||||
);
|
||||
conv3d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
ConvOptions::new(self.stride, padding, self.dilation, self.groups),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// Float element type of the backend under test.
    type FT = FloatElem<TestBackend>;

    /// Weights from the default initializer stay within `±sqrt(groups / fan_in)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv3dConfig::new([5, 1], [5, 5, 5]);
        let k = (config.channels[0]
            * config.kernel_size[0]
            * config.kernel_size[1]
            * config.kernel_size[2]) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    /// The `Zeros` initializer produces an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros);
        // Initialise on the seeded device; no second `device` binding is needed.
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        // The expected data uses the backend float type, like the assertion itself.
        conv.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<FT, _>(conv.weight.shape()),
            Tolerance::default(),
        );
    }

    /// Initialisation succeeds with a fan-out-only initializer.
    #[test]
    fn initializer_fan_out() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true, // test that fan_out is passed to `init_with()`
        };
        let config = Conv3dConfig::new([5, 1], [5, 5, 5]).with_initializer(init.clone());
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// Initialisation succeeds when `groups` divides both channel counts.
    #[test]
    fn initializer_fan_with_groups_is_valid() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = Conv3dConfig::new([4, 4], [1, 1, 1])
            .with_initializer(init.clone())
            .with_groups(4);
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// `Same` padding with an even kernel size is rejected at initialisation.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let device = Default::default();
        let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same);
        let _ = config.init::<TestBackend>(&device);
    }

    /// The `Display` output of a default-configured layer lists every attribute on one line.
    #[test]
    fn display() {
        let config = Conv3dConfig::new([5, 1], [5, 5, 5]);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{conv}"),
            "Conv3d {ch_in: 5, ch_out: 1, stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: Valid, params: 626}"
        );
    }

    /// Feeding a tensor whose channel count differs from the configured one must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = Conv3dConfig::new([5, 3], [3, 3, 3]);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 5>::zeros([1, 4, 10, 10, 10], &Default::default());
        let _ = conv.forward(input);
    }
}
|
||||
@@ -0,0 +1,216 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::conv::checks;
|
||||
use burn::config::Config;
|
||||
use burn::module::Content;
|
||||
use burn::module::DisplaySettings;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::Module;
|
||||
use burn::module::ModuleDisplay;
|
||||
use burn::module::Param;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::conv_transpose1d;
|
||||
use burn::tensor::ops::ConvTransposeOptions;
|
||||
|
||||
/// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer
|
||||
/// using the [init function](ConvTranspose1dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ConvTranspose1dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "1")]
|
||||
pub stride: usize,
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "1")]
|
||||
pub dilation: usize,
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "0")]
|
||||
pub padding: usize,
|
||||
/// The padding output configuration.
|
||||
#[config(default = "0")]
|
||||
pub padding_out: usize,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a 1D transposed convolution over input tensors.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct ConvTranspose1d<B: Backend> {
|
||||
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
|
||||
pub weight: Param<Tensor<B, 3>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: usize,
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: usize,
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: usize,
|
||||
/// The padding output configuration.
|
||||
pub padding_out: usize,
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for ConvTranspose1d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("channels", &format!("{:?}", &self.channels))
|
||||
.add("stride", &self.stride)
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("dilation", &self.dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &self.padding)
|
||||
.add("padding_out", &self.padding_out)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConvTranspose1dConfig {
|
||||
/// Initialize a new [conv transpose 1d](ConvTranspose1d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose1d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
|
||||
|
||||
let shape = [
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size,
|
||||
];
|
||||
|
||||
let fan_in = self.channels[1] / self.groups * self.kernel_size;
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), None, device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(
|
||||
self.initializer
|
||||
.init_with([self.channels[1]], Some(fan_in), None, device),
|
||||
);
|
||||
}
|
||||
|
||||
ConvTranspose1d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
groups: self.groups,
|
||||
padding: self.padding,
|
||||
padding_out: self.padding_out,
|
||||
channels: self.channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvTranspose1d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See also [conv_transpose1d](burn::tensor::module::conv_transpose1d).
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, length_in]`
|
||||
/// - output: `[batch_size, channels_out, length_out]`
|
||||
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
conv_transpose1d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
ConvTransposeOptions::new(
|
||||
[self.stride],
|
||||
[self.padding],
|
||||
[self.padding_out],
|
||||
[self.dilation],
|
||||
self.groups,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{ElementConversion, Tolerance};

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// Float element type of the backend under test.
    type FT = FloatElem<TestBackend>;

    /// Weights from the default initializer stay within `±sqrt(groups / fan_in)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = ConvTranspose1dConfig::new([5, 1], 5);
        let k = (config.channels[1] * config.kernel_size) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        // Initialise on the device that was just seeded.
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    /// The `Zeros` initializer produces an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        // Compare using the backend float type rather than a hard-coded `f32`.
        conv.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<FT, _>(conv.weight.shape()),
            Tolerance::default(),
        );
    }

    /// The `Display` output of a default-configured layer lists every attribute on one line.
    #[test]
    fn display() {
        let config = ConvTranspose1dConfig::new([5, 2], 5);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{conv}"),
            "ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}"
        );
    }

    /// Feeding a tensor whose channel count differs from the configured one must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = ConvTranspose1dConfig::new([5, 3], 3);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
        let _ = conv.forward(input);
    }
}
|
||||
@@ -0,0 +1,216 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::conv::checks;
|
||||
use burn::config::Config;
|
||||
use burn::module::Content;
|
||||
use burn::module::DisplaySettings;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::Module;
|
||||
use burn::module::ModuleDisplay;
|
||||
use burn::module::Param;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::conv_transpose2d;
|
||||
use burn::tensor::ops::ConvTransposeOptions;
|
||||
|
||||
/// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer
|
||||
/// using the [init function](ConvTranspose2dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ConvTranspose2dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "[0, 0]")]
|
||||
pub padding: [usize; 2],
|
||||
/// The padding output configuration.
|
||||
#[config(default = "[0, 0]")]
|
||||
pub padding_out: [usize; 2],
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a 2D transposed convolution over input tensors.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct ConvTranspose2d<B: Backend> {
|
||||
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
|
||||
pub weight: Param<Tensor<B, 4>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// Padding configuration.
|
||||
pub padding: [usize; 2],
|
||||
/// Padding output configuration.
|
||||
pub padding_out: [usize; 2],
|
||||
/// Number of channels.
|
||||
pub channels: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("channels", &format!("{:?}", &self.channels))
|
||||
.add("stride", &format!("{:?}", &self.stride))
|
||||
.add("kernel_size", &format!("{:?}", &self.kernel_size))
|
||||
.add("dilation", &format!("{:?}", &self.dilation))
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &format!("{:?}", &self.padding))
|
||||
.add("padding_out", &format!("{:?}", &self.padding_out))
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConvTranspose2dConfig {
|
||||
/// Initialize a new [conv transpose 2d](ConvTranspose2d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose2d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
|
||||
|
||||
let shape = [
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
];
|
||||
|
||||
let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::<usize>();
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), None, device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(
|
||||
self.initializer
|
||||
.init_with([self.channels[1]], Some(fan_in), None, device),
|
||||
);
|
||||
}
|
||||
|
||||
ConvTranspose2d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
groups: self.groups,
|
||||
padding: self.padding,
|
||||
padding_out: self.padding_out,
|
||||
channels: self.channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvTranspose2d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See also [conv_transpose2d](burn::tensor::module::conv_transpose2d).
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, height_in, width_in]`
|
||||
/// - output: `[batch_size, channels_out, height_out, width_out]`
|
||||
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
conv_transpose2d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
ConvTransposeOptions::new(
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// The default initializer must keep every weight within `(-k, k)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = ConvTranspose2dConfig::new([5, 1], [5, 5]);
        let fan_in = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64;
        let bound = (config.groups as f64 / fan_in).sqrt().elem::<FT>();

        let layer = config.init::<TestBackend>(&device);

        layer.weight.to_data().assert_within_range(-bound..bound);
    }

    /// The `Zeros` initializer must produce an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config =
            ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
        let layer = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<f32, _>(layer.weight.shape());
        layer
            .weight
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display lists every hyper-parameter on a single line.
    #[test]
    fn display() {
        let layer =
            ConvTranspose2dConfig::new([5, 2], [5, 5]).init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer}"),
            "ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}"
        );
    }

    /// Feeding 4 input channels to a layer configured for 5 must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let layer =
            ConvTranspose2dConfig::new([5, 3], [3, 3]).init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
        let _ = layer.forward(input);
    }
}
|
||||
@@ -0,0 +1,221 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::conv::checks;
|
||||
use burn::config::Config;
|
||||
use burn::module::Content;
|
||||
use burn::module::DisplaySettings;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::Module;
|
||||
use burn::module::ModuleDisplay;
|
||||
use burn::module::Param;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::conv_transpose3d;
|
||||
use burn::tensor::ops::ConvTransposeOptions;
|
||||
|
||||
/// Configuration to create an [3D transposed convolution](ConvTranspose3d) layer
|
||||
/// using the [init function](ConvTranspose3dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ConvTranspose3dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 3],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1, 1]")]
|
||||
pub stride: [usize; 3],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1, 1]")]
|
||||
pub dilation: [usize; 3],
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "[0, 0, 0]")]
|
||||
pub padding: [usize; 3],
|
||||
/// The padding output configuration.
|
||||
#[config(default = "[0, 0, 0]")]
|
||||
pub padding_out: [usize; 3],
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a 3D transposed convolution over input tensors.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct ConvTranspose3d<B: Backend> {
|
||||
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2, kernel_size_3]`
|
||||
pub weight: Param<Tensor<B, 5>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 3],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 3],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 3],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// Padding configuration.
|
||||
pub padding: [usize; 3],
|
||||
/// Padding output configuration.
|
||||
pub padding_out: [usize; 3],
|
||||
/// Number of channels.
|
||||
pub channels: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for ConvTranspose3d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("channels", &format!("{:?}", &self.channels))
|
||||
.add("stride", &format!("{:?}", &self.stride))
|
||||
.add("kernel_size", &format!("{:?}", &self.kernel_size))
|
||||
.add("dilation", &format!("{:?}", &self.dilation))
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &format!("{:?}", &self.padding))
|
||||
.add("padding_out", &format!("{:?}", &self.padding_out))
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConvTranspose3dConfig {
|
||||
/// Initialize a new [conv transpose 2d](ConvTranspose3d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose3d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
|
||||
|
||||
let shape = [
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
self.kernel_size[2],
|
||||
];
|
||||
|
||||
let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::<usize>();
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), None, device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(
|
||||
self.initializer
|
||||
.init_with([self.channels[1]], Some(fan_in), None, device),
|
||||
);
|
||||
}
|
||||
|
||||
ConvTranspose3d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
groups: self.groups,
|
||||
padding: self.padding,
|
||||
padding_out: self.padding_out,
|
||||
channels: self.channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvTranspose3d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See also [conv_transpose3d](burn::tensor::module::conv_transpose3d).
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`
|
||||
/// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`
|
||||
pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {
|
||||
conv_transpose3d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
ConvTransposeOptions::new(
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// The default initializer must keep every weight within `(-k, k)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = ConvTranspose3dConfig::new([5, 1], [5, 5, 5]);
        let k = (config.channels[1]
            * config.kernel_size[0]
            * config.kernel_size[1]
            * config.kernel_size[2]) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        // Initialise on the device that was just seeded.
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    /// The `Zeros` initializer must produce an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config =
            ConvTranspose3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        // Compare with the backend float type, like the sibling convolution tests.
        conv.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(conv.weight.shape()),
            Tolerance::default(),
        );
    }

    /// The custom display lists every hyper-parameter on a single line.
    #[test]
    fn display() {
        let config = ConvTranspose3dConfig::new([5, 2], [5, 5, 5]);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{conv}"),
            "ConvTranspose3d {channels: [5, 2], stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: [0, 0, 0], padding_out: [0, 0, 0], params: 1252}"
        );
    }

    /// Feeding 4 input channels to a layer configured for 5 must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = ConvTranspose3dConfig::new([5, 3], [3, 3, 3]);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 5>::zeros([1, 4, 10, 10, 10], &Default::default());
        let _ = conv.forward(input);
    }
}
|
||||
@@ -0,0 +1,295 @@
|
||||
use alloc::format;
|
||||
use burn::tensor::ops::DeformConvOptions;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig2d;
|
||||
use burn::config::Config;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::deform_conv2d;
|
||||
|
||||
use crate::conv::checks;
|
||||
|
||||
/// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DeformConv2dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub weight_groups: usize,
|
||||
/// Offset groups.
|
||||
#[config(default = "1")]
|
||||
pub offset_groups: usize,
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// ### Warning
|
||||
/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
|
||||
/// size is not supported as it will not produce the same output size.
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a deformable 2D convolution over input tensors.
|
||||
///
|
||||
/// Should be created with [DeformConv2dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct DeformConv2d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
|
||||
pub weight: Param<Tensor<B, 4>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub weight_groups: usize,
|
||||
/// Offset groups.
|
||||
pub offset_groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
}
|
||||
|
||||
impl DeformConv2dConfig {
|
||||
/// Initialize a new [DeformConv2d](DeformConv2d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
|
||||
if self.padding == PaddingConfig2d::Same {
|
||||
checks::check_same_padding_support(&self.kernel_size);
|
||||
}
|
||||
|
||||
let shape = [
|
||||
self.channels[1],
|
||||
self.channels[0] / self.weight_groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
];
|
||||
|
||||
let k = self.kernel_size.iter().product::<usize>();
|
||||
let fan_in = self.channels[0] / self.weight_groups * k;
|
||||
let fan_out = self.channels[1] / self.weight_groups * k;
|
||||
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), Some(fan_out), device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(self.initializer.init_with(
|
||||
[self.channels[1]],
|
||||
Some(fan_in),
|
||||
Some(fan_out),
|
||||
device,
|
||||
));
|
||||
}
|
||||
|
||||
DeformConv2d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
weight_groups: self.weight_groups,
|
||||
offset_groups: self.weight_groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for DeformConv2d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Since padding does not implement ModuleDisplay, we need to format it manually.
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
|
||||
let stride = format!("{:?}", self.stride);
|
||||
let kernel_size = format!("{:?}", self.kernel_size);
|
||||
let dilation = format!("{:?}", self.dilation);
|
||||
|
||||
content
|
||||
.add("stride", &stride)
|
||||
.add("kernel_size", &kernel_size)
|
||||
.add("dilation", &dilation)
|
||||
.add("weight_groups", &self.weight_groups)
|
||||
.add("offset_groups", &self.offset_groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> DeformConv2d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [deform_conv2d](burn::tensor::module::deform_conv2d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, height_in, width_in]`
|
||||
/// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]`
|
||||
/// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]`
|
||||
/// - output: `[batch_size, channels_out, height_out, width_out]`
|
||||
pub fn forward(
|
||||
&self,
|
||||
input: Tensor<B, 4>,
|
||||
offset: Tensor<B, 4>,
|
||||
mask: Option<Tensor<B, 4>>,
|
||||
) -> Tensor<B, 4> {
|
||||
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
|
||||
let padding =
|
||||
self.padding
|
||||
.calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
|
||||
deform_conv2d(
|
||||
input,
|
||||
offset,
|
||||
self.weight.val(),
|
||||
mask,
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
DeformConvOptions::new(
|
||||
self.stride,
|
||||
padding,
|
||||
self.dilation,
|
||||
self.weight_groups,
|
||||
self.offset_groups,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// The default initializer must keep every weight within `(-k, k)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = DeformConv2dConfig::new([5, 1], [5, 5]);
        let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
        // `init` divides the fan-in by the weight groups, so the bound must use them too.
        let k = (config.weight_groups as f64 / k).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    /// The `Zeros` initializer must produce an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        conv.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(conv.weight.shape()),
            Tolerance::default(),
        );
    }

    /// Initialisation must succeed with a fan-out-only initializer.
    #[test]
    fn initializer_fan_out() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true, // test that fan_out is passed to `init_with()`
        };

        let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// Channels divisible by the weight groups must be accepted.
    #[test]
    fn initializer_fan_with_groups_is_valid() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = DeformConv2dConfig::new([4, 4], [1, 1])
            .with_initializer(init.clone())
            .with_weight_groups(4);
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// Channels that are not divisible by the weight groups must be rejected.
    #[test]
    #[should_panic = "Both channels must be divisible by the number of groups."]
    fn channels_with_groups_is_invalid() {
        let device = Default::default();
        let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4);
        let _ = config.init::<TestBackend>(&device);
    }

    /// `Same` padding combined with an even kernel size must be rejected.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let device = Default::default();
        let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same);
        let _ = config.init::<TestBackend>(&device);
    }

    /// The custom display lists every hyper-parameter on a single line.
    #[test]
    fn display() {
        let config = DeformConv2dConfig::new([5, 1], [5, 5]);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{conv}"),
            "DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}"
        );
    }

    /// Feeding 4 input channels to a layer configured for 5 must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = DeformConv2dConfig::new([5, 3], [3, 3]);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
        let offset = Tensor::<TestBackend, 4>::zeros([1, 2 * 3 * 3, 10, 10], &Default::default());
        let _ = conv.forward(input, offset, None);
    }
}
|
||||
@@ -0,0 +1,17 @@
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv3d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
|
||||
pub(crate) mod checks;
|
||||
|
||||
pub use conv_transpose1d::*;
|
||||
pub use conv_transpose2d::*;
|
||||
pub use conv_transpose3d::*;
|
||||
pub use conv1d::*;
|
||||
pub use conv2d::*;
|
||||
pub use conv3d::*;
|
||||
pub use deform_conv2d::*;
|
||||
@@ -0,0 +1,124 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Distribution, Tensor};
|
||||
|
||||
/// Configuration to create a [Dropout](Dropout) layer using the [init function](DropoutConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DropoutConfig {
|
||||
/// The probability of randomly zeroes some elements of the input tensor during training.
|
||||
pub prob: f64,
|
||||
}
|
||||
|
||||
/// Set at random some elements of the input tensor to zero during training.
|
||||
///
|
||||
/// This is an effective regularization technique as describe in the paper
|
||||
/// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
|
||||
///
|
||||
/// The input is also scaled during training to `1 / (1 - prob_keep)`.
|
||||
///
|
||||
/// Should be created with [DropoutConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Dropout {
|
||||
/// The probability of randomly zeroes some elements of the input tensor during training.
|
||||
pub prob: f64,
|
||||
}
|
||||
|
||||
impl DropoutConfig {
|
||||
/// Initialize a new [dropout](Dropout) module.
|
||||
pub fn init(&self) -> Dropout {
|
||||
if self.prob < 0.0 || self.prob > 1.0 {
|
||||
panic!(
|
||||
"Dropout probability should be between 0 and 1, but got {}",
|
||||
self.prob
|
||||
);
|
||||
}
|
||||
Dropout { prob: self.prob }
|
||||
}
|
||||
}
|
||||
|
||||
impl Dropout {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [Dropout](Dropout) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
if !B::ad_enabled(&input.device()) || self.prob == 0.0 {
|
||||
return input;
|
||||
}
|
||||
|
||||
let prob_keep = 1.0 - self.prob;
|
||||
let random = input.random_like(Distribution::Bernoulli(prob_keep));
|
||||
let x = input * random;
|
||||
|
||||
x * (1.0 / prob_keep)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Dropout {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("prob", &self.prob).optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Shape;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// With an autodiff backend, dropout must alter the input.
    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_should_mark_input() {
        let input =
            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let layer = DropoutConfig::new(0.5).init();

        let result = layer.forward(input.clone());

        assert_ne!(input.to_data(), result.to_data());
    }

    /// Without an autodiff backend, dropout must leave the input as is.
    #[test]
    fn without_ad_backend_should_not_change_input() {
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let layer = DropoutConfig::new(0.5).init();

        let result = layer.forward(input.clone());

        assert_eq!(input.to_data(), result.to_data());
    }

    /// The custom display shows the drop probability.
    #[test]
    fn display() {
        let layer = DropoutConfig::new(0.5).init();

        assert_eq!(alloc::format!("{layer}"), "Dropout {prob: 0.5}");
    }

    /// A probability outside of `[0, 1]` must be rejected at initialisation.
    #[test]
    #[should_panic = "Dropout probability should be between 0 and 1,"]
    fn dropout_prob_invalid() {
        let _layer = DropoutConfig::new(-10.).init();
    }
}
|
||||
@@ -0,0 +1,111 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::Module;
|
||||
use burn::module::Param;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Int;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use burn::tensor::module::embedding;
|
||||
|
||||
/// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct EmbeddingConfig {
|
||||
/// The number of embedding vectors.
|
||||
pub n_embedding: usize,
|
||||
/// The size of each vector.
|
||||
pub d_model: usize,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(default = "Initializer::Normal{mean:0.0, std:1.0}")]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Lookup table to store a fix number of vectors.
|
||||
///
|
||||
/// Should be created with [EmbeddingConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Embedding<B: Backend> {
|
||||
/// The learnable weights of the module of shape `[n_embedding, d_model]` initialized
|
||||
/// from a normal distribution `N(0, 1)`.
|
||||
pub weight: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Embedding<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [n_embedding, d_model] = self.weight.shape().dims();
|
||||
content
|
||||
.add("n_embedding", &n_embedding)
|
||||
.add("d_model", &d_model)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingConfig {
|
||||
/// Initialize a new [embedding](Embedding) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
|
||||
let weight = self
|
||||
.initializer
|
||||
.init([self.n_embedding, self.d_model], device);
|
||||
|
||||
Embedding { weight }
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Embedding<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See also [embedding](burn::tensor::module::embedding).
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, seq_length]`
|
||||
/// - output: `[batch_size, seq_length, d_model]`
|
||||
pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
embedding(self.weight.val(), input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// The `Zeros` initializer must produce an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let layer = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<f32, _>(layer.weight.shape());
        layer
            .weight
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display shows the table dimensions and the parameter count.
    #[test]
    fn display() {
        let layer = EmbeddingConfig::new(100, 10).init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Embedding {n_embedding: 100, d_model: 10, params: 1000}"
        );
    }
}
|
||||
@@ -0,0 +1,258 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn::tensor::module::interpolate;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::ops::InterpolateOptions;
|
||||
|
||||
use super::InterpolateMode;
|
||||
|
||||
/// Configuration for the 1D interpolation module.
|
||||
///
|
||||
/// This struct defines the configuration options for the 1D interpolation operation.
|
||||
/// It allows specifying the output size, scale factor, and interpolation mode.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Interpolate1dConfig {
|
||||
/// Output size of the interpolated tensor.
|
||||
/// If specified, this takes precedence over `scale_factor`.
|
||||
#[config(default = "None")]
|
||||
pub output_size: Option<usize>,
|
||||
|
||||
/// Scale factor for resizing the input tensor.
|
||||
/// This is used when `output_size` is not specified.
|
||||
#[config(default = "None")]
|
||||
pub scale_factor: Option<f32>,
|
||||
|
||||
/// Interpolation mode to use for resizing.
|
||||
/// Determines how the output values are calculated.
|
||||
#[config(default = "InterpolateMode::Nearest")]
|
||||
pub mode: InterpolateMode,
|
||||
|
||||
/// If `true`, the input and output tensors are aligned by their corner pixels.
|
||||
/// If `false`, half-pixel coordinate mapping is used instead.
|
||||
#[config(default = true)]
|
||||
pub align_corners: bool,
|
||||
}
|
||||
|
||||
/// Interpolate module for resizing 1D tensors with shape [N, C, L].
|
||||
///
|
||||
/// This struct represents a 1D interpolation module that can resize tensors
|
||||
/// using various interpolation methods. It provides flexibility in specifying
|
||||
/// either an output size or a scale factor for resizing, along with options
|
||||
/// for the interpolation mode.
|
||||
///
|
||||
/// The module can be used to upsample or downsample 1D tensors, preserving the
|
||||
/// number of channels and batch size while adjusting the length dimension.
|
||||
///
|
||||
/// The module can be created using the [Interpolate1dConfig] struct and the
|
||||
/// `init` method, which returns an instance of the [Interpolate1d] struct.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Interpolate1d {
|
||||
/// Output size of the interpolated tensor
|
||||
pub output_size: Option<usize>,
|
||||
|
||||
/// Scale factor for resizing the input tensor
|
||||
pub scale_factor: Option<f32>,
|
||||
|
||||
/// Interpolation mode used for resizing
|
||||
pub mode: Ignored<InterpolateMode>,
|
||||
|
||||
/// Whether to align corner pixels
|
||||
pub align_corners: bool,
|
||||
}
|
||||
|
||||
impl Interpolate1dConfig {
|
||||
/// Initialize the interpolation module
|
||||
pub fn init(self) -> Interpolate1d {
|
||||
Interpolate1d {
|
||||
output_size: self.output_size,
|
||||
scale_factor: self.scale_factor,
|
||||
mode: Ignored(self.mode),
|
||||
align_corners: self.align_corners,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Interpolate1d {
|
||||
/// Performs the forward pass of the 1D interpolation module
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - Input tensor with shape [N, C, L]
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Resized tensor with shape [N, C, L'], where L' is determined by
|
||||
/// the output_size or scale_factor specified in the module configuration
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);
|
||||
/// let interpolate = Interpolate1dConfig::new()
|
||||
/// .with_output_size(Some(128))
|
||||
/// .init();
|
||||
/// let output = interpolate.forward(input);
|
||||
/// assert_eq!(output.dims(), [1, 3, 128]);
|
||||
/// ```
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
|
||||
|
||||
// Use the interpolate operation to resize the temporal input tensor
|
||||
// by adding a new dimension for the interpolation axis
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let result = interpolate(
|
||||
input,
|
||||
[1, output_size],
|
||||
InterpolateOptions::new(self.mode.0.clone().into())
|
||||
.with_align_corners(self.align_corners),
|
||||
);
|
||||
|
||||
result.squeeze_dims(&[2])
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the output length from the input dimensions, output size, and scale factor.
///
/// An explicit `output_size` always takes precedence over `scale_factor`, matching the
/// contract documented on the configuration fields. Previously, providing both caused a
/// panic claiming that neither had been provided.
///
/// # Arguments
///
/// * `input_dims` - Input dimensions of the tensor, `[N, C, L]`
/// * `output_size` - Output size for the interpolated tensor
/// * `scale_factor` - Scale factor for resizing the tensor
///
/// # Returns
///
/// Output size for the interpolated tensor. When derived from the scale factor, the
/// product `L * scale_factor` is computed in `f64` and truncated toward zero.
///
/// # Panics
///
/// Panics if neither output_size nor scale_factor is provided
/// or if the scale factor is too large
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        // Use the provided size, whether or not a scale factor is also set.
        (Some(output_size), _) => output_size,
        (None, Some(scale_factor)) => {
            // Calculate output size based on scale factor
            let [_, _, l] = input_dims;

            let new_dim = (l as f64) * (scale_factor as f64);

            if new_dim > usize::MAX as f64 {
                panic!("Scale factor is too large");
            }

            new_dim as usize
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
|
||||
|
||||
impl ModuleDisplay for Interpolate1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("mode", &self.mode)
|
||||
.add("output_size", &format!("{:?}", self.output_size))
|
||||
.add("scale_factor", &self.scale_factor)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    /// Explicit sizes are returned untouched; scale factors multiply the length.
    #[test]
    fn test_calculate_output_size() {
        let dims = [1, 1, 4];

        assert_eq!(calculate_output_size(dims, Some(2), None), 2);

        for (scale, expected) in [(2.0, 8), (0.5, 2), (1.5, 6)] {
            assert_eq!(calculate_output_size(dims, None, Some(scale)), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        calculate_output_size([1, 1, 4], None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        calculate_output_size([1, 1, usize::MAX - 1], None, Some(2.0));
    }

    /// Checks the output shape for each way of configuring the module.
    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // Explicit output size
        let by_size = Interpolate1dConfig::new().with_output_size(Some(8)).init();
        assert_eq!(by_size.forward(input.clone()).dims(), [2, 3, 8]);

        // Scale factor
        let by_scale = Interpolate1dConfig::new()
            .with_scale_factor(Some(0.5))
            .init();
        assert_eq!(by_scale.forward(input.clone()).dims(), [2, 3, 2]);

        // Non-default interpolation mode
        let linear_mode = Interpolate1dConfig::new()
            .with_output_size(Some(6))
            .with_mode(InterpolateMode::Linear)
            .init();
        assert_eq!(linear_mode.forward(input).dims(), [2, 3, 6]);
    }

    #[test]
    fn display() {
        let module = Interpolate1dConfig::new()
            .with_output_size(Some(20))
            .init();

        assert_eq!(
            alloc::format!("{module}"),
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}
|
||||
@@ -0,0 +1,261 @@
|
||||
use alloc::format;
|
||||
|
||||
use burn::tensor::module::interpolate;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::ops::InterpolateOptions;
|
||||
|
||||
use super::InterpolateMode;
|
||||
|
||||
/// Configuration for the 2D interpolation module.
|
||||
///
|
||||
/// This struct defines the configuration options for the 2D interpolation operation.
|
||||
/// It allows specifying the output size, scale factor, and interpolation mode.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Interpolate2dConfig {
|
||||
/// Output size of the interpolated tensor.
|
||||
/// If specified, this takes precedence over `scale_factor`.
|
||||
#[config(default = "None")]
|
||||
pub output_size: Option<[usize; 2]>,
|
||||
|
||||
/// Scale factor for resizing the input tensor.
|
||||
/// This is used when `output_size` is not specified.
|
||||
#[config(default = "None")]
|
||||
pub scale_factor: Option<[f32; 2]>,
|
||||
|
||||
/// Interpolation mode to use for resizing.
|
||||
/// Determines how the output values are calculated.
|
||||
#[config(default = "InterpolateMode::Nearest")]
|
||||
pub mode: InterpolateMode,
|
||||
|
||||
/// If `true`, the input and output tensors are aligned by their corner pixels.
|
||||
/// If `false`, half-pixel coordinate mapping is used instead.
|
||||
#[config(default = true)]
|
||||
pub align_corners: bool,
|
||||
}
|
||||
|
||||
/// Interpolate module for resizing tensors with shape [N, C, H, W].
|
||||
///
|
||||
/// This struct represents an interpolation module that can resize tensors
|
||||
/// using various interpolation methods. It provides flexibility in specifying
|
||||
/// either an output size or a scale factor for resizing, along with options
|
||||
/// for the interpolation mode.
|
||||
///
|
||||
/// The module can be used to upsample or downsample tensors, preserving the
|
||||
/// number of channels and batch size while adjusting the height and width
|
||||
/// dimensions.
|
||||
///
|
||||
/// The module can be created using the [Interpolate2dConfig] struct and the
|
||||
/// `init` method, which returns an instance of the [Interpolate2d] struct.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Interpolate2d {
|
||||
/// Output size of the interpolated tensor
|
||||
pub output_size: Option<[usize; 2]>,
|
||||
|
||||
/// Scale factor for resizing the input tensor
|
||||
pub scale_factor: Option<[f32; 2]>,
|
||||
|
||||
/// Interpolation mode used for resizing
|
||||
pub mode: Ignored<InterpolateMode>,
|
||||
|
||||
/// Whether to align corner pixels
|
||||
pub align_corners: bool,
|
||||
}
|
||||
|
||||
impl Interpolate2dConfig {
|
||||
/// Initialize the interpolation module
|
||||
pub fn init(self) -> Interpolate2d {
|
||||
Interpolate2d {
|
||||
output_size: self.output_size,
|
||||
scale_factor: self.scale_factor,
|
||||
mode: Ignored(self.mode),
|
||||
align_corners: self.align_corners,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Interpolate2d {
|
||||
/// Performs the forward pass of the interpolation module
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - Input tensor with shape [N, C, H, W]
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by
|
||||
/// the output_size or scale_factor specified in the module configuration
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// let input = Tensor::<Backend, 2>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);
|
||||
/// let interpolate = Interpolate2dConfig::new()
|
||||
/// .with_output_size(Some([128, 128]))
|
||||
/// .init();
|
||||
/// let output = interpolate.forward(input);
|
||||
/// assert_eq!(output.dims(), [1, 3, 128, 128]);
|
||||
/// ```
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
|
||||
interpolate(
|
||||
input,
|
||||
output_size,
|
||||
InterpolateOptions::new(self.mode.0.clone().into())
|
||||
.with_align_corners(self.align_corners),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates the output size for tensor interpolation.
///
/// An explicit `output_size` always takes precedence over `scale_factor`, matching the
/// contract documented on the configuration fields. Previously, providing both caused a
/// panic claiming that neither had been provided.
///
/// # Arguments
///
/// * `input_dims` - The dimensions of the input tensor [N, C, H, W].
/// * `output_size` - Optional desired output size [H', W'].
/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w].
///
/// # Returns
///
/// An array [H', W'] representing the calculated output size. When derived from the
/// scale factors, each product is computed in `f64` and truncated toward zero.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if the scale factor results in dimensions exceeding usize::MAX.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        // Use the provided size, whether or not a scale factor is also set.
        (Some(output_size), _) => output_size,
        (None, Some([scale_h, scale_w])) => {
            // Calculate output size based on scale factor
            let [_, _, h, w] = input_dims;

            // Height is validated before width so the height message wins when both overflow.
            let new_dim_h = (h as f64) * (scale_h as f64);

            if new_dim_h > usize::MAX as f64 {
                panic!("Scale factor for height is too large");
            }

            let new_dim_w = (w as f64) * (scale_w as f64);

            if new_dim_w > usize::MAX as f64 {
                panic!("Scale factor for width is too large");
            }

            [new_dim_h as usize, new_dim_w as usize]
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
|
||||
|
||||
impl ModuleDisplay for Interpolate2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("mode", &self.mode)
|
||||
.add("output_size", &format!("{:?}", self.output_size))
|
||||
.add("scale_factor", &self.scale_factor)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    /// Explicit sizes are returned untouched; scale factors multiply height and width.
    #[test]
    fn test_calculate_output_size() {
        let dims = [1, 1, 4, 4];

        assert_eq!(calculate_output_size(dims, Some([2, 2]), None), [2, 2]);

        let scaled = [
            ([2.0, 2.0], [8, 8]),
            ([0.5, 0.5], [2, 2]),
            ([2.0, 1.5], [8, 6]),
        ];
        for (scale, expected) in scaled {
            assert_eq!(calculate_output_size(dims, None, Some(scale)), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        calculate_output_size([1, 1, 4, 4], None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        let dims = [1, 1, usize::MAX - 1, 4];
        calculate_output_size(dims, None, Some([2.0, 1.0]));
    }

    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        let dims = [1, 1, 4, usize::MAX - 1];
        calculate_output_size(dims, None, Some([1.0, 2.0]));
    }

    /// Checks the output shape for each way of configuring the module.
    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // Explicit output size
        let by_size = Interpolate2dConfig::new()
            .with_output_size(Some([8, 8]))
            .init();
        assert_eq!(by_size.forward(input.clone()).dims(), [2, 3, 8, 8]);

        // Scale factor
        let by_scale = Interpolate2dConfig::new()
            .with_scale_factor(Some([0.5, 0.5]))
            .init();
        assert_eq!(by_scale.forward(input.clone()).dims(), [2, 3, 2, 2]);

        // Non-default interpolation mode
        let linear_mode = Interpolate2dConfig::new()
            .with_output_size(Some([6, 6]))
            .with_mode(InterpolateMode::Linear)
            .init();
        assert_eq!(linear_mode.forward(input).dims(), [2, 3, 6, 6]);
    }

    #[test]
    fn display() {
        let module = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();

        assert_eq!(
            alloc::format!("{module}"),
            "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
            scale_factor: None}"
        );
    }
}
|
||||
@@ -0,0 +1,49 @@
|
||||
mod interpolate1d;
|
||||
mod interpolate2d;
|
||||
|
||||
pub use interpolate1d::*;
|
||||
pub use interpolate2d::*;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::ops::InterpolateMode as OpsInterpolateMode;
|
||||
|
||||
/// Algorithm used for downsampling and upsampling
|
||||
///
|
||||
/// This enum defines different interpolation modes for resampling data.
|
||||
#[derive(Config, Debug)]
|
||||
pub enum InterpolateMode {
|
||||
/// Nearest-neighbor interpolation
|
||||
///
|
||||
/// This mode selects the value of the nearest sample point for each output pixel.
|
||||
/// It is applicable for both temporal and spatial data.
|
||||
Nearest,
|
||||
|
||||
/// Linear interpolation
|
||||
///
|
||||
/// This mode calculates the output value using linear
|
||||
/// interpolation between nearby sample points.
|
||||
///
|
||||
/// It is applicable for both temporal and spatial data.
|
||||
Linear,
|
||||
|
||||
/// Cubic interpolation
|
||||
///
|
||||
/// This mode uses cubic interpolation to calculate the output value
|
||||
/// based on surrounding sample points.
|
||||
///
|
||||
/// It is applicable for both temporal and spatial data and generally
|
||||
/// provides smoother results than linear interpolation.
|
||||
Cubic,
|
||||
}
|
||||
|
||||
impl From<InterpolateMode> for OpsInterpolateMode {
|
||||
fn from(mode: InterpolateMode) -> Self {
|
||||
match mode {
|
||||
InterpolateMode::Nearest => OpsInterpolateMode::Nearest,
|
||||
InterpolateMode::Linear => OpsInterpolateMode::Bilinear,
|
||||
InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,340 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Param;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::module::linear;
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
/// Configuration to create a [`Linear`] layer using the [init function](LinearConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LinearConfig {
|
||||
/// The size of the input features.
|
||||
pub d_input: usize,
|
||||
/// The size of the output features.
|
||||
pub d_output: usize,
|
||||
/// If a bias should be applied during the linear transformation.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
/// The layout in which the linear parameters are stored.
|
||||
#[config(default = "LinearLayout::Row")]
|
||||
pub layout: LinearLayout,
|
||||
}
|
||||
|
||||
#[derive(Config, Debug, Copy)]
|
||||
/// The layout in which the linear parameters are stored.
|
||||
///
|
||||
/// This can have performance impacts.
|
||||
pub enum LinearLayout {
|
||||
/// Parameters are stored in Row major.
|
||||
Row,
|
||||
/// Parameters are stored in Col major.
|
||||
Col,
|
||||
}
|
||||
|
||||
/// Applies a linear transformation to the input tensor.
|
||||
///
|
||||
/// Should be created with [LinearConfig]
|
||||
///
|
||||
/// `O = IW + b`
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Linear<B: Backend> {
|
||||
/// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
|
||||
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
|
||||
pub weight: Param<Tensor<B, 2>>,
|
||||
/// Vector of size `d_output` initialized from a uniform distribution:
|
||||
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
}
|
||||
|
||||
impl LinearConfig {
|
||||
/// Initialize a new [`Linear`] module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
|
||||
let weight = match self.layout {
|
||||
LinearLayout::Row => {
|
||||
let shape = [self.d_input, self.d_output];
|
||||
self.initializer
|
||||
.init_with(shape, Some(self.d_input), Some(self.d_output), device)
|
||||
}
|
||||
LinearLayout::Col => {
|
||||
let shape = [self.d_output, self.d_input];
|
||||
|
||||
self.initializer
|
||||
.init_with(shape, Some(self.d_output), Some(self.d_input), device)
|
||||
// The param is already transposed when init. We re-transpose to have
|
||||
// [d_output, d_input] while saving.
|
||||
.save_mapper(move |tensor| {
|
||||
B::sync(&tensor.device()).unwrap();
|
||||
let tensor = tensor.transpose();
|
||||
B::sync(&tensor.device()).unwrap();
|
||||
tensor
|
||||
})
|
||||
// When loading from record we have to transpose.
|
||||
.load_mapper(move |tensor| {
|
||||
B::sync(&tensor.device()).unwrap();
|
||||
let tensor = tensor.transpose();
|
||||
B::sync(&tensor.device()).unwrap();
|
||||
|
||||
tensor
|
||||
})
|
||||
// When loading from initialization, we have to transpose.
|
||||
.init_mapper(|tensor| {
|
||||
B::sync(&tensor.device()).unwrap();
|
||||
let tensor = tensor.transpose();
|
||||
B::sync(&tensor.device()).unwrap();
|
||||
tensor
|
||||
})
|
||||
}
|
||||
};
|
||||
let bias = if self.bias {
|
||||
Some(self.initializer.init_with(
|
||||
[self.d_output],
|
||||
Some(self.d_input),
|
||||
Some(self.d_output),
|
||||
device,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Linear { weight, bias }
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Linear<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `input` - The input tensor of shape `[..., d_input]`.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., d_input]`
|
||||
/// - output: `[..., d_output]`
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The transformed tensor of shape `[..., d_output]`.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
linear(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.as_ref().map(|b| b.val()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Linear<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, d_output] = self.weight.shape().dims();
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_output", &d_output)
|
||||
.add("bias", &self.bias.is_some())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::ParamId;
    use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
    use burn::tensor::ElementConversion;
    use burn::tensor::{Shape, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// The default initializer is Kaiming uniform and keeps weights within `(-k, k)`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LinearConfig::new(5, 5);
        let bound = (1.0 / config.d_input as f64).sqrt().elem::<FT>();
        let layer = config.init::<TestBackend>(&device);

        let expected_initializer = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: false,
        };
        assert_eq!(config.initializer, expected_initializer);
        layer.weight.to_data().assert_within_range(-bound..bound);
    }

    /// A zero initializer must yield an all-zero weight matrix.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let layer = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<f32, _>(layer.weight.shape());
        layer
            .weight
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Constant weights of 2 and no bias: each output is `1 * 2 + 1 * 2 = 4`.
    #[test]
    fn test_linear_forward_no_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let layer = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false)
            .init::<TestBackend>(&device);

        let ones = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let actual = layer.forward(ones);
        let expected = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(actual.into_data(), expected.into_data());
    }

    /// Constant weights and bias of 2: each output is `4 + 2 = 6`.
    #[test]
    fn test_linear_forward_with_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let device = Default::default();

        let value = 2.;
        let layer = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .init::<TestBackend>(&device);

        let ones = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let actual = layer.forward(ones);
        let expected = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(actual.into_data(), expected.into_data());
    }

    /// A rank-1 input gives the same values as the equivalent single-row rank-2 input.
    #[test]
    fn test_linear_1d() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let device = Default::default();

        let value = 2.;
        let layer = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .init::<TestBackend>(&device);

        let flat = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let row = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        let from_flat = layer.forward(flat).unsqueeze::<2>();
        let from_row = layer.forward(row);

        assert_eq!(from_flat.into_data(), from_row.into_data());
    }

    #[test]
    fn display() {
        let layer = LinearConfig::new(3, 5).init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }

    /// Round-trips a column-layout layer through a recorder to exercise the mappers.
    #[test]
    fn layout() {
        let device = Default::default();
        let col_config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let original = col_config.init::<TestBackend>(&device);

        assert_eq!(original.weight.dims(), [6, 12], "Shape is as configured");

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::new();

        // We go through serialization to trigger the mappers..
        let bytes = recorder.record(original.into_record(), ()).unwrap();
        let record = recorder.load(bytes.clone(), &device).unwrap();

        let row_config = LinearConfig::new(12, 6).with_layout(LinearLayout::Row);
        let linear_row = row_config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_row.weight.dims(),
            [12, 6],
            "Shape should be transposed"
        );

        let record = recorder.load(bytes.clone(), &device).unwrap();
        let col_config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = col_config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_col.weight.dims(),
            [6, 12],
            "Shape should be as configured"
        );

        // We go through serialization to trigger the mappers.
        //
        // The test will fail if the mapper is not correctly given to the module after
        // loading a record.
        let bytes = recorder.record(linear_col.into_record(), ()).unwrap();

        let record = recorder.load(bytes, &device).unwrap();
        let col_config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = col_config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_col.weight.dims(),
            [6, 12],
            "Shape should be as configured"
        );
    }

    /// A layer rebuilt from a column-layout layer's weight and bias values gives the same output.
    #[test]
    fn col_row_same_result() {
        let device = Default::default();
        let linear_col = LinearConfig::new(6, 12)
            .with_layout(LinearLayout::Col)
            .init::<TestBackend>(&device);
        let signal = Tensor::<_, 2>::random([8, 6], burn::tensor::Distribution::Default, &device);

        let from_col = linear_col.forward(signal.clone()).into_data();

        let weight_data = linear_col.weight.val().into_data();
        let weights = Tensor::from_data(weight_data, &device);

        let rebuilt = Linear {
            weight: Param::initialized(ParamId::new(), weights),
            bias: linear_col
                .bias
                .map(|b| Param::initialized(ParamId::new(), b.val())),
        };

        let from_rebuilt = rebuilt.forward(signal).into_data();

        from_col.assert_approx_eq::<f32>(&from_rebuilt, Default::default());
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
/// Attention module
|
||||
pub mod attention;
|
||||
|
||||
/// Cache module
|
||||
pub mod cache;
|
||||
|
||||
/// Convolution module
|
||||
pub mod conv;
|
||||
|
||||
/// Pooling module
|
||||
pub mod pool;
|
||||
|
||||
/// Transformer module
|
||||
pub mod transformer;
|
||||
|
||||
/// Interpolate module
|
||||
pub mod interpolate;
|
||||
|
||||
mod dropout;
|
||||
mod embedding;
|
||||
mod linear;
|
||||
mod noise;
|
||||
mod pos_encoding;
|
||||
mod rnn;
|
||||
mod rope_encoding;
|
||||
mod unfold;
|
||||
|
||||
pub mod norm;
|
||||
pub use norm::{batch::*, group::*, instance::*, layer::*, rms::*};
|
||||
|
||||
pub use dropout::*;
|
||||
pub use embedding::*;
|
||||
pub use linear::*;
|
||||
pub use noise::*;
|
||||
pub use pos_encoding::*;
|
||||
pub use rnn::*;
|
||||
pub use rope_encoding::*;
|
||||
pub use unfold::*;
|
||||
@@ -0,0 +1,123 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Distribution, Tensor};
|
||||
|
||||
/// Configuration to create a [GaussianNoise](GaussianNoise) layer using the [init function](GaussianNoiseConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct GaussianNoiseConfig {
|
||||
/// Standard deviation of the normal noise distribution.
|
||||
pub std: f64,
|
||||
}
|
||||
|
||||
/// Add pseudorandom Gaussian noise to an arbitrarily shaped tensor.
|
||||
///
|
||||
/// This is an effective regularization technique that also contributes to data augmentation.
|
||||
/// Please keep in mind that the value of [std](GaussianNoise::std) should be chosen with care in order to avoid
|
||||
/// distortion.
|
||||
///
|
||||
/// Should be created with [GaussianNoiseConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct GaussianNoise {
|
||||
/// Standard deviation of the normal noise distribution.
|
||||
pub std: f64,
|
||||
}
|
||||
|
||||
impl GaussianNoiseConfig {
|
||||
/// Initialize a new [Gaussian noise](GaussianNoise) module.
|
||||
pub fn init(&self) -> GaussianNoise {
|
||||
if self.std.is_sign_negative() {
|
||||
panic!(
|
||||
"Standard deviation is required to be non-negative, but got {}",
|
||||
self.std
|
||||
);
|
||||
}
|
||||
GaussianNoise { std: self.std }
|
||||
}
|
||||
}
|
||||
|
||||
impl GaussianNoise {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [GaussianNoise](GaussianNoise) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
if B::ad_enabled(&input.device()) && self.std != 0.0 {
|
||||
let noise = Tensor::random(
|
||||
input.shape(),
|
||||
Distribution::Normal(0.0, self.std),
|
||||
&input.device(),
|
||||
);
|
||||
input + noise
|
||||
} else {
|
||||
input
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for GaussianNoise {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("std", &self.std).optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Shape;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// With an autodiff backend the output must differ from the input.
    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_should_mark_input() {
        let input =
            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let layer = GaussianNoiseConfig::new(0.5).init();

        let noisy = layer.forward(input.clone());

        assert_ne!(input.to_data(), noisy.to_data());
    }

    /// Without autodiff the layer must be the identity.
    #[test]
    fn without_ad_backend_should_not_change_input() {
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let layer = GaussianNoiseConfig::new(0.5).init();

        let passthrough = layer.forward(input.clone());

        assert_eq!(input.to_data(), passthrough.to_data());
    }

    /// A negative standard deviation is rejected at construction time.
    #[test]
    #[should_panic(expected = "Standard deviation is required to be non-negative")]
    fn negative_std_should_panic() {
        let config = GaussianNoiseConfig { std: -0.5 };
        config.init();
    }

    /// The display output lists the standard deviation on one line.
    #[test]
    fn display() {
        let layer = GaussianNoiseConfig::new(0.5).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(rendered, "GaussianNoise {std: 0.5}");
    }
}
|
||||
@@ -0,0 +1,484 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::Initializer;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param, RunningState},
|
||||
};
|
||||
|
||||
/// [`BatchNorm`] Configuration.
|
||||
///
|
||||
/// Used to create a [`BatchNorm`] layer using the [`BatchNormConfig::init`].
|
||||
#[derive(Config, Debug)]
|
||||
pub struct BatchNormConfig {
|
||||
/// The number of features.
|
||||
pub num_features: usize,
|
||||
/// A value required for numerical stability. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub epsilon: f64,
|
||||
/// Momentum used to update the metrics. Default: 0.1
|
||||
#[config(default = 0.1)]
|
||||
pub momentum: f64,
|
||||
}
|
||||
|
||||
/// Applies Batch Normalization over a tensor.
|
||||
///
|
||||
/// Based upon the paper [Batch Normalization](https://arxiv.org/abs/1502.03167).
|
||||
///
|
||||
/// Assumes input tensor is of shape ``[batch_size, channels, ...]``.
|
||||
///
|
||||
/// `Y = norm(X) * γ + β`
|
||||
///
|
||||
/// Where:
|
||||
/// - `X` is the input tensor
|
||||
/// - `Y` is the output tensor
|
||||
/// - `norm` is the normalization function
|
||||
/// - `γ` is the learnable weight
|
||||
/// - `β` is the learnable bias
|
||||
///
|
||||
/// Should be created using [`BatchNormConfig`].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BatchNorm<B: Backend> {
|
||||
/// The learnable weight gamma.
|
||||
pub gamma: Param<Tensor<B, 1>>,
|
||||
/// The learnable weight beta.
|
||||
pub beta: Param<Tensor<B, 1>>,
|
||||
/// The running mean.
|
||||
pub running_mean: RunningState<Tensor<B, 1>>,
|
||||
/// The running variance.
|
||||
pub running_var: RunningState<Tensor<B, 1>>,
|
||||
/// Momentum used to update the metrics.
|
||||
pub momentum: f64,
|
||||
/// A value required for numerical stability.
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl BatchNormConfig {
|
||||
/// Initializes a new [batch norm](BatchNorm) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {
|
||||
let gamma = Initializer::Ones.init([self.num_features], device);
|
||||
let beta = Initializer::Zeros.init([self.num_features], device);
|
||||
|
||||
let running_mean = Tensor::zeros([self.num_features], device);
|
||||
let running_var = Tensor::ones([self.num_features], device);
|
||||
|
||||
BatchNorm {
|
||||
gamma,
|
||||
beta,
|
||||
running_mean: RunningState::new(running_mean),
|
||||
running_var: RunningState::new(running_var),
|
||||
momentum: self.momentum,
|
||||
epsilon: self.epsilon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BatchNorm<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [`BatchNorm`] for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - `input`: ``[batch_size, channels, ...]``
|
||||
/// - `output`: ``[batch_size, channels, ...]``
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if the input tensor has rank < 2.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
// Should be move to a compilation error when const generic support that kind of
|
||||
// validation. https://github.com/rust-lang/rust/issues/76560
|
||||
if D < 2 {
|
||||
panic!(
|
||||
"BatchNorm can only be applied on tensors of rank >= 2 with the following shape \
|
||||
[batch_size, channels, ...], received {}D tensor",
|
||||
D
|
||||
);
|
||||
}
|
||||
|
||||
match B::ad_enabled(&input.device()) {
|
||||
true => self.forward_train(input),
|
||||
false => self.forward_inference(input),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let device = input.device();
|
||||
let channels = input.dims()[1];
|
||||
let mean = self.running_mean.value().to_device(&device);
|
||||
let var = self.running_var.value().to_device(&device);
|
||||
|
||||
let mut shape = [1; D];
|
||||
shape[1] = channels;
|
||||
|
||||
self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
|
||||
}
|
||||
|
||||
fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let device = input.device();
|
||||
let dims = input.dims();
|
||||
let batch_size = dims[0];
|
||||
let channels = dims[1];
|
||||
|
||||
let mut shape_unsqueeze = [1; D];
|
||||
let mut flatten_size = batch_size;
|
||||
shape_unsqueeze[1] = channels;
|
||||
|
||||
for dim in dims.iter().take(D).skip(2) {
|
||||
flatten_size *= dim;
|
||||
}
|
||||
|
||||
let mean = input
|
||||
.clone()
|
||||
.swap_dims(0, 1)
|
||||
.reshape([channels, flatten_size])
|
||||
.mean_dim(1)
|
||||
.reshape(shape_unsqueeze);
|
||||
|
||||
let var = input
|
||||
.clone()
|
||||
.sub(mean.clone())
|
||||
.square()
|
||||
.swap_dims(0, 1)
|
||||
.reshape([channels, flatten_size])
|
||||
.mean_dim(1)
|
||||
.reshape(shape_unsqueeze);
|
||||
|
||||
let running_mean = self.running_mean.value_sync().to_device(&device);
|
||||
let running_var = self.running_var.value_sync().to_device(&device);
|
||||
|
||||
let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
|
||||
mean.clone()
|
||||
.detach()
|
||||
.mul_scalar(self.momentum)
|
||||
.reshape([channels]),
|
||||
);
|
||||
let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
|
||||
var.clone()
|
||||
.detach()
|
||||
.mul_scalar(self.momentum)
|
||||
.reshape([channels]),
|
||||
);
|
||||
|
||||
self.running_mean.update(running_mean.detach());
|
||||
self.running_var.update(running_var.detach());
|
||||
|
||||
self.forward_shared(input, mean, var)
|
||||
}
|
||||
|
||||
fn forward_shared<const D: usize>(
|
||||
&self,
|
||||
x: Tensor<B, D>,
|
||||
mean: Tensor<B, D>,
|
||||
var: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let channels = x.dims()[1];
|
||||
let mut shape = [1; D];
|
||||
shape[1] = channels;
|
||||
|
||||
let std = var.add_scalar(self.epsilon).sqrt();
|
||||
|
||||
let x = x.sub(mean);
|
||||
let x = x.div(std);
|
||||
|
||||
let x = x.mul(self.gamma.val().reshape(shape));
|
||||
|
||||
x.add(self.beta.val().reshape(shape))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for BatchNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [num_features] = self.beta.shape().dims();
|
||||
|
||||
content
|
||||
.add("num_features", &num_features)
|
||||
.add("momentum", &self.momentum)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// Fixed `[2, 3, 2]` input shared by every test in this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        Tensor::<B, 3>::from_floats(
            [
                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
            ],
            device,
        )
    }

    /// Reference output of a training-mode forward pass on [`input_tensor`].
    fn expected_train() -> TensorData {
        TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ])
    }

    /// Reference output of an inference forward pass after one training pass.
    fn expected_valid() -> TensorData {
        TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ])
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let actual = layer.forward(input_tensor(&device)).to_data();

        actual.assert_approx_eq::<FT>(&expected_train(), Tolerance::rel_abs(0.1, 0.001));
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass to populate the running statistics.
        layer.forward(input_tensor(&device));
        let layer = layer.valid();

        let actual = layer.forward(input_tensor(&device)).to_data();

        actual.assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
    }

    #[test]
    fn batch_norm_forward_train_inference() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass to populate the running statistics.
        layer.forward(input_tensor(&device));
        let layer = layer.valid();

        let inference_out = layer.forward(input_tensor(&device)).to_data();
        inference_out.assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());

        // Switching back to training must use batch statistics again.
        let layer = layer.train::<TestAutodiffBackend>();
        let train_out = layer.forward(input_tensor(&device)).to_data();
        train_out.assert_approx_eq::<FT>(&expected_train(), Tolerance::default());
    }
}
|
||||
|
||||
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// Fixed `[2, 3, 2, 2]` input shared by every test in this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        Tensor::<B, 4>::from_floats(
            [
                [
                    [[0.9601, 0.7277], [0.1270, 0.5441]],
                    [[0.6272, 0.9034], [0.4066, 0.7179]],
                    [[0.9378, 0.7230], [0.3544, 0.9591]],
                ],
                [
                    [[0.6356, 0.1362], [0.1333, 0.7287]],
                    [[0.0249, 0.9509], [0.3791, 0.2481]],
                    [[0.6600, 0.5945], [0.5424, 0.4767]],
                ],
            ],
            device,
        )
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let actual = layer.forward(input_tensor(&device)).to_data();

        let reference = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::rel_abs(0.1, 0.001));
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass to populate the running statistics.
        layer.forward(input_tensor(&device));
        let layer = layer.valid();
        let actual = layer.forward(input_tensor(&device)).to_data();

        let reference = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = layer.forward(input_tensor(&device));

        let tracked_mean = layer.running_mean.value_sync().reshape([3]).into_data();

        let reference = TensorData::from([0.0499, 0.0532, 0.0656]);
        tracked_mean.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = layer.forward(input_tensor(&device));

        let tracked_var = layer.running_var.value_sync().reshape([3]).into_data();

        let reference = TensorData::from([0.9106, 0.9105, 0.9045]);
        tracked_var.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    /// The running mean seen through the validation module must match the original module's.
    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = layer.forward(input_tensor(&device));

        let inner_layer = layer.valid();
        let inner_mean = inner_layer.running_mean.value();
        let outer_mean = layer.running_mean.value();

        outer_mean
            .into_data()
            .assert_approx_eq::<FT>(&inner_mean.into_data(), Tolerance::default());
    }

    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
        let input = input_tensor(&device).require_grad();

        let grads = layer.forward(input.clone()).backward();

        let tolerance = Tolerance::rel_abs(0.1, 0.001);

        let gamma_reference = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        let gamma_grad = layer.gamma.grad(&grads).unwrap().reshape([3]).into_data();
        gamma_grad.assert_approx_eq::<FT>(&gamma_reference, tolerance);

        let beta_reference = TensorData::from([8., 8., 8.]);
        let beta_grad = layer.beta.grad(&grads).unwrap().reshape([3]).into_data();
        beta_grad.assert_approx_eq::<FT>(&beta_reference, tolerance);

        let input_reference = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        let input_grad = input.grad(&grads).unwrap().into_data();
        input_grad.assert_approx_eq::<FT>(&input_reference, tolerance);
    }

    #[test]
    fn display() {
        let layer = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());
        let rendered = format!("{layer}");

        assert_eq!(
            rendered,
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}
|
||||
@@ -0,0 +1,336 @@
|
||||
use burn::module::Initializer;
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::Param;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init).
|
||||
#[derive(Debug, Config)]
|
||||
pub struct GroupNormConfig {
|
||||
/// The number of groups to separate the channels into
|
||||
pub num_groups: usize,
|
||||
/// The number of channels expected in the input
|
||||
pub num_channels: usize,
|
||||
/// A value required for numerical stability. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub epsilon: f64,
|
||||
/// A boolean value that when set to `true`, this module has learnable
|
||||
/// per-channel affine parameters initialized to ones (for weights)
|
||||
/// and zeros (for biases). Default: `true`
|
||||
#[config(default = true)]
|
||||
pub affine: bool,
|
||||
}
|
||||
|
||||
/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
|
||||
///
|
||||
/// `Y = groupnorm(X) * γ + β`
|
||||
///
|
||||
/// Where:
|
||||
/// - `X` is the input tensor
|
||||
/// - `Y` is the output tensor
|
||||
/// - `γ` is the learnable weight
|
||||
/// - `β` is the learnable bias
|
||||
///
|
||||
/// Should be created using [GroupNormConfig](GroupNormConfig).
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct GroupNorm<B: Backend> {
|
||||
/// The learnable weight
|
||||
pub gamma: Option<Param<Tensor<B, 1>>>,
|
||||
/// The learnable bias
|
||||
pub beta: Option<Param<Tensor<B, 1>>>,
|
||||
/// The number of groups to separate the channels into
|
||||
pub num_groups: usize,
|
||||
/// The number of channels expected in the input
|
||||
pub num_channels: usize,
|
||||
/// A value required for numerical stability
|
||||
pub epsilon: f64,
|
||||
/// A boolean value that when set to `true`, this module has learnable
|
||||
pub affine: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for GroupNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("num_groups", &self.num_groups)
|
||||
.add("num_channels", &self.num_channels)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.add("affine", &self.affine)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
/// Initialize a new [group norm](GroupNorm) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
|
||||
assert_eq!(
|
||||
self.num_channels % self.num_groups,
|
||||
0,
|
||||
"The number of channels must be divisible by the number of groups"
|
||||
);
|
||||
|
||||
let (gamma, beta) = if self.affine {
|
||||
let gamma = Initializer::Ones.init([self.num_channels], device);
|
||||
let beta = Initializer::Zeros.init([self.num_channels], device);
|
||||
|
||||
(Some(gamma), Some(beta))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
GroupNorm {
|
||||
num_groups: self.num_groups,
|
||||
num_channels: self.num_channels,
|
||||
gamma,
|
||||
beta,
|
||||
epsilon: self.epsilon,
|
||||
affine: self.affine,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> GroupNorm<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [GroupNorm](GroupNorm) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, num_channels, *]`
|
||||
/// - output: `[batch_size, num_channels, *]`
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
if input.shape()[1] != self.num_channels {
|
||||
panic!(
|
||||
"The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
|
||||
self.num_channels,
|
||||
input.shape()[1]
|
||||
);
|
||||
}
|
||||
|
||||
let gamma = self.gamma.as_ref().map(|x| x.val());
|
||||
let beta = self.beta.as_ref().map(|x| x.val());
|
||||
|
||||
group_norm(
|
||||
input,
|
||||
gamma,
|
||||
beta,
|
||||
self.num_groups,
|
||||
self.epsilon,
|
||||
self.affine,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
|
||||
///
|
||||
/// `Y = groupnorm(X) * γ + β`
|
||||
///
|
||||
/// Where:
|
||||
/// - `X` is the input tensor
|
||||
/// - `Y` is the output tensor
|
||||
/// - `γ` is the learnable weight
|
||||
/// - `β` is the learnable bias
|
||||
///
|
||||
pub(crate) fn group_norm<B: Backend, const D: usize>(
|
||||
input: Tensor<B, D>,
|
||||
gamma: Option<Tensor<B, 1>>,
|
||||
beta: Option<Tensor<B, 1>>,
|
||||
num_groups: usize,
|
||||
epsilon: f64,
|
||||
affine: bool,
|
||||
) -> Tensor<B, D> {
|
||||
if (beta.is_none() || gamma.is_none()) && affine {
|
||||
panic!("Affine is set to true, but gamma or beta is None");
|
||||
}
|
||||
|
||||
let shape = input.shape();
|
||||
if shape.num_elements() <= 2 {
|
||||
panic!(
|
||||
"input rank for GroupNorm should be at least 3, but got {}",
|
||||
shape.num_elements()
|
||||
);
|
||||
}
|
||||
|
||||
let batch_size = shape[0];
|
||||
let num_channels = shape[1];
|
||||
|
||||
let hidden_size = shape[2..].iter().product::<usize>() * num_channels / num_groups;
|
||||
let input = input.reshape([batch_size, num_groups, hidden_size]);
|
||||
|
||||
let mean = input.clone().sum_dim(2) / hidden_size as f64;
|
||||
let input = input.sub(mean);
|
||||
|
||||
let var = input.clone().square().sum_dim(2) / hidden_size as f64;
|
||||
let input_normalized = input.div(var.add_scalar(epsilon).sqrt());
|
||||
|
||||
if affine {
|
||||
let mut affine_shape = [1; D];
|
||||
affine_shape[1] = num_channels;
|
||||
|
||||
input_normalized
|
||||
.reshape(shape)
|
||||
.mul(gamma.clone().unwrap().reshape(affine_shape))
|
||||
.add(beta.clone().unwrap().reshape(affine_shape))
|
||||
} else {
|
||||
input_normalized.reshape(shape)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Without affine parameters the layer only normalizes each group.
    #[test]
    fn group_norm_forward_affine_false() {
        let device = Default::default();
        let layer = GroupNormConfig::new(2, 6)
            .with_affine(false)
            .init::<TestBackend>(&device);

        assert!(layer.gamma.is_none());
        assert!(layer.beta.is_none());

        let raw_input = TensorData::from([
            [
                [-0.3034, 0.2726, -0.9659],
                [-1.1845, -1.3236, 0.0172],
                [1.9507, 1.2554, -0.8625],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]);
        let input = Tensor::<TestBackend, 3>::from_data(raw_input, &device);

        let actual = layer.forward(input).to_data();

        let reference = TensorData::from([
            [
                [-0.1653, 0.3748, -0.7866],
                [-0.9916, -1.1220, 0.1353],
                [1.9485, 1.2965, -0.6896],
                [1.2769, 0.3628, 0.4120],
                [-0.7427, -0.6786, -1.3578],
                [1.8547, -0.3022, -0.8252],
            ],
            [
                [-1.9342, 0.0211, -0.5793],
                [1.2223, 0.4945, 0.4365],
                [-0.8163, 1.4887, -0.3333],
                [-1.7960, -0.0392, 0.3875],
                [-1.5469, 0.3998, 0.9561],
                [-0.3428, 0.7970, 1.1845],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }

    /// With affine parameters at their initial values (ones and zeros) the output is the
    /// plain normalization.
    #[test]
    fn group_norm_forward_affine_true() {
        let device = Default::default();
        let layer = GroupNormConfig::new(3, 6)
            .with_affine(true)
            .init::<TestBackend>(&device);

        let tolerance = Tolerance::permissive();

        let gamma_values = layer
            .gamma
            .as_ref()
            .expect("gamma should not be None")
            .val()
            .to_data();
        gamma_values.assert_approx_eq::<FT>(&TensorData::ones::<f32, _>([6]), tolerance);

        let beta_values = layer
            .beta
            .as_ref()
            .expect("beta should not be None")
            .val()
            .to_data();
        beta_values.assert_approx_eq::<FT>(&TensorData::zeros::<f32, _>([6]), tolerance);

        let raw_input = TensorData::from([
            [
                [0.3345, 0.4429, 0.6639],
                [0.5041, 0.4175, 0.8437],
                [0.6159, 0.3758, 0.4071],
                [0.5417, 0.5785, 0.7671],
                [0.3837, 0.9883, 0.0420],
                [0.4808, 0.8989, 0.6144],
            ],
            [
                [0.3930, 0.2098, 0.0602],
                [0.2298, 0.9425, 0.0333],
                [0.7409, 0.8172, 0.8879],
                [0.4846, 0.0486, 0.2029],
                [0.6741, 0.9765, 0.6864],
                [0.2827, 0.5534, 0.2125],
            ],
        ]);
        let input = Tensor::<TestBackend, 3>::from_data(raw_input, &device);

        let actual = layer.forward(input).to_data();

        let reference = TensorData::from([
            [
                [-1.1694, -0.5353, 0.7572],
                [-0.1775, -0.6838, 1.8087],
                [0.5205, -1.3107, -1.0723],
                [-0.0459, 0.2351, 1.6734],
                [-0.5796, 1.3218, -1.6544],
                [-0.2744, 1.0406, 0.1459],
            ],
            [
                [0.2665, -0.3320, -0.8205],
                [-0.2667, 2.0612, -0.9085],
                [0.6681, 0.9102, 1.1345],
                [-0.1453, -1.5287, -1.0389],
                [0.4253, 1.5962, 0.4731],
                [-1.0903, -0.0419, -1.3623],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&reference, tolerance);
    }

    #[test]
    fn display() {
        let layer = GroupNormConfig::new(3, 6).init::<TestBackend>(&Default::default());
        let rendered = format!("{layer}");

        assert_eq!(
            rendered,
            "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}
|
||||
@@ -0,0 +1,228 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::norm::group_norm;
|
||||
use burn::config::Config;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
/// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).
|
||||
#[derive(Debug, Config)]
|
||||
pub struct InstanceNormConfig {
|
||||
/// The number of channels expected in the input
|
||||
pub num_channels: usize,
|
||||
/// A value required for numerical stability. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub epsilon: f64,
|
||||
/// A boolean value that when set to `true`, this module has learnable
|
||||
/// per-channel affine parameters initialized to ones (for weights)
|
||||
/// and zeros (for biases). Default: `true`
|
||||
#[config(default = true)]
|
||||
pub affine: bool,
|
||||
}
|
||||
|
||||
/// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022)
|
||||
///
|
||||
/// Should be created using [InstanceNormConfig](InstanceNormConfig).
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct InstanceNorm<B: Backend> {
|
||||
/// The learnable weight
|
||||
pub gamma: Option<Param<Tensor<B, 1>>>,
|
||||
/// The learnable bias
|
||||
pub beta: Option<Param<Tensor<B, 1>>>,
|
||||
/// The number of channels expected in the input
|
||||
pub num_channels: usize,
|
||||
/// A value required for numerical stability
|
||||
pub epsilon: f64,
|
||||
/// A boolean value that when set to `true`, this module has learnable
|
||||
pub affine: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("num_channels", &self.num_channels)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.add("affine", &self.affine)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl InstanceNormConfig {
|
||||
/// Initialize a new [instance norm](InstanceNorm) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
|
||||
let (gamma, beta) = if self.affine {
|
||||
let gamma = Initializer::Ones.init([self.num_channels], device);
|
||||
let beta = Initializer::Zeros.init([self.num_channels], device);
|
||||
|
||||
(Some(gamma), Some(beta))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
InstanceNorm {
|
||||
gamma,
|
||||
beta,
|
||||
num_channels: self.num_channels,
|
||||
epsilon: self.epsilon,
|
||||
affine: self.affine,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> InstanceNorm<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See also [InstanceNormConfig](InstanceNormConfig) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, num_channels, *]`
|
||||
/// - output: `[batch_size, num_channels, *]`
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
// Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.
|
||||
let num_groups = self.num_channels;
|
||||
|
||||
let gamma = self.gamma.as_ref().map(|x| x.val());
|
||||
let beta = self.beta.as_ref().map(|x| x.val());
|
||||
|
||||
group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    /// Float element type of the test backend, used for approximate comparisons.
    type FT = FloatElem<TestBackend>;

    /// Forward pass without affine parameters matches the reference values.
    #[test]
    fn instance_norm_forward_affine_false() {
        let device = Default::default();
        let config = InstanceNormConfig::new(6).with_affine(false);
        let norm = config.init::<TestBackend>(&device);

        let input_data = TensorData::from([
            [
                [-0.3034, 0.2726, -0.9659],
                [-1.1845, 1.4078, 0.9774],
                [0.3963, -1.3738, 1.4125],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]);
        let x = Tensor::<TestBackend, 3>::from_data(input_data, &device);

        let y = norm.forward(x);

        let expected = TensorData::from([
            [
                [0.0569, 1.1952, -1.2522],
                [-1.3971, 0.8883, 0.5088],
                [0.2183, -1.3192, 1.1009],
                [1.4126, -0.7649, -0.6477],
                [0.5999, 0.8091, -1.409],
                [1.39, -0.4696, -0.9205],
            ],
            [
                [-1.3492, 1.0417, 0.3075],
                [1.411, -0.6243, -0.7867],
                [-0.9363, 1.386, -0.4497],
                [-1.3899, 0.4692, 0.9208],
                [-1.3822, 0.4319, 0.9503],
                [-1.3714, 0.3868, 0.9846],
            ],
        ]);
        let actual = y.to_data();
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Forward pass with freshly initialized affine parameters matches the reference values.
    #[test]
    fn instance_norm_forward_affine_true() {
        let device = Default::default();
        let config = InstanceNormConfig::new(6).with_affine(true);
        let norm = config.init::<TestBackend>(&device);

        let input_data = TensorData::from([
            [
                [0.3345, 0.4429, 0.6639],
                [0.5041, 0.4175, 0.8437],
                [0.6159, 0.3758, 0.4071],
                [0.5417, 0.5785, 0.7671],
                [0.3837, 0.9883, 0.0420],
                [0.4808, 0.8989, 0.6144],
            ],
            [
                [0.3930, 0.2098, 0.0602],
                [0.2298, 0.9425, 0.0333],
                [0.7409, 0.8172, 0.8879],
                [0.4846, 0.0486, 0.2029],
                [0.6741, 0.9765, 0.6864],
                [0.2827, 0.5534, 0.2125],
            ],
        ]);
        let x = Tensor::<TestBackend, 3>::from_data(input_data, &device);

        let y = norm.forward(x);

        let expected = TensorData::from([
            [
                [-1.06458, -0.2738, 1.33838],
                [-0.45848, -0.92929, 1.38777],
                [1.40388, -0.84877, -0.55511],
                [-0.88515, -0.51245, 1.3976],
                [-0.22397, 1.32124, -1.09727],
                [-1.05468, 1.34316, -0.28848],
            ],
            [
                [1.26372, -0.08229, -1.18144],
                [-0.44049, 1.38403, -0.94354],
                [-1.23828, 0.03109, 1.2072],
                [1.32524, -1.08999, -0.23524],
                [-0.75061, 1.4132, -0.66259],
                [-0.45469, 1.38697, -0.93228],
            ],
        ]);
        let actual = y.to_data();
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display lists the configuration attributes and the parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let instance_norm = InstanceNormConfig::new(6).init::<TestBackend>(&device);

        let rendered = format!("{instance_norm}");

        assert_eq!(
            rendered,
            "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}
|
||||
@@ -0,0 +1,232 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Content;
|
||||
use burn::module::DisplaySettings;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::Module;
|
||||
use burn::module::ModuleDisplay;
|
||||
use burn::module::Param;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
|
||||
#[derive(Debug, Config)]
|
||||
pub struct LayerNormConfig {
|
||||
/// The size of the input features.
|
||||
pub d_model: usize,
|
||||
/// A value required for numerical stability. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub epsilon: f64,
|
||||
/// If a bias (beta) should be applied during the normalization. Default: true
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
}
|
||||
|
||||
/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).
|
||||
///
|
||||
/// `Y = norm(X) * γ + β`
|
||||
///
|
||||
/// Where:
|
||||
/// - `X` is the input tensor
|
||||
/// - `Y` is the output tensor
|
||||
/// - `γ` is the learnable weight (scale)
|
||||
/// - `β` is the learnable bias (optional)
|
||||
///
|
||||
/// Should be created using [LayerNormConfig](LayerNormConfig).
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct LayerNorm<B: Backend> {
|
||||
/// The learnable weight (scale).
|
||||
pub gamma: Param<Tensor<B, 1>>,
|
||||
/// The learnable bias (optional).
|
||||
pub beta: Option<Param<Tensor<B, 1>>>,
|
||||
/// A value required for numerical stability.
|
||||
epsilon: f64,
|
||||
}
|
||||
|
||||
impl LayerNormConfig {
|
||||
/// Initialize a new [layer norm](LayerNorm) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
|
||||
let gamma = Initializer::Ones.init([self.d_model], device);
|
||||
let beta = if self.bias {
|
||||
Some(Initializer::Zeros.init([self.d_model], device))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
LayerNorm {
|
||||
gamma,
|
||||
beta,
|
||||
epsilon: self.epsilon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> LayerNorm<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See the [LayerNorm](LayerNorm) documentation for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any, d_model]`
|
||||
/// - output: `[..., any, d_model]`
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let (var, mean) = input.clone().var_mean_bias(D - 1);
|
||||
|
||||
let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());
|
||||
|
||||
let output = input_normalized.mul(self.gamma.val().unsqueeze());
|
||||
|
||||
match &self.beta {
|
||||
Some(beta) => output.add(beta.val().unsqueeze()),
|
||||
None => output,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for LayerNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_model] = self.gamma.shape().dims();
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.add("bias", &self.beta.is_some())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    // The autodiff backend is only needed by the `std`-gated backward test.
    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    /// Float element type of the test backend, used for approximate comparisons.
    type FT = FloatElem<TestBackend>;

    /// Forward pass with the default epsilon matches the reference values.
    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let norm = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input_data = TensorData::from([[
            -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
        ]]);
        let x = Tensor::<TestBackend, 2>::from_data(input_data, &device);

        let y = norm.forward(x);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        let actual = y.to_data();
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Forward pass with `epsilon = 1e-1` matches the reference values.
    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let config = LayerNormConfig::new(10).with_epsilon(1e-1);
        let norm = config.init::<TestBackend>(&device);
        let input_data = TensorData::from([[
            -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
        ]]);
        let x = Tensor::<TestBackend, 2>::from_data(input_data, &device);

        let y = norm.forward(x);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
        ]]);
        let actual = y.to_data();
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Gradients flowing through a matmul followed by the layer match the reference values.
    #[cfg(feature = "std")]
    #[test]
    fn layer_norm_backward() {
        let device = Default::default();
        let norm = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);

        let lhs_data = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
        let tensor_1 =
            Tensor::<TestAutodiffBackend, 2>::from_data(lhs_data, &device).require_grad();
        let rhs_data = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
        let tensor_2 =
            Tensor::<TestAutodiffBackend, 2>::from_data(rhs_data, &device).require_grad();

        let product = tensor_1.clone().matmul(tensor_2.clone());
        let grads = norm.forward(product).backward();

        let tensor_1_grad = tensor_1.grad(&grads).unwrap();
        let tensor_2_grad = tensor_2.grad(&grads).unwrap();
        let gamma_grad = norm.gamma.grad(&grads).unwrap();
        let beta_grad = norm.beta.as_ref().unwrap().grad(&grads).unwrap();

        let expected_gamma = TensorData::from([-2.0, 2.0]);
        gamma_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected_gamma, Tolerance::default());

        let expected_beta = TensorData::from([2.0, 2.0]);
        beta_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected_beta, Tolerance::default());

        let expected_grad_1 = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected_grad_1, Tolerance::default());

        let expected_grad_2 = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected_grad_2, Tolerance::default());
    }

    /// The custom display of a layer with bias reports `bias: true` and 12 params.
    #[test]
    fn display() {
        let device = Default::default();
        let layer_norm = LayerNormConfig::new(6).init::<TestBackend>(&device);

        let rendered = format!("{layer_norm}");

        assert_eq!(
            rendered,
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}"
        );
    }

    /// The custom display of a layer without bias reports `bias: false` and 6 params.
    #[test]
    fn display_no_bias() {
        let device = Default::default();
        let layer_norm = LayerNormConfig::new(6)
            .with_bias(false)
            .init::<TestBackend>(&device);

        let rendered = format!("{layer_norm}");

        assert_eq!(
            rendered,
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}"
        );
    }
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//! # Normalization Layers
|
||||
//!
|
||||
//! Users who wish to provide an abstraction over swappable normalization
|
||||
//! layers can use the [`Normalization`] wrapper, with support for:
|
||||
//! * [`Normalization::Batch`] - [`BatchNorm`]
|
||||
//! * [`Normalization::Group`] - [`GroupNorm`]
|
||||
//! * [`Normalization::Instance`] - [`InstanceNorm`]
|
||||
//! * [`Normalization::Layer`] - [`LayerNorm`]
|
||||
//! * [`Normalization::Rms`] - [`RmsNorm`]
|
||||
//!
|
||||
//! [`NormalizationConfig`] can be used as a generic normalization policy:
|
||||
//! * Construct a config with arbitrary input features (we suggest `0`).
|
||||
//! * Clone and match that config to the target input layer,
|
||||
//! using the [`NormalizationConfig::with_num_features()`] method.
|
||||
pub(crate) mod batch;
|
||||
pub(crate) mod group;
|
||||
pub(crate) mod instance;
|
||||
pub(crate) mod layer;
|
||||
pub(crate) mod rms;
|
||||
|
||||
mod normalization_wrapper;
|
||||
|
||||
pub use batch::*;
|
||||
pub use group::*;
|
||||
pub use instance::*;
|
||||
pub use layer::*;
|
||||
pub use normalization_wrapper::*;
|
||||
pub use rms::*;
|
||||
@@ -0,0 +1,368 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::{
|
||||
BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,
|
||||
LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig,
|
||||
};
|
||||
use burn::prelude::{Config, Module};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// ['Normalization'] Configuration.
|
||||
///
|
||||
/// The enum is non-exhaustive to prepare for future additions.
|
||||
///
|
||||
/// Can be used as a generic configuration for normalization layers:
|
||||
/// * Construct a config with arbitrary input features (we suggest `0`).
|
||||
/// * Clone and match that config to the target input layer,
|
||||
/// using the [`NormalizationConfig::with_num_features()`] method.
|
||||
#[derive(Config, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum NormalizationConfig {
|
||||
/// ['BatchNorm'] Configuration.
|
||||
Batch(BatchNormConfig),
|
||||
|
||||
/// ['GroupNorm'] Configuration.
|
||||
Group(GroupNormConfig),
|
||||
|
||||
/// ['InstanceNorm'] Configuration.
|
||||
Instance(InstanceNormConfig),
|
||||
|
||||
/// ['LayerNorm'] Configuration.
|
||||
Layer(LayerNormConfig),
|
||||
|
||||
/// ['RmsNorm'] Configuration.
|
||||
Rms(RmsNormConfig),
|
||||
}
|
||||
|
||||
impl From<BatchNormConfig> for NormalizationConfig {
|
||||
fn from(config: BatchNormConfig) -> Self {
|
||||
Self::Batch(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GroupNormConfig> for NormalizationConfig {
|
||||
fn from(config: GroupNormConfig) -> Self {
|
||||
Self::Group(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InstanceNormConfig> for NormalizationConfig {
|
||||
fn from(config: InstanceNormConfig) -> Self {
|
||||
Self::Instance(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LayerNormConfig> for NormalizationConfig {
|
||||
fn from(config: LayerNormConfig) -> Self {
|
||||
Self::Layer(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RmsNormConfig> for NormalizationConfig {
|
||||
fn from(config: RmsNormConfig) -> Self {
|
||||
Self::Rms(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl NormalizationConfig {
|
||||
/// Initialize a ['Norm'] layer.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Normalization<B> {
|
||||
match self {
|
||||
NormalizationConfig::Batch(config) => config.init(device).into(),
|
||||
NormalizationConfig::Group(config) => config.init(device).into(),
|
||||
NormalizationConfig::Instance(config) => config.init(device).into(),
|
||||
NormalizationConfig::Layer(config) => config.init(device).into(),
|
||||
NormalizationConfig::Rms(config) => config.init(device).into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the number of features.
|
||||
pub fn with_num_features(self, num_features: usize) -> Self {
|
||||
match self {
|
||||
NormalizationConfig::Batch(config) => BatchNormConfig {
|
||||
num_features,
|
||||
..config
|
||||
}
|
||||
.into(),
|
||||
NormalizationConfig::Group(config) => GroupNormConfig {
|
||||
num_channels: num_features,
|
||||
..config
|
||||
}
|
||||
.into(),
|
||||
NormalizationConfig::Instance(config) => InstanceNormConfig {
|
||||
num_channels: num_features,
|
||||
..config
|
||||
}
|
||||
.into(),
|
||||
NormalizationConfig::Layer(config) => LayerNormConfig {
|
||||
d_model: num_features,
|
||||
..config
|
||||
}
|
||||
.into(),
|
||||
NormalizationConfig::Rms(config) => RmsNormConfig {
|
||||
d_model: num_features,
|
||||
..config
|
||||
}
|
||||
.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of features.
|
||||
pub fn num_features(&self) -> usize {
|
||||
match self {
|
||||
NormalizationConfig::Batch(config) => config.num_features,
|
||||
NormalizationConfig::Group(config) => config.num_channels,
|
||||
NormalizationConfig::Instance(config) => config.num_channels,
|
||||
NormalizationConfig::Layer(config) => config.d_model,
|
||||
NormalizationConfig::Rms(config) => config.d_model,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalization Layer Wrapper
|
||||
///
|
||||
/// Provides support for built-in ``burn::nn::norm`` norm layers:
|
||||
/// * [`Normalization::Batch`] - [`BatchNorm`]
|
||||
/// * [`Normalization::Group`] - [`GroupNorm`]
|
||||
/// * [`Normalization::Instance`] - [`InstanceNorm`]
|
||||
/// * [`Normalization::Layer`] - [`LayerNorm`]
|
||||
/// * [`Normalization::Rms`] - [`RmsNorm`]
|
||||
///
|
||||
/// The enum is non-exhaustive, to prepare for future additions.
|
||||
#[derive(Module, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum Normalization<B: Backend> {
|
||||
/// [`BatchNorm`] layer.
|
||||
Batch(BatchNorm<B>),
|
||||
|
||||
/// [`GroupNorm`] layer.
|
||||
Group(GroupNorm<B>),
|
||||
|
||||
/// ['InstanceNorm'] layer.
|
||||
Instance(InstanceNorm<B>),
|
||||
|
||||
/// [`LayerNorm`] layer.
|
||||
Layer(LayerNorm<B>),
|
||||
|
||||
/// ['RmsNorm'] layer.
|
||||
Rms(RmsNorm<B>),
|
||||
}
|
||||
|
||||
impl<B: Backend> From<BatchNorm<B>> for Normalization<B> {
|
||||
fn from(layer: BatchNorm<B>) -> Self {
|
||||
Self::Batch(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<GroupNorm<B>> for Normalization<B> {
|
||||
fn from(layer: GroupNorm<B>) -> Self {
|
||||
Self::Group(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<InstanceNorm<B>> for Normalization<B> {
|
||||
fn from(layer: InstanceNorm<B>) -> Self {
|
||||
Self::Instance(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<LayerNorm<B>> for Normalization<B> {
|
||||
fn from(layer: LayerNorm<B>) -> Self {
|
||||
Self::Layer(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<RmsNorm<B>> for Normalization<B> {
|
||||
fn from(layer: RmsNorm<B>) -> Self {
|
||||
Self::Rms(layer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Normalization<B> {
|
||||
/// Applies normalization to a tensor.
|
||||
///
|
||||
/// The normalization contract depends upon the wrapped norm layer;
|
||||
/// but all norm layers assume an input of at least rank 2;
|
||||
/// and produce an output of the same rank and shape.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
match self {
|
||||
Normalization::Batch(norm) => norm.forward(input),
|
||||
Normalization::Group(norm) => norm.forward(input),
|
||||
Normalization::Instance(norm) => norm.forward(input),
|
||||
Normalization::Layer(norm) => norm.forward(input),
|
||||
Normalization::Rms(norm) => norm.forward(input),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of features.
|
||||
pub fn num_features(&self) -> usize {
|
||||
match self {
|
||||
Normalization::Batch(norm) => norm.gamma.shape()[0],
|
||||
Normalization::Group(norm) => norm.num_channels,
|
||||
Normalization::Instance(norm) => norm.num_channels,
|
||||
Normalization::Layer(norm) => norm.gamma.shape()[0],
|
||||
Normalization::Rms(norm) => norm.gamma.shape()[0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};

    /// Backend used by every test in this module.
    type B = TestAutodiffBackend;
    /// Float element type of that backend, used for approximate comparisons.
    type FT = FloatElem<TestAutodiffBackend>;

    /// `with_num_features` / `num_features` round-trip for every config variant.
    #[test]
    fn test_match_feature_size() {
        let configs: [NormalizationConfig; 5] = [
            BatchNormConfig::new(0).into(),
            GroupNormConfig::new(4, 0).into(),
            InstanceNormConfig::new(0).into(),
            LayerNormConfig::new(0).into(),
            RmsNormConfig::new(0).into(),
        ];

        for config in configs {
            assert_eq!(config.num_features(), 0);
            let resized = config.with_num_features(12);
            assert_eq!(resized.num_features(), 12);
        }
    }

    /// The wrapper's forward agrees exactly with the wrapped batch norm.
    #[test]
    fn test_batch_norm() {
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);

        let config: NormalizationConfig = BatchNormConfig::new(12).into();
        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let Normalization::Batch(inner) = &layer else {
            panic!("Unexpected layer type")
        };
        let expected = inner.forward(input.clone());

        let output = layer.forward(input);

        output.to_data().assert_eq(&expected.to_data(), true);
    }

    /// The wrapper's forward agrees with the wrapped group norm.
    #[test]
    fn test_group_norm() {
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);

        let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into();
        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let Normalization::Group(inner) = &layer else {
            panic!("Unexpected layer type")
        };
        let expected = inner.forward(input.clone());

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// The wrapper's forward agrees with the wrapped instance norm.
    #[test]
    fn test_instance_norm() {
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);

        let config: NormalizationConfig = InstanceNormConfig::new(num_features).into();
        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let Normalization::Instance(inner) = &layer else {
            panic!("Unexpected layer type")
        };
        let expected = inner.forward(input.clone());

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// The wrapper's forward agrees with the wrapped layer norm.
    #[test]
    fn test_layer_norm() {
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);

        let config: NormalizationConfig = LayerNormConfig::new(num_features).into();
        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let Normalization::Layer(inner) = &layer else {
            panic!("Unexpected layer type")
        };
        let expected = inner.forward(input.clone());

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// The wrapper's forward agrees with the wrapped RMS norm.
    #[test]
    fn test_rms_norm() {
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);

        let config: NormalizationConfig = RmsNormConfig::new(num_features).into();
        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let Normalization::Rms(inner) = &layer else {
            panic!("Unexpected layer type")
        };
        let expected = inner.forward(input.clone());

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,134 @@
|
||||
use burn::tensor::DType;
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Initializer;
|
||||
use burn::module::Module;
|
||||
use burn::module::Param;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct RmsNormConfig {
|
||||
/// The size of the input features.
|
||||
pub d_model: usize,
|
||||
/// A value required for numerical stability. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl RmsNormConfig {
|
||||
/// Initialize a new [RMS Norm](RmsNorm) module.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `epsilon` is not positive.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
|
||||
assert!(self.epsilon > 0.0, "epsilon must be positive.");
|
||||
|
||||
let gamma = Initializer::Ones.init([self.d_model], device);
|
||||
|
||||
RmsNorm {
|
||||
gamma,
|
||||
epsilon: self.epsilon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies RMS Normalization over an input tensor along the last dimension.
|
||||
///
|
||||
/// `Y = X / sqrt(mean(X^2) + eps) * gamma`
|
||||
///
|
||||
/// Where:
|
||||
/// - `X` is the input tensor
|
||||
/// - `Y` is the output tensor
|
||||
/// - `gamma` is the learnable weight
|
||||
/// - `mean` is the mean operation
|
||||
/// - `eps` is a small value to avoid division by zero.
|
||||
///
|
||||
/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct RmsNorm<B: Backend> {
|
||||
/// The learnable parameter to scale the normalized tensor
|
||||
pub gamma: Param<Tensor<B, 1>>,
|
||||
/// A value required for numerical stability
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> RmsNorm<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See the [RmsNorm](RmsNorm) documentation for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[..., any, d_model]`
|
||||
/// - output: `[..., any, d_model]`
|
||||
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
// Calculate the root-mean-square norm of the input tensor along the last dimension
|
||||
let dtype = x.dtype();
|
||||
let rms = (x.clone().cast(DType::F32).square().mean_dim(D - 1) + self.epsilon).sqrt();
|
||||
(x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for RmsNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_model] = self.gamma.shape().dims();
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    /// Float element type of the test backend, used for approximate comparisons.
    type FT = FloatElem<TestBackend>;

    /// Forward pass on the values 0..9 arranged as a 3x3 tensor matches the reference values.
    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let config = RmsNormConfig::new(3).with_epsilon(1e-5);
        let norm = config.init::<TestBackend>(&device);

        let x = Tensor::arange(0..9, &device).float().reshape([3, 3]);
        let y = norm.forward(x);

        let expected = TensorData::from([
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]);
        let actual = y.to_data();
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The custom display lists `d_model`, epsilon and the parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let rms_norm = RmsNormConfig::new(6).init::<TestBackend>(&device);

        let rendered = format!("{rms_norm}");

        assert_eq!(rendered, "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}");
    }
}
|
||||
@@ -0,0 +1,77 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use burn::tensor::module::adaptive_avg_pool1d;
|
||||
|
||||
/// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer using the [init function](AdaptiveAvgPool1dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AdaptiveAvgPool1dConfig {
|
||||
/// The size of the output.
|
||||
pub output_size: usize,
|
||||
}
|
||||
|
||||
/// Applies a 1D adaptive avg pooling over input tensors.
|
||||
///
|
||||
/// Should be created with [AdaptiveAvgPool1dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AdaptiveAvgPool1d {
|
||||
/// The size of the output.
|
||||
pub output_size: usize,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AdaptiveAvgPool1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("output_size", &self.output_size).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool1dConfig {
|
||||
/// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module.
|
||||
pub fn init(&self) -> AdaptiveAvgPool1d {
|
||||
AdaptiveAvgPool1d {
|
||||
output_size: self.output_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool1d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [adaptive_avg_pool1d](burn::tensor::module::adaptive_avg_pool1d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels, length]`
|
||||
/// - output: `[batch_size, channels, length_out]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
adaptive_avg_pool1d(input, self.output_size)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the single-line display representation of the module.
    #[test]
    fn display() {
        let pool = AdaptiveAvgPool1dConfig::new(3).init();
        let rendered = alloc::format!("{pool}");

        assert_eq!(rendered, "AdaptiveAvgPool1d {output_size: 3}");
    }
}
|
||||
@@ -0,0 +1,79 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::Module;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use burn::tensor::module::adaptive_avg_pool2d;
|
||||
|
||||
/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer using the [init function](AdaptiveAvgPool2dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AdaptiveAvgPool2dConfig {
|
||||
/// The size of the output.
|
||||
pub output_size: [usize; 2],
|
||||
}
|
||||
|
||||
/// Applies a 2D adaptive avg pooling over input tensors.
|
||||
///
|
||||
/// Should be created with [AdaptiveAvgPool2dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AdaptiveAvgPool2d {
|
||||
/// The size of the output.
|
||||
pub output_size: [usize; 2],
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AdaptiveAvgPool2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let output_size = alloc::format!("{:?}", self.output_size);
|
||||
|
||||
content.add("output_size", &output_size).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dConfig {
|
||||
/// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module.
|
||||
pub fn init(&self) -> AdaptiveAvgPool2d {
|
||||
AdaptiveAvgPool2d {
|
||||
output_size: self.output_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [adaptive_avg_pool2d](burn::tensor::module::adaptive_avg_pool2d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels, height_in, width_in]`
|
||||
/// - output: `[batch_size, channels, height_out, width_out]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
adaptive_avg_pool2d(input, self.output_size)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the single-line display representation of the module.
    #[test]
    fn display() {
        let pool = AdaptiveAvgPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{pool}");

        assert_eq!(rendered, "AdaptiveAvgPool2d {output_size: [3, 3]}");
    }
}
|
||||
@@ -0,0 +1,220 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig1d;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::module::{Ignored, Module};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::ops::PadMode;
|
||||
|
||||
use burn::tensor::module::avg_pool1d;
|
||||
|
||||
/// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AvgPool1dConfig {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The stride.
|
||||
#[config(default = "kernel_size")]
|
||||
pub stride: usize,
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
|
||||
/// will automatically use asymmetric padding to preserve input dimensions.
|
||||
#[config(default = "PaddingConfig1d::Valid")]
|
||||
pub padding: PaddingConfig1d,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
#[config(default = "true")]
|
||||
pub count_include_pad: bool,
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
#[config(default = "false")]
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
/// Applies a 1D avg pooling over input tensors.
|
||||
///
|
||||
/// Should be created with [AvgPool1dConfig](AvgPool1dConfig).
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// The zero-padding values will be included in the calculation
|
||||
/// of the average. This means that the zeros are counted as
|
||||
/// legitimate values, and they contribute to the denominator
|
||||
/// when calculating the average. This is equivalent to
|
||||
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AvgPool1d {
|
||||
/// The stride.
|
||||
pub stride: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig1d>,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
pub count_include_pad: bool,
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AvgPool1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("stride", &self.stride)
|
||||
.add("padding", &self.padding)
|
||||
.add("count_include_pad", &self.count_include_pad)
|
||||
.add("ceil_mode", &self.ceil_mode)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AvgPool1dConfig {
|
||||
/// Initialize a new [avg pool 1d](AvgPool1d) module.
|
||||
pub fn init(&self) -> AvgPool1d {
|
||||
AvgPool1d {
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
count_include_pad: self.count_include_pad,
|
||||
ceil_mode: self.ceil_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AvgPool1d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [avg_pool1d](burn::tensor::module::avg_pool1d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels, length_in]`
|
||||
/// - output: `[batch_size, channels, length_out]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let [_batch_size, _channels, length] = input.dims();
|
||||
|
||||
// Calculate padding as pair - handles Same, Valid, and Explicit uniformly
|
||||
let (left, right) =
|
||||
self.padding
|
||||
.calculate_padding_1d_pair(length, self.kernel_size, self.stride);
|
||||
|
||||
// TODO: Move asymmetric padding to functional level via PoolOptions
|
||||
// See: https://github.com/tracel-ai/burn/issues/4362
|
||||
// Handle asymmetric padding by applying explicit pad operation first
|
||||
if left != right {
|
||||
// Burn's pad takes (left, right, top, bottom) for the last two dimensions
|
||||
// For 1D (NCL format), we only pad L (last dim), so top/bottom = 0
|
||||
let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
|
||||
// Use zero padding for the pool operation since we already padded
|
||||
avg_pool1d(
|
||||
padded,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
0,
|
||||
self.count_include_pad,
|
||||
self.ceil_mode,
|
||||
)
|
||||
} else {
|
||||
// Symmetric padding
|
||||
avg_pool1d(
|
||||
input,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
left,
|
||||
self.count_include_pad,
|
||||
self.ceil_mode,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// `Same` padding with an even kernel must keep the input length.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let pool = AvgPool1dConfig::new(2)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Same)
            .init();

        // Input: [batch=1, channels=2, length=5]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);
        let output_dims = pool.forward(input).dims();

        // Same padding should preserve spatial dimensions
        assert_eq!(output_dims, [1, 2, 5]);
    }

    /// Checks the single-line display representation of the module.
    #[test]
    fn display() {
        let pool = AvgPool1dConfig::new(3).init();
        let rendered = alloc::format!("{pool}");

        assert_eq!(
            rendered,
            "AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}"
        );
    }

    /// The default stride must equal the kernel size.
    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = AvgPool1dConfig::new(kernel_size);

        assert_eq!(
            config.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor",
            config.stride, config.kernel_size
        );
    }

    /// Explicit asymmetric padding: left=1, right=2.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        let pool = AvgPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(1, 2))
            .init();

        // Input: [batch=1, channels=2, length=4]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output_dims = pool.forward(input).dims();

        // Padded length is 4 + 1 + 2 = 7, so output length = (7 - 3) / 1 + 1 = 5
        assert_eq!(output_dims, [1, 2, 5]);
    }

    /// Explicit symmetric padding: left=2, right=2.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        let pool = AvgPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(2, 2))
            .init();

        // Input: [batch=1, channels=2, length=4]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output_dims = pool.forward(input).dims();

        // Padded length is 4 + 2 + 2 = 8, so output length = (8 - 3) / 1 + 1 = 6
        assert_eq!(output_dims, [1, 2, 6]);
    }
}
|
||||
@@ -0,0 +1,223 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig2d;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::module::{Ignored, Module};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::ops::PadMode;
|
||||
|
||||
use burn::tensor::module::avg_pool2d;
|
||||
|
||||
/// Configuration to create a [2D avg pooling](AvgPool2d) layer using the [init function](AvgPool2dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AvgPool2dConfig {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The strides.
|
||||
#[config(default = "kernel_size")]
|
||||
pub strides: [usize; 2],
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
|
||||
/// will automatically use asymmetric padding to preserve input dimensions.
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
#[config(default = "true")]
|
||||
pub count_include_pad: bool,
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
#[config(default = "false")]
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
/// Applies a 2D avg pooling over input tensors.
|
||||
///
|
||||
/// Should be created with [AvgPool2dConfig](AvgPool2dConfig).
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// The zero-padding values will be included in the calculation
|
||||
/// of the average. This means that the zeros are counted as
|
||||
/// legitimate values, and they contribute to the denominator
|
||||
/// when calculating the average. This is equivalent to
|
||||
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AvgPool2d {
|
||||
/// Stride of the pooling.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
pub count_include_pad: bool,
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AvgPool2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
|
||||
.add("stride", &alloc::format!("{:?}", &self.stride))
|
||||
.add("padding", &self.padding)
|
||||
.add("count_include_pad", &self.count_include_pad)
|
||||
.add("ceil_mode", &self.ceil_mode)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AvgPool2dConfig {
|
||||
/// Initialize a new [avg pool 2d](AvgPool2d) module.
|
||||
pub fn init(&self) -> AvgPool2d {
|
||||
AvgPool2d {
|
||||
stride: self.strides,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
count_include_pad: self.count_include_pad,
|
||||
ceil_mode: self.ceil_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AvgPool2d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [avg_pool2d](burn::tensor::module::avg_pool2d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels, height_in, width_in]`
|
||||
/// - output: `[batch_size, channels, height_out, width_out]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
|
||||
|
||||
// Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
|
||||
let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
|
||||
height_in,
|
||||
width_in,
|
||||
&self.kernel_size,
|
||||
&self.stride,
|
||||
);
|
||||
|
||||
// TODO: Move asymmetric padding to functional level via PoolOptions
|
||||
// See: https://github.com/tracel-ai/burn/issues/4362
|
||||
// Handle asymmetric padding by applying explicit pad operation first
|
||||
if top != bottom || left != right {
|
||||
// Burn's pad takes (left, right, top, bottom) for the last two dimensions
|
||||
let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0));
|
||||
// Use zero padding for the pool operation since we already padded
|
||||
avg_pool2d(
|
||||
padded,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
[0, 0],
|
||||
self.count_include_pad,
|
||||
self.ceil_mode,
|
||||
)
|
||||
} else {
|
||||
// Symmetric padding
|
||||
avg_pool2d(
|
||||
input,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
[top, left],
|
||||
self.count_include_pad,
|
||||
self.ceil_mode,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// `Same` padding with an even kernel must keep the spatial dimensions.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let pool = AvgPool2dConfig::new([2, 2])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Same)
            .init();

        // Input: [batch=1, channels=2, height=5, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device);
        let output_dims = pool.forward(input).dims();

        // Same padding should preserve spatial dimensions
        assert_eq!(output_dims, [1, 2, 5, 5]);
    }

    /// Checks the single-line display representation of the module.
    #[test]
    fn display() {
        let pool = AvgPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{pool}");

        assert_eq!(
            rendered,
            "AvgPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, count_include_pad: true, ceil_mode: false}"
        );
    }

    /// The default strides must equal the kernel size.
    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let config = AvgPool2dConfig::new(kernel_size);

        assert_eq!(
            config.strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default AvgPool2dConfig::new constructor",
            config.strides, config.kernel_size
        );
    }

    /// Explicit asymmetric padding: top=1, left=2, bottom=3, right=4.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        let pool = AvgPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
            .init();

        // Input: [batch=1, channels=2, height=4, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let output_dims = pool.forward(input).dims();

        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
        assert_eq!(output_dims, [1, 2, 6, 9]);
    }

    /// Explicit symmetric padding: top=2, left=2, bottom=2, right=2.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        let pool = AvgPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
            .init();

        // Input: [batch=1, channels=2, height=4, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let output_dims = pool.forward(input).dims();

        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
        assert_eq!(output_dims, [1, 2, 6, 7]);
    }
}
|
||||
@@ -0,0 +1,214 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig1d;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::module::{Ignored, Module};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::ops::PadMode;
|
||||
|
||||
use burn::tensor::module::max_pool1d;
|
||||
|
||||
/// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MaxPool1dConfig {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The stride.
|
||||
#[config(default = "kernel_size")]
|
||||
pub stride: usize,
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
|
||||
/// will automatically use asymmetric padding to preserve input dimensions.
|
||||
#[config(default = "PaddingConfig1d::Valid")]
|
||||
pub padding: PaddingConfig1d,
|
||||
/// The dilation.
|
||||
#[config(default = "1")]
|
||||
pub dilation: usize,
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
#[config(default = "false")]
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
/// Applies a 1D max pooling over input tensors.
|
||||
///
|
||||
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct MaxPool1d {
|
||||
/// The stride.
|
||||
pub stride: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig1d>,
|
||||
/// The dilation.
|
||||
pub dilation: usize,
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for MaxPool1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("stride", &self.stride)
|
||||
.add("padding", &self.padding)
|
||||
.add("dilation", &self.dilation)
|
||||
.add("ceil_mode", &self.ceil_mode)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl MaxPool1dConfig {
|
||||
/// Initialize a new [max pool 1d](MaxPool1d) module.
|
||||
pub fn init(&self) -> MaxPool1d {
|
||||
MaxPool1d {
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
dilation: self.dilation,
|
||||
ceil_mode: self.ceil_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MaxPool1d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [max_pool1d](burn::tensor::module::max_pool1d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels, length_in]`
|
||||
/// - output: `[batch_size, channels, length_out]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let [_batch_size, _channels, length] = input.dims();
|
||||
|
||||
// Calculate padding as pair - handles Same, Valid, and Explicit uniformly
|
||||
let (left, right) =
|
||||
self.padding
|
||||
.calculate_padding_1d_pair(length, self.kernel_size, self.stride);
|
||||
|
||||
// TODO: Move asymmetric padding to functional level via PoolOptions
|
||||
// See: https://github.com/tracel-ai/burn/issues/4362
|
||||
// Handle asymmetric padding by applying explicit pad operation first
|
||||
if left != right {
|
||||
// For 1D (NCL format), pad the length dimension with (left, right)
|
||||
// and no padding for channel dimension (top=0, bottom=0)
|
||||
// Use -inf for max pooling so padded values don't affect the max
|
||||
let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY));
|
||||
// Use zero padding for the pool operation since we already padded
|
||||
max_pool1d(
|
||||
padded,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
0,
|
||||
self.dilation,
|
||||
self.ceil_mode,
|
||||
)
|
||||
} else {
|
||||
// Symmetric padding
|
||||
max_pool1d(
|
||||
input,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
left,
|
||||
self.dilation,
|
||||
self.ceil_mode,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// `Same` padding with an even kernel must keep the input length.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let pool = MaxPool1dConfig::new(2)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Same)
            .init();

        // Input: [batch=1, channels=2, length=5]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);
        let output_dims = pool.forward(input).dims();

        // Same padding should preserve spatial dimensions
        assert_eq!(output_dims, [1, 2, 5]);
    }

    /// Checks the single-line display representation of the module.
    #[test]
    fn display() {
        let pool = MaxPool1dConfig::new(3).init();
        let rendered = alloc::format!("{pool}");

        assert_eq!(
            rendered,
            "MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}"
        );
    }

    /// The default stride must equal the kernel size.
    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = MaxPool1dConfig::new(kernel_size);

        assert_eq!(
            config.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor",
            config.stride, config.kernel_size
        );
    }

    /// Explicit asymmetric padding: left=1, right=2.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        let pool = MaxPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(1, 2))
            .init();

        // Input: [batch=1, channels=2, length=4]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output_dims = pool.forward(input).dims();

        // Padded length is 4 + 1 + 2 = 7, so output length = (7 - 3) / 1 + 1 = 5
        assert_eq!(output_dims, [1, 2, 5]);
    }

    /// Explicit symmetric padding: left=2, right=2.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        let pool = MaxPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(2, 2))
            .init();

        // Input: [batch=1, channels=2, length=4]
        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output_dims = pool.forward(input).dims();

        // Padded length is 4 + 2 + 2 = 8, so output length = (8 - 3) / 1 + 1 = 6
        assert_eq!(output_dims, [1, 2, 6]);
    }
}
|
||||
@@ -0,0 +1,219 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::PaddingConfig2d;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use burn::module::{Ignored, Module};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::ops::PadMode;
|
||||
|
||||
use burn::tensor::module::max_pool2d;
|
||||
|
||||
/// Configuration to create a [2D max pooling](MaxPool2d) layer using the [init function](MaxPool2dConfig::init).
|
||||
#[derive(Debug, Config)]
|
||||
pub struct MaxPool2dConfig {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The strides.
|
||||
#[config(default = "kernel_size")]
|
||||
pub strides: [usize; 2],
|
||||
/// The padding configuration.
|
||||
///
|
||||
/// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
|
||||
/// will automatically use asymmetric padding to preserve input dimensions.
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
/// The dilation.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
#[config(default = "false")]
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
/// Applies a 2D max pooling over input tensors.
|
||||
///
|
||||
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct MaxPool2d {
|
||||
/// The strides.
|
||||
pub stride: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
/// The dilation.
|
||||
pub dilation: [usize; 2],
|
||||
/// If true, use ceiling instead of floor for output size calculation.
|
||||
pub ceil_mode: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for MaxPool2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
|
||||
.add("stride", &alloc::format!("{:?}", &self.stride))
|
||||
.add("padding", &self.padding)
|
||||
.add("dilation", &alloc::format!("{:?}", &self.dilation))
|
||||
.add("ceil_mode", &self.ceil_mode)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl MaxPool2dConfig {
|
||||
/// Initialize a new [max pool 2d](MaxPool2d) module.
|
||||
pub fn init(&self) -> MaxPool2d {
|
||||
MaxPool2d {
|
||||
stride: self.strides,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
dilation: self.dilation,
|
||||
ceil_mode: self.ceil_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MaxPool2d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [max_pool2d](burn::tensor::module::max_pool2d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels, height_in, width_in]`
|
||||
/// - output: `[batch_size, channels, height_out, width_out]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
|
||||
|
||||
// Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
|
||||
let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
|
||||
height_in,
|
||||
width_in,
|
||||
&self.kernel_size,
|
||||
&self.stride,
|
||||
);
|
||||
|
||||
// TODO: Move asymmetric padding to functional level via PoolOptions
|
||||
// See: https://github.com/tracel-ai/burn/issues/4362
|
||||
// Handle asymmetric padding by applying explicit pad operation first
|
||||
if top != bottom || left != right {
|
||||
// Burn's pad takes (left, right, top, bottom) for the last two dimensions
|
||||
// Use -inf for max pooling so padded values don't affect the max
|
||||
let padded = input.pad(
|
||||
(left, right, top, bottom),
|
||||
PadMode::Constant(f32::NEG_INFINITY),
|
||||
);
|
||||
// Use zero padding for the pool operation since we already padded
|
||||
max_pool2d(
|
||||
padded,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
[0, 0],
|
||||
self.dilation,
|
||||
self.ceil_mode,
|
||||
)
|
||||
} else {
|
||||
// Symmetric padding
|
||||
max_pool2d(
|
||||
input,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
[top, left],
|
||||
self.dilation,
|
||||
self.ceil_mode,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// `Same` padding with a 2x2 kernel and unit stride must keep the spatial size.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let pool = MaxPool2dConfig::new([2, 2])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Same)
            .init();

        // Input: [batch=1, channels=2, height=5, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device);
        let dims = pool.forward(input).dims();

        // Spatial dimensions are expected to be preserved.
        assert_eq!(dims, [1, 2, 5, 5]);
    }

    /// The display output lists every pooling attribute on a single line.
    #[test]
    fn display() {
        let layer = MaxPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}"
        );
    }

    /// Without an explicit stride, the stride defaults to the kernel size.
    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let config = MaxPool2dConfig::new(kernel_size);

        assert_eq!(
            config.strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor",
            config.strides, config.kernel_size
        );
    }

    /// Explicit padding that differs per side goes through the pre-padding path.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        // Asymmetric padding: top=1, left=2, bottom=3, right=4
        let pool = MaxPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
            .init();

        // Input: [batch=1, channels=2, height=4, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let dims = pool.forward(input).dims();

        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
        assert_eq!(dims, [1, 2, 6, 9]);
    }

    /// Explicit padding that is equal on all sides takes the symmetric path.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        // Symmetric explicit padding: top=2, left=2, bottom=2, right=2
        let pool = MaxPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
            .init();

        // Input: [batch=1, channels=2, height=4, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let dims = pool.forward(input).dims();

        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
        assert_eq!(dims, [1, 2, 6, 7]);
    }
}
|
||||
@@ -0,0 +1,13 @@
|
||||
mod adaptive_avg_pool1d;
|
||||
mod adaptive_avg_pool2d;
|
||||
mod avg_pool1d;
|
||||
mod avg_pool2d;
|
||||
mod max_pool1d;
|
||||
mod max_pool2d;
|
||||
|
||||
pub use adaptive_avg_pool1d::*;
|
||||
pub use adaptive_avg_pool2d::*;
|
||||
pub use avg_pool1d::*;
|
||||
pub use avg_pool2d::*;
|
||||
pub use max_pool1d::*;
|
||||
pub use max_pool2d::*;
|
||||
@@ -0,0 +1,291 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::TensorData;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PositionalEncodingConfig {
|
||||
/// Maximum sequence size to use.
|
||||
#[config(default = "5_000")]
|
||||
pub max_sequence_size: usize,
|
||||
|
||||
/// The size of each vector.
|
||||
pub d_model: usize,
|
||||
|
||||
/// Max time scale to use.
|
||||
#[config(default = "10_000")]
|
||||
pub max_timescale: usize,
|
||||
}
|
||||
|
||||
/// Positional encoding layer for transformer models.
|
||||
///
|
||||
/// This layer adds positional information to the input embeddings, allowing the transformer model
|
||||
/// to take into account the order of the sequence. The positional encoding is added to the input
|
||||
/// embeddings by computing a set of sinusoidal functions with different frequencies and phases.
|
||||
///
|
||||
/// Sinusoids are used for positional embedding introduced in
|
||||
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
|
||||
///
|
||||
/// The reference implementation can be found here:
|
||||
/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT
|
||||
/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
|
||||
///
|
||||
/// Should be created using [PositionalEncodingConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PositionalEncoding<B: Backend> {
|
||||
/// The sinusoids used to add positional information to the input embeddings.
|
||||
pub sinusoids: Tensor<B, 3>,
|
||||
/// The maximum sequence size to use.
|
||||
pub max_sequence_size: usize,
|
||||
/// Max time scale to use.
|
||||
pub max_timescale: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [_, _, d_model] = self.sinusoids.shape().dims();
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("max_sequence_size", &self.max_sequence_size)
|
||||
.add("max_timescale", &self.max_timescale)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl PositionalEncodingConfig {
|
||||
/// Initialize a new [PositionalEncoding](PositionalEncoding) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
|
||||
let sinusoids = generate_sinusoids::<B>(
|
||||
self.max_sequence_size,
|
||||
self.d_model,
|
||||
self.max_timescale,
|
||||
device,
|
||||
)
|
||||
.unsqueeze::<3>();
|
||||
|
||||
PositionalEncoding {
|
||||
sinusoids,
|
||||
max_sequence_size: self.max_sequence_size,
|
||||
max_timescale: self.max_timescale,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PositionalEncoding<B> {
|
||||
/// Applies the forward pass on the input tensor by adding the sinusoids to the input.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// * input: [batch_size, seq_length, d_model]
|
||||
/// * output: [batch_size, seq_length, d_model]
|
||||
///
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * Panics if the input sequence length is greater than the maximum sequence size.
|
||||
/// * Panics if the input d_model is not equal to the d_model of the sinusoids.
|
||||
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let [_, seq_length, d_model_input] = input.dims();
|
||||
|
||||
let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();
|
||||
|
||||
assert!(
|
||||
max_sequence_size >= seq_length,
|
||||
"max_sequence_size({max_sequence_size}) must be greater or equal than length({seq_length})"
|
||||
);
|
||||
|
||||
assert!(
|
||||
d_model_input == d_model,
|
||||
"d_model({d_model_input}) of the input must be equal to d_model of encoding({d_model})"
|
||||
);
|
||||
|
||||
let slices = [0..batch_size, 0..seq_length, 0..d_model];
|
||||
|
||||
input.add(self.sinusoids.clone().slice(slices))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns sinusoids for positional embedding introduced in
|
||||
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
|
||||
///
|
||||
/// The reference implementation can be found here:
|
||||
/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT
|
||||
/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `length` - The length of the sequence.
|
||||
/// * `d_model` - The size of each vector.
|
||||
/// * `max_timescale` - The maximum time scale to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor of shape [length, d_model] containing the sinusoids.
|
||||
pub fn generate_sinusoids<B: Backend>(
|
||||
length: usize,
|
||||
d_model: usize,
|
||||
max_timescale: usize,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, 2> {
|
||||
assert!(d_model.is_multiple_of(2), "d_model must be even");
|
||||
assert!(
|
||||
max_timescale >= length,
|
||||
"max_timescale must be greater than length"
|
||||
);
|
||||
|
||||
// Calculate the increment for the logarithmic timescale
|
||||
let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;
|
||||
|
||||
// Create a vector to hold the sinusoids
|
||||
let mut scaled_time_sin_cos = Vec::with_capacity(length);
|
||||
|
||||
// Loop over each position in the sequence
|
||||
for i in 0..length {
|
||||
// Create a vector to hold the sinusoids for this position
|
||||
let mut row = Vec::with_capacity(d_model / 2);
|
||||
// Loop over each dimension of the sinusoids
|
||||
for k in (0..d_model).step_by(2) {
|
||||
// Calculate the division term for this dimension
|
||||
let div_term = (k as f32 * log_timescale_increment).exp();
|
||||
// Calculate the sine and cosine values for this dimension and position
|
||||
row.push((div_term * i as f32).sin());
|
||||
row.push((div_term * i as f32).cos());
|
||||
}
|
||||
|
||||
// Add the sinusoids for this position to the vector
|
||||
scaled_time_sin_cos.push(row);
|
||||
}
|
||||
|
||||
// Convert the sinusoids to a tensor and return it
|
||||
let data = TensorData::new(
|
||||
scaled_time_sin_cos.into_iter().flatten().collect(),
|
||||
[length, d_model],
|
||||
);
|
||||
|
||||
Tensor::<B, 2>::from_data(data, device)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// A zero input must come back as the sinusoids themselves, repeated for
    /// every element of the batch.
    #[test]
    fn test_module() {
        let d_model = 6;
        let length = 3;

        // expected to broadcast
        let batch_size = 2;

        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        // With an all-zero input the output equals the encoding that was added.
        let zeros = Tensor::zeros([batch_size, length, d_model], &device);
        let output = pe.forward(zeros);

        assert_eq!(&*output.shape(), [batch_size, length, d_model]);

        // Encoding of the first three positions, identical for both batch entries.
        let rows = [
            [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
            [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
            [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
        ];
        let expected = Tensor::<TestBackend, 3>::from_floats([rows, rows], &device);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// Compares the generated table against precomputed reference values.
    #[test]
    fn test_generate_sinusoids() {
        let device = Default::default();
        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);

        // The values are taken from the pytorch reference implementation
        let reference = [
            [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
            [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
            [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
            [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
            [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
            [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
            [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
            [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
            [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
            [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
            [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
            [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
        ];
        let expected = Tensor::<TestBackend, 2>::from_floats(reference, &device);

        sinusoids
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// An input whose feature size differs from the layer's `d_model` must panic.
    #[test]
    #[should_panic]
    fn d_model_input_should_match() {
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(8).init::<TestBackend>(&device);

        let mismatched = Tensor::zeros([1, 5, 10], &device);
        let _output = pe.forward(mismatched);
    }

    /// An input longer than the default maximum sequence size must panic.
    #[test]
    #[should_panic]
    fn input_length_should_be_less_than_max_len() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        let too_long = Tensor::zeros([1, 6_000, d_model], &device);
        let _output = pe.forward(too_long);
    }

    /// The display output lists `d_model` and both default limits.
    #[test]
    fn display() {
        let pe = PositionalEncodingConfig::new(4).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{pe}");

        assert_eq!(
            rendered,
            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
        );
    }
}
|
||||
@@ -0,0 +1,742 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::GateController;
|
||||
use crate::activation::{Activation, ActivationConfig};
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// A RnnState is used to store hidden state in RNN.
|
||||
pub struct RnnState<B: Backend, const D: usize> {
|
||||
/// The hidden state.
|
||||
pub hidden: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> RnnState<B, D> {
|
||||
/// Initialize a new [RNN State](RnnState).
|
||||
pub fn new(hidden: Tensor<B, D>) -> Self {
|
||||
Self { hidden }
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration to create a [Rnn](Rnn) module using the [init function](RnnConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct RnnConfig {
|
||||
/// The size of the input features.
|
||||
pub d_input: usize,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
/// If a bias should be applied during the Rnn transformation.
|
||||
pub bias: bool,
|
||||
/// Rnn initializer
|
||||
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
|
||||
pub initializer: Initializer,
|
||||
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
|
||||
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
|
||||
#[config(default = true)]
|
||||
pub batch_first: bool,
|
||||
/// If true, process the sequence in reverse order.
|
||||
/// This is useful for implementing reverse-direction RNNs (e.g., ONNX reverse direction).
|
||||
#[config(default = false)]
|
||||
pub reverse: bool,
|
||||
/// Optional hidden state clip threshold. If provided, hidden state values are clipped
|
||||
/// to the range `[-clip, +clip]` after each timestep. This can help prevent
|
||||
/// exploding values during inference.
|
||||
pub clip: Option<f64>,
|
||||
/// Activation function applied to the hidden state before computing hidden output.
|
||||
/// Default is Tanh, which is standard for Rnn.
|
||||
#[config(default = "ActivationConfig::Tanh")]
|
||||
pub hidden_activation: ActivationConfig,
|
||||
}
|
||||
|
||||
/// The Rnn module. This implementation is for a unidirectional, stateless, Rnn.
|
||||
/// Should be created with [RnnConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Rnn<B: Backend> {
|
||||
/// gate controller for Rnn (has single gate).
|
||||
pub gate: GateController<B>,
|
||||
/// The hidden state of the Rnn.
|
||||
pub d_hidden: usize,
|
||||
/// If true, input is `[batch_size, seq_length, input_size]`.
|
||||
/// If false, input is `[seq_length, batch_size, input_size]`.
|
||||
pub batch_first: bool,
|
||||
/// If true, process the sequence in reverse order.
|
||||
pub reverse: bool,
|
||||
/// Optional hidden state clip threshold.
|
||||
pub clip: Option<f64>,
|
||||
/// Activation function for hidden output.
|
||||
pub hidden_activation: Activation<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Rnn<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self.gate.input_transform.weight.shape().dims();
|
||||
let bias = self.gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl RnnConfig {
|
||||
/// Initialize a new [Rnn](Rnn) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Rnn<B> {
|
||||
let d_output = self.d_hidden;
|
||||
|
||||
let new_gate = || {
|
||||
GateController::new(
|
||||
self.d_input,
|
||||
d_output,
|
||||
self.bias,
|
||||
self.initializer.clone(),
|
||||
device,
|
||||
)
|
||||
};
|
||||
|
||||
Rnn {
|
||||
gate: new_gate(),
|
||||
d_hidden: self.d_hidden,
|
||||
batch_first: self.batch_first,
|
||||
reverse: self.reverse,
|
||||
clip: self.clip,
|
||||
hidden_activation: self.hidden_activation.init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Rnn<B> {
|
||||
/// Applies the forward pass on the input tensor. This RNN implementation
|
||||
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
|
||||
///
|
||||
/// ## Parameters:
|
||||
/// - batched_input: The input tensor of shape:
|
||||
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
|
||||
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
|
||||
/// - state: An optional `RnnState` representing the initial hidden state.
|
||||
/// The state tensor has shape `[batch_size, hidden_size]`.
|
||||
/// If no initial state is provided, these tensors are initialized to zeros.
|
||||
///
|
||||
/// ## Returns:
|
||||
/// - output: A tensor represents the output features of Rnn. Shape:
|
||||
/// - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true
|
||||
/// - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false
|
||||
/// - state: A `RnnState` represents the final hidden state. The hidden state tensor has the shape
|
||||
/// `[batch_size, hidden_size]`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
batched_input: Tensor<B, 3>,
|
||||
state: Option<RnnState<B, 2>>,
|
||||
) -> (Tensor<B, 3>, RnnState<B, 2>) {
|
||||
// Convert to batch-first layout internally if needed
|
||||
let batched_input = if self.batch_first {
|
||||
batched_input
|
||||
} else {
|
||||
batched_input.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
let device = batched_input.device();
|
||||
let [batch_size, seq_length, _] = batched_input.dims();
|
||||
|
||||
// Process sequence in forward or reverse order based on config
|
||||
let (output, state) = if self.reverse {
|
||||
self.forward_iter(
|
||||
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
|
||||
state,
|
||||
batch_size,
|
||||
seq_length,
|
||||
&device,
|
||||
)
|
||||
} else {
|
||||
self.forward_iter(
|
||||
batched_input.iter_dim(1).zip(0..seq_length),
|
||||
state,
|
||||
batch_size,
|
||||
seq_length,
|
||||
&device,
|
||||
)
|
||||
};
|
||||
|
||||
// Convert output back to seq-first layout if needed
|
||||
let output = if self.batch_first {
|
||||
output
|
||||
} else {
|
||||
output.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
(output, state)
|
||||
}
|
||||
|
||||
fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
|
||||
&self,
|
||||
input_timestep_iter: I,
|
||||
state: Option<RnnState<B, 2>>,
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
device: &B::Device,
|
||||
) -> (Tensor<B, 3>, RnnState<B, 2>) {
|
||||
let mut batched_hidden_state =
|
||||
Tensor::empty([batch_size, seq_length, self.d_hidden], device);
|
||||
|
||||
let mut hidden_state = match state {
|
||||
Some(state) => state.hidden,
|
||||
None => Tensor::zeros([batch_size, self.d_hidden], device),
|
||||
};
|
||||
|
||||
for (input_t, t) in input_timestep_iter {
|
||||
let input_t = input_t.squeeze_dim(1);
|
||||
|
||||
// Compute gate output: h_t = activation(W_i @ x_t + W_h @ h_{t-1} + b)
|
||||
let biased_gate_sum = self
|
||||
.gate
|
||||
.gate_product(input_t.clone(), hidden_state.clone());
|
||||
|
||||
let output_values = self.hidden_activation.forward(biased_gate_sum);
|
||||
|
||||
// Update hidden state
|
||||
hidden_state = output_values;
|
||||
|
||||
// Apply hidden state clipping if configured
|
||||
if let Some(clip) = self.clip {
|
||||
hidden_state = hidden_state.clamp(-clip, clip);
|
||||
}
|
||||
|
||||
let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
|
||||
|
||||
// store the hidden state for this timestep
|
||||
batched_hidden_state = batched_hidden_state.slice_assign(
|
||||
[0..batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
unsqueezed_hidden_state.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
(batched_hidden_state, RnnState::new(hidden_state))
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration to create a [BiRnn](BiRnn) module using the [init function](BiRnnConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct BiRnnConfig {
|
||||
/// The size of the input features.
|
||||
pub d_input: usize,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
/// If a bias should be applied during the BiRnn transformation.
|
||||
pub bias: bool,
|
||||
/// BiRnn initializer
|
||||
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
|
||||
pub initializer: Initializer,
|
||||
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
|
||||
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
|
||||
#[config(default = true)]
|
||||
pub batch_first: bool,
|
||||
/// Optional hidden state clip threshold.
|
||||
pub clip: Option<f64>,
|
||||
/// Activation function applied to the hidden state before computing hidden output.
|
||||
#[config(default = "ActivationConfig::Tanh")]
|
||||
pub hidden_activation: ActivationConfig,
|
||||
}
|
||||
|
||||
/// The BiRnn module. This implementation is for Bidirectional RNN.
|
||||
/// Should be created with [BiRnnConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BiRnn<B: Backend> {
|
||||
/// RNN for the forward direction.
|
||||
pub forward: Rnn<B>,
|
||||
/// RNN for the reverse direction.
|
||||
pub reverse: Rnn<B>,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
/// If true, input is `[batch_size, seq_length, input_size]`.
|
||||
/// If false, input is `[seq_length, batch_size, input_size]`.
|
||||
pub batch_first: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for BiRnn<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self.forward.gate.input_transform.weight.shape().dims();
|
||||
let bias = self.forward.gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl BiRnnConfig {
|
||||
/// Initialize a new [Bidirectional RNN](BiRnn) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> BiRnn<B> {
|
||||
// Internal RNNs always use batch_first=true; BiRnn handles layout conversion
|
||||
let base_config = RnnConfig::new(self.d_input, self.d_hidden, self.bias)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.with_batch_first(true)
|
||||
.with_clip(self.clip)
|
||||
.with_hidden_activation(self.hidden_activation.clone());
|
||||
|
||||
BiRnn {
|
||||
forward: base_config.clone().init(device),
|
||||
reverse: base_config.init(device),
|
||||
d_hidden: self.d_hidden,
|
||||
batch_first: self.batch_first,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BiRnn<B> {
|
||||
/// Applies the forward pass on the input tensor. This Bidirectional RNN implementation
|
||||
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
|
||||
///
|
||||
/// ## Parameters:
|
||||
/// - batched_input: The input tensor of shape:
|
||||
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
|
||||
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
|
||||
/// - state: An optional `RnnState` representing the hidden state.
|
||||
/// Each state tensor has shape `[2, batch_size, hidden_size]`.
|
||||
/// If no initial state is provided, these tensors are initialized to zeros.
|
||||
///
|
||||
/// ## Returns:
|
||||
/// - output: A tensor represents the output features of RNN. Shape:
|
||||
/// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
|
||||
/// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
|
||||
/// - state: A `RnnState` represents the final forward and reverse states.
|
||||
/// The `state.hidden` have the shape `[2, batch_size, hidden_size]`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
batched_input: Tensor<B, 3>,
|
||||
state: Option<RnnState<B, 3>>,
|
||||
) -> (Tensor<B, 3>, RnnState<B, 3>) {
|
||||
// Convert to batch-first layout internally if needed
|
||||
let batched_input = if self.batch_first {
|
||||
batched_input
|
||||
} else {
|
||||
batched_input.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
let device = batched_input.clone().device();
|
||||
let [batch_size, seq_length, _] = batched_input.shape().dims();
|
||||
|
||||
let [init_state_forward, init_state_reverse] = match state {
|
||||
Some(state) => {
|
||||
let hidden_state_forward = state
|
||||
.hidden
|
||||
.clone()
|
||||
.slice([0..1, 0..batch_size, 0..self.d_hidden])
|
||||
.squeeze_dim(0);
|
||||
let hidden_state_reverse = state
|
||||
.hidden
|
||||
.slice([1..2, 0..batch_size, 0..self.d_hidden])
|
||||
.squeeze_dim(0);
|
||||
|
||||
[
|
||||
Some(RnnState::new(hidden_state_forward)),
|
||||
Some(RnnState::new(hidden_state_reverse)),
|
||||
]
|
||||
}
|
||||
None => [None, None],
|
||||
};
|
||||
|
||||
// forward direction
|
||||
let (batched_hidden_state_forward, final_state_forward) = self
|
||||
.forward
|
||||
.forward(batched_input.clone(), init_state_forward);
|
||||
|
||||
// reverse direction
|
||||
let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
|
||||
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
|
||||
init_state_reverse,
|
||||
batch_size,
|
||||
seq_length,
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = Tensor::cat(
|
||||
[batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
|
||||
2,
|
||||
);
|
||||
|
||||
// Convert output back to seq-first layout if needed
|
||||
let output = if self.batch_first {
|
||||
output
|
||||
} else {
|
||||
output.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
let state = RnnState::new(Tensor::stack(
|
||||
[final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
|
||||
0,
|
||||
));
|
||||
|
||||
(output, state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{LinearRecord, TestBackend};
|
||||
use burn::module::Param;
|
||||
use burn::tensor::{Device, Distribution, TensorData};
|
||||
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use crate::TestAutodiffBackend;
|
||||
|
||||
fn create_single_feature_gate_controller(
|
||||
weights: f32,
|
||||
biases: f32,
|
||||
d_input: usize,
|
||||
d_output: usize,
|
||||
bias: bool,
|
||||
initializer: Initializer,
|
||||
device: &Device<TestBackend>,
|
||||
) -> GateController<TestBackend> {
|
||||
let record_1 = LinearRecord {
|
||||
weight: Param::from_data(TensorData::from([[weights]]), device),
|
||||
bias: Some(Param::from_data(TensorData::from([biases]), device)),
|
||||
};
|
||||
let record_2 = LinearRecord {
|
||||
weight: Param::from_data(TensorData::from([[weights]]), device),
|
||||
bias: Some(Param::from_data(TensorData::from([biases]), device)),
|
||||
};
|
||||
GateController::create_with_weights(
|
||||
d_input,
|
||||
d_output,
|
||||
bias,
|
||||
initializer,
|
||||
record_1,
|
||||
record_2,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_uniform_initializer() {
|
||||
let device = Default::default();
|
||||
TestBackend::seed(&device, 0);
|
||||
|
||||
let config = RnnConfig::new(5, 5, false)
|
||||
.with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
|
||||
let rnn = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
let gate_to_data =
|
||||
|gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();
|
||||
|
||||
gate_to_data(rnn.gate).assert_within_range::<FT>(0.elem()..1.elem());
|
||||
}
|
||||
|
||||
/// Test forward pass with simple input vector.
|
||||
///
|
||||
/// Simple RNN: h_t = tanh(W_input @ x_t + W_hidden @ h_{t-1} + b)
|
||||
/// With input=0.1, weight_input=0.5, bias=0.0, h_0=0.0, weight_hidden=0.5
|
||||
/// h_t = tanh(0.5*0.1 + 0.5*0) = tanh(0.05) = 0.04995
|
||||
#[test]
|
||||
fn test_forward_single_input_single_feature() {
|
||||
let device = Default::default();
|
||||
TestBackend::seed(&device, 0);
|
||||
|
||||
let config = RnnConfig::new(1, 1, false);
|
||||
let device = Default::default();
|
||||
let mut rnn = config.init::<TestBackend>(&device);
|
||||
|
||||
rnn.gate = create_single_feature_gate_controller(
|
||||
0.5,
|
||||
0.0,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
Initializer::XavierUniform { gain: 1.0 },
|
||||
&device,
|
||||
);
|
||||
|
||||
// single timestep with single feature
|
||||
let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
|
||||
|
||||
let (output, state) = rnn.forward(input, None);
|
||||
|
||||
let tolerance = Tolerance::default();
|
||||
let expected = TensorData::from([[0.04995]]);
|
||||
state
|
||||
.hidden
|
||||
.to_data()
|
||||
.assert_approx_eq::<FT>(&expected, tolerance);
|
||||
|
||||
output
|
||||
.select(0, Tensor::arange(0..1, &device))
|
||||
.squeeze_dim::<2>(0)
|
||||
.to_data()
|
||||
.assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
|
||||
}
|
||||
|
||||
/// Runs a forward pass on a batch containing a single sequence and checks
/// only the shapes of the returned output and final hidden state.
#[test]
fn test_batched_forward_pass_batch_of_one() {
    let device = Default::default();
    let layer = RnnConfig::new(64, 1024, true).init(&device);

    // One sequence, two timesteps, 64 input features.
    let single_batch =
        Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

    let (sequence_output, final_state) = layer.forward(single_batch, None);

    let output_dims = sequence_output.dims();
    let hidden_dims = final_state.hidden.dims();
    assert_eq!(output_dims, [1, 2, 1024]);
    assert_eq!(hidden_dims, [1, 1024]);
}
|
||||
|
||||
/// Runs a backward pass through the RNN on an autodiff backend and checks that
/// a gradient is produced for the hidden-transform weights.
#[test]
#[cfg(feature = "std")]
fn test_batched_backward_pass() {
    use burn::tensor::Shape;

    let device = Default::default();
    let layer = RnnConfig::new(64, 32, true).init(&device);

    let input_shape: Shape = [8, 10, 64].into();
    let sequences =
        Tensor::<TestAutodiffBackend, 3>::random(input_shape, Distribution::Default, &device);

    // The raw output stands in for a loss value.
    let (per_step_output, _) = layer.forward(sequences.clone(), None);
    let grads = per_step_output.backward();

    // Panics here if no gradient was recorded for the hidden weights.
    let hidden_weight_grad = layer.gate.hidden_transform.weight.grad(&grads).unwrap();

    // Asserts that the gradients exist and are non-zero
    let any_flag = hidden_weight_grad
        .any()
        .into_data()
        .iter::<f32>()
        .next()
        .unwrap();
    assert_ne!(any_flag, 0.0);
}
|
||||
|
||||
/// Checks the bidirectional RNN against reference values, both with and
/// without an explicit initial hidden state.
#[test]
fn test_bidirectional() {
    let device = Default::default();
    TestBackend::seed(&device, 0);

    let config = BiRnnConfig::new(2, 3, true);
    let mut rnn = config.init(&device);

    /// Builds a gate controller from explicit weights and biases.
    ///
    /// `D1` is the hidden size and `D2` the input size: `input_weights` is
    /// laid out as `[input_size][hidden_size]`.
    fn create_gate_controller<const D1: usize, const D2: usize>(
        input_weights: [[f32; D1]; D2],
        input_biases: [f32; D1],
        hidden_weights: [[f32; D1]; D1],
        hidden_biases: [f32; D1],
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        // Take the sizes from the const generics: the outer array length is
        // the input size and the inner array length is the hidden size.
        let d_input = D2;
        let d_output = D1;

        let input_record = LinearRecord {
            weight: Param::from_data(TensorData::from(input_weights), device),
            bias: Some(Param::from_data(TensorData::from(input_biases), device)),
        };
        let hidden_record = LinearRecord {
            weight: Param::from_data(TensorData::from(hidden_weights), device),
            bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            true,
            Initializer::XavierUniform { gain: 1.0 },
            input_record,
            hidden_record,
        )
    }

    // [batch_size=1, seq_length=4, input_size=2]
    let input = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[
            [0.949, -0.861],
            [0.892, 0.927],
            [-0.173, -0.301],
            [-0.081, 0.992],
        ]]),
        &device,
    );

    // [2, batch_size=1, hidden_size=3]
    let h0 = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
        &device,
    );

    rnn.forward.gate = create_gate_controller(
        // input_weights: [input_size=2, hidden_size=3]
        [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
        // input_biases: [hidden_size=3]
        [-0.196, 0.354, 0.209],
        // hidden_weights: [hidden_size=3, hidden_size=3]
        [
            [-0.320, 0.232, -0.165],
            [0.093, -0.572, -0.315],
            [-0.467, 0.325, 0.046],
        ],
        // hidden_biases: [hidden_size=3]
        [0.181, -0.190, -0.245],
        &device,
    );

    rnn.reverse.gate = create_gate_controller(
        [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
        [0.540, -0.164, 0.033],
        [
            [0.159, 0.180, -0.037],
            [-0.443, 0.485, -0.488],
            [0.098, -0.085, -0.140],
        ],
        [-0.510, 0.105, 0.114],
        &device,
    );

    // [batch_size=1, sequence_length=4, hidden_size * 2 = 6]
    // The expected output values were computed from PyTorch
    let expected_output_with_init_state = TensorData::from([[
        [0.5226, -0.6370, 0.0210, 0.0685, 0.3867, 0.3602],
        [0.3580, 0.8431, 0.4129, -0.3175, 0.4374, 0.1766],
        [-0.3837, -0.2703, -0.3957, -0.1542, -0.1122, 0.0725],
        [0.5059, 0.5527, 0.1244, -0.6779, 0.3725, -0.3387],
    ]]);
    let expected_output_without_init_state = TensorData::from([[
        [0.0560, -0.2056, 0.2334, 0.0892, 0.3912, 0.3607],
        [0.4340, 0.7378, 0.3714, -0.2394, 0.4235, 0.2002],
        [-0.3962, -0.2097, -0.3798, 0.0532, -0.2067, 0.1727],
        [0.5075, 0.5298, 0.1083, -0.3200, 0.0764, -0.1282],
    ]]);

    // [2, batch_size=1, hidden_size=3]
    let expected_hn_with_init_state =
        TensorData::from([[[0.5059, 0.5527, 0.1244]], [[0.0685, 0.3867, 0.3602]]]);
    let expected_hn_without_init_state =
        TensorData::from([[[0.5075, 0.5298, 0.1083]], [[0.0892, 0.3912, 0.3607]]]);

    let (output_with_init_state, state_with_init_state) =
        rnn.forward(input.clone(), Some(RnnState::new(h0)));
    let (output_without_init_state, state_without_init_state) = rnn.forward(input, None);

    let tolerance = Tolerance::permissive();
    output_with_init_state
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
    output_without_init_state
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
    state_with_init_state
        .hidden
        .to_data()
        .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
    state_without_init_state
        .hidden
        .to_data()
        .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
}
|
||||
|
||||
/// Pins the one-line `Display` rendering of an `Rnn` module.
#[test]
fn display_rnn() {
    let layer = RnnConfig::new(2, 3, true).init::<TestBackend>(&Default::default());

    let rendered = alloc::format!("{layer}");

    assert_eq!(
        rendered,
        "Rnn {d_input: 2, d_hidden: 3, bias: true, params: 21}"
    );
}
|
||||
|
||||
/// Pins the one-line `Display` rendering of a `BiRnn` module.
#[test]
fn display_birnn() {
    let layer = BiRnnConfig::new(2, 3, true).init::<TestBackend>(&Default::default());

    let rendered = alloc::format!("{layer}");

    assert_eq!(
        rendered,
        "BiRnn {d_input: 2, d_hidden: 3, bias: true, params: 42}"
    );
}
|
||||
|
||||
/// Checks that, with a clip value configured, every element of the final
/// hidden state lies inside `[-clip_value, clip_value]`.
#[test]
fn test_rnn_clipping() {
    let device = Default::default();

    // Create Rnn with clipping enabled
    let clip_value = 0.3;
    let config = RnnConfig::new(4, 8, true).with_clip(Some(clip_value));
    let rnn = config.init::<TestBackend>(&device);

    let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);
    let (_, state) = rnn.forward(input, None);

    // Verify output values are within the clip range
    let hidden_state: Vec<f32> = state.hidden.to_data().to_vec().unwrap();
    // Convert the bound once instead of on every comparison.
    let bound = clip_value as f32;
    for val in hidden_state {
        assert!(
            (-bound..=bound).contains(&val),
            "Value {} is outside clip range [-{}, {}]",
            val,
            clip_value,
            clip_value
        );
    }
}
|
||||
|
||||
/// Checks an RNN configured with `reverse = true` on a three-step sequence:
/// the final hidden state must equal the value obtained by consuming the
/// timesteps from last to first.
#[test]
fn test_forward_reverse_sequence() {
    let device = Default::default();
    TestBackend::seed(&device, 0);

    // Create RNN with reverse=true to process sequence in reverse order
    let mut rnn = RnnConfig::new(1, 1, false)
        .with_reverse(true)
        .init::<TestBackend>(&device);

    // Swap in a gate with known weights (0.5) and zero bias values.
    let known_gate = create_single_feature_gate_controller(
        0.5,
        0.0,
        1,
        1,
        false,
        Initializer::XavierUniform { gain: 1.0 },
        &device,
    );
    rnn.gate = known_gate;

    // Create input with 3 timesteps: [0.1, 0.2, 0.3]
    // Shape: [batch_size=1, seq_length=3, input_features=1]
    let timesteps = TensorData::from([[[0.1], [0.2], [0.3]]]);
    let input = Tensor::<TestBackend, 3>::from_data(timesteps, &device);

    let (output, state) = rnn.forward(input, None);

    // With reverse=true and weight=0.5, sequence is processed in reverse:
    // t=2 (last):  h = tanh(0.5*0.3 + 0.5*0) = tanh(0.15) ≈ 0.1488850
    // t=1 (mid):   h = tanh(0.5*0.2 + 0.5*0.1488850) ≈ 0.17269433
    // t=0 (first): h = tanh(0.5*0.1 + 0.5*0.17269433) ≈ 0.135508
    let expected_final_hidden = TensorData::from([[0.135508]]);

    let tolerance = Tolerance::default();
    let final_hidden = state.hidden.to_data();
    final_hidden.assert_approx_eq::<FT>(&expected_final_hidden, tolerance);

    // Only the shape of the per-timestep output is checked here.
    assert_eq!(output.dims(), [1, 3, 1]);
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::{Linear, LinearConfig, LinearLayout};
|
||||
use burn::module::{Initializer, Module};
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
/// A GateController represents a gate in an LSTM cell. An
|
||||
/// LSTM cell generally contains three gates: an input gate,
|
||||
/// forget gate, and output gate. Additionally, cell gate
|
||||
/// is just used to compute the cell state.
|
||||
///
|
||||
/// An Lstm gate is modeled as two linear transformations.
|
||||
/// The results of these transformations are used to calculate
|
||||
/// the gate's output.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GateController<B: Backend> {
|
||||
/// Represents the affine transformation applied to input vector
|
||||
pub input_transform: Linear<B>,
|
||||
/// Represents the affine transformation applied to the hidden state
|
||||
pub hidden_transform: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> GateController<B> {
|
||||
/// Initialize a new [gate_controller](GateController) module.
|
||||
pub fn new(
|
||||
d_input: usize,
|
||||
d_output: usize,
|
||||
bias: bool,
|
||||
initializer: Initializer,
|
||||
device: &B::Device,
|
||||
) -> Self {
|
||||
Self {
|
||||
input_transform: LinearConfig {
|
||||
d_input,
|
||||
d_output,
|
||||
bias,
|
||||
initializer: initializer.clone(),
|
||||
layout: LinearLayout::Row,
|
||||
}
|
||||
.init(device),
|
||||
hidden_transform: LinearConfig {
|
||||
d_input: d_output,
|
||||
d_output,
|
||||
bias,
|
||||
initializer,
|
||||
layout: LinearLayout::Row,
|
||||
}
|
||||
.init(device),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function for performing weighted matrix product for a gate and adds
|
||||
/// bias, if any.
|
||||
///
|
||||
/// Mathematically, performs `Wx*X + Wh*H + b`, where:
|
||||
/// Wx = weight matrix for the connection to input vector X
|
||||
/// Wh = weight matrix for the connection to hidden state H
|
||||
/// X = input vector
|
||||
/// H = hidden state
|
||||
/// b = bias terms
|
||||
pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
|
||||
}
|
||||
|
||||
/// Used to initialize a gate controller with known weight layers,
|
||||
/// allowing for predictable behavior. Used only for testing in
|
||||
/// lstm.
|
||||
#[cfg(test)]
|
||||
pub fn create_with_weights(
|
||||
d_input: usize,
|
||||
d_output: usize,
|
||||
bias: bool,
|
||||
initializer: Initializer,
|
||||
input_record: crate::LinearRecord<B>,
|
||||
hidden_record: crate::LinearRecord<B>,
|
||||
) -> Self {
|
||||
let l1 = LinearConfig {
|
||||
d_input,
|
||||
d_output,
|
||||
bias,
|
||||
initializer: initializer.clone(),
|
||||
layout: LinearLayout::Row,
|
||||
}
|
||||
.init(&input_record.weight.device())
|
||||
.load_record(input_record);
|
||||
let l2 = LinearConfig {
|
||||
d_input,
|
||||
d_output,
|
||||
bias,
|
||||
initializer,
|
||||
layout: LinearLayout::Row,
|
||||
}
|
||||
.init(&hidden_record.weight.device())
|
||||
.load_record(hidden_record);
|
||||
|
||||
Self {
|
||||
input_transform: l1,
|
||||
hidden_transform: l2,
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,922 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::GateController;
|
||||
use crate::activation::{Activation, ActivationConfig};
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// A LstmState is used to store cell state and hidden state in LSTM.
|
||||
pub struct LstmState<B: Backend, const D: usize> {
|
||||
/// The cell state.
|
||||
pub cell: Tensor<B, D>,
|
||||
/// The hidden state.
|
||||
pub hidden: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> LstmState<B, D> {
|
||||
/// Initialize a new [LSTM State](LstmState).
|
||||
pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
|
||||
Self { cell, hidden }
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LstmConfig {
|
||||
/// The size of the input features.
|
||||
pub d_input: usize,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
/// If a bias should be applied during the Lstm transformation.
|
||||
pub bias: bool,
|
||||
/// Lstm initializer
|
||||
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
|
||||
pub initializer: Initializer,
|
||||
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
|
||||
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
|
||||
#[config(default = true)]
|
||||
pub batch_first: bool,
|
||||
/// If true, process the sequence in reverse order.
|
||||
/// This is useful for implementing reverse-direction LSTMs (e.g., ONNX reverse direction).
|
||||
#[config(default = false)]
|
||||
pub reverse: bool,
|
||||
/// Optional cell state clip threshold. If provided, cell state values are clipped
|
||||
/// to the range `[-clip, +clip]` after each timestep. This can help prevent
|
||||
/// exploding values during inference.
|
||||
pub clip: Option<f64>,
|
||||
/// If true, couples the input and forget gates: `f_t = 1 - i_t`.
|
||||
/// This reduces the number of parameters and is based on GRU-style simplification.
|
||||
#[config(default = false)]
|
||||
pub input_forget: bool,
|
||||
/// Activation function for the input, forget, and output gates.
|
||||
/// Default is Sigmoid, which is standard for LSTM gates.
|
||||
#[config(default = "ActivationConfig::Sigmoid")]
|
||||
pub gate_activation: ActivationConfig,
|
||||
/// Activation function for the cell gate (candidate cell state).
|
||||
/// Default is Tanh, which is standard for LSTM.
|
||||
#[config(default = "ActivationConfig::Tanh")]
|
||||
pub cell_activation: ActivationConfig,
|
||||
/// Activation function applied to the cell state before computing hidden output.
|
||||
/// Default is Tanh, which is standard for LSTM.
|
||||
#[config(default = "ActivationConfig::Tanh")]
|
||||
pub hidden_activation: ActivationConfig,
|
||||
}
|
||||
|
||||
/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.
|
||||
///
|
||||
/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).
|
||||
///
|
||||
/// Should be created with [LstmConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Lstm<B: Backend> {
|
||||
/// The input gate regulates which information to update and store in the cell state at each time step.
|
||||
pub input_gate: GateController<B>,
|
||||
/// The forget gate is used to control which information to discard or keep in the memory cell at each time step.
|
||||
/// Note: When `input_forget` is true, this gate is not used (forget = 1 - input).
|
||||
pub forget_gate: GateController<B>,
|
||||
/// The output gate determines which information from the cell state to output at each time step.
|
||||
pub output_gate: GateController<B>,
|
||||
/// The cell gate is used to compute the cell state that stores and carries information through time.
|
||||
pub cell_gate: GateController<B>,
|
||||
/// The hidden state of the LSTM.
|
||||
pub d_hidden: usize,
|
||||
/// If true, input is `[batch_size, seq_length, input_size]`.
|
||||
/// If false, input is `[seq_length, batch_size, input_size]`.
|
||||
pub batch_first: bool,
|
||||
/// If true, process the sequence in reverse order.
|
||||
pub reverse: bool,
|
||||
/// Optional cell state clip threshold.
|
||||
pub clip: Option<f64>,
|
||||
/// If true, couples input and forget gates: f_t = 1 - i_t.
|
||||
pub input_forget: bool,
|
||||
/// Activation function for gates (input, forget, output).
|
||||
pub gate_activation: Activation<B>,
|
||||
/// Activation function for cell gate (candidate cell state).
|
||||
pub cell_activation: Activation<B>,
|
||||
/// Activation function for hidden output.
|
||||
pub hidden_activation: Activation<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Lstm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
|
||||
let bias = self.input_gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl LstmConfig {
|
||||
/// Initialize a new [lstm](Lstm) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
|
||||
let d_output = self.d_hidden;
|
||||
|
||||
let new_gate = || {
|
||||
GateController::new(
|
||||
self.d_input,
|
||||
d_output,
|
||||
self.bias,
|
||||
self.initializer.clone(),
|
||||
device,
|
||||
)
|
||||
};
|
||||
|
||||
Lstm {
|
||||
input_gate: new_gate(),
|
||||
forget_gate: new_gate(),
|
||||
output_gate: new_gate(),
|
||||
cell_gate: new_gate(),
|
||||
d_hidden: self.d_hidden,
|
||||
batch_first: self.batch_first,
|
||||
reverse: self.reverse,
|
||||
clip: self.clip,
|
||||
input_forget: self.input_forget,
|
||||
gate_activation: self.gate_activation.init(device),
|
||||
cell_activation: self.cell_activation.init(device),
|
||||
hidden_activation: self.hidden_activation.init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Lstm<B> {
|
||||
/// Applies the forward pass on the input tensor. This LSTM implementation
|
||||
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
|
||||
///
|
||||
/// ## Parameters:
|
||||
/// - batched_input: The input tensor of shape:
|
||||
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
|
||||
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
|
||||
/// - state: An optional `LstmState` representing the initial cell state and hidden state.
|
||||
/// Each state tensor has shape `[batch_size, hidden_size]`.
|
||||
/// If no initial state is provided, these tensors are initialized to zeros.
|
||||
///
|
||||
/// ## Returns:
|
||||
/// - output: A tensor represents the output features of LSTM. Shape:
|
||||
/// - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true
|
||||
/// - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false
|
||||
/// - state: A `LstmState` represents the final states. Both `state.cell` and `state.hidden` have the shape
|
||||
/// `[batch_size, hidden_size]`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
batched_input: Tensor<B, 3>,
|
||||
state: Option<LstmState<B, 2>>,
|
||||
) -> (Tensor<B, 3>, LstmState<B, 2>) {
|
||||
// Convert to batch-first layout internally if needed
|
||||
let batched_input = if self.batch_first {
|
||||
batched_input
|
||||
} else {
|
||||
batched_input.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
let device = batched_input.device();
|
||||
let [batch_size, seq_length, _] = batched_input.dims();
|
||||
|
||||
// Process sequence in forward or reverse order based on config
|
||||
let (output, state) = if self.reverse {
|
||||
self.forward_iter(
|
||||
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
|
||||
state,
|
||||
batch_size,
|
||||
seq_length,
|
||||
&device,
|
||||
)
|
||||
} else {
|
||||
self.forward_iter(
|
||||
batched_input.iter_dim(1).zip(0..seq_length),
|
||||
state,
|
||||
batch_size,
|
||||
seq_length,
|
||||
&device,
|
||||
)
|
||||
};
|
||||
|
||||
// Convert output back to seq-first layout if needed
|
||||
let output = if self.batch_first {
|
||||
output
|
||||
} else {
|
||||
output.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
(output, state)
|
||||
}
|
||||
|
||||
fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
|
||||
&self,
|
||||
input_timestep_iter: I,
|
||||
state: Option<LstmState<B, 2>>,
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
device: &B::Device,
|
||||
) -> (Tensor<B, 3>, LstmState<B, 2>) {
|
||||
let mut batched_hidden_state =
|
||||
Tensor::empty([batch_size, seq_length, self.d_hidden], device);
|
||||
|
||||
let (mut cell_state, mut hidden_state) = match state {
|
||||
Some(state) => (state.cell, state.hidden),
|
||||
None => (
|
||||
Tensor::zeros([batch_size, self.d_hidden], device),
|
||||
Tensor::zeros([batch_size, self.d_hidden], device),
|
||||
),
|
||||
};
|
||||
|
||||
for (input_t, t) in input_timestep_iter {
|
||||
let input_t = input_t.squeeze_dim(1);
|
||||
|
||||
// i(nput)g(ate) tensors
|
||||
let biased_ig_input_sum = self
|
||||
.input_gate
|
||||
.gate_product(input_t.clone(), hidden_state.clone());
|
||||
let input_values = self.gate_activation.forward(biased_ig_input_sum);
|
||||
|
||||
// f(orget)g(ate) tensors - either computed or coupled to input gate
|
||||
let forget_values = if self.input_forget {
|
||||
// Coupled mode: f_t = 1 - i_t
|
||||
input_values.clone().neg().add_scalar(1.0)
|
||||
} else {
|
||||
let biased_fg_input_sum = self
|
||||
.forget_gate
|
||||
.gate_product(input_t.clone(), hidden_state.clone());
|
||||
self.gate_activation.forward(biased_fg_input_sum)
|
||||
};
|
||||
|
||||
// o(output)g(ate) tensors
|
||||
let biased_og_input_sum = self
|
||||
.output_gate
|
||||
.gate_product(input_t.clone(), hidden_state.clone());
|
||||
let output_values = self.gate_activation.forward(biased_og_input_sum);
|
||||
|
||||
// c(ell)g(ate) tensors
|
||||
let biased_cg_input_sum = self
|
||||
.cell_gate
|
||||
.gate_product(input_t.clone(), hidden_state.clone());
|
||||
let candidate_cell_values = self.cell_activation.forward(biased_cg_input_sum);
|
||||
|
||||
cell_state = forget_values * cell_state.clone() + input_values * candidate_cell_values;
|
||||
|
||||
// Apply cell state clipping if configured
|
||||
if let Some(clip) = self.clip {
|
||||
cell_state = cell_state.clamp(-clip, clip);
|
||||
}
|
||||
|
||||
hidden_state = output_values * self.hidden_activation.forward(cell_state.clone());
|
||||
|
||||
let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
|
||||
|
||||
// store the hidden state for this timestep
|
||||
batched_hidden_state = batched_hidden_state.slice_assign(
|
||||
[0..batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
unsqueezed_hidden_state.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
(
|
||||
batched_hidden_state,
|
||||
LstmState::new(cell_state, hidden_state),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct BiLstmConfig {
|
||||
/// The size of the input features.
|
||||
pub d_input: usize,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
/// If a bias should be applied during the BiLstm transformation.
|
||||
pub bias: bool,
|
||||
/// BiLstm initializer
|
||||
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
|
||||
pub initializer: Initializer,
|
||||
/// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
|
||||
/// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
|
||||
#[config(default = true)]
|
||||
pub batch_first: bool,
|
||||
/// Optional cell state clip threshold.
|
||||
pub clip: Option<f64>,
|
||||
/// If true, couples the input and forget gates.
|
||||
#[config(default = false)]
|
||||
pub input_forget: bool,
|
||||
/// Activation function for the input, forget, and output gates.
|
||||
#[config(default = "ActivationConfig::Sigmoid")]
|
||||
pub gate_activation: ActivationConfig,
|
||||
/// Activation function for the cell gate (candidate cell state).
|
||||
#[config(default = "ActivationConfig::Tanh")]
|
||||
pub cell_activation: ActivationConfig,
|
||||
/// Activation function applied to the cell state before computing hidden output.
|
||||
#[config(default = "ActivationConfig::Tanh")]
|
||||
pub hidden_activation: ActivationConfig,
|
||||
}
|
||||
|
||||
/// The BiLstm module. This implementation is for Bidirectional LSTM.
|
||||
///
|
||||
/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).
|
||||
///
|
||||
/// Should be created with [BiLstmConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BiLstm<B: Backend> {
|
||||
/// LSTM for the forward direction.
|
||||
pub forward: Lstm<B>,
|
||||
/// LSTM for the reverse direction.
|
||||
pub reverse: Lstm<B>,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
/// If true, input is `[batch_size, seq_length, input_size]`.
|
||||
/// If false, input is `[seq_length, batch_size, input_size]`.
|
||||
pub batch_first: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for BiLstm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self
|
||||
.forward
|
||||
.input_gate
|
||||
.input_transform
|
||||
.weight
|
||||
.shape()
|
||||
.dims();
|
||||
let bias = self.forward.input_gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl BiLstmConfig {
|
||||
/// Initialize a new [Bidirectional LSTM](BiLstm) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
|
||||
// Internal LSTMs always use batch_first=true; BiLstm handles layout conversion
|
||||
let base_config = LstmConfig::new(self.d_input, self.d_hidden, self.bias)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.with_batch_first(true)
|
||||
.with_clip(self.clip)
|
||||
.with_input_forget(self.input_forget)
|
||||
.with_gate_activation(self.gate_activation.clone())
|
||||
.with_cell_activation(self.cell_activation.clone())
|
||||
.with_hidden_activation(self.hidden_activation.clone());
|
||||
|
||||
BiLstm {
|
||||
forward: base_config.clone().init(device),
|
||||
reverse: base_config.init(device),
|
||||
d_hidden: self.d_hidden,
|
||||
batch_first: self.batch_first,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BiLstm<B> {
|
||||
/// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
|
||||
/// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
|
||||
///
|
||||
/// ## Parameters:
|
||||
/// - batched_input: The input tensor of shape:
|
||||
/// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
|
||||
/// - `[sequence_length, batch_size, input_size]` if `batch_first` is false
|
||||
/// - state: An optional `LstmState` representing the initial cell state and hidden state.
|
||||
/// Each state tensor has shape `[2, batch_size, hidden_size]`.
|
||||
/// If no initial state is provided, these tensors are initialized to zeros.
|
||||
///
|
||||
/// ## Returns:
|
||||
/// - output: A tensor represents the output features of LSTM. Shape:
|
||||
/// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
|
||||
/// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
|
||||
/// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and
|
||||
/// `state.hidden` have the shape `[2, batch_size, hidden_size]`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
batched_input: Tensor<B, 3>,
|
||||
state: Option<LstmState<B, 3>>,
|
||||
) -> (Tensor<B, 3>, LstmState<B, 3>) {
|
||||
// Convert to batch-first layout internally if needed
|
||||
let batched_input = if self.batch_first {
|
||||
batched_input
|
||||
} else {
|
||||
batched_input.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
let device = batched_input.clone().device();
|
||||
let [batch_size, seq_length, _] = batched_input.shape().dims();
|
||||
|
||||
let [init_state_forward, init_state_reverse] = match state {
|
||||
Some(state) => {
|
||||
let cell_state_forward = state
|
||||
.cell
|
||||
.clone()
|
||||
.slice([0..1, 0..batch_size, 0..self.d_hidden])
|
||||
.squeeze_dim(0);
|
||||
let hidden_state_forward = state
|
||||
.hidden
|
||||
.clone()
|
||||
.slice([0..1, 0..batch_size, 0..self.d_hidden])
|
||||
.squeeze_dim(0);
|
||||
let cell_state_reverse = state
|
||||
.cell
|
||||
.slice([1..2, 0..batch_size, 0..self.d_hidden])
|
||||
.squeeze_dim(0);
|
||||
let hidden_state_reverse = state
|
||||
.hidden
|
||||
.slice([1..2, 0..batch_size, 0..self.d_hidden])
|
||||
.squeeze_dim(0);
|
||||
|
||||
[
|
||||
Some(LstmState::new(cell_state_forward, hidden_state_forward)),
|
||||
Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
|
||||
]
|
||||
}
|
||||
None => [None, None],
|
||||
};
|
||||
|
||||
// forward direction
|
||||
let (batched_hidden_state_forward, final_state_forward) = self
|
||||
.forward
|
||||
.forward(batched_input.clone(), init_state_forward);
|
||||
|
||||
// reverse direction
|
||||
let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
|
||||
batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
|
||||
init_state_reverse,
|
||||
batch_size,
|
||||
seq_length,
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = Tensor::cat(
|
||||
[batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
|
||||
2,
|
||||
);
|
||||
|
||||
// Convert output back to seq-first layout if needed
|
||||
let output = if self.batch_first {
|
||||
output
|
||||
} else {
|
||||
output.swap_dims(0, 1)
|
||||
};
|
||||
|
||||
let state = LstmState::new(
|
||||
Tensor::stack(
|
||||
[final_state_forward.cell, final_state_reverse.cell].to_vec(),
|
||||
0,
|
||||
),
|
||||
Tensor::stack(
|
||||
[final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
|
||||
0,
|
||||
),
|
||||
);
|
||||
|
||||
(output, state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{LinearRecord, TestBackend};
|
||||
use burn::module::Param;
|
||||
use burn::tensor::{Device, Distribution, TensorData};
|
||||
use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use crate::TestAutodiffBackend;
|
||||
|
||||
/// Checks that an LSTM built with `Initializer::Uniform { min: 0.0, max: 1.0 }`
/// has the input-transform weights of all four gates within the range `0..1`.
#[test]
fn test_with_uniform_initializer() {
    let device = Default::default();
    TestBackend::seed(&device, 0);

    let config = LstmConfig::new(5, 5, false)
        .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
    // Initialize on the device that was just seeded instead of constructing a
    // second default device.
    let lstm = config.init::<TestBackend>(&device);

    // Extracts the input-transform weight matrix of a gate as raw tensor data.
    let gate_to_data =
        |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

    gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());
    gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());
    gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());
    gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());
}
|
||||
|
||||
/// Test forward pass with a single timestep and a single feature.
///
/// The input is `0.1`, every gate bias is `0.0`, and no initial state is supplied (the
/// derivation below treats the previous hidden and cell state as zero). Values are rounded
/// to six decimals:
///
/// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) ≈ 0.517493
/// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) ≈ 0.512497
/// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) ≈ 0.527472
/// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) ≈ 0.089758
/// C_t = f_t * 0 + i_t * c_t ≈ 0 + 0.512497 * 0.089758 ≈ 0.046001
/// h_t = o_t * tanh(C_t) ≈ 0.527472 * tanh(0.046001) ≈ 0.527472 * 0.045968 ≈ 0.024247
#[test]
fn test_forward_single_input_single_feature() {
    let device = Default::default();
    TestBackend::seed(&device, 0);

    let config = LstmConfig::new(1, 1, false);
    let mut lstm = config.init::<TestBackend>(&device);

    /// Builds a 1x1 gate controller whose input and hidden transforms both use the same
    /// scalar weight and bias.
    fn create_gate_controller(
        weights: f32,
        biases: f32,
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        let record_1 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        let record_2 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            bias,
            initializer,
            record_1,
            record_2,
        )
    }

    lstm.input_gate = create_gate_controller(
        0.5,
        0.0,
        1,
        1,
        false,
        Initializer::XavierUniform { gain: 1.0 },
        &device,
    );
    lstm.forget_gate = create_gate_controller(
        0.7,
        0.0,
        1,
        1,
        false,
        Initializer::XavierUniform { gain: 1.0 },
        &device,
    );
    lstm.cell_gate = create_gate_controller(
        0.9,
        0.0,
        1,
        1,
        false,
        Initializer::XavierUniform { gain: 1.0 },
        &device,
    );
    lstm.output_gate = create_gate_controller(
        1.1,
        0.0,
        1,
        1,
        false,
        Initializer::XavierUniform { gain: 1.0 },
        &device,
    );

    // single timestep with single feature
    let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

    let (output, state) = lstm.forward(input, None);

    // C_t from the derivation above
    let expected = TensorData::from([[0.046]]);
    let tolerance = Tolerance::default();
    state
        .cell
        .to_data()
        .assert_approx_eq::<FT>(&expected, tolerance);

    // h_t from the derivation above
    let expected = TensorData::from([[0.0242]]);
    state
        .hidden
        .to_data()
        .assert_approx_eq::<FT>(&expected, tolerance);

    // With a single timestep, the output sequence must equal the final hidden state.
    output
        .select(0, Tensor::arange(0..1, &device))
        .squeeze_dim::<2>(0)
        .to_data()
        .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
}
|
||||
|
||||
/// Feeds a random `[8, 10, 64]` batch through an LSTM with 1024 hidden units and checks
/// the dimensions of the output sequence and of the final cell/hidden state.
#[test]
fn test_batched_forward_pass() {
    const BATCH: usize = 8;
    const SEQ_LEN: usize = 10;
    const D_INPUT: usize = 64;
    const D_HIDDEN: usize = 1024;

    let device = Default::default();
    let layer = LstmConfig::new(D_INPUT, D_HIDDEN, true).init(&device);
    let random_batch = Tensor::<TestBackend, 3>::random(
        [BATCH, SEQ_LEN, D_INPUT],
        Distribution::Default,
        &device,
    );

    let (hidden_sequence, final_state) = layer.forward(random_batch, None);

    assert_eq!(hidden_sequence.dims(), [BATCH, SEQ_LEN, D_HIDDEN]);
    assert_eq!(final_state.cell.dims(), [BATCH, D_HIDDEN]);
    assert_eq!(final_state.hidden.dims(), [BATCH, D_HIDDEN]);
}
|
||||
|
||||
/// Same shape check as the batched forward test, but with a batch of one and a sequence
/// length of two.
#[test]
fn test_batched_forward_pass_batch_of_one() {
    const BATCH: usize = 1;
    const SEQ_LEN: usize = 2;
    const D_INPUT: usize = 64;
    const D_HIDDEN: usize = 1024;

    let device = Default::default();
    let layer = LstmConfig::new(D_INPUT, D_HIDDEN, true).init(&device);
    let random_batch = Tensor::<TestBackend, 3>::random(
        [BATCH, SEQ_LEN, D_INPUT],
        Distribution::Default,
        &device,
    );

    let (hidden_sequence, final_state) = layer.forward(random_batch, None);

    assert_eq!(hidden_sequence.dims(), [BATCH, SEQ_LEN, D_HIDDEN]);
    assert_eq!(final_state.cell.dims(), [BATCH, D_HIDDEN]);
    assert_eq!(final_state.hidden.dims(), [BATCH, D_HIDDEN]);
}
|
||||
|
||||
/// Runs a backward pass through the LSTM on the autodiff backend and checks that the
/// gradient of the output gate's hidden-transform weight exists and is not all zeros.
#[test]
#[cfg(feature = "std")]
fn test_batched_backward_pass() {
    use burn::tensor::Shape;

    let device = Default::default();
    let lstm = LstmConfig::new(64, 32, true).init(&device);
    let shape: Shape = [8, 10, 64].into();
    let batched_input =
        Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

    // The raw output stands in for a loss; only the existence of gradients matters here.
    let (fake_loss, _) = lstm.forward(batched_input.clone(), None);
    let grads = fake_loss.backward();

    let weight = &lstm.output_gate.hidden_transform.weight;
    let some_gradient = weight.grad(&grads).unwrap();

    // `any()` reduces the gradient to a single flag; read it back as a float.
    let any_flag = some_gradient
        .any()
        .into_data()
        .iter::<f32>()
        .next()
        .unwrap();

    // Asserts that the gradients exist and are non-zero
    assert_ne!(any_flag, 0.0);
}
|
||||
|
||||
/// Checks `BiLstm` against precomputed reference values, both with and without an explicit
/// initial state.
///
/// Every gate of the forward and the reverse LSTM is overwritten with fixed weights so the
/// output sequence and the final hidden/cell states can be compared to known numbers.
#[test]
fn test_bidirectional() {
    let device = Default::default();
    TestBackend::seed(&device, 0);

    let config = BiLstmConfig::new(2, 3, true);
    let mut lstm = config.init(&device);

    /// Builds a gate controller from explicit weights.
    ///
    /// `input_weights` has `D2` rows of `D1` values, so `D2` is the input size and `D1`
    /// (which also sizes the biases and the square hidden weights) is the output size.
    fn create_gate_controller<const D1: usize, const D2: usize>(
        input_weights: [[f32; D1]; D2],
        input_biases: [f32; D1],
        hidden_weights: [[f32; D1]; D1],
        hidden_biases: [f32; D1],
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        // Take the sizes from the const generics so they match the record shapes.
        let d_input = D2;
        let d_output = D1;

        let input_record = LinearRecord {
            weight: Param::from_data(TensorData::from(input_weights), device),
            bias: Some(Param::from_data(TensorData::from(input_biases), device)),
        };
        let hidden_record = LinearRecord {
            weight: Param::from_data(TensorData::from(hidden_weights), device),
            bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            true,
            Initializer::XavierUniform { gain: 1.0 },
            input_record,
            hidden_record,
        )
    }

    // One batch, four timesteps, two features.
    let input = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[
            [0.949, -0.861],
            [0.892, 0.927],
            [-0.173, -0.301],
            [-0.081, 0.992],
        ]]),
        &device,
    );
    // Initial hidden and cell states; the leading axis of size two holds one entry per direction.
    let h0 = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
        &device,
    );
    let c0 = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
        &device,
    );

    lstm.forward.input_gate = create_gate_controller(
        [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
        [-0.196, 0.354, 0.209],
        [
            [-0.320, 0.232, -0.165],
            [0.093, -0.572, -0.315],
            [-0.467, 0.325, 0.046],
        ],
        [0.181, -0.190, -0.245],
        &device,
    );

    lstm.forward.forget_gate = create_gate_controller(
        [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
        [0.315, -0.413, -0.041],
        [
            [0.453, 0.063, 0.561],
            [0.211, 0.149, 0.213],
            [-0.499, -0.158, 0.068],
        ],
        [-0.431, -0.535, 0.125],
        &device,
    );

    lstm.forward.cell_gate = create_gate_controller(
        [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
        [-0.358, 0.282, -0.078],
        [
            [-0.358, 0.109, 0.139],
            [-0.345, 0.091, -0.368],
            [-0.508, 0.221, -0.507],
        ],
        [0.502, -0.509, -0.247],
        &device,
    );

    lstm.forward.output_gate = create_gate_controller(
        [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
        [-0.227, -0.274, 0.039],
        [
            [-0.383, 0.449, 0.222],
            [-0.357, -0.093, 0.449],
            [-0.106, 0.236, 0.360],
        ],
        [-0.361, -0.209, -0.454],
        &device,
    );

    lstm.reverse.input_gate = create_gate_controller(
        [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
        [0.540, -0.164, 0.033],
        [
            [0.159, 0.180, -0.037],
            [-0.443, 0.485, -0.488],
            [0.098, -0.085, -0.140],
        ],
        [-0.510, 0.105, 0.114],
        &device,
    );

    lstm.reverse.forget_gate = create_gate_controller(
        [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
        [0.141, 0.004, 0.055],
        [
            [-0.005, -0.277, -0.515],
            [-0.011, -0.101, -0.365],
            [0.426, 0.379, 0.337],
        ],
        [-0.382, 0.331, -0.176],
        &device,
    );

    lstm.reverse.cell_gate = create_gate_controller(
        [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
        [-0.206, -0.546, 0.462],
        [
            [0.449, -0.240, 0.071],
            [-0.045, 0.131, 0.124],
            [0.138, -0.201, 0.191],
        ],
        [-0.030, 0.211, -0.352],
        &device,
    );

    lstm.reverse.output_gate = create_gate_controller(
        [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
        [-0.387, -0.250, 0.066],
        [
            [-0.030, 0.268, 0.299],
            [-0.019, -0.280, -0.314],
            [0.466, -0.365, -0.248],
        ],
        [-0.398, -0.199, -0.566],
        &device,
    );

    // Each output row is the forward hidden state (first three values) concatenated with
    // the reverse hidden state (last three values).
    let expected_output_with_init_state = TensorData::from([[
        [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
        [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
        [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
        [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
    ]]);
    let expected_output_without_init_state = TensorData::from([[
        [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
        [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
        [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
        [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
    ]]);
    let expected_hn_with_init_state = TensorData::from([
        [[-0.03420, 0.07774, -0.09774]],
        [[-0.15635, -0.03366, -0.05798]],
    ]);
    let expected_cn_with_init_state = TensorData::from([
        [[-0.13593, 0.17125, -0.22395]],
        [[-0.45425, -0.11206, -0.12908]],
    ]);
    let expected_hn_without_init_state = TensorData::from([
        [[-0.04026, 0.07178, -0.10189]],
        [[-0.15969, -0.05322, -0.08863]],
    ]);
    let expected_cn_without_init_state = TensorData::from([
        [[-0.15839, 0.15923, -0.23569]],
        [[-0.47407, -0.17493, -0.19643]],
    ]);

    let (output_with_init_state, state_with_init_state) =
        lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
    let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);

    let tolerance = Tolerance::permissive();
    output_with_init_state
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
    output_without_init_state
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
    state_with_init_state
        .hidden
        .to_data()
        .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
    state_with_init_state
        .cell
        .to_data()
        .assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);
    state_without_init_state
        .hidden
        .to_data()
        .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
    state_without_init_state
        .cell
        .to_data()
        .assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);
}
|
||||
|
||||
/// Checks the `Display` rendering of an `Lstm` layer.
#[test]
fn display_lstm() {
    let layer = LstmConfig::new(2, 3, true).init::<TestBackend>(&Default::default());

    let rendered = alloc::format!("{layer}");

    assert_eq!(
        rendered,
        "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
    );
}
|
||||
|
||||
/// Checks the `Display` rendering of a `BiLstm` layer.
#[test]
fn display_bilstm() {
    let layer = BiLstmConfig::new(2, 3, true).init::<TestBackend>(&Default::default());

    let rendered = alloc::format!("{layer}");

    assert_eq!(
        rendered,
        "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
    );
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
mod gate_controller;
|
||||
|
||||
/// Basic RNN.
|
||||
pub mod basic;
|
||||
|
||||
/// Gated Recurrent Unit module.
|
||||
pub mod gru;
|
||||
|
||||
/// Long Short-Term Memory module.
|
||||
pub mod lstm;
|
||||
|
||||
pub use basic::*;
|
||||
pub use gate_controller::*;
|
||||
pub use gru::*;
|
||||
pub use lstm::*;
|
||||
@@ -0,0 +1,581 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use alloc::vec;
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use burn::tensor::Int;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use core::ops::Range;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct RotaryEncodingConfig {
|
||||
/// Maximum sequence length of input
|
||||
pub max_sequence_length: usize,
|
||||
|
||||
/// Size of the input embedding or hidden dimension
|
||||
pub d_model: usize,
|
||||
|
||||
/// Scaling factor for frequency computation. Defaults to 10000.0
|
||||
#[config(default = "10000.0")]
|
||||
pub theta: f32,
|
||||
}
|
||||
|
||||
impl RotaryEncodingConfig {
|
||||
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the size of input embedding dimension is not even.
|
||||
/// Panics if the theta parameter is not positive.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
|
||||
self.initialize(|x| x, device)
|
||||
}
|
||||
|
||||
/// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
|
||||
/// This is useful to apply different RoPE extensions.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the size of input embedding dimension is not even.
|
||||
/// Panics if the theta parameter is not positive.
|
||||
pub fn init_with_frequency_scaling<B: Backend>(
|
||||
&self,
|
||||
scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
|
||||
device: &B::Device,
|
||||
) -> RotaryEncoding<B> {
|
||||
self.initialize(scaling, device)
|
||||
}
|
||||
|
||||
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the size of input embedding dimension is not even.
|
||||
/// Panics if the theta parameter is not positive.
|
||||
fn initialize<B: Backend>(
|
||||
&self,
|
||||
scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
|
||||
device: &B::Device,
|
||||
) -> RotaryEncoding<B> {
|
||||
assert_eq!(
|
||||
self.d_model % 2,
|
||||
0,
|
||||
"The input embedding dimension must be even"
|
||||
);
|
||||
assert!(
|
||||
self.theta > 0.0,
|
||||
"Theta parameter must be positive (default: 10000)."
|
||||
);
|
||||
|
||||
// Calculate the rotation frequencies for positional embeddings based on the formula
|
||||
// `theta = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
|
||||
let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
|
||||
.float()
|
||||
.div_scalar(self.d_model as f32);
|
||||
|
||||
// Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
|
||||
// This is done since burn doesn't support exponentiation of scalar to tensor
|
||||
let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();
|
||||
|
||||
let theta = scaling(theta);
|
||||
|
||||
let freq_complex =
|
||||
RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());
|
||||
|
||||
RotaryEncoding {
|
||||
freq_complex,
|
||||
theta,
|
||||
start_offset: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A module that applies rotary positional encoding to a tensor.
|
||||
/// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes
|
||||
/// absolute positional information with rotation matrix and naturally incorporates
|
||||
/// explicit relative position dependency in self-attention formulation.
|
||||
///
|
||||
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
|
||||
///
|
||||
/// Should be created using [RotaryEncodingConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct RotaryEncoding<B: Backend> {
|
||||
/// Complex frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
|
||||
// Essentially a cache of pre-computed RoPE values.
|
||||
pub freq_complex: Tensor<B, 3>,
|
||||
/// Frequency vector used to compute/apply the complex rotations.
|
||||
pub theta: Tensor<B, 1>,
|
||||
start_offset: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("max_sequence_length", &max_sequence_length)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
impl<B: Backend> RotaryEncoding<B> {
|
||||
/// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
|
||||
///
|
||||
/// # Arguments:
|
||||
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
|
||||
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
|
||||
/// respectively.
|
||||
///
|
||||
/// # Returns:
|
||||
/// Output tensor with the same shape as input tensor after applying rotary encoding.
|
||||
///
|
||||
/// # Panics
|
||||
/// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
|
||||
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
self.apply(x, 0)
|
||||
}
|
||||
|
||||
/// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
|
||||
///
|
||||
/// # Arguments:
|
||||
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
|
||||
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
|
||||
/// respectively.
|
||||
/// * `start` - Sequence start position index.
|
||||
///
|
||||
/// # Returns:
|
||||
/// Output tensor with the same shape as input tensor after applying rotary encoding.
|
||||
///
|
||||
/// # Panics
|
||||
/// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
|
||||
pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
|
||||
assert!(
|
||||
D >= 2,
|
||||
"Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
|
||||
);
|
||||
|
||||
let device = x.device();
|
||||
let input_shape = x.shape();
|
||||
|
||||
// Extract the sequence length and embedding dimension, other dimensions are kept generic
|
||||
// to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads)
|
||||
let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
|
||||
let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
|
||||
|
||||
// Create a dummy tensor with signed ones based on the 2D rotation matrix
|
||||
// [[cos, -sin], [sin, cos]]
|
||||
let sign_tensor =
|
||||
Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
|
||||
|
||||
// Rotate input using the frequency tensor. Slice the frequencies till input sequence length
|
||||
let out: Tensor<B, 4> = x
|
||||
.reshape([dummy_dim_size, seq_len, d_model / 2, 2])
|
||||
.matmul(sign_tensor.unsqueeze())
|
||||
.reshape([dummy_dim_size, seq_len, d_model, 2])
|
||||
* self
|
||||
.freq_complex
|
||||
.clone()
|
||||
.slice([start..start + seq_len])
|
||||
.unsqueeze();
|
||||
|
||||
// Sum the real and imaginary components to get output tensor and reshape to original shape
|
||||
out.sum_dim(-1).reshape(input_shape)
|
||||
}
|
||||
|
||||
/// Shifts the pre-computed rotary frequency to cover a new range of positions.
|
||||
///
|
||||
/// This method updates the internal frequency tensor `freq_complex` to store
|
||||
/// the rotary positional encodings for a new window of positions starting at `start`.
|
||||
pub fn shift(&mut self, start: usize) {
|
||||
let max_seq_len = self.freq_complex.dims()[0];
|
||||
assert!(
|
||||
start > self.start_offset,
|
||||
"Shift start position must be monotonically increasing"
|
||||
);
|
||||
|
||||
let current_end = self.start_offset + max_seq_len;
|
||||
|
||||
if start >= current_end {
|
||||
// Overwrite the whole buffer
|
||||
let new_freqs =
|
||||
Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());
|
||||
self.freq_complex
|
||||
.inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));
|
||||
} else {
|
||||
// Shift the tail
|
||||
let num_keep = current_end - start;
|
||||
let start_rel = start - self.start_offset;
|
||||
let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);
|
||||
self.freq_complex
|
||||
.inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));
|
||||
// Compute the rest and assign
|
||||
let new_freqs = Self::compute_rotary_frequencies(
|
||||
current_end..start + max_seq_len,
|
||||
self.theta.clone(),
|
||||
);
|
||||
self.freq_complex
|
||||
.inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));
|
||||
}
|
||||
self.start_offset = start;
|
||||
}
|
||||
|
||||
/// Computes the positional rotation frequencies (cosine and sine values) used in RoPE.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `range`: Range of position indices `[start, end)`.
|
||||
/// - `theta`: 1D tensor of shape `(d_model / 2)` containing base angular frequencies.
|
||||
///
|
||||
/// # Returns
|
||||
/// Tensor of shape `(range.len(), d_model, 2)` containing `[cos, sin]` pairs for each position and frequency.
|
||||
fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {
|
||||
let d_model = theta.dims()[0] * 2;
|
||||
let num_positions = range.end - range.start;
|
||||
|
||||
// Generate frequency values for positional embeddings
|
||||
let frequencies: Tensor<B, 2> =
|
||||
Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())
|
||||
.float()
|
||||
.unsqueeze()
|
||||
.transpose()
|
||||
.repeat_dim(1, d_model / 2)
|
||||
* theta.unsqueeze();
|
||||
|
||||
// Convert frequency values to complex numbers (polar form)
|
||||
let p_cos = frequencies.clone().cos();
|
||||
let p_sin = frequencies.sin();
|
||||
|
||||
Tensor::cat(vec![p_cos, p_sin], 1)
|
||||
.reshape([num_positions, 2, d_model / 2])
|
||||
.transpose()
|
||||
.unsqueeze_dim::<4>(2)
|
||||
.repeat_dim(2, 2)
|
||||
.reshape([num_positions, d_model, 2])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn::tensor::{Tolerance, ops::FloatElem};
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
/// Applies the encoding to a 4D input and compares the result with reference values.
#[test]
fn test_rotary_encoding_forward() {
    let device = Default::default();
    let rope = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

    // The data is written as a 3D tensor and then given a leading axis:
    // Input = [Batch size, Num of heads, Seq_len, d_model]
    let input = Tensor::<TestBackend, 3>::from_floats(
        [
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
            [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
        ],
        &device,
    )
    .unsqueeze::<4>();

    let actual = rope.forward(input).squeeze_dim::<3>(0);

    let expected = Tensor::<TestBackend, 3>::from_floats(
        [
            [
                [1.0000, 2.0000, 3.0000, 4.0000],
                [-2.3473, 7.4492, 6.9197, 8.0696],
            ],
            [
                [9.0000, 10.0000, 11.0000, 12.0000],
                [-4.7567, 18.5034, 14.8393, 16.1492],
            ],
        ],
        &device,
    );

    actual
        .to_data()
        .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Applies the encoding directly to a 3D input and compares the result with the same
/// reference values as the 4D forward test.
#[test]
fn test_rotary_encoding_3d() {
    let device = Default::default();
    let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

    // The input stays 3D here; no extra leading axis is added.
    let input = Tensor::<TestBackend, 3>::from_floats(
        [
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
            [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
        ],
        &device,
    );

    let output = rotary_encoding.forward(input);
    let expected_output = Tensor::<TestBackend, 3>::from_floats(
        [
            [
                [1.0000, 2.0000, 3.0000, 4.0000],
                [-2.3473, 7.4492, 6.9197, 8.0696],
            ],
            [
                [9.0000, 10.0000, 11.0000, 12.0000],
                [-4.7567, 18.5034, 14.8393, 16.1492],
            ],
        ],
        &device,
    );

    output
        .to_data()
        .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
}
|
||||
|
||||
/// An all-zero input must stay all-zero after the encoding is applied.
#[test]
fn test_zero_input_rotary_encoding_forward() {
    let device = Default::default();
    let rope = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

    // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
    let zeros_in = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);

    let actual = rope.forward(zeros_in).squeeze_dim::<3>(0);

    // Same shape as the input once the leading axis is squeezed away
    let expected = Tensor::<TestBackend, 3>::zeros([2, 2, 4], &device);

    actual
        .to_data()
        .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Building the encoding with an odd hidden dimension must panic.
#[test]
#[should_panic]
fn test_valid_input_hidden_dim() {
    // The hidden dimension has to be even so it can be split into real and imaginary
    // components for the rotation; 15 is not.
    let d_model = 15;
    let device = Default::default();

    let encoding = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);

    let zeros = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
    let _output = encoding.forward(zeros);
}
|
||||
|
||||
/// Compares the pre-computed frequency cache for two positions and `d_model = 8` with
/// reference `[cos, sin]` values.
#[test]
fn test_rotary_encoding_frequencies() {
    let device = Default::default();
    let rope = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

    // One [cos, sin] pair per position and frequency
    let cos_sin_pairs = Tensor::<TestBackend, 3>::from_floats(
        [
            [
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
            ],
            [
                [5.4030e-01, 8.4147e-01],
                [9.9500e-01, 9.9833e-02],
                [9.9995e-01, 9.9998e-03],
                [9.9999e-01, 9.9999e-04],
            ],
        ],
        &device,
    );

    // The cache stores every pair twice, giving the (2, 8, 2) layout.
    let expected_freqs = cos_sin_pairs
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

    rope.freq_complex
        .to_data()
        .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Piecewise frequency scaling used as the custom scaling function in the tests.
///
/// Frequencies are chosen by wavelength (`2 * PI / freq`):
/// - above `low_freq_wavelen`, the frequency is divided by `scale_factor`;
/// - below `high_freq_wavelen`, the frequency is kept unchanged;
/// - otherwise the scaled and unscaled frequency are blended using `smooth`.
fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
    // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
    let scale_factor = 8.;
    let low_freq_factor = 1.;
    let high_freq_factor = 4.;
    let old_context_len = 8192.;

    let low_freq_wavelen = old_context_len / low_freq_factor;
    let high_freq_wavelen = old_context_len / high_freq_factor;

    let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

    // if wavelen >= high_freq_wavelen
    let in_blend_band = wavelen.clone().greater_equal_elem(high_freq_wavelen);
    let smooth = wavelen
        .clone()
        .recip()
        .mul_scalar(old_context_len)
        .sub_scalar(low_freq_factor)
        .div_scalar(high_freq_factor - low_freq_factor);
    // (1 - smooth) * freq / scale_factor + smooth * freq
    let one_minus_smooth = smooth.clone().neg().add_scalar(1.);
    let scaled_part = one_minus_smooth * freqs.clone().div_scalar(scale_factor);
    let unscaled_part = smooth.clone() * freqs.clone();
    let blended = scaled_part + unscaled_part;
    let result = freqs.clone().mask_where(in_blend_band, blended);

    // if wavelen > low_freq_wavelen
    let is_low_freq = wavelen.clone().greater_elem(low_freq_wavelen);
    let result = result.mask_where(is_low_freq, freqs.clone().div_scalar(scale_factor));

    // if wavelen < high_freq_wavelen
    let is_high_freq = wavelen.lower_elem(high_freq_wavelen);
    result.mask_where(is_high_freq, freqs)
}
|
||||
|
||||
/// Builds the encoding with the piecewise frequency scaling and compares the resulting
/// cache with reference `[cos, sin]` values.
#[test]
fn test_rotary_encoding_with_frequency_scaling() {
    let device = Default::default();
    let config = RotaryEncodingConfig::new(2, 8);
    let rope =
        config.init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

    // One [cos, sin] pair per position and frequency
    let cos_sin_pairs = Tensor::<TestBackend, 3>::from_floats(
        [
            [
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
            ],
            [
                [5.4030e-01, 8.4148e-01],
                [9.9500e-01, 9.9833e-02],
                [9.9995e-01, 9.9998e-03],
                [1.0000, 2.1361e-04],
            ],
        ],
        &device,
    );

    // The cache stores every pair twice, giving the (2, 8, 2) layout.
    let expected_freqs = cos_sin_pairs
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

    rope.freq_complex
        .to_data()
        .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Shifting a small cache past its current end must give the same result as reading a
/// larger cache at the same position.
#[test]
fn test_rotary_encoding_shift_full() {
    let device = Default::default();
    let large_cache = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

    // Input = [Batch size, Num of heads, Seq_len, d_model]
    let data = Tensor::<TestBackend, 3>::from_floats(
        [
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
            [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
        ],
        &device,
    );
    let input = data.unsqueeze::<4>();

    // Reference: a bigger cache (max_seq_len = 10) read at position 6. A smaller cache of
    // pre-computed RoPE frequencies shifted to the same position should agree with it.
    let expected_output = large_cache.apply(input.clone(), 6);

    let mut small_cache = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
    // 6 is past the end of the 4-row cache, so every row is recomputed
    small_cache.shift(6);

    let output = small_cache.apply(input, 0);

    output
        .into_data()
        .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
}
|
||||
|
||||
/// Shifting a small cache by less than its own length must give the same result
/// as applying a larger cache at the equivalent start offset.
#[test]
fn test_rotary_encoding_shift() {
    let device = Default::default();
    let reference = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

    // Laid out as [batch size, number of heads, sequence length, d_model] once the
    // leading axis has been added by `unsqueeze`.
    let input = Tensor::<TestBackend, 3>::from_floats(
        [
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
            [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
        ],
        &device,
    )
    .unsqueeze::<4>();

    // Ground truth: the larger cache (max_seq_len = 10) read from position 2.
    let expected_output = reference.apply(input.clone(), 2);

    // Candidate: a cache of length 4 moved to the same initial position.
    // A start position below 4 shifts the (current_end - start) frequencies and computes the rest.
    let mut shifted = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
    shifted.shift(2);

    let actual = shifted.apply(input, 0).into_data();
    actual.assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
}
|
||||
|
||||
/// Two successive shifts with increasing start positions must not panic.
#[test]
fn test_rotary_encoding_shift_multiple() {
    let device = Default::default();
    let mut rope = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);

    for start in [2, 5] {
        rope.shift(start);
    }
}
|
||||
|
||||
/// A shift to a start position lower than the previous one must panic.
#[test]
#[should_panic = "Shift start position must be monotonically increasing"]
fn test_rotary_encoding_shift_should_increase() {
    let device = Default::default();
    let mut rope = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);

    rope.shift(6);
    // Going back from 6 to 4 violates the monotonicity requirement.
    rope.shift(4);
}
|
||||
|
||||
/// The `Display` output of the module lists `d_model` and `max_sequence_length`.
#[test]
fn display() {
    let module = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&Default::default());
    let rendered = alloc::format!("{}", module);

    assert_eq!(
        rendered,
        "RotaryEncoding {d_model: 4, max_sequence_length: 10}"
    );
}
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::{Bool, Tensor, backend::Backend};
|
||||
|
||||
use crate::activation::ActivationConfig;
|
||||
use crate::cache::TensorCache;
|
||||
use crate::{
|
||||
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
|
||||
attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
|
||||
};
|
||||
|
||||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
|
||||
/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct TransformerDecoderConfig {
|
||||
/// The size of the model.
|
||||
pub d_model: usize,
|
||||
/// The size of the position-wise feed-forward network.
|
||||
pub d_ff: usize,
|
||||
/// The number of attention heads.
|
||||
pub n_heads: usize,
|
||||
/// The number of layers.
|
||||
pub n_layers: usize,
|
||||
/// The dropout rate. Default: 0.1
|
||||
#[config(default = 0.1)]
|
||||
pub dropout: f64,
|
||||
/// Layer norm will be applied first instead of after the other modules.
|
||||
#[config(default = false)]
|
||||
pub norm_first: bool,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
///
|
||||
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
|
||||
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
|
||||
///
|
||||
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
|
||||
#[config(default = false)]
|
||||
pub quiet_softmax: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
/// The activation function used in the position-wise feed-forward network. Default: Gelu
|
||||
#[config(default = "ActivationConfig::Gelu")]
|
||||
pub activation: ActivationConfig,
|
||||
/// The epsilon value for layer normalization. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub layer_norm_eps: f64,
|
||||
}
|
||||
|
||||
/// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - layers: transformer decoder layers with `d_model` input and output features.
|
||||
///
|
||||
/// Should be created using [TransformerDecoderConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct TransformerDecoder<B: Backend> {
|
||||
/// Transformer decoder layers.
|
||||
pub layers: Vec<TransformerDecoderLayer<B>>,
|
||||
|
||||
/// The size of the model.
|
||||
pub d_model: usize,
|
||||
|
||||
/// The size of the position-wise feed-forward network.
|
||||
pub d_ff: usize,
|
||||
|
||||
/// The number of attention heads.
|
||||
pub n_heads: usize,
|
||||
|
||||
/// The number of layers.
|
||||
pub n_layers: usize,
|
||||
|
||||
/// The dropout rate. Default: 0.1
|
||||
pub dropout: f64,
|
||||
|
||||
/// Layer norm will be applied first instead of after the other modules.
|
||||
pub norm_first: bool,
|
||||
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("d_model", &self.d_model)
|
||||
.add("d_ff", &self.d_ff)
|
||||
.add("n_heads", &self.n_heads)
|
||||
.add("n_layers", &self.n_layers)
|
||||
.add("dropout", &self.dropout)
|
||||
.add("norm_first", &self.norm_first)
|
||||
.add("quiet_softmax", &self.quiet_softmax)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerDecoderConfig {
|
||||
/// Initialize a new [Transformer Decoder](TransformerDecoder) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
|
||||
let layers = (0..self.n_layers)
|
||||
.map(|_| TransformerDecoderLayer::new(self, device))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
TransformerDecoder {
|
||||
layers,
|
||||
d_model: self.d_model,
|
||||
d_ff: self.d_ff,
|
||||
n_heads: self.n_heads,
|
||||
n_layers: self.n_layers,
|
||||
dropout: self.dropout,
|
||||
norm_first: self.norm_first,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [Transformer Decoder](TransformerDecoder) forward pass input argument.
|
||||
#[derive(Debug)]
|
||||
pub struct TransformerDecoderInput<B: Backend> {
|
||||
target: Tensor<B, 3>,
|
||||
target_mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
target_mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
memory: Tensor<B, 3>,
|
||||
memory_mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
memory_mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerDecoderInput<B> {
|
||||
/// Create a [transformer decoder](TransformerDecoder) input argument.
|
||||
pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
|
||||
Self {
|
||||
target,
|
||||
target_mask_pad: None,
|
||||
target_mask_attn: None,
|
||||
memory,
|
||||
memory_mask_pad: None,
|
||||
memory_mask_attn: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the memory padding mask.
|
||||
pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
|
||||
self.memory_mask_pad = Some(mask_pad);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register the memory attention mask.
|
||||
pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
|
||||
self.memory_mask_attn = Some(mask_attn);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register the target padding mask.
|
||||
pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
|
||||
self.target_mask_pad = Some(mask_pad);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register the target attention mask.
|
||||
pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
|
||||
self.target_mask_attn = Some(mask_attn);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// [Transformer Decoder](TransformerDecoder) layer module.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TransformerDecoderLayer<B: Backend> {
|
||||
/// Cross-attention module.
|
||||
pub cross_attn: MultiHeadAttention<B>,
|
||||
/// Self-attention module.
|
||||
pub self_attn: MultiHeadAttention<B>,
|
||||
/// Position-wise feed-forward module.
|
||||
pub pwff: PositionWiseFeedForward<B>,
|
||||
/// First layer norm.
|
||||
pub norm_1: LayerNorm<B>,
|
||||
/// Second layer norm.
|
||||
pub norm_2: LayerNorm<B>,
|
||||
/// Third layer norm.
|
||||
pub norm_3: LayerNorm<B>,
|
||||
/// Dropout.
|
||||
pub dropout: Dropout,
|
||||
/// Whether to apply norm first.
|
||||
pub norm_first: bool,
|
||||
}
|
||||
|
||||
/// Autoregressive cache for a single [Transformer Decoder Layer](TransformerDecoderLayer).
|
||||
pub struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
|
||||
/// Cross-attention cache.
|
||||
pub cross_attn: MhaCache<B>,
|
||||
/// Self-attention cache.
|
||||
pub self_attn: MhaCache<B>,
|
||||
/// Position-wise feed-forward cache.
|
||||
pub pwff: TensorCache<B, 3>,
|
||||
/// First layer norm cache.
|
||||
pub norm_1: TensorCache<B, 3>,
|
||||
/// Second layer norm cache.
|
||||
pub norm_2: TensorCache<B, 3>,
|
||||
/// Third layer norm cache.
|
||||
pub norm_3: TensorCache<B, 3>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
|
||||
/// Create an empty cache.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
cross_attn: MhaCache::autoregressive_cross_attention(),
|
||||
self_attn: MhaCache::autoregressive(),
|
||||
pwff: TensorCache::empty(),
|
||||
norm_1: TensorCache::empty(),
|
||||
norm_2: TensorCache::empty(),
|
||||
norm_3: TensorCache::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer.
|
||||
///
|
||||
/// To be used during inference when decoding tokens.
|
||||
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
|
||||
layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
|
||||
fn empty(num_layers: usize) -> Self {
|
||||
Self {
|
||||
layers: (0..num_layers)
|
||||
.map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerDecoderLayer<B> {
|
||||
/// Create a new [TransformerDecoderLayer](TransformerDecoderLayer).
|
||||
pub fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
|
||||
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init(device);
|
||||
|
||||
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init(device);
|
||||
let norm_1 = LayerNormConfig::new(config.d_model)
|
||||
.with_epsilon(config.layer_norm_eps)
|
||||
.init(device);
|
||||
let norm_2 = LayerNormConfig::new(config.d_model)
|
||||
.with_epsilon(config.layer_norm_eps)
|
||||
.init(device);
|
||||
let norm_3 = LayerNormConfig::new(config.d_model)
|
||||
.with_epsilon(config.layer_norm_eps)
|
||||
.init(device);
|
||||
let dropout = DropoutConfig::new(config.dropout).init();
|
||||
let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_activation(config.activation.clone())
|
||||
.init(device);
|
||||
|
||||
Self {
|
||||
cross_attn,
|
||||
self_attn,
|
||||
norm_1,
|
||||
norm_2,
|
||||
norm_3,
|
||||
pwff,
|
||||
dropout,
|
||||
norm_first: config.norm_first,
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the TransformerDecoder forward pass to the input tensor.
|
||||
pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
|
||||
// Self attention residual path.
|
||||
let x = input.target;
|
||||
let mut residual_path = x.clone();
|
||||
|
||||
// Normalize.
|
||||
if self.norm_first {
|
||||
residual_path = self.norm_3.forward(residual_path);
|
||||
}
|
||||
|
||||
// Self attention.
|
||||
let mut self_attn_input = MhaInput::self_attn(residual_path);
|
||||
if let Some(mask_pad) = &input.target_mask_pad {
|
||||
self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
|
||||
}
|
||||
if let Some(mask_attn) = &input.target_mask_attn {
|
||||
self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
|
||||
}
|
||||
let residual_path = self.self_attn.forward(self_attn_input).context;
|
||||
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Cross attention residual path.
|
||||
// Normalize.
|
||||
let residual_path = if self.norm_first {
|
||||
self.norm_1.forward(x.clone())
|
||||
} else {
|
||||
x = self.norm_1.forward(x);
|
||||
x.clone()
|
||||
};
|
||||
|
||||
// Cross attention.
|
||||
let mut cross_attn_input =
|
||||
MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
|
||||
if let Some(mask_pad) = &input.memory_mask_pad {
|
||||
cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
|
||||
}
|
||||
if let Some(mask_attn) = &input.memory_mask_attn {
|
||||
cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
|
||||
}
|
||||
let residual_path = self.cross_attn.forward(cross_attn_input).context;
|
||||
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Feed forward residual path.
|
||||
// Normalize.
|
||||
let residual_path = if self.norm_first {
|
||||
self.norm_2.forward(x.clone())
|
||||
} else {
|
||||
x = self.norm_2.forward(x);
|
||||
x.clone()
|
||||
};
|
||||
|
||||
let residual_path = self.pwff.forward(residual_path);
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Main path.
|
||||
// Normalize.
|
||||
if !self.norm_first {
|
||||
x = self.norm_3.forward(x)
|
||||
}
|
||||
|
||||
input.target = x;
|
||||
input
|
||||
}
|
||||
|
||||
/// Applies the forward pass using an autoregressive cache.
|
||||
pub fn forward_autoregressive_inference(
|
||||
&self,
|
||||
mut input: TransformerDecoderInput<B>,
|
||||
cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
|
||||
) -> TransformerDecoderInput<B> {
|
||||
// Self attention residual path.
|
||||
let x = input.target;
|
||||
let mut residual_path = x.clone();
|
||||
|
||||
// Normalize.
|
||||
if self.norm_first {
|
||||
residual_path = cache
|
||||
.norm_3
|
||||
.forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
|
||||
}
|
||||
|
||||
// Self attention.
|
||||
let mut self_attn_input = MhaInput::self_attn(residual_path);
|
||||
if let Some(mask_pad) = &input.target_mask_pad {
|
||||
self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
|
||||
}
|
||||
if let Some(mask_attn) = &input.target_mask_attn {
|
||||
self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
|
||||
}
|
||||
let residual_path = self
|
||||
.self_attn
|
||||
.forward_cache(self_attn_input, &mut cache.self_attn)
|
||||
.context;
|
||||
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Cross attention residual path.
|
||||
// Normalize.
|
||||
let residual_path = if self.norm_first {
|
||||
cache
|
||||
.norm_1
|
||||
.forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
|
||||
} else {
|
||||
x = cache
|
||||
.norm_1
|
||||
.forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
|
||||
x.clone()
|
||||
};
|
||||
|
||||
// Cross attention.
|
||||
let mut cross_attn_input =
|
||||
MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
|
||||
if let Some(mask_pad) = &input.memory_mask_pad {
|
||||
cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
|
||||
}
|
||||
if let Some(mask_attn) = &input.memory_mask_attn {
|
||||
cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
|
||||
}
|
||||
let residual_path = self
|
||||
.cross_attn
|
||||
.forward_cache(cross_attn_input, &mut cache.cross_attn)
|
||||
.context;
|
||||
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Feed forward residual path.
|
||||
// Normalize.
|
||||
let residual_path = if self.norm_first {
|
||||
cache
|
||||
.norm_2
|
||||
.forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
|
||||
} else {
|
||||
x = cache
|
||||
.norm_2
|
||||
.forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
|
||||
x.clone()
|
||||
};
|
||||
|
||||
let residual_path = cache
|
||||
.pwff
|
||||
.forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Main path.
|
||||
// Normalize.
|
||||
if !self.norm_first {
|
||||
x = cache
|
||||
.norm_3
|
||||
.forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
|
||||
}
|
||||
|
||||
input.target = x;
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerDecoder<B> {
|
||||
/// Applies the forward pass.
|
||||
pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
|
||||
for layer in self.layers.iter() {
|
||||
input = layer.forward(input);
|
||||
}
|
||||
|
||||
input.target
|
||||
}
|
||||
|
||||
/// Applies the forward pass on the input using autoregressive cache.
|
||||
pub fn forward_autoregressive_inference(
|
||||
&self,
|
||||
mut input: TransformerDecoderInput<B>,
|
||||
cache: &mut TransformerDecoderAutoregressiveCache<B>,
|
||||
) -> Tensor<B, 3> {
|
||||
for i in 0..self.layers.len() {
|
||||
let layer = self.layers.get(i).unwrap();
|
||||
let cache = cache.layers.get_mut(i).unwrap();
|
||||
|
||||
input = layer.forward_autoregressive_inference(input, cache);
|
||||
}
|
||||
|
||||
input.target
|
||||
}
|
||||
/// Create an empty autoregressive cache.
|
||||
pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
|
||||
TransformerDecoderAutoregressiveCache::empty(self.layers.len())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::Device;

    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};

    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Cached and uncached decoding must agree when normalization runs last.
    #[test]
    fn test_autoregressive_norm_last() {
        let (d_model, d_ff, n_heads, num_layers) = (12, 24, 2, 3);
        let device: Device<TestBackend> = Default::default();
        TestBackend::seed(&device, 0);

        let config = TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
            .with_norm_first(false);
        test_autoregressive(config)
    }

    /// Cached and uncached decoding must agree when normalization runs first.
    #[test]
    fn test_autoregressive_norm_first() {
        let (d_model, d_ff, n_heads, num_layers) = (12, 24, 2, 3);
        let device: Device<TestBackend> = Default::default();
        TestBackend::seed(&device, 0);

        let config = TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
            .with_norm_first(true);
        test_autoregressive(config)
    }

    /// Compares one masked forward pass over the full sequence with a step-by-step
    /// pass through the autoregressive cache.
    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let (batch_size, seq_length, d_model): (usize, usize, usize) = (3, 4, config.d_model);
        let transformer = config.init::<TestBackend>(&device);

        // Memory and target hold the same 0, 1, 2, ... ramp, built independently.
        let ramp = || -> Tensor<TestBackend, 3> {
            Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
                .float()
                .reshape([batch_size, seq_length, d_model])
        };
        let memory = ramp();
        let target = ramp();

        // Reference: a single forward pass with an autoregressive attention mask.
        let full_mask = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let full_input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(full_mask);
        let output_1 = transformer.forward(full_input);

        // Candidate: feed prefixes of growing length through the cache and keep only
        // the last position of each step.
        let mut cache = transformer.new_autoregressive_cache();
        let steps: Vec<_> = (1..=seq_length)
            .map(|len| {
                let prefix = target.clone().slice([0..batch_size, 0..len, 0..d_model]);
                let mask = generate_autoregressive_mask(batch_size, len, &prefix.device());
                let step_input =
                    TransformerDecoderInput::new(prefix, memory.clone()).target_mask_attn(mask);

                transformer
                    .forward_autoregressive_inference(step_input, &mut cache)
                    .slice([0..batch_size, len - 1..len, 0..d_model])
            })
            .collect();
        let output_2 = Tensor::cat(steps, 1);

        // Both strategies should produce the same tokens.
        let tolerance = Tolerance::rel_abs(5e-3, 1e-4);
        let reference = output_1.into_data();
        reference.assert_approx_eq::<FT>(&output_2.into_data(), tolerance);
    }

    /// The `Display` output lists the hyper-parameters and the parameter count.
    #[test]
    fn display() {
        let transformer =
            TransformerDecoderConfig::new(2, 4, 2, 3).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{}", transformer);

        assert_eq!(
            rendered,
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}
|
||||
@@ -0,0 +1,489 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
use crate::{
|
||||
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
|
||||
activation::ActivationConfig,
|
||||
attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
|
||||
cache::TensorCache,
|
||||
};
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::{Bool, Tensor, backend::Backend};
|
||||
|
||||
/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct TransformerEncoderConfig {
|
||||
/// The size of the model.
|
||||
pub d_model: usize,
|
||||
/// The size of the position-wise feed-forward network.
|
||||
pub d_ff: usize,
|
||||
/// The number of attention heads.
|
||||
pub n_heads: usize,
|
||||
/// The number of layers.
|
||||
pub n_layers: usize,
|
||||
/// The dropout rate. Default: 0.1
|
||||
#[config(default = 0.1)]
|
||||
pub dropout: f64,
|
||||
/// Layer norm will be applied first instead of after the other modules.
|
||||
#[config(default = false)]
|
||||
pub norm_first: bool,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
///
|
||||
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
|
||||
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
|
||||
///
|
||||
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
|
||||
#[config(default = false)]
|
||||
pub quiet_softmax: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
/// The activation function used in the position-wise feed-forward network. Default: Gelu
|
||||
#[config(default = "ActivationConfig::Gelu")]
|
||||
pub activation: ActivationConfig,
|
||||
/// The epsilon value for layer normalization. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub layer_norm_eps: f64,
|
||||
}
|
||||
|
||||
/// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - layers: transformer encoder layers with `d_model` input and output features.
|
||||
///
|
||||
/// Should be created using [TransformerEncoderConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct TransformerEncoder<B: Backend> {
|
||||
/// The transformer encoder layers.
|
||||
pub layers: Vec<TransformerEncoderLayer<B>>,
|
||||
|
||||
/// The size of the model.
|
||||
pub d_model: usize,
|
||||
|
||||
/// The size of the position-wise feed-forward network.
|
||||
pub d_ff: usize,
|
||||
|
||||
/// The number of attention heads.
|
||||
pub n_heads: usize,
|
||||
|
||||
/// The number of layers.
|
||||
pub n_layers: usize,
|
||||
|
||||
/// The dropout rate. Default: 0.1
|
||||
pub dropout: f64,
|
||||
|
||||
/// Layer norm will be applied first instead of after the other modules.
|
||||
pub norm_first: bool,
|
||||
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("d_model", &self.d_model)
|
||||
.add("d_ff", &self.d_ff)
|
||||
.add("n_heads", &self.n_heads)
|
||||
.add("n_layers", &self.n_layers)
|
||||
.add("dropout", &self.dropout)
|
||||
.add("norm_first", &self.norm_first)
|
||||
.add("quiet_softmax", &self.quiet_softmax)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
|
||||
#[derive(Debug)]
|
||||
pub struct TransformerEncoderInput<B: Backend> {
|
||||
tensor: Tensor<B, 3>,
|
||||
mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerEncoderInput<B> {
|
||||
/// Create a [transformer encoder](TransformerEncoder) input argument.
|
||||
pub fn new(tensor: Tensor<B, 3>) -> Self {
|
||||
Self {
|
||||
tensor,
|
||||
mask_pad: None,
|
||||
mask_attn: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the padding mask.
|
||||
pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
|
||||
self.mask_pad = Some(mask_pad);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register the attention mask.
|
||||
pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
|
||||
self.mask_attn = Some(mask_attn);
|
||||
self
|
||||
}
|
||||
}
|
||||
impl TransformerEncoderConfig {
|
||||
/// Initialize a new [transformer encoder](TransformerEncoder) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
|
||||
let layers = (0..self.n_layers)
|
||||
.map(|_| TransformerEncoderLayer::new(self, device))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
TransformerEncoder {
|
||||
layers,
|
||||
d_model: self.d_model,
|
||||
d_ff: self.d_ff,
|
||||
n_heads: self.n_heads,
|
||||
n_layers: self.n_layers,
|
||||
dropout: self.dropout,
|
||||
norm_first: self.norm_first,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerEncoder<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - tensor: `[batch_size, seq_length, d_model]`
|
||||
/// - output: `[batch_size, seq_length, d_model]`
|
||||
pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
|
||||
let mut x = input.tensor;
|
||||
|
||||
for layer in self.layers.iter() {
|
||||
x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
/// Applies the forward pass on the input tensor using autoregressive cache.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - tensor: `[batch_size, seq_length, d_model]`
|
||||
/// - output: `[batch_size, seq_length, d_model]`
|
||||
pub fn forward_autoregressive_inference(
|
||||
&self,
|
||||
input: TransformerEncoderInput<B>,
|
||||
cache: &mut TransformerEncoderAutoregressiveCache<B>,
|
||||
) -> Tensor<B, 3> {
|
||||
let mut x = input.tensor;
|
||||
|
||||
for i in 0..self.layers.len() {
|
||||
let layer = self.layers.get(i).unwrap();
|
||||
let cache = cache.layers.get_mut(i).unwrap();
|
||||
|
||||
x = layer.forward_autoregressive_inference(
|
||||
x,
|
||||
input.mask_pad.clone(),
|
||||
input.mask_attn.clone(),
|
||||
cache,
|
||||
);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Create an empty autoregressive cache.
|
||||
pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
|
||||
TransformerEncoderAutoregressiveCache::empty(self.layers.len())
|
||||
}
|
||||
}
|
||||
|
||||
/// Transformer encoder layer module.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TransformerEncoderLayer<B: Backend> {
|
||||
/// Multi-head self-attention sub-layer.
|
||||
pub mha: MultiHeadAttention<B>,
|
||||
/// Position-wise feed-forward sub-layer.
|
||||
pub pwff: PositionWiseFeedForward<B>,
|
||||
/// Layer normalization applied around the feed-forward sub-layer.
|
||||
pub norm_1: LayerNorm<B>,
|
||||
/// Layer normalization applied around the attention sub-layer.
|
||||
pub norm_2: LayerNorm<B>,
|
||||
/// Dropout module applied to residual connections.
|
||||
pub dropout: Dropout,
|
||||
/// If `true`, apply layer normalization before sub-layers (pre-norm),
|
||||
/// otherwise apply it after (post-norm).
|
||||
pub norm_first: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerEncoderLayer<B> {
|
||||
/// Create a new transformer encoder layer from the given configuration.
|
||||
pub fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
|
||||
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init(device);
|
||||
let norm_1 = LayerNormConfig::new(config.d_model)
|
||||
.with_epsilon(config.layer_norm_eps)
|
||||
.init(device);
|
||||
let norm_2 = LayerNormConfig::new(config.d_model)
|
||||
.with_epsilon(config.layer_norm_eps)
|
||||
.init(device);
|
||||
let dropout = DropoutConfig::new(config.dropout).init();
|
||||
let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_activation(config.activation.clone())
|
||||
.init(device);
|
||||
|
||||
Self {
|
||||
mha,
|
||||
norm_1,
|
||||
norm_2,
|
||||
pwff,
|
||||
dropout,
|
||||
norm_first: config.norm_first,
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, seq_length, d_model]`
|
||||
/// - output: `[batch_size, seq_length, d_model]`
|
||||
pub fn forward(
|
||||
&self,
|
||||
input: Tensor<B, 3>,
|
||||
mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
) -> Tensor<B, 3> {
|
||||
// Multi-head attention residual path.
|
||||
let x = input;
|
||||
let mut residual_path = x.clone();
|
||||
|
||||
// Normalize.
|
||||
if self.norm_first {
|
||||
residual_path = self.norm_2.forward(residual_path)
|
||||
}
|
||||
|
||||
// Multi-head attention.
|
||||
let mut input_mhs = MhaInput::self_attn(residual_path);
|
||||
if let Some(mask_pad) = mask_pad {
|
||||
input_mhs = input_mhs.mask_pad(mask_pad);
|
||||
}
|
||||
if let Some(mask_attn) = mask_attn {
|
||||
input_mhs = input_mhs.mask_attn(mask_attn);
|
||||
}
|
||||
let residual_path = self.mha.forward(input_mhs).context;
|
||||
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Feed forward residual path.
|
||||
// Normalize.
|
||||
let residual_path = if self.norm_first {
|
||||
self.norm_1.forward(x.clone())
|
||||
} else {
|
||||
x = self.norm_1.forward(x);
|
||||
x.clone()
|
||||
};
|
||||
|
||||
// Feed forward.
|
||||
let residual_path = self.pwff.forward(residual_path);
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Main path.
|
||||
// Normalize.
|
||||
if !self.norm_first {
|
||||
x = self.norm_2.forward(x)
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Applies the forward pass using an autoregressive cache.
|
||||
pub fn forward_autoregressive_inference(
|
||||
&self,
|
||||
input: Tensor<B, 3>,
|
||||
mask_pad: Option<Tensor<B, 2, Bool>>,
|
||||
mask_attn: Option<Tensor<B, 3, Bool>>,
|
||||
cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
|
||||
) -> Tensor<B, 3> {
|
||||
// Multi-head attention residual path.
|
||||
let x = input;
|
||||
let mut residual_path = x.clone();
|
||||
|
||||
// Normalize.
|
||||
if self.norm_first {
|
||||
residual_path = cache
|
||||
.norm_2
|
||||
.forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
|
||||
}
|
||||
|
||||
// Multi-head attention.
|
||||
let mut input_mhs = MhaInput::self_attn(residual_path);
|
||||
if let Some(mask_pad) = mask_pad {
|
||||
input_mhs = input_mhs.mask_pad(mask_pad);
|
||||
}
|
||||
if let Some(mask_attn) = mask_attn {
|
||||
input_mhs = input_mhs.mask_attn(mask_attn);
|
||||
}
|
||||
let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;
|
||||
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Feed forward residual path.
|
||||
// Normalize.
|
||||
let residual_path = if self.norm_first {
|
||||
cache
|
||||
.norm_1
|
||||
.forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
|
||||
} else {
|
||||
x = cache
|
||||
.norm_1
|
||||
.forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
|
||||
x.clone()
|
||||
};
|
||||
|
||||
// Feed forward.
|
||||
let residual_path = cache
|
||||
.pwff
|
||||
.forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
|
||||
let residual_path = self.dropout.forward(residual_path);
|
||||
let mut x = x + residual_path;
|
||||
|
||||
// Main path.
|
||||
// Normalize.
|
||||
if !self.norm_first {
|
||||
x = cache
|
||||
.norm_2
|
||||
.forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
/// Autoregressive cache for a single [Transformer Encoder Layer](TransformerEncoderLayer).
|
||||
pub struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
|
||||
/// Multi-head attention cache.
|
||||
pub mha: MhaCache<B>,
|
||||
/// Position-wise feed-forward cache.
|
||||
pub pwff: TensorCache<B, 3>,
|
||||
/// First layer norm cache.
|
||||
pub norm_1: TensorCache<B, 3>,
|
||||
/// Second layer norm cache.
|
||||
pub norm_2: TensorCache<B, 3>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
|
||||
/// Create an empty cache.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
mha: MhaCache::autoregressive(),
|
||||
pwff: TensorCache::empty(),
|
||||
norm_1: TensorCache::empty(),
|
||||
norm_2: TensorCache::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
|
||||
///
|
||||
/// To be used during inference when decoding tokens.
|
||||
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
|
||||
layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
|
||||
fn empty(num_layers: usize) -> Self {
|
||||
Self {
|
||||
layers: (0..num_layers)
|
||||
.map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};
    use burn::tensor::Distribution;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Configuration shared by both autoregressive tests:
    /// `d_model = 12`, `d_ff = 24`, `n_heads = 2`, `n_layers = 3`.
    fn base_config() -> TransformerEncoderConfig {
        TransformerEncoderConfig::new(12, 24, 2, 3)
    }

    #[test]
    fn test_autoregressive_norm_last() {
        test_autoregressive(base_config().with_norm_first(false))
    }

    #[test]
    fn test_autoregressive_norm_first() {
        test_autoregressive(base_config().with_norm_first(true))
    }

    /// Decoding the sequence step by step through the cache must reproduce the output of
    /// a single masked forward pass over the full sequence.
    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let full_input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        // Reference: one pass over the whole sequence with the autoregressive mask.
        let expected = transformer.forward(full_input);

        // Candidate: feed growing prefixes and keep only the newest position each step.
        let mut cache = transformer.new_autoregressive_cache();
        let steps: Vec<_> = (1..=seq_length)
            .map(|len| {
                let prefix = tensor.clone().slice([0..batch_size, 0..len, 0..d_model]);
                transformer
                    .forward_autoregressive_inference(
                        TransformerEncoderInput::new(prefix),
                        &mut cache,
                    )
                    .slice([0..batch_size, len - 1..len, 0..d_model])
            })
            .collect();
        let actual = Tensor::cat(steps, 1);

        expected
            .into_data()
            .assert_approx_eq::<FT>(&actual.into_data(), Tolerance::permissive());
    }

    #[test]
    fn display() {
        let transformer =
            TransformerEncoderConfig::new(2, 4, 2, 3).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{transformer}");

        assert_eq!(
            rendered,
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod decoder;
|
||||
mod encoder;
|
||||
mod pwff;
|
||||
|
||||
pub use decoder::*;
|
||||
pub use encoder::*;
|
||||
pub use pwff::*;
|
||||
@@ -0,0 +1,117 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::activation::{Activation, ActivationConfig};
|
||||
use crate::{Dropout, DropoutConfig, Linear, LinearConfig};
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PositionWiseFeedForwardConfig {
|
||||
/// The size of the input and output features.
|
||||
pub d_model: usize,
|
||||
/// The size of the hidden inner features.
|
||||
pub d_ff: usize,
|
||||
/// The dropout rate. Default: 0.1
|
||||
#[config(default = 0.1)]
|
||||
pub dropout: f64,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
/// The activation function used between the two linear layers. Default: Gelu
|
||||
#[config(default = "ActivationConfig::Gelu")]
|
||||
pub activation: ActivationConfig,
|
||||
}
|
||||
|
||||
/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.
|
||||
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
|
||||
///
|
||||
/// `FFN(x) = max(0, xW1 + b1)W2 + b2`
|
||||
///
|
||||
/// Should be created using [PositionWiseFeedForwardConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PositionWiseFeedForward<B: Backend> {
|
||||
/// Linear layer with `d_model` input features and `d_ff` output features.
|
||||
pub linear_inner: Linear<B>,
|
||||
/// Linear layer with `d_ff` input features and `d_model` output features.
|
||||
pub linear_outer: Linear<B>,
|
||||
/// Dropout layer.
|
||||
pub dropout: Dropout,
|
||||
/// Activation function.
|
||||
pub activation: Activation<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_model, dff] = self.linear_inner.weight.shape().dims();
|
||||
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("d_ff", &dff)
|
||||
.add("prob", &self.dropout.prob)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl PositionWiseFeedForwardConfig {
|
||||
/// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
|
||||
PositionWiseFeedForward {
|
||||
linear_inner: LinearConfig::new(self.d_model, self.d_ff)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.init(device),
|
||||
linear_outer: LinearConfig::new(self.d_ff, self.d_model)
|
||||
.with_initializer(self.initializer.clone())
|
||||
.init(device),
|
||||
dropout: DropoutConfig::new(self.dropout).init(),
|
||||
activation: self.activation.init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PositionWiseFeedForward<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - tensor: `[batch_size, seq_length, d_model]`
|
||||
/// - output: `[batch_size, seq_length, d_model]`
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let x = self.linear_inner.forward(input);
|
||||
let x = self.activation.forward(x);
|
||||
let x = self.dropout.forward(x);
|
||||
|
||||
self.linear_outer.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// The custom display lists the layer sizes, the dropout probability and the
    /// parameter count on a single line.
    #[test]
    fn display() {
        let pwff = PositionWiseFeedForwardConfig::new(2, 4).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{pwff}");

        assert_eq!(
            rendered,
            "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
        );
    }
}
|
||||
@@ -0,0 +1,104 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::module::unfold4d;
|
||||
use burn::tensor::ops::UnfoldOptions;
|
||||
|
||||
/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Unfold4dConfig {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// The padding configuration.
|
||||
#[config(default = "[0, 0]")]
|
||||
pub padding: [usize; 2],
|
||||
}
|
||||
|
||||
/// Four-dimensional unfolding.
|
||||
///
|
||||
/// Should be created with [Unfold4dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Unfold4d {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// The padding configuration.
|
||||
pub padding: [usize; 2],
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Unfold4d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
|
||||
.add("stride", &alloc::format!("{:?}", &self.stride))
|
||||
.add("dilation", &alloc::format!("{:?}", &self.dilation))
|
||||
.add("padding", &alloc::format!("{:?}", &self.padding))
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Unfold4dConfig {
|
||||
/// Initializes a new [Unfold4d] module.
|
||||
pub fn init(&self) -> Unfold4d {
|
||||
Unfold4d {
|
||||
kernel_size: self.kernel_size,
|
||||
stride: self.stride,
|
||||
dilation: self.dilation,
|
||||
padding: self.padding,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unfold4d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [unfold4d](burn::tensor::module::unfold4d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// input: `[batch_size, channels_in, height, width]`
|
||||
/// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
|
||||
unfold4d(
|
||||
input,
|
||||
self.kernel_size,
|
||||
UnfoldOptions::new(self.stride, self.padding, self.dilation),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A config built from the kernel size alone displays the default stride, dilation
    /// and padding.
    #[test]
    fn display() {
        let unfold = Unfold4dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{unfold}");

        assert_eq!(
            rendered,
            "Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}"
        );
    }
}
|
||||
247
crates/stable-diffusion-burn/burn-crates/burn-nn/src/padding.rs
Normal file
247
crates/stable-diffusion-burn/burn-crates/burn-nn/src/padding.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
|
||||
/// Computes the `(start, end)` padding of one dimension for a "same" convolution.
///
/// The target output size is `ceil(size_in / stride)`. The total padding is whatever is
/// missing for the kernel to cover that many output positions (never negative). `start`
/// is the padding applied first (top/left); when the total is odd, the extra unit goes
/// to `end` (bottom/right), following the ONNX convention.
fn calculate_same_padding(kernel_size: usize, stride: usize, size_in: usize) -> (usize, usize) {
    let size_out = size_in.div_ceil(stride);
    if size_out == 0 {
        // Empty input: nothing to pad.
        return (0, 0);
    }

    // Extent the kernel sweeps to produce `size_out` positions.
    let covered = (size_out - 1) * stride + kernel_size;
    let total_padding = covered.saturating_sub(size_in);

    let pad_start = total_padding / 2;
    (pad_start, total_padding - pad_start)
}
|
||||
|
||||
/// Padding configuration for 1D operators.
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum PaddingConfig1d {
|
||||
/// Dynamically calculates padding to ensure output size matches input size.
|
||||
Same,
|
||||
/// No padding applied.
|
||||
Valid,
|
||||
/// Applies explicit padding values.
|
||||
/// Format: (left, right)
|
||||
/// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`).
|
||||
Explicit(usize, usize),
|
||||
}
|
||||
|
||||
impl PaddingConfig1d {
|
||||
/// Calculate padding as (left, right) pair for 1D operations.
|
||||
/// For `Same` padding, this computes the actual asymmetric padding if needed.
|
||||
pub(crate) fn calculate_padding_1d_pair(
|
||||
&self,
|
||||
length: usize,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
) -> (usize, usize) {
|
||||
match self {
|
||||
Self::Valid => (0, 0),
|
||||
Self::Same => calculate_same_padding(kernel_size, stride, length),
|
||||
Self::Explicit(left, right) => (*left, *right),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Padding configuration for 2D operators.
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum PaddingConfig2d {
|
||||
/// Dynamically calculates padding to preserve input dimensions in output.
|
||||
Same,
|
||||
/// No padding applied.
|
||||
Valid,
|
||||
/// Applies explicit padding values.
|
||||
/// Format: (top, left, bottom, right)
|
||||
/// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`).
|
||||
Explicit(usize, usize, usize, usize),
|
||||
}
|
||||
|
||||
impl PaddingConfig2d {
|
||||
/// Calculate padding as ((top, bottom), (left, right)) pairs for 2D operations.
|
||||
/// For `Same` padding, this computes the actual asymmetric padding if needed.
|
||||
pub(crate) fn calculate_padding_2d_pairs(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
kernel_size: &[usize; 2],
|
||||
stride: &[usize; 2],
|
||||
) -> ((usize, usize), (usize, usize)) {
|
||||
match self {
|
||||
Self::Valid => ((0, 0), (0, 0)),
|
||||
Self::Same => {
|
||||
let (top, bottom) = calculate_same_padding(kernel_size[0], stride[0], height);
|
||||
let (left, right) = calculate_same_padding(kernel_size[1], stride[1], width);
|
||||
((top, bottom), (left, right))
|
||||
}
|
||||
Self::Explicit(top, left, bottom, right) => ((*top, *bottom), (*left, *right)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate symmetric padding for 2D operations.
|
||||
/// Returns padding values [height, width] (same for both sides).
|
||||
/// Panics if asymmetric padding is detected.
|
||||
pub(crate) fn calculate_padding_2d(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
kernel_size: &[usize; 2],
|
||||
stride: &[usize; 2],
|
||||
) -> [usize; 2] {
|
||||
let ((top, bottom), (left, right)) =
|
||||
self.calculate_padding_2d_pairs(height, width, kernel_size, stride);
|
||||
if top != bottom || left != right {
|
||||
panic!("Asymmetric padding should be handled via calculate_padding_2d_pairs()")
|
||||
}
|
||||
[top, left]
|
||||
}
|
||||
}
|
||||
|
||||
/// Padding configuration for 3D operators.
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum PaddingConfig3d {
|
||||
/// Dynamically calculates padding to preserve input dimensions in output.
|
||||
Same,
|
||||
/// No padding applied.
|
||||
Valid,
|
||||
/// Applies explicit symmetric padding values.
|
||||
/// Format: (depth, height, width) — same padding on both sides of each dimension.
|
||||
Explicit(usize, usize, usize),
|
||||
}
|
||||
|
||||
impl PaddingConfig3d {
|
||||
/// Calculate symmetric padding for 3D operations.
|
||||
/// Returns padding values [depth, height, width] (same for both sides).
|
||||
pub(crate) fn calculate_padding_3d(
|
||||
&self,
|
||||
depth: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
kernel_size: &[usize; 3],
|
||||
stride: &[usize; 3],
|
||||
) -> [usize; 3] {
|
||||
match self {
|
||||
Self::Valid => [0, 0, 0],
|
||||
Self::Same => {
|
||||
let (front, back) = calculate_same_padding(kernel_size[0], stride[0], depth);
|
||||
let (top, bottom) = calculate_same_padding(kernel_size[1], stride[1], height);
|
||||
let (left, right) = calculate_same_padding(kernel_size[2], stride[2], width);
|
||||
if front != back || top != bottom || left != right {
|
||||
panic!(
|
||||
"Asymmetric 3D 'Same' padding is not supported. \
|
||||
Use odd kernel sizes for symmetric padding."
|
||||
)
|
||||
}
|
||||
[front, top, left]
|
||||
}
|
||||
Self::Explicit(depth, height, width) => [*depth, *height, *width],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------- PaddingConfig1d ----------------------------

    #[test]
    fn test_padding_config_1d_calculate_pair_valid() {
        let pair = PaddingConfig1d::Valid.calculate_padding_1d_pair(10, 3, 1);
        assert_eq!(pair, (0, 0));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_explicit() {
        let pair = PaddingConfig1d::Explicit(1, 2).calculate_padding_1d_pair(10, 3, 1);
        assert_eq!(pair, (1, 2));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_same() {
        // length=10, kernel=3, stride=1 -> total padding 2, split as (1, 1).
        let pair = PaddingConfig1d::Same.calculate_padding_1d_pair(10, 3, 1);
        assert_eq!(pair, (1, 1));
    }

    // ---------------------------- PaddingConfig2d ----------------------------

    #[test]
    fn test_padding_config_2d_calculate_pairs_valid() {
        let pairs = PaddingConfig2d::Valid.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]);
        assert_eq!(pairs, ((0, 0), (0, 0)));
    }

    #[test]
    fn test_padding_config_2d_calculate_pairs_explicit() {
        // Explicit is (top, left, bottom, right); the result groups (top, bottom), (left, right).
        let pairs = PaddingConfig2d::Explicit(1, 2, 3, 4)
            .calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]);
        assert_eq!(pairs, ((1, 3), (2, 4)));
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_valid() {
        let padding = PaddingConfig2d::Valid.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
        assert_eq!(padding, [0, 0]);
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_explicit() {
        let padding =
            PaddingConfig2d::Explicit(2, 3, 2, 3).calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
        assert_eq!(padding, [2, 3]);
    }

    #[test]
    #[should_panic(
        expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs"
    )]
    fn test_padding_config_2d_calculate_symmetric_asymmetric_panics() {
        let asymmetric = PaddingConfig2d::Explicit(1, 2, 3, 4);
        let _ = asymmetric.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
    }

    // ---------------------------- PaddingConfig3d ----------------------------

    #[test]
    fn test_padding_config_3d_calculate_valid() {
        let padding =
            PaddingConfig3d::Valid.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]);
        assert_eq!(padding, [0, 0, 0]);
    }

    #[test]
    fn test_padding_config_3d_calculate_explicit() {
        let padding = PaddingConfig3d::Explicit(1, 2, 3)
            .calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]);
        assert_eq!(padding, [1, 2, 3]);
    }

    #[test]
    fn test_padding_config_3d_calculate_same_odd_kernel() {
        // kernel=3, stride=1 -> total padding 2 per dimension, split as (1, 1).
        let padding =
            PaddingConfig3d::Same.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]);
        assert_eq!(padding, [1, 1, 1]);
    }
}
|
||||
@@ -0,0 +1,167 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::module::{Module, Quantizer};
|
||||
use burn::tensor::{
|
||||
Device, Distribution, Tensor, Tolerance,
|
||||
ops::{FloatElem, QuantizedTensor},
|
||||
quantization::{
|
||||
Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantScheme, QuantValue,
|
||||
},
|
||||
};
|
||||
use burn_nn::{
|
||||
Linear, LinearConfig,
|
||||
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
|
||||
};
|
||||
|
||||
/// Backend for test cases: ndarray with `f32` elements, selected when none of the
/// `test-wgpu`, `test-cuda` or `test-rocm` features is enabled.
#[cfg(all(
    test,
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type B = burn_ndarray::NdArray<f32>;

/// Backend for test cases: wgpu, selected by the `test-wgpu` feature.
#[cfg(all(test, feature = "test-wgpu"))]
pub type B = burn_wgpu::Wgpu;

/// Backend for test cases: CUDA, selected by the `test-cuda` feature.
#[cfg(all(test, feature = "test-cuda"))]
pub type B = burn_cuda::Cuda;

/// Backend for test cases: ROCm, selected by the `test-rocm` feature.
#[cfg(all(test, feature = "test-rocm"))]
pub type B = burn_rocm::Rocm;
|
||||
|
||||
fn should_quantize_module<M: Module<B>, const D: usize, F: Fn(&M) -> Tensor<B, D>>(
|
||||
module: M,
|
||||
scheme: QuantScheme,
|
||||
func: F,
|
||||
tolerance: Tolerance<FloatElem<B>>,
|
||||
) {
|
||||
let result = func(&module);
|
||||
|
||||
let calibration = Calibration::MinMax;
|
||||
let mut quantizer = Quantizer {
|
||||
calibration,
|
||||
scheme,
|
||||
};
|
||||
let q_module = module.quantize_weights(&mut quantizer);
|
||||
let q_result = func(&q_module);
|
||||
|
||||
result
|
||||
.into_data()
|
||||
.assert_approx_eq::<f32>(&q_result.into_data(), tolerance);
|
||||
}
|
||||
|
||||
/// Block-wise (block size 32) Q8S quantization of a transformer encoder must keep the
/// forward output close to the full-precision output.
#[test]
fn should_quantize_transformer() {
    let device: Device<B> = Default::default();
    let encoder: TransformerEncoder<B> =
        TransformerEncoderConfig::new(128, 256, 2, 2).init(&device);
    let input = Tensor::<B, 3>::random([2, 32, 128], Distribution::Default, &device);

    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::block([32]))
        .with_param(QuantParam::F32);

    // Slightly higher abs tolerance than the permissive preset (permissive: 1e-2).
    let tolerance = Tolerance::rel_abs(1e-2, 2e-2);

    should_quantize_module(
        encoder,
        scheme,
        |module| module.forward(TransformerEncoderInput::new(input.clone())),
        tolerance,
    );
}
|
||||
|
||||
/// Per-tensor Q8S quantization of a bias-free 128 -> 256 linear layer must keep the
/// forward output within the permissive tolerance.
#[test]
fn should_quantize_linear_128_256() {
    let device: Device<B> = Default::default();
    let linear: Linear<B> = LinearConfig::new(128, 256).with_bias(false).init(&device);
    let input = Tensor::<B, 2>::random([1, 128], Distribution::Default, &device);

    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::Tensor)
        .with_param(QuantParam::F32);

    should_quantize_module(
        linear,
        scheme,
        |module| module.forward(input.clone()),
        Tolerance::permissive(),
    );
}
|
||||
|
||||
/// Per-tensor Q8S quantization of a bias-free 32 -> 32 linear layer must keep the
/// forward output within the permissive tolerance.
#[test]
fn should_quantize_linear() {
    let device: Device<B> = Default::default();
    let linear: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
    let input = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);

    // The default scheme should select a supported QuantStore default.
    // TODO: set native (`.with_store(QuantStore::Native)`) if the dtype is supported by
    // the test backend.
    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::Tensor)
        .with_param(QuantParam::F32);

    should_quantize_module(
        linear,
        scheme,
        |module| module.forward(input.clone()),
        Tolerance::permissive(),
    );
}
|
||||
|
||||
/// Per-tensor Q8S quantization: the dequantized weights of a bias-free 32 -> 32 linear
/// layer must stay within the permissive tolerance of the original weights.
#[test]
fn should_quantize_linear_weights() {
    let device: Device<B> = Default::default();
    let linear: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);

    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::Tensor)
        .with_param(QuantParam::F32);

    should_quantize_module(
        linear,
        scheme,
        |module| module.weight.val().dequantize(),
        Tolerance::permissive(),
    );
}
|
||||
|
||||
/// Block-wise (block size 16) Q8S quantization of a bias-free 32 -> 32 linear layer must
/// keep the forward output within the permissive tolerance.
#[test]
fn should_quantize_linear_blocks() {
    let device: Device<B> = Default::default();
    let linear: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
    let input = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);

    // The store is left at the scheme default (`QuantStore::Native` is not forced).
    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::block([16]))
        .with_param(QuantParam::F32);

    should_quantize_module(
        linear,
        scheme,
        |module| module.forward(input.clone()),
        Tolerance::permissive(),
    );
}
|
||||
|
||||
/// Block-wise (block size 16) Q8S quantization: the dequantized weights of a bias-free
/// 32 -> 32 linear layer must stay within the permissive tolerance of the original weights.
#[test]
fn should_quantize_linear_weights_blocks() {
    let device: Device<B> = Default::default();
    let linear: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);

    // The store is left at the scheme default (`QuantStore::Native` is not forced).
    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::block([16]))
        .with_param(QuantParam::F32);

    should_quantize_module(
        linear,
        scheme,
        |module| module.weight.val().dequantize(),
        Tolerance::permissive(),
    );
}
|
||||
Reference in New Issue
Block a user