feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,102 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Optimizer building blocks for the Burn deep learning framework"
documentation = "https://docs.rs/burn-optim"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-optim"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-optim"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"std",
"burn-core/default",
]
doc = [
"std",
# Doc features
"burn-core/doc",
]
std = [
"burn-core/std",
"num-traits/std",
"serde/std",
"log",
]
tracing = [
"burn-collective?/tracing",
"burn-core/tracing",
"burn-cuda?/tracing",
"burn-fusion?/tracing",
"burn-remote?/tracing",
"burn-rocm?/tracing",
"burn-router?/tracing",
"burn-tch?/tracing",
"burn-wgpu?/tracing",
]
collective = ["burn-collective"]
test-cuda = [
"burn-cuda/default",
] # To use cuda during testing, default uses ndarray.
test-rocm = [
"burn-rocm/default",
] # To use hip during testing, default uses ndarray.
test-tch = [
"burn-tch/default",
] # To use tch during testing, default uses ndarray.
test-wgpu = [
"burn-wgpu/default",
] # To use wgpu during testing, default uses ndarray.
test-vulkan = [
"test-wgpu",
"burn-wgpu/vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.
test-metal = [
"test-wgpu",
"burn-wgpu/metal",
] # To use wgpu-spirv during testing, default uses ndarray.
# Memory checks are disabled by default
test-memory-checks = ["burn-fusion/memory-checks"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true, default-features = false }
num-traits = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
# FOR TESTING
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
[dev-dependencies]
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2" }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
rstest = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1,3 @@
# Burn Optimizers
Core building blocks for Burn optimizers.

View File

@@ -0,0 +1,144 @@
use burn_core as burn;
use burn::tensor::backend::Backend;
use burn::{config::Config, tensor::Tensor};
/// Gradient Clipping provides a way to mitigate exploding gradients
#[derive(Config, Debug)]
pub enum GradientClippingConfig {
/// Clip the gradient by value.
Value(f32),
/// Clip the gradient by norm.
Norm(f32),
}
impl GradientClippingConfig {
/// Initialize the gradient clipping.
///
/// # Returns
///
/// The gradient clipping.
pub fn init(&self) -> GradientClipping {
match self {
GradientClippingConfig::Value(val) => GradientClipping::Value(*val),
GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val),
}
}
}
/// Gradient Clipping provides a way to mitigate exploding gradients
/// by clipping every component of the gradient by value or by norm during
/// backpropagation.
#[derive(Clone)]
pub enum GradientClipping {
/// Clip the gradient by value.
Value(f32),
/// Clip the gradient by norm.
Norm(f32),
}
impl GradientClipping {
/// Clip the gradient.
///
/// # Arguments
///
/// * `grad` - The gradient to clip.
///
/// # Returns
///
/// The clipped gradient.
pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {
match self {
GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),
GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm),
}
}
fn clip_by_value<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
threshold: f32,
) -> Tensor<B, D> {
let greater_mask = grad.clone().greater_elem(threshold);
let lower_mask = grad.clone().lower_elem(-threshold);
let clipped_grad = grad.mask_fill(greater_mask, threshold);
clipped_grad.mask_fill(lower_mask, -threshold)
}
fn clip_by_norm<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
threshold: f32,
) -> Tensor<B, D> {
let norm = Self::l2_norm(grad.clone());
let clip_coef = threshold / norm.add_scalar(1e-6); // avoid div by zero
let clip_coef_clamped = clip_coef.clamp_max(1.0);
grad.mul(clip_coef_clamped.unsqueeze())
}
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
let squared = tensor.square();
let sum = squared.sum();
sum.sqrt()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::Tensor;
#[test]
fn test_clip_by_value() {
let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&Default::default(),
);
let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient);
let clipped_gradient_data = clipped_gradient.into_data();
for value in clipped_gradient_data.iter::<f32>() {
assert!(value <= 0.5);
}
}
#[test]
fn test_clip_by_norm() {
let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&Default::default(),
);
let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient);
let clipped_gradient_data = clipped_gradient.into_data();
for value in clipped_gradient_data.iter::<f32>() {
assert!(value <= 0.88);
}
}
#[test]
fn test_clip_by_norm_no_clipping() {
let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
[[0.3, 0.4, 0.5, 0.2], [0.1, 0.6, 0.3, 0.4]],
&Default::default(),
);
let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient.clone());
clipped_gradient
.into_data()
.assert_eq(&gradient.into_data(), true);
}
}

View File

@@ -0,0 +1,2 @@
mod base;
pub use base::*;

View File

@@ -0,0 +1,63 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "256"]
//! Burn optimizers.
#[macro_use]
extern crate derive_new;
extern crate alloc;
/// Optimizer module.
pub mod optim;
pub use optim::*;
/// Gradient clipping module.
pub mod grad_clipping;
/// Learning rate scheduler module.
#[cfg(feature = "std")]
pub mod lr_scheduler;
/// Type alias for the learning rate.
///
/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
/// can be used for constant learning rate.
pub type LearningRate = f64; // We could potentially change the type.
/// Backend for test cases
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda"),
not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(all(test, feature = "test-tch"))]
/// Backend for test cases
pub type TestBackend = burn_tch::LibTorch<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases
pub type TestBackend = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases
pub type TestBackend = burn_cuda::Cuda;
#[cfg(all(test, feature = "test-rocm"))]
/// Backend for test cases
pub type TestBackend = burn_rocm::Rocm;
/// Backend for autodiff test cases
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
burn_fusion::memory_checks!();
}

View File

@@ -0,0 +1,85 @@
pub(super) use alloc::string::String;
use burn_core as burn;
use burn::record::Record;
use burn::tensor::backend::Backend;
use crate::LearningRate;
/// Learning rate scheduler defines how the learning rate will evolve during training.
pub trait LrScheduler: Clone + Send + Sync {
/// Scheduler associative type to be used when saving and loading the state.
type Record<B: Backend>: Record<B>;
/// Perform the scheduler step, potentially updating its state, and returning the effective
/// learning rate.
fn step(&mut self) -> LearningRate;
/// Get the current state of the scheduler as a [record](Record).
fn to_record<B: Backend>(&self) -> Self::Record<B>;
/// Load the state of the scheduler as a [record](Record).
fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;
}
#[cfg(test)]
pub(super) mod test_utils {
use super::*;
use crate::TestBackend;
// A small tolerance for learning rate comparisons. Depending on how learning rates are
// computed, floating-point arithmetic error might exceed f64::EPSILON, so a larger value is
// used here.
const LOOSE_EPSILON: LearningRate = 1e-10;
pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
where
I: IntoIterator<Item = LearningRate>,
S: LrScheduler,
{
expected_lrs
.into_iter()
.enumerate()
.for_each(|(i, expected)| {
let lr = scheduler.step();
assert!(
(lr - expected).abs() < LOOSE_EPSILON,
"Scheduled learning rate {lr} is not approximately equal to the expected value \
{expected} at step {i}",
);
});
}
// save_at_step is the number of steps to run the scheduler before saving and loading back its
// state.
pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)
where
S: Clone + LrScheduler,
{
let mut truth = scheduler.clone();
// Consume some steps before saving and loading back
(0..save_at_step).for_each(|_| {
truth.step();
scheduler.step();
});
let rec = scheduler.to_record::<TestBackend>();
scheduler = scheduler.load_record::<TestBackend>(rec);
// Validate that the scheduler resumes from where it left off.
compare_steps(&mut scheduler, &mut truth, save_at_step);
}
// Check if two schedulers produce the same learning rate sequences over the specified number of
// steps.
pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
(0..num_steps).for_each(|i| {
let lr_a = a.step();
let lr_b = b.step();
assert!(
(lr_a - lr_b).abs() < LOOSE_EPSILON,
"The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
sequences are not approximately equal",
);
});
}
}

View File

@@ -0,0 +1,195 @@
use burn_core as burn;
use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};
use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};
use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};
use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};
use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::record::Record;
use burn::tensor::backend::Backend;
/// Compose multiple [learning rate schedulers](LrScheduler) together.
#[derive(Config, Debug)]
pub struct ComposedLrSchedulerConfig {
#[config(default = "Vec::new()")]
schedulers: Vec<LrSchedulerConfig>,
#[config(default = "SchedulerReduction::Prod")]
reduction: SchedulerReduction,
}
/// Compose multiple [learning rate schedulers](LrScheduler) together.
#[derive(Clone)]
pub struct ComposedLrScheduler {
schedulers: Vec<LrSchedulerItem>,
reduction: SchedulerReduction,
}
/// Defines how the learning rates generated by the schedulers are combined.
#[derive(Config, Debug, Copy)]
pub enum SchedulerReduction {
/// All learning rates are averaged.
Avg,
/// All learning rates are summed.
Sum,
/// All learning rates are multiplied.
Prod,
}
impl ComposedLrSchedulerConfig {
/// Initialize the learning rate scheduler.
pub fn init(&self) -> Result<ComposedLrScheduler, String> {
let mut schedulers = Vec::with_capacity(self.schedulers.len());
for config in self.schedulers.iter() {
let config = match config {
LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),
LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),
LrSchedulerConfig::Exponential(config) => {
LrSchedulerItem::Exponential(config.init()?)
}
LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),
};
schedulers.push(config);
}
Ok(ComposedLrScheduler {
schedulers,
reduction: self.reduction,
})
}
/// Appends a [linear scheduler](LinearLrScheduler).
pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Linear(config));
self
}
/// Appends a [cosine scheduler](ComposedLrSchedulerConfig).
pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Cosine(config));
self
}
/// Appends an [exponential scheduler](ExponentialLrScheduler).
pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Exponential(config));
self
}
/// Appends a [noam scheduler](NoamLrScheduler).
pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Noam(config));
self
}
}
#[derive(Config, Debug)]
enum LrSchedulerConfig {
Linear(LinearLrSchedulerConfig),
Cosine(CosineAnnealingLrSchedulerConfig),
Exponential(ExponentialLrSchedulerConfig),
Noam(NoamLrSchedulerConfig),
}
#[derive(Clone)]
enum LrSchedulerItem {
Linear(LinearLrScheduler),
Cosine(CosineAnnealingLrScheduler),
Exponential(ExponentialLrScheduler),
Noam(NoamLrScheduler),
}
#[derive(Record)]
/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).
pub enum LrSchedulerRecord<B: Backend> {
/// The linear variant.
Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
/// The cosine variant.
Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),
/// The exponential variant.
Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),
/// The noam variant.
Noam(<NoamLrScheduler as LrScheduler>::Record<B>),
}
#[derive(Record)]
/// Records for the [composed learning rate scheduler](ComposedLrScheduler).
pub struct ComposedLrSchedulerRecord<B: Backend> {
schedulers: Vec<LrSchedulerRecord<B>>,
}
impl LrScheduler for ComposedLrScheduler {
type Record<B: Backend> = ComposedLrSchedulerRecord<B>;
fn step(&mut self) -> LearningRate {
let mut step = match self.reduction {
SchedulerReduction::Avg => 0.0,
SchedulerReduction::Sum => 0.0,
SchedulerReduction::Prod => 1.0,
};
let num_scheduler = self.schedulers.len() as f64;
for lr in self.schedulers.iter_mut().map(|s| match s {
LrSchedulerItem::Linear(item) => item.step(),
LrSchedulerItem::Cosine(item) => item.step(),
LrSchedulerItem::Exponential(item) => item.step(),
LrSchedulerItem::Noam(item) => item.step(),
}) {
step = match self.reduction {
SchedulerReduction::Avg => step + (lr / num_scheduler),
SchedulerReduction::Sum => step + lr,
SchedulerReduction::Prod => step * lr,
}
}
step
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
ComposedLrSchedulerRecord::<B> {
schedulers: self
.schedulers
.iter()
.map(|s| match s {
LrSchedulerItem::Linear(item) => {
LrSchedulerRecord::Linear(item.to_record::<B>())
}
LrSchedulerItem::Cosine(item) => {
LrSchedulerRecord::Linear(item.to_record::<B>())
}
LrSchedulerItem::Exponential(item) => {
LrSchedulerRecord::Exponential(item.to_record::<B>())
}
LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),
})
.collect(),
}
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.schedulers = self
.schedulers
.into_iter()
.zip(record.schedulers)
.map(|scheduler| match scheduler {
(LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {
LrSchedulerItem::Linear(item.load_record::<B>(record))
}
(LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {
LrSchedulerItem::Cosine(item.load_record::<B>(record))
}
(LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {
LrSchedulerItem::Exponential(item.load_record::<B>(record))
}
(LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {
LrSchedulerItem::Noam(item.load_record::<B>(record))
}
_ => panic!("Invalid state"),
})
.collect();
self
}
}

View File

@@ -0,0 +1,50 @@
use burn_core as burn;
use burn::tensor::backend::Backend;
use super::LrScheduler;
use crate::LearningRate;
/// Constant learning rate implementing [learning rate scheduler](LrScheduler).
///
/// # Notes
///
/// You can also use [learning rate](LearningRate) which the same effect.
#[derive(new, Clone, Debug)]
pub struct ConstantLr {
lr: LearningRate,
}
impl From<LearningRate> for ConstantLr {
fn from(lr: LearningRate) -> Self {
Self { lr }
}
}
impl LrScheduler for ConstantLr {
type Record<B: Backend> = ();
fn step(&mut self) -> LearningRate {
self.lr
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {}
fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {
self
}
}
impl LrScheduler for LearningRate {
type Record<B: Backend> = ();
fn step(&mut self) -> LearningRate {
*self
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {}
fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {
self
}
}

View File

@@ -0,0 +1,191 @@
use burn_core as burn;
use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::tensor::backend::Backend;
/// The configuration for creating a [Cosine Annealing learning rate scheduler with warm
/// restarts](CosineAnnealingLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by
/// following a cosine function. After `num_iters` iterations, the learning rate is reset to
/// `initial_lr`.
#[derive(Config, Debug)]
pub struct CosineAnnealingLrSchedulerConfig {
// The initial learning rate.
initial_lr: LearningRate,
// The final learning rate.
#[config(default = 0.0)]
min_lr: LearningRate,
// The number of iterations between two restarts. The two restart iterations themselves are not
// included.
num_iters: usize,
}
impl CosineAnnealingLrSchedulerConfig {
/// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler).
///
/// # Errors
///
/// An error will be returned if any of the following conditions is true:
///
/// * `initial_lr` is out of range (0.0, 1.0]
/// * `min_lr` is out of range [0.0, `initial_lr`]
/// * `num_iters` is 0
pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
if self.initial_lr <= 0. || self.initial_lr > 1. {
return Err("Initial learning rate must be greater than 0 and at most 1".into());
}
if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
return Err(
"Minimum learning rate must be at least 0 and at most equal to the initial \
learning rate"
.into(),
);
}
if self.num_iters == 0 {
return Err("Number of iterations must be at least 1".into());
}
Ok(CosineAnnealingLrScheduler {
min_lr: self.min_lr,
max_lr: self.initial_lr,
num_iters: self.num_iters,
current_iter: usize::MAX,
})
}
}
/// A Cosine Annealing learning rate scheduler.
///
/// This scheduler is described in [SGDR: Stochastic Gradient Descent with Warm
/// Restarts](https://arxiv.org/abs/1608.03983). See [CosineAnnealingLrSchedulerConfig] for more
/// information.
#[derive(Clone, Copy, Debug)]
pub struct CosineAnnealingLrScheduler {
min_lr: LearningRate,
max_lr: LearningRate,
num_iters: usize,
current_iter: usize,
}
impl LrScheduler for CosineAnnealingLrScheduler {
type Record<B: Backend> = usize;
fn step(&mut self) -> LearningRate {
// Make current_iter overflow from usize::MAX to 0 to get the initial learning rate on the
// first call. We could've used i64 with an initial value -1, but keeping it in usize saves
// us from some type casting here.
self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);
self.min_lr
+ 0.5
* (self.max_lr - self.min_lr)
* (1.0
+ (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)
.cos())
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.current_iter
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.current_iter = record;
self
}
}
#[cfg(test)]
mod tests {
use super::super::test_utils;
use super::*;
#[test]
fn config_initial_lr_too_low() {
let r = CosineAnnealingLrSchedulerConfig::new(0., 10).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_initial_lr_too_high() {
let r = CosineAnnealingLrSchedulerConfig::new(1.5, 10).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_min_lr_too_low() {
let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10)
.with_min_lr(-0.1)
.init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Minimum learning rate must be at least 0 and at most equal to the initial learning \
rate",
"Error messages should match",
);
}
#[test]
fn config_min_lr_too_high() {
let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10)
.with_min_lr(0.6)
.init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Minimum learning rate must be at least 0 and at most equal to the initial learning \
rate",
"Error messages should match",
);
}
#[test]
fn config_num_iters_too_low() {
let r = CosineAnnealingLrSchedulerConfig::new(0.5, 0).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Number of iterations must be at least 1",
"Error messages should match",
);
}
#[test]
fn test_lr_change() {
const INITIAL_LR: LearningRate = 0.5;
const MIN_LR: LearningRate = 0.1;
let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2)
.with_min_lr(MIN_LR)
.init()
.unwrap();
let expected_lrs = [
INITIAL_LR, // cos(0)
(INITIAL_LR + MIN_LR) * 0.5, // cos(PI/2)
MIN_LR, // cos(PI)
INITIAL_LR, // restart
];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_save_and_load() {
const NUM_ITERS: usize = 9;
let scheduler = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS)
.init()
.unwrap();
test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
}
}

View File

@@ -0,0 +1,139 @@
use burn_core as burn;
use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::tensor::backend::Backend;
/// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at the first step, then multiplies it by
/// a constant `gamma` at every iteration. At any iteration `i` (which starts from 0), the learning
/// rate is given by `initial_lr * gamma^i`.
#[derive(Config, Debug)]
pub struct ExponentialLrSchedulerConfig {
// The initial learning rate.
initial_lr: LearningRate,
// The constant that the learning rate is multiplied by on each iteration.
gamma: f64,
}
impl ExponentialLrSchedulerConfig {
/// Initializes a [exponential learning rate scheduler](ExponentialLrScheduler).
///
/// # Errors
///
/// An error will be returned if any of the following conditions is true:
///
/// * `initial_lr` is out of range (0.0, 1.0]
/// * `gamma` is out of range (0.0, 1.0]
pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
if self.initial_lr <= 0. || self.initial_lr > 1. {
return Err("Initial learning rate must be greater than 0 and at most 1".into());
}
if self.gamma <= 0. || self.gamma > 1. {
return Err("Gamma must be greater than 0 and at most 1".into());
}
Ok(ExponentialLrScheduler {
// Such an initial value eliminates the need for special-case handling of the first
// learning rate.
previous_lr: self.initial_lr / self.gamma,
gamma: self.gamma,
})
}
}
/// A exponential learning rate scheduler.
///
/// See [ExponentialLrSchedulerConfig] for more information.
#[derive(Clone, Copy, Debug)]
pub struct ExponentialLrScheduler {
// The previous iteration's learning rate.
previous_lr: LearningRate,
// The constant that the learning rate is multiplied by on each iteration.
gamma: f64,
}
impl LrScheduler for ExponentialLrScheduler {
type Record<B: Backend> = LearningRate;
fn step(&mut self) -> LearningRate {
self.previous_lr *= self.gamma;
self.previous_lr
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.previous_lr
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.previous_lr = record;
self
}
}
#[cfg(test)]
mod tests {
use super::super::test_utils;
use super::*;
#[test]
fn config_initial_lr_too_low() {
let r = ExponentialLrSchedulerConfig::new(0., 0.5).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_initial_lr_too_high() {
let r = ExponentialLrSchedulerConfig::new(1.5, 0.5).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_gamma_too_low() {
let r = ExponentialLrSchedulerConfig::new(0.5, 0.0).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Gamma must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_gamma_too_high() {
let r = ExponentialLrSchedulerConfig::new(0.5, 1.5).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Gamma must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn test_lr_change() {
let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_save_and_load() {
let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)
.init()
.unwrap();
test_utils::check_save_load(scheduler, 7);
}
}

View File

@@ -0,0 +1,173 @@
use burn_core as burn;
use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::tensor::backend::Backend;
/// The configuration for creating a [linear learning rate scheduler](LinearLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by a
/// constant amount on each iteration until reaching a final learning rate `final_lr`. The
/// `num_iters` parameter controls how many iterations are needed to go from `initial_lr` to
/// `final_lr`.
#[derive(Config, Debug)]
pub struct LinearLrSchedulerConfig {
// The initial learning rate.
initial_lr: LearningRate,
// The final learning rate.
final_lr: LearningRate,
// The number of iterations before reaching the final learning rate.
num_iters: usize,
}
impl LinearLrSchedulerConfig {
/// Initializes a [linear learning rate scheduler](LinearLrScheduler).
///
/// # Errors
///
/// An error will be returned if any of the following conditions is true:
///
/// * `initial_lr` is out of range (0.0, 1.0]
/// * `final_lr` is out of range [0.0, 1.0]
/// * `num_iters` is 0
pub fn init(&self) -> Result<LinearLrScheduler, String> {
if self.initial_lr <= 0. || self.initial_lr > 1. {
return Err("Initial learning rate must be greater than 0 and at most 1".into());
}
if self.final_lr < 0. || self.final_lr > 1. {
return Err("Final learning rate must be at least 0 and at most 1".into());
}
if self.num_iters == 0 {
return Err("Number of iterations must be at least 1".into());
}
Ok(LinearLrScheduler {
final_lr: self.final_lr,
step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,
remaining_iters: self.num_iters + 1,
})
}
}
/// A linear learning rate scheduler.
///
/// See [LinearLrSchedulerConfig] for more information.
#[derive(Clone, Copy, Debug)]
pub struct LinearLrScheduler {
// The final learning rate after the linear changing process stops.
final_lr: LearningRate,
// The amount that the learning rate changes by on each iteration.
step_size: f64,
// The number of iterations left before reaching the final learning rate.
remaining_iters: usize,
}
impl LrScheduler for LinearLrScheduler {
type Record<B: Backend> = usize;
fn step(&mut self) -> LearningRate {
self.remaining_iters -= (self.remaining_iters != 0) as usize;
self.final_lr - self.step_size * self.remaining_iters as f64
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.remaining_iters
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.remaining_iters = record;
self
}
}
#[cfg(test)]
mod tests {
use super::super::test_utils;
use super::*;
#[test]
fn config_initial_lr_too_low() {
let r = LinearLrSchedulerConfig::new(0., 0.5, 100).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_initial_lr_too_high() {
let r = LinearLrSchedulerConfig::new(1.5, 0.5, 100).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_final_lr_too_low() {
let r = LinearLrSchedulerConfig::new(0.5, -0.5, 100).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Final learning rate must be at least 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_final_lr_too_high() {
let r = LinearLrSchedulerConfig::new(0.5, 1.5, 100).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Final learning rate must be at least 0 and at most 1",
"Error messages should match",
);
}
#[test]
fn config_num_iters_too_low() {
let r = LinearLrSchedulerConfig::new(0.9, 0.1, 0).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Number of iterations must be at least 1",
"Error messages should match",
);
}
#[test]
fn test_lr_decreasing() {
let scheduler = LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap();
let expected_lrs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.5];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_lr_increasing() {
let scheduler = LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap();
let expected_lrs = [0.01, 0.02, 0.03, 0.04, 0.04];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_lr_unchanging() {
let scheduler = LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap();
let expected_lrs = [0.3, 0.3, 0.3, 0.3];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_save_and_load() {
const NUM_ITERS: usize = 6;
let scheduler = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS)
.init()
.unwrap();
test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
}
}

View File

@@ -0,0 +1,24 @@
/// Constant learning rate scheduler
pub mod constant;
/// Composed learning rate scheduler
pub mod composed;
/// Linear learning rate scheduler
pub mod linear;
/// Noam learning rate scheduler
pub mod noam;
/// Exponential learning rate scheduler
pub mod exponential;
/// Cosine learning rate scheduler
pub mod cosine;
/// Step learning rate scheduler
pub mod step;
mod base;
pub use base::*;

View File

@@ -0,0 +1,136 @@
use burn_core as burn;
use burn::config::Config;
use burn::tensor::backend::Backend;
use super::{LrScheduler, String};
use crate::LearningRate;
/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
#[derive(Config, Debug)]
pub struct NoamLrSchedulerConfig {
/// The overall scale factor for the learning rate decay.
factor: f64,
/// The number of steps before the exponential decay stats.
#[config(default = 4000)]
warmup_steps: usize,
/// The size of the model.
#[config(default = 512)]
model_size: usize,
}
/// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
#[derive(Clone, Debug)]
pub struct NoamLrScheduler {
warmup_steps: f64,
embedding_size: f64,
factor: f64,
step: f64,
}
impl NoamLrSchedulerConfig {
/// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.
///
/// # Errors
///
/// An error will be returned if any of the following conditions is true:
///
/// * `warmup_steps` is 0
/// * `model_size` is 0
pub fn init(&self) -> Result<NoamLrScheduler, String> {
if self.warmup_steps == 0 {
return Err(
"Number of steps before exponential decay starts must be greater than 0".into(),
);
}
if self.model_size == 0 {
return Err("Model size must be greater than 0".into());
}
Ok(NoamLrScheduler {
warmup_steps: self.warmup_steps as f64,
embedding_size: self.model_size as f64,
factor: self.factor,
step: 0.0,
})
}
}
impl LrScheduler for NoamLrScheduler {
type Record<B: Backend> = usize;
fn step(&mut self) -> LearningRate {
self.step += 1.0;
let arg1 = self.step.powf(-0.5);
let arg2 = self.step * self.warmup_steps.powf(-1.5);
self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.step as usize
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.step = record as f64;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_warmup_steps_invalid() {
let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init();
assert!(r.is_err(), "Should return an error");
}
#[test]
fn test_config_warmup_steps_valid() {
let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init();
assert!(r.is_ok(), "Should return a success value");
}
#[test]
fn test_config_model_size_invalid() {
let r = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init();
assert!(r.is_err(), "Should return an error");
}
#[test]
fn test_config_model_size_valid() {
let r = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init();
assert!(r.is_ok(), "Should return a success value");
}
#[test]
fn test_function_increase_and_decrease() {
let warmup_steps = 100;
let mut scheduler = NoamLrSchedulerConfig::new(10.0)
.with_warmup_steps(warmup_steps)
.init()
.unwrap();
let mut lr_current = 0.0;
for _ in 0..warmup_steps {
let lr = scheduler.step();
assert!(
lr > lr_current,
"Learning rate should increase before the warmup_steps is reached."
);
lr_current = lr;
}
for _ in 0..warmup_steps {
let lr = scheduler.step();
assert!(
lr < lr_current,
"Learning rate should decrease after the warmup_steps is reached."
);
lr_current = lr;
}
}
}

View File

@@ -0,0 +1,218 @@
use burn_core as burn;
use burn::config::Config;
use burn::tensor::backend::Backend;
use super::{LrScheduler, String};
use crate::LearningRate;
/// The configuration for create a [step learning rate scheduler](StepLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` from the start, and keeps doing so until
/// the same value has been given for `step_size` times. Then it multiplies the learning rate by
/// `gamma` before repeating the process.
///
/// Gamma values out of range (0.0, 1.0) and non-positive initial learning rates are acceptable, but
/// a warning log will be output for such a value in case of mistyping.
///
/// ## Notes
///
/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than
/// `i32::MAX + 1` times.
#[derive(Config, Debug)]
pub struct StepLrSchedulerConfig {
// The learning rate at the initial step.
initial_lr: LearningRate,
// The number of iterations over which the learning rate remains unchanged before the next
// update.
step_size: usize,
/// The factor by which the learning rate is multiplied with each update. Default: 0.1.
#[config(default = 0.1)]
gamma: f64,
}
impl StepLrSchedulerConfig {
/// Initializes a [step learning rate scheduler](StepLrScheduler).
///
/// # Errors
///
/// An error will be returned if `step_size` is 0.
pub fn init(&self) -> Result<StepLrScheduler, String> {
if self.step_size == 0 {
return Err("Step size must be greater than 0".into());
}
// Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful
// in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).
if self.initial_lr <= 0.0 {
log::warn!(
"Initial learning rate value of {} is not a positive number. Ignore this warning \
if it is intended.",
self.initial_lr
);
}
if self.gamma <= 0.0 || self.gamma >= 1.0 {
log::warn!(
"Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
intended.",
self.gamma
);
}
Ok(StepLrScheduler {
init_lr: self.initial_lr,
step_size: self.step_size,
gamma: self.gamma,
iter_idx: -1,
})
}
}
/// Step learning rate scheduler.
#[derive(Clone, Debug)]
pub struct StepLrScheduler {
init_lr: LearningRate,
step_size: usize,
gamma: f64,
// The index of the current iteration.
// `i32` is used for avoiding truncating the exponent when taking powers of `gamma`.
iter_idx: i32,
}
impl LrScheduler for StepLrScheduler {
type Record<B: Backend> = i32;
fn step(&mut self) -> LearningRate {
self.iter_idx = self
.iter_idx
.checked_add(1)
.expect("`.step()` should be called no more than `i32::MAX + 1` times");
// Type casting below causes no truncation, as all the values fall within the ranges.
self.init_lr
* self
.gamma
.powi((self.iter_idx as usize / self.step_size) as i32)
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.iter_idx
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.iter_idx = record;
self
}
}
#[cfg(test)]
mod tests {
use super::super::test_utils;
use super::*;
use crate::TestBackend;
// Warning logs for initial LR and gamma are not tested because there seems no straightforward
// way to do it.
//
// Creating a mock logger that collects logs into `String` for later examination seems a possible
// solution, but unit tests run in the same process in parallel, where the single logger would
// be shared by multiple tests, so logs from different tests would be mixed up with no easy way
// to separate them.
// Using "--test-threads=1" could prevent mixup, but whether the ability to test logging is
// worth the slowdown would be a question. Also, using a primitive provided by `std` to
// synchronize the logger across tests is not an option since we need to support `no-std`.
// Maybe the mocking approach can be reconsidered after we are given an option to run tests in
// separate processes like what the issue below is proposing:
// https://github.com/rust-lang/rust/issues/47506
//
// As a side note, a helper crate exists for the exact purpose:
// https://crates.io/crates/testing_logger
// but the crate has been unmaintained and using it would introduce another dependency.
#[test]
fn test_config_step_size_zero() {
let r = StepLrSchedulerConfig::new(1.0, 0).init();
assert!(r.is_err(), "Should return an error");
}
#[test]
fn test_config_step_size_nonzero() {
let r = StepLrSchedulerConfig::new(1.0, 1).init();
assert!(r.is_ok(), "Should return a success value");
}
#[test]
fn test_config_default_gamma() {
const INIT_LR: LearningRate = 0.4;
const STEP_SIZE: usize = 2;
let mut default = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
.init()
.unwrap();
let mut explicit = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
.with_gamma(0.1)
.init()
.unwrap();
test_utils::compare_steps(&mut default, &mut explicit, 3 * STEP_SIZE);
}
#[test]
fn test_lr_decreasing() {
let scheduler = StepLrSchedulerConfig::new(0.5, 3)
.with_gamma(0.1)
.init()
.unwrap();
let expected_lrs = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_lr_increasing() {
let scheduler = StepLrSchedulerConfig::new(0.1, 2)
.with_gamma(2.0)
.init()
.unwrap();
let expected_lrs = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_lr_unchanging() {
let scheduler = StepLrSchedulerConfig::new(3.1, 1)
.with_gamma(1.0)
.init()
.unwrap();
let expected_lrs = [3.1, 3.1, 3.1];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}
#[test]
fn test_save_and_load() {
const STEP_SIZE: usize = 10;
let scheduler = StepLrSchedulerConfig::new(0.007, STEP_SIZE)
.with_gamma(0.03)
.init()
.unwrap();
test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2);
}
// It's too time consuming to actually run a scheduler `i32::MAX` steps, so an approach that
// depends on private fields is used to implement the test.
#[test]
fn test_number_of_calls_within_limit() {
// Create a scheduler that has already run `i32::MAX` steps
let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap();
scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
scheduler.step();
}
#[test]
#[should_panic = "i32::MAX"]
fn test_number_of_calls_over_limit() {
// Create a scheduler that has already run `i32::MAX` steps
let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap();
scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
scheduler.step();
scheduler.step();
}
}

View File

@@ -0,0 +1,306 @@
use burn_core as burn;
use burn::{module::AutodiffModule, record::Record};
use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use super::{
SimpleOptimizer,
adaptor::OptimizerAdaptor,
decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
/// AdaGrad configuration.
#[derive(Config, Debug)]
pub struct AdaGradConfig {
#[config(default = 0.)]
lr_decay: f64,
#[config(default = 1e-5)]
epsilon: f32,
/// [Weight decay](WeightDecayConfig) config.
weight_decay: Option<WeightDecayConfig>,
/// [Gradient Clipping](GradientClippingConfig) config.
grad_clipping: Option<GradientClippingConfig>,
}
/// AdaGrad optimizer
#[derive(Clone)]
pub struct AdaGrad {
lr_decay: LrDecay,
weight_decay: Option<WeightDecay>,
}
/// AdaGrad state.
#[derive(Record, Clone, new)]
pub struct AdaGradState<B: Backend, const D: usize> {
lr_decay: LrDecayState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
type State<const D: usize> = AdaGradState<B, D>;
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
mut grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
let mut state_lr_decay = None;
if let Some(state) = state {
state_lr_decay = Some(state.lr_decay);
}
if let Some(weight_decay) = &self.weight_decay {
grad = weight_decay.transform(grad, tensor.clone());
}
let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay);
let state = AdaGradState::new(state_lr_decay);
(tensor - grad, Some(state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
state.lr_decay = state.lr_decay.to_device(device);
state
}
}
impl AdaGradConfig {
/// Initialize AdaGrad optimizer.
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
) -> OptimizerAdaptor<AdaGrad, M, B> {
let optim = AdaGrad {
lr_decay: LrDecay {
lr_decay: self.lr_decay,
epsilon: self.epsilon,
},
weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
};
let mut optim = OptimizerAdaptor::from(optim);
if let Some(config) = &self.grad_clipping {
optim = optim.with_grad_clipping(config.init());
}
optim
}
}
/// Learning rate decay state (also includes sum state).
#[derive(Record, new, Clone)]
pub struct LrDecayState<B: Backend, const D: usize> {
time: usize,
sum: Tensor<B, D>,
}
#[derive(Clone)]
struct LrDecay {
lr_decay: f64,
epsilon: f32,
}
impl LrDecay {
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
lr: LearningRate,
lr_decay_state: Option<LrDecayState<B, D>>,
) -> (Tensor<B, D>, LrDecayState<B, D>) {
let state = if let Some(mut state) = lr_decay_state {
state.sum = state.sum.add(grad.clone().square());
state.time += 1;
state
} else {
LrDecayState::new(1, grad.clone().square())
};
let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
let grad = grad
.div(state.sum.clone().sqrt().add_scalar(self.epsilon))
.mul_scalar(new_lr);
(grad, state)
}
}
impl<B: Backend, const D: usize> LrDecayState<B, D> {
/// Move state to device.
///
/// # Arguments
///
/// * `device` - Device to move state to.
///
/// # Returns
///
/// Returns state moved to device.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.sum = self.sum.to_device(device);
self
}
}
#[cfg(test)]
mod tests {
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
use super::*;
use crate::TestAutodiffBackend;
use crate::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
const LEARNING_RATE: LearningRate = 0.01;
#[test]
fn test_adagrad_optimizer_save_load_state() {
let device = Default::default();
let linear = LinearConfig::new(6, 6).init(&device);
let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
let mut optimizer = create_adagrad();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let _linear = optimizer.step(LEARNING_RATE, linear, grads);
#[cfg(feature = "std")]
{
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
BinFileRecorder::<FullPrecisionSettings>::default()
.record(
optimizer.to_record(),
std::env::temp_dir().as_path().join("test_optim_adagrad"),
)
.unwrap();
}
#[cfg(not(feature = "std"))]
{
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
let result = BinBytesRecorder::<FullPrecisionSettings>::default()
.record(optimizer.to_record(), ())
.unwrap();
assert!(!result.is_empty());
}
let state_optim_before = optimizer.to_record();
let state_optim_before_copy = optimizer.to_record();
let optimizer = create_adagrad();
let optimizer = optimizer.load_record(state_optim_before_copy);
let state_optim_after = optimizer.to_record();
assert_eq!(state_optim_before.len(), state_optim_after.len());
}
#[test]
fn test_adagrad_optimizer_with_numbers() {
let device = Default::default();
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&device,
)
.require_grad();
let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&device,
)
.require_grad();
let mut optimizer = AdaGradConfig::new()
.with_epsilon(1e-8)
.with_lr_decay(0.5)
.init();
let grads = linear.forward(x_1).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x_2).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
let weights_expected = TensorData::from([
[-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711],
[
0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756,
],
[
-0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538,
],
[
-0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964,
],
[
0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504,
],
[-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895],
]);
let bias_expected = TensorData::from([
-0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714,
]);
let (weight_updated, bias_updated) = (
state_updated.weight.val().into_data(),
state_updated.bias.unwrap().val().into_data(),
);
type FT = FloatElem<TestAutodiffBackend>;
let tolerance = Tolerance::absolute(1e-6);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
let record = LinearRecord {
weight: Param::from_data(weight, &device),
bias: Some(Param::from_data(bias, &device)),
};
LinearConfig::new(6, 6).init(&device).load_record(record)
}
fn create_adagrad()
-> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
let config = AdaGradConfig::new();
AdaGrad {
lr_decay: LrDecay {
lr_decay: config.lr_decay,
epsilon: config.epsilon,
},
weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
}
.into()
}
}

View File

@@ -0,0 +1,521 @@
use burn_core as burn;
use burn::{module::AutodiffModule, record::Record};
use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use super::{
SimpleOptimizer,
adaptor::OptimizerAdaptor,
decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Adam configuration.
#[derive(Config, Debug)]
pub struct AdamConfig {
/// Parameter for Adam.
#[config(default = 0.9)]
beta_1: f32,
/// Parameter for Adam.
#[config(default = 0.999)]
beta_2: f32,
/// A value required for numerical stability.
#[config(default = 1e-5)]
epsilon: f32,
/// Whether to use AMSGrad algorithm
#[config(default = false)]
amsgrad: bool,
/// [Weight decay](WeightDecayConfig) config.
weight_decay: Option<WeightDecayConfig>,
/// [Gradient Clipping](GradientClippingConfig) config.
grad_clipping: Option<GradientClippingConfig>,
}
/// Adam optimizer.
///
/// See:
/// - [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
#[derive(Clone)]
pub struct Adam {
momentum: AdaptiveMomentum,
weight_decay: Option<WeightDecay>,
}
/// Adam state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
/// The current adaptive momentum.
pub momentum: AdaptiveMomentumState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for Adam {
type State<const D: usize> = AdamState<B, D>;
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
mut grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
let mut state_momentum = None;
if let Some(state) = state {
state_momentum = Some(state.momentum);
}
if let Some(weight_decay) = &self.weight_decay {
grad = weight_decay.transform(grad, tensor.clone());
}
let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);
let state = AdamState::new(state_momentum);
let delta = grad.mul_scalar(lr);
(tensor - delta, Some(state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
state.momentum = state.momentum.to_device(device);
state
}
}
impl AdamConfig {
/// Initialize Adam optimizer.
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
let optim = Adam {
momentum: AdaptiveMomentum {
beta_1: self.beta_1,
beta_2: self.beta_2,
epsilon: self.epsilon,
amsgrad: self.amsgrad,
},
weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
};
let mut optim = OptimizerAdaptor::from(optim);
if let Some(config) = &self.grad_clipping {
optim = optim.with_grad_clipping(config.init());
}
optim
}
}
/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
/// The number of iterations aggregated.
pub time: usize,
/// The first order momentum.
pub moment_1: Tensor<B, D>,
/// The second order momentum.
pub moment_2: Tensor<B, D>,
/// Max of second order momentum (for AMSGrad)
#[new(default)]
pub max_moment_2: Option<Tensor<B, D>>,
}
#[derive(Clone)]
struct AdaptiveMomentum {
beta_1: f32,
beta_2: f32,
epsilon: f32,
amsgrad: bool,
}
impl AdaptiveMomentum {
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
momentum_state: Option<AdaptiveMomentumState<B, D>>,
) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
let state = if let Some(mut state) = momentum_state {
let factor = 1.0 - self.beta_1;
state.moment_1 = state
.moment_1
.mul_scalar(self.beta_1)
.add(grad.clone().mul_scalar(factor));
let factor = 1.0 - self.beta_2;
state.moment_2 = state
.moment_2
.mul_scalar(self.beta_2)
.add(grad.square().mul_scalar(factor));
if self.amsgrad {
let max_v = state
.max_moment_2
.take()
.unwrap_or_else(|| state.moment_2.clone());
let new_max = max_v.max_pair(state.moment_2.clone());
state.max_moment_2 = Some(new_max);
}
state.time += 1;
state
} else {
let factor = 1.0 - self.beta_1;
let moment_1 = grad.clone().mul_scalar(factor);
let factor = 1.0 - self.beta_2;
let moment_2 = grad.square().mul_scalar(factor);
let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
AdaptiveMomentumState {
time: 1,
moment_1,
moment_2,
max_moment_2,
}
};
let time = state.time as i32;
let bias_correction2_sqrt = (1.0 - self.beta_2.powi(time)).sqrt();
let combined_factor = bias_correction2_sqrt / (1.0 - self.beta_1.powi(time));
let v_to_use = if self.amsgrad {
state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
} else {
&state.moment_2
};
let grad = state.moment_1.clone().mul_scalar(combined_factor).div(
v_to_use
.clone()
.sqrt()
.add_scalar(self.epsilon * bias_correction2_sqrt),
);
(grad, state)
}
}
impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
/// Move state to device.
///
/// # Arguments
///
/// * `device` - Device to move state to.
///
/// # Returns
///
/// Returns state moved to device.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.moment_1 = self.moment_1.to_device(device);
self.moment_2 = self.moment_2.to_device(device);
self.max_moment_2 = self.max_moment_2.map(|tensor| tensor.to_device(device));
self
}
}
#[cfg(test)]
mod tests {
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
use super::*;
use crate::TestAutodiffBackend;
use crate::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
const LEARNING_RATE: LearningRate = 0.01;
#[test]
fn test_adam_optimizer_save_load_state() {
let device = Default::default();
let linear = LinearConfig::new(6, 6).init(&device);
let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
let mut optimizer = create_adam();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let _linear = optimizer.step(LEARNING_RATE, linear, grads);
#[cfg(feature = "std")]
{
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
BinFileRecorder::<FullPrecisionSettings>::default()
.record(
optimizer.to_record(),
std::env::temp_dir().as_path().join("test_optim_adam"),
)
.unwrap();
}
#[cfg(not(feature = "std"))]
{
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
let result = BinBytesRecorder::<FullPrecisionSettings>::default()
.record(optimizer.to_record(), ())
.unwrap();
assert!(!result.is_empty());
}
let state_optim_before = optimizer.to_record();
let state_optim_before_copy = optimizer.to_record();
let optimizer = create_adam();
let optimizer = optimizer.load_record(state_optim_before_copy);
let state_optim_after = optimizer.to_record();
assert_eq!(state_optim_before.len(), state_optim_after.len());
}
#[test]
fn test_adam_optimizer_with_amsgrad_50_steps() {
let device = Default::default();
let mut linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let mut optimizer = AdamConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_amsgrad(true)
.with_weight_decay(Some(WeightDecayConfig::new(0.5)))
.init();
for i in 1..=50 {
let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
.mul_scalar(i as f32 * 0.1)
.require_grad();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
linear = optimizer.step(LEARNING_RATE, linear, grads);
}
let state_updated = linear.into_record();
let weight_updated = state_updated.weight.to_data();
let bias_updated = state_updated.bias.unwrap().to_data();
let weights_expected = TensorData::from([
[
-0.9125810265541077,
-0.45855265855789185,
-0.1915993094444275,
-0.2759990692138672,
-0.5099529027938843,
-0.5287043452262878,
],
[
-0.5181325674057007,
-0.6139854788780212,
-0.9574727416038513,
-0.34102925658226013,
-0.400514155626297,
-0.8847861886024475,
],
[
-0.614483118057251,
-0.5611032247543335,
-0.8887064456939697,
-0.34762972593307495,
-0.8708556890487671,
-0.2830044627189636,
],
[
-0.8904699683189392,
-0.8151527643203735,
-0.9621278643608093,
-0.8905676603317261,
-0.671261191368103,
-0.4333854615688324,
],
[
-0.26599061489105225,
-0.8119961023330688,
-0.22424538433551788,
-0.7672406435012817,
-0.2163349837064743,
-0.6258266568183899,
],
[
-0.611397922039032,
-0.6075160503387451,
-0.4701341986656189,
-0.4039117991924286,
-0.5663845539093018,
-0.21262989938259125,
],
]);
let bias_expected = TensorData::from([
-0.8817203044891357,
-0.4038999378681183,
-0.5889149308204651,
-0.37475723028182983,
-0.3557940721511841,
-0.47914788126945496,
]);
type FT = FloatElem<TestAutodiffBackend>;
let tolerance = Tolerance::absolute(1e-5);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
}
#[test]
fn test_adam_optimizer_with_numbers() {
let device = Default::default();
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&device,
)
.require_grad();
let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&device,
)
.require_grad();
let mut optimizer = AdamConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_weight_decay(Some(WeightDecayConfig::new(0.5)))
.init();
let grads = linear.forward(x_1).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x_2).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
let weights_expected = TensorData::from([
[-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
[
0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
],
[
-0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
],
[
-0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
],
[
0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
],
[-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
]);
let bias_expected = TensorData::from([
-0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
]);
let (weight_updated, bias_updated) = (
state_updated.weight.to_data(),
state_updated.bias.unwrap().to_data(),
);
type FT = FloatElem<TestAutodiffBackend>;
let tolerance = Tolerance::absolute(1e-2);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
#[test]
fn test_adam_optimizer_no_nan() {
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&Default::default(),
)
.require_grad();
let mut optimizer = AdamConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_weight_decay(Some(WeightDecayConfig::new(0.5)))
.init();
let grads = linear.forward(x.clone()).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
}
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
let record = LinearRecord {
weight: Param::from_data(weight, &device),
bias: Some(Param::from_data(bias, &device)),
};
LinearConfig::new(6, 6).init(&device).load_record(record)
}
fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
let config = AdamConfig::new();
Adam {
momentum: AdaptiveMomentum {
beta_1: config.beta_1,
beta_2: config.beta_2,
epsilon: config.epsilon,
amsgrad: config.amsgrad,
},
weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
}
.into()
}
}

View File

@@ -0,0 +1,598 @@
use burn_core as burn;
use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use burn::{module::AutodiffModule, record::Record};
use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// [`AdamW`] Configuration.
#[derive(Config, Debug)]
pub struct AdamWConfig {
/// Parameter for AdamW.
#[config(default = 0.9)]
beta_1: f32,
/// Parameter for AdamW.
#[config(default = 0.999)]
beta_2: f32,
/// A value required for numerical stability.
#[config(default = 1e-5)]
epsilon: f32,
/// Weight decay config.
#[config(default = 1e-4)]
weight_decay: f32,
/// Cautious weight decay config.
///
/// See: <https://arxiv.org/abs/2510.12402>
#[config(default = false)]
cautious_weight_decay: bool,
/// Whether to use AMSGrad algorithm
#[config(default = false)]
amsgrad: bool,
/// [Gradient Clipping](GradientClippingConfig) config.
grad_clipping: Option<GradientClippingConfig>,
}
/// AdamW optimizer.
///
/// See:
/// - [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
/// - [Cautious Weight Decay, 2025](https://arxiv.org/abs/2510.12402)
/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
///
/// Configured by [`AdamWConfig`].
#[derive(Clone)]
pub struct AdamW {
momentum: AdaptiveMomentumW,
weight_decay: f32,
cautious_weight_decay: bool,
}
/// AdamW state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
/// Th current adaptive momentum state.
pub momentum: AdaptiveMomentumState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for AdamW {
type State<const D: usize> = AdamWState<B, D>;
/// A single optimization step for any tensor that represents the parameters of a model.
fn step<const D: usize>(
&self,
// Learning rate.
lr: LearningRate,
// Any tensor that represents the parameters of a model.
tensor: Tensor<B, D>,
// Gradient of the loss w.r.t. the parameters.
grad: Tensor<B, D>,
// State of the optimizer.
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));
let decay_rate = lr * (self.weight_decay as f64);
let decayed_tensor = if decay_rate == 0.0 {
tensor.clone()
} else if self.cautious_weight_decay {
// Cautious weight decay.
// See: https://arxiv.org/abs/2510.12402
let tensor_pos = tensor.clone().greater_equal_elem(0.0);
let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);
let differ = tensor_pos.not_equal(grad_pos);
// Zero out the decay where the decay is counter to the update direction.
tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)
} else {
tensor.clone().mul_scalar(1.0 - decay_rate)
};
let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);
let state = AdamWState {
momentum: momentum_state,
};
(tensor_updated, Some(state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
state.momentum = state.momentum.to_device(device);
state
}
}
impl AdamWConfig {
/// Initialize AdamW optimizer.
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
let optim = AdamW {
momentum: AdaptiveMomentumW {
beta_1: self.beta_1,
beta_2: self.beta_2,
epsilon: self.epsilon,
amsgrad: self.amsgrad,
},
weight_decay: self.weight_decay,
cautious_weight_decay: self.cautious_weight_decay,
};
let mut optim = OptimizerAdaptor::from(optim);
if let Some(config) = &self.grad_clipping {
optim = optim.with_grad_clipping(config.init());
}
optim
}
}
#[derive(Clone)]
struct AdaptiveMomentumW {
beta_1: f32,
beta_2: f32,
epsilon: f32,
amsgrad: bool,
}
impl AdaptiveMomentumW {
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
state: Option<AdaptiveMomentumState<B, D>>,
) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
let factor_1 = 1.0 - self.beta_1;
let factor_2 = 1.0 - self.beta_2;
let state = if let Some(mut state) = state {
// Update first moment estimate.
state.moment_1 = state
.moment_1
.mul_scalar(self.beta_1)
.add(grad.clone().mul_scalar(factor_1));
// Update second moment estimate.
state.moment_2 = state
.moment_2
.mul_scalar(self.beta_2)
.add(grad.square().mul_scalar(factor_2));
if self.amsgrad {
let max_v = state
.max_moment_2
.take()
.unwrap_or_else(|| state.moment_2.clone());
state.max_moment_2 = Some(max_v.max_pair(state.moment_2.clone()));
}
// Update time.
state.time += 1;
state
} else {
// Initialize first moment estimate.
let moment_1 = grad.clone().mul_scalar(factor_1);
// Initialize second moment estimate.
let moment_2 = grad.square().mul_scalar(factor_2);
let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
AdaptiveMomentumState {
time: 1,
moment_1,
moment_2,
max_moment_2,
}
};
let time: i32 = state.time as i32;
// Compute bias-corrected first and second moment estimates.
let moment_1_corrected = state
.moment_1
.clone()
.div_scalar(1f32 - self.beta_1.powi(time));
let v_to_use = if self.amsgrad {
state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
} else {
&state.moment_2
};
let moment_2_corrected = v_to_use.clone().div_scalar(1f32 - self.beta_2.powi(time));
let update_delta =
moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
(update_delta, state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use crate::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_nn::{Linear, LinearConfig, LinearRecord};
type FT = FloatElem<TestAutodiffBackend>;
const LEARNING_RATE: LearningRate = 0.01;
#[test]
fn test_adamw_optimizer_save_load_state() {
let device = Default::default();
let linear = LinearConfig::new(6, 6).init(&device);
let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
let mut optimizer = create_adamw();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let _linear = optimizer.step(LEARNING_RATE, linear, grads);
#[cfg(feature = "std")]
{
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
BinFileRecorder::<FullPrecisionSettings>::default()
.record(
optimizer.to_record(),
std::env::temp_dir().as_path().join("test_optim_adamw"),
)
.unwrap();
}
#[cfg(not(feature = "std"))]
{
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
let result = BinBytesRecorder::<FullPrecisionSettings>::default()
.record(optimizer.to_record(), ())
.unwrap();
assert!(!result.is_empty());
}
let state_optim_before = optimizer.to_record();
let state_optim_before_copy = optimizer.to_record();
let optimizer = create_adamw();
let optimizer = optimizer.load_record(state_optim_before_copy);
let state_optim_after = optimizer.to_record();
assert_eq!(state_optim_before.len(), state_optim_after.len());
}
#[test]
fn test_adamw_optimizer_with_amsgrad_50_steps() {
let device = Default::default();
let mut linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let mut optimizer = AdamWConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_amsgrad(true)
.with_weight_decay(0.5)
.init();
for i in 1..=50 {
let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
.mul_scalar(i as f32 * 0.1)
.require_grad();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
linear = optimizer.step(LEARNING_RATE, linear, grads);
}
let state_updated = linear.into_record();
let weight_updated = state_updated.weight.to_data();
let bias_updated = state_updated.bias.unwrap().to_data();
let weights_expected = TensorData::from([
[
-0.7822558283805847,
-0.42578864097595215,
-0.21805696189403534,
-0.28366872668266296,
-0.46587175130844116,
-0.4805040955543518,
],
[
-0.4722539782524109,
-0.5471276640892029,
-0.8181359767913818,
-0.33425918221473694,
-0.3805687427520752,
-0.7601516842842102,
],
[
-0.5475167632102966,
-0.5057991743087769,
-0.763265073299408,
-0.3393959403038025,
-0.7490996718406677,
-0.28911691904067993,
],
[
-0.7646660208702087,
-0.7050473093986511,
-0.8218720555305481,
-0.7647438049316406,
-0.5919585227966309,
-0.40617525577545166,
],
[
-0.27588561177253723,
-0.7025567889213562,
-0.24343004822731018,
-0.6672990918159485,
-0.23728127777576447,
-0.556389570236206,
],
[
-0.5451040267944336,
-0.5420684814453125,
-0.4348171353340149,
-0.3832150399684906,
-0.5099242925643921,
-0.23440153896808624,
],
]);
let bias_expected = TensorData::from([
-0.7473056316375732,
-0.3745720386505127,
-0.5188710689544678,
-0.35184532403945923,
-0.33705732226371765,
-0.4332566559314728,
]);
type FT = FloatElem<TestAutodiffBackend>;
let tolerance = Tolerance::absolute(1e-5);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
}
#[test]
fn test_adamw_optimizer_with_numbers() {
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let device = Default::default();
let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&device,
)
.require_grad();
let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&device,
)
.require_grad();
let mut optimizer = AdamWConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_weight_decay(0.5)
.init();
let grads = linear.forward(x_1).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x_2).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
let weights_expected = TensorData::from([
[-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
[
0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
],
[
-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
],
[
-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
],
[
0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
],
[-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
]);
let bias_expected = TensorData::from([
-0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
]);
let (weight_updated, bias_updated) = (
state_updated.weight.to_data(),
state_updated.bias.unwrap().to_data(),
);
let tolerance = Tolerance::absolute(1e-2);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
#[test]
fn test_adamw_optimizer_with_numbers_cautious() {
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let device = Default::default();
let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&device,
)
.require_grad();
let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],
],
&device,
)
.require_grad();
let mut optimizer = AdamWConfig::new()
.with_cautious_weight_decay(true)
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_weight_decay(0.5)
.init();
let grads = linear.forward(x_1).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x_2).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
let weights_expected = TensorData::from([
[-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
[
0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
],
[
-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
],
[
-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
],
[
0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
],
[
-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332,
],
]);
let bias_expected = TensorData::from([
-0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
]);
let (weight_updated, bias_updated) = (
state_updated.weight.to_data(),
state_updated.bias.unwrap().to_data(),
);
let tolerance = Tolerance::absolute(1e-2);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
#[test]
fn test_adam_optimizer_no_nan() {
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&Default::default(),
)
.require_grad();
let mut optimizer = AdamWConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_weight_decay(0.5)
.init();
let grads = linear.forward(x.clone()).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
}
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
let record = LinearRecord {
weight: Param::from_data(weight, &device),
bias: Some(Param::from_data(bias, &device)),
};
LinearConfig::new(6, 6).init(&device).load_record(record)
}
fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
let config = AdamWConfig::new();
AdamW {
momentum: AdaptiveMomentumW {
beta_1: config.beta_1,
beta_2: config.beta_2,
epsilon: config.epsilon,
amsgrad: config.amsgrad,
},
weight_decay: config.weight_decay,
cautious_weight_decay: false,
}
.into()
}
}

View File

@@ -0,0 +1,88 @@
use burn_core::{self as burn, Tensor};
use burn_core::module::ParamId;
use burn_core::prelude::{Backend, DeviceOps};
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;
use super::GradientsParams;
use crate::LearningRate;
use alloc::vec::Vec;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;
#[derive(Default)]
/// Exposes multiple gradients for each parameter.
pub struct MultiGradientsParams {
/// Each [GradientsParams] has its associated [DeviceId].
pub grads: Vec<(GradientsParams, DeviceId)>,
}
impl MultiGradientsParams {
/// Removes the gradients for the given [parameter id](ParamId).
///
/// Potentially accumulates the gradients from multiple sources using a device associated with
/// a parameter id. The same parameter will be accumulated using the same device during
/// all training.
pub fn remove<B: Backend, const D: usize>(
&mut self,
id: ParamId,
) -> Option<(Tensor<B, D>, Device<B>)> {
let (mut tensor, device, index) = self.select(id)?;
for (i, (grads, _)) in self.grads.iter_mut().enumerate() {
if i == index {
continue;
}
if let Some(grad) = grads.remove::<B, D>(id) {
tensor = tensor + grad.to_device(&device);
}
}
Some((tensor, device))
}
fn select<B: Backend, const D: usize>(
&mut self,
id: ParamId,
) -> Option<(Tensor<B, D>, Device<B>, usize)> {
let id_val = id.val() as usize;
for i in 0..self.grads.len() {
let selected_device_index = (id_val + i) % self.grads.len();
if let Some(acc) = self.grads[selected_device_index].0.remove::<B, D>(id) {
let device_id = self.grads[selected_device_index].1;
let device = <B::Device as DeviceOps>::from_id(device_id);
return Some((acc.to_device(&device), device, selected_device_index));
}
}
None
}
}
/// General trait to optimize [module](AutodiffModule).
pub trait Optimizer<M, B>: Send + Clone
where
M: AutodiffModule<B>,
B: AutodiffBackend,
{
/// Optimizer associative type to be used when saving and loading the state.
type Record: Record<B>;
/// Perform the optimizer step using the given learning rate and gradients.
/// The updated module is returned.
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
/// Perform the optimizer step using the given learning rate and gradients.
/// The updated module is returned.
fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M;
/// Get the current state of the optimizer as a [record](Record).
fn to_record(&self) -> Self::Record;
/// Load the state of the optimizer as a [record](Record).
fn load_record(self, record: Self::Record) -> Self;
}

View File

@@ -0,0 +1,68 @@
use burn_core as burn;
use burn::config::Config;
use burn::record::Record;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Configuration to create [weight decay](WeightDecay).
#[derive(Config, Debug)]
pub struct WeightDecayConfig {
/// L2 penalty.
pub penalty: f32,
}
/// State of [weight decay](WeightDecay).
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
pub(crate) grad_last_step: Tensor<B, D>,
}
/// Weight decay implementation that transforms gradients.
#[derive(Clone)]
pub struct WeightDecay {
penalty: f32,
}
impl WeightDecay {
/// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).
pub fn new(config: &WeightDecayConfig) -> Self {
Self {
penalty: config.penalty,
}
}
/// Transforms a gradient.
///
/// # Arguments
///
/// * `grad` - Gradient to transform.
/// * `tensor` - Tensor param of the last iteration.
///
/// # Returns
///
/// * `grad` - Transformed gradient.
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
tensor: Tensor<B, D>,
) -> Tensor<B, D> {
tensor.mul_scalar(self.penalty).add(grad)
}
}
impl<B: Backend, const D: usize> WeightDecayState<B, D> {
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.grad_last_step = self.grad_last_step.to_device(device);
self
}
}

View File

@@ -0,0 +1,121 @@
use burn_core as burn;
use core::marker::PhantomData;
use burn::module::{AutodiffModule, ModuleVisitor, Param};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use super::GradientsParams;
/// Accumulate gradients into a single [Gradients](AutodiffBackend::Gradients) object.
pub struct GradientsAccumulator<M> {
grads: GradientsParams,
phantom: PhantomData<M>,
}
impl<M> Default for GradientsAccumulator<M> {
fn default() -> Self {
Self::new()
}
}
impl<M> GradientsAccumulator<M> {
/// Create a new gradients accumulator.
pub fn new() -> Self {
Self {
grads: GradientsParams::new(),
phantom: PhantomData,
}
}
}
impl<M> GradientsAccumulator<M> {
/// Accumulate the given gradients for each parameter in the given module.
pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)
where
M: AutodiffModule<B>,
{
let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
module.visit(&mut visitor);
}
/// Return the accumulated gradients and reset the accumulator state.
pub fn grads(&mut self) -> GradientsParams {
let mut grads = GradientsParams::new();
core::mem::swap(&mut self.grads, &mut grads);
grads
}
}
#[derive(new)]
struct ModuleGradsAccumulator<'a, M> {
grads: &'a mut GradientsParams,
grads_new: GradientsParams,
phantom: PhantomData<M>,
}
impl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(param.id) {
Some(new) => match self.grads.remove::<B::InnerBackend, D>(param.id) {
Some(grad) => grad.add(new),
None => new,
},
None => match self.grads.remove::<B::InnerBackend, D>(param.id) {
Some(grad) => grad,
None => return,
},
};
self.grads
.register::<B::InnerBackend, D>(param.id, grad_updated);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use burn::tensor::{Distribution, backend::Backend};
use burn_nn::{Linear, LinearConfig};
#[test]
fn test_accumulate_gradients_one_step() {
let device = Default::default();
let mut accumulator = GradientsAccumulator::new();
let layer = layer::<TestAutodiffBackend>(&device);
let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
let grads = GradientsParams::from_grads(loss.backward(), &layer);
accumulator.accumulate(&layer, grads);
let grads = accumulator.grads();
assert!(!grads.is_empty())
}
#[test]
fn test_accumulate_gradients_two_steps() {
let device = Default::default();
let mut accumulator = GradientsAccumulator::new();
let layer = layer::<TestAutodiffBackend>(&device);
let loss_1 = layer.forward(random_tensor(&device));
let loss_2 = layer.forward(random_tensor(&device));
let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);
let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);
accumulator.accumulate(&layer, grads_1);
accumulator.accumulate(&layer, grads_2);
let grads = accumulator.grads();
assert_eq!(grads.len(), 2)
}
fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
LinearConfig::new(20, 20).init(device)
}
fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
}
}

View File

@@ -0,0 +1,192 @@
use burn_core as burn;
#[cfg(feature = "collective")]
use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};
use burn::{
Tensor,
tensor::{
backend::{AutodiffBackend, Backend},
container::TensorContainer,
},
};
use burn::module::{AutodiffModule, ParamId};
use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};
/// Data type that contains gradients for parameters.
#[derive(Default, Debug)]
pub struct GradientsParams {
container: TensorContainer<ParamId>,
}
impl GradientsParams {
/// Creates a new [GradientsParams](GradientsParams).
pub fn new() -> Self {
Self::default()
}
/// Extract each tensor gradients for the given [module](AutodiffModule).
///
/// Note: This consumes the gradients. See ['from_module'] to extract gradients only for
/// a specific module.
pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
grads: B::Gradients,
module: &M,
) -> Self {
let mut grads = grads;
Self::from_module(&mut grads, module)
}
/// Extract each tensor gradients for the given [module](AutodiffModule).
pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
grads: &mut B::Gradients,
module: &M,
) -> Self {
let mut grads_params = GradientsParams::new();
let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
module.visit(&mut visitor);
grads_params
}
/// Extract tensor gradients for the given [module](AutodiffModule) and given parameters.
pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
grads: &mut B::Gradients,
module: &M,
params: &[ParamId],
) -> Self {
let mut grads_params = GradientsParams::new();
let mut visitor =
GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
module.visit(&mut visitor);
grads_params
}
/// Get the gradients for the given [parameter id](ParamId).
///
/// # Notes
///
/// You should use [remove](GradientsParams::remove) if you want to get the gradients
/// only one time.
pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
where
B: Backend,
{
self.container.get(&id).map(Tensor::from_primitive)
}
/// Remove the gradients for the given [parameter id](ParamId).
pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
where
B: Backend,
{
self.container.remove(&id).map(Tensor::from_primitive)
}
/// Register a gradients tensor for the given [parameter id](ParamId).
///
/// # Notes
///
/// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced.
pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
where
B: Backend,
{
self.container.register(id, value.into_primitive())
}
/// The number of gradients tensors registered.
pub fn len(&self) -> usize {
self.container.len()
}
/// If any tensor is contained.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Change the device of each tensor gradients registered for the given [module](AutodiffModule).
pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
mut self,
device: &B::Device,
module: &M,
) -> Self {
let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
module.visit(&mut visitor);
self
}
/// Syncs the gradient params with the other peers in the collective.
#[cfg(feature = "collective")]
pub fn all_reduce<B: Backend>(
mut self,
peer_id: PeerId,
op: ReduceOperation,
) -> Result<Self, CollectiveError> {
let mut ids = self
.container
.ids()
.into_iter()
.copied()
.collect::<Vec<ParamId>>();
// This is crucial, since the all-reduce operations need to happen in the same order for the same parameters on all nodes!
ids.sort();
for id in ids {
let Some(grad) = self.container.remove::<B>(&id) else {
todo!()
};
let grad = match grad {
burn::tensor::TensorPrimitive::Float(grad) => {
let grad = all_reduce::<B>(peer_id, grad, op)?;
burn::tensor::TensorPrimitive::Float(grad)
}
burn::tensor::TensorPrimitive::QFloat(_grad) => {
unimplemented!("quantized all-reduce unimplemented")
}
};
self.container.register::<B>(id, grad);
}
Ok(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use burn::module::{Module, list_param_ids};
use burn::tensor::{Distribution, backend::Backend};
use burn_nn::{Linear, LinearConfig};
#[test]
fn test_convert_grads() {
let device = Default::default();
let layer_1 = layer::<TestAutodiffBackend>(&device);
let mut layer_2 = layer_1.clone();
layer_2 = layer_2.fork(&device);
let loss_1 = layer_1.forward(random_tensor(&device));
let loss_2 = layer_2.forward(random_tensor(&device));
let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);
let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);
let param_ids_1 = list_param_ids(&layer_1);
let param_ids_2 = list_param_ids(&layer_2);
assert_eq!(param_ids_1, param_ids_2);
assert_eq!(grads_1.len(), param_ids_1.len());
assert_eq!(grads_2.len(), param_ids_2.len());
}
fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
LinearConfig::new(20, 20).init(device)
}
fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
}
}

View File

@@ -0,0 +1,978 @@
#![allow(clippy::excessive_precision)]
use burn_core as burn;
use super::GradientsParams;
use crate::LearningRate;
use burn::config::Config;
use burn::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use burn::prelude::ToElement;
use burn::record::Record;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use serde::{Deserialize, Serialize};
use alloc::vec;
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Cubic Interpolate
///
/// Uses two points (x1, f1), (x2, f2) and their first derivatives g1,g2 to construct
/// a cubic interpolant and return its minimum within the given bounds.
fn cubic_interpolate(
x1: f64,
f1: f64,
g1: f64,
x2: f64,
f2: f64,
g2: f64,
bounds: Option<(f64, f64)>,
) -> f64 {
// Compute bounds of interpolation area
let (min_bound, max_bound) = bounds.unwrap_or(if x1 <= x2 { (x1, x2) } else { (x2, x1) });
// Code for most common case: cubic interpolation of 2 points
// with function and derivative values for both
// Solution in this case (where x2 is the farthest point)
// d1 = g1 + g2 - 3*(f1 - f2) / (x1-x2);
// d2 = sqrt(d1^2 - g1 * g2);
// min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
// t_new = min(max(min_pos,min_bound), max_bound);
let d1 = g1 + g2 - 3.0 * (f1 - f2) / (x1 - x2);
let d2_square = d1 * d1 - g1 * g2;
if d2_square >= 0.0 {
let d2 = d2_square.sqrt();
let min_pos = if x1 <= x2 {
x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2.0 * d2))
} else {
x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2.0 * d2))
};
min_pos.max(min_bound).min(max_bound)
} else {
(min_bound + max_bound) / 2.0
}
}
/// Auxiliary Struct For Strong_Wolfe
struct LineSearchSample<B: Backend> {
// step size
t: f64,
// loss
f: f64,
// gradient
g: Tensor<B, 1>,
// directional derivative
gtd: f64,
}
#[allow(clippy::too_many_arguments)]
fn strong_wolfe<B: Backend, F>(
// obj_func(x,step size,direction) -> (loss,grad)
obj_func: &mut F,
x: &Tensor<B, 1>,
// initial step size
mut t: f64,
d: &Tensor<B, 1>,
f: f64,
g: Tensor<B, 1>,
gtd: f64,
c1: f64,
c2: f64,
tolerance_change: f64,
max_ls: usize,
) -> (f64, Tensor<B, 1>, f64, usize)
where
F: FnMut(&Tensor<B, 1>, f64, &Tensor<B, 1>) -> (f64, Tensor<B, 1>),
{
let d_norm = d.clone().abs().max().into_scalar().to_f64();
// evaluate objective and gradient using initial step
let (mut f_new, mut g_new) = obj_func(x, t, d);
let mut ls_func_evals = 1;
let mut gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
// bracket an interval [t_prev,t] containing a point satisfying the Wolfe criteria
let (mut t_prev, mut f_prev, mut g_prev, mut gtd_prev) = (0.0, f, g.clone(), gtd);
let mut done = false;
let mut ls_iter = 0;
// the interval [low,high] using for Zoom phase
let mut bracket: Option<[LineSearchSample<B>; 2]> = None;
// point which satisfy the wolfe condition
let mut wolfe_bracket: Option<LineSearchSample<B>> = None;
while ls_iter < max_ls {
// Checking Conditions.
// Checking the Armijo Condition and function value increasing condition.
// Armijo: f(x+t*d) <= f(x) + c_1 t gtd
if f_new > (f + c1 * t * gtd) || (ls_iter > 1 && f_new >= f_prev) {
bracket = Some([
LineSearchSample {
t: t_prev,
f: f_prev,
g: g_prev,
gtd: gtd_prev,
},
LineSearchSample {
t,
f: f_new,
g: g_new.clone(),
gtd: gtd_new,
},
]);
break;
}
// Checking Strong Wolfe Condition
// |gtd_new| <= -c_2 gtd
if gtd_new.abs() <= -c2 * gtd {
wolfe_bracket = Some(LineSearchSample {
t,
f: f_new,
g: g_new.clone(),
gtd: gtd_new,
});
done = true;
break;
}
// gtd_new >=0 , there must be a local minimum in the interval.
if gtd_new >= 0.0 {
bracket = Some([
LineSearchSample {
t: t_prev,
f: f_prev,
g: g_prev,
gtd: gtd_prev,
},
LineSearchSample {
t,
f: f_new,
g: g_new.clone(),
gtd: gtd_new,
},
]);
break;
}
// interpolate
let min_step = t + 0.01 * (t - t_prev);
let max_step = t * 10.0;
let t_next = cubic_interpolate(
t_prev,
f_prev,
gtd_prev,
t,
f_new,
gtd_new,
Some((min_step, max_step)),
);
t_prev = t;
f_prev = f_new;
g_prev = g_new;
gtd_prev = gtd_new;
// next step
t = t_next;
(f_new, g_new) = obj_func(x, t, d);
ls_func_evals += 1;
gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
ls_iter += 1;
}
if let Some(sample) = wolfe_bracket {
return (sample.f, sample.g, sample.t, ls_func_evals);
}
let mut bracket = bracket.unwrap_or_else(|| {
[
LineSearchSample {
t: 0.0,
f,
g: g.clone(),
gtd,
},
LineSearchSample {
t,
f: f_new,
g: g_new.clone(),
gtd: gtd_new,
},
]
});
// zoom phase
let mut insuf_progress = false;
// find high and low points in bracket
let (mut low_idx, mut high_idx) = if bracket[0].f <= bracket[1].f {
(0, 1)
} else {
(1, 0)
};
while !done && ls_iter < max_ls {
let diff = (bracket[1].t - bracket[0].t).abs();
// line-search bracket is so small
if diff * d_norm < tolerance_change {
break;
}
// compute new trial value
t = cubic_interpolate(
bracket[0].t,
bracket[0].f,
bracket[0].gtd,
bracket[1].t,
bracket[1].f,
bracket[1].gtd,
None,
);
let b_min = bracket[0].t.min(bracket[1].t);
let b_max = bracket[0].t.max(bracket[1].t);
let eps = 0.1 * (b_max - b_min);
if (b_max - t).min(t - b_min) < eps {
// interpolation close to boundary
if insuf_progress || t >= b_max || t <= b_min {
t = if (t - b_max).abs() < (t - b_min).abs() {
b_max - eps
} else {
b_min + eps
};
insuf_progress = false;
} else {
insuf_progress = true;
}
} else {
insuf_progress = false;
}
// Evaluate new point
(f_new, g_new) = obj_func(x, t, d);
ls_func_evals += 1;
gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
ls_iter += 1;
let armijo_holds = f_new <= (f + c1 * t * gtd) && f_new < bracket[low_idx].f;
if !armijo_holds {
bracket[high_idx] = LineSearchSample {
t,
f: f_new,
g: g_new,
gtd: gtd_new,
};
} else {
if gtd_new.abs() <= -c2 * gtd {
return (f_new, g_new, t, ls_func_evals);
}
if gtd_new * (bracket[high_idx].t - bracket[low_idx].t) >= 0.0 {
bracket[high_idx] = LineSearchSample {
t: bracket[low_idx].t,
f: bracket[low_idx].f,
g: bracket[low_idx].g.clone(),
gtd: bracket[low_idx].gtd,
};
}
bracket[low_idx] = LineSearchSample {
t,
f: f_new,
g: g_new,
gtd: gtd_new,
};
}
if bracket[0].f <= bracket[1].f {
low_idx = 0;
high_idx = 1;
} else {
low_idx = 1;
high_idx = 0;
}
}
// return stuff
(
bracket[low_idx].f,
bracket[low_idx].g.clone(),
bracket[low_idx].t,
ls_func_evals,
)
}
/// Strategy for the line search optimization phase
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LineSearchFn {
/// No line search performed
#[default]
None,
/// strong wolfe conditions
///
/// See: <https://en.wikipedia.org/wiki/Wolfe_conditions>
StrongWolfe,
}
/// LBFGS Configuration.
#[derive(Config, Debug)]
pub struct LBFGSConfig {
/// Maximal number of iterations per optimization step (default: 20)
#[config(default = 20)]
pub max_iter: usize,
/// Update history size (default: 100).
#[config(default = 100)]
pub history_size: usize,
/// Termination tolerance on first order optimality (default: 1e-7).
#[config(default = 1e-7)]
pub tolerance_grad: f64,
/// Termination tolerance on function value/parameter changes (default: 1e-9).
#[config(default = 1e-9)]
pub tolerance_change: f64,
/// Maximal number of function evaluations per optimization step (default: max_iter * 1.25).
#[config(default = "None")]
pub max_eval: Option<usize>,
/// Either strong_wolfe or None (default: None).
#[config(default = "LineSearchFn::None")]
pub line_search_fn: LineSearchFn,
}
impl LBFGSConfig {
/// Initialize AdamW optimizer
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module
pub fn init<B: AutodiffBackend>(&self) -> LBFGS<B> {
// by default max_eval = max_iter * 5/4
let max_eval = self.max_eval.unwrap_or(self.max_iter * 5 / 4);
LBFGS {
config: LBFGSConfig {
max_iter: self.max_iter,
history_size: self.history_size,
tolerance_grad: self.tolerance_grad,
tolerance_change: self.tolerance_change,
max_eval: Some(max_eval),
line_search_fn: self.line_search_fn,
},
state: Default::default(),
}
}
}
/// Collects gradients in module visit order.
struct FlattenGradsVisitorInner<'a, B: AutodiffBackend> {
grads: &'a GradientsParams,
tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
}
impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenGradsVisitorInner<'_, B> {
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
if let Some(g) = self.grads.get::<B::InnerBackend, D>(param.id) {
let numel = g.shape().num_elements();
self.tensors.push(g.reshape([numel]));
}
}
}
/// Flatten params to inner backend 1D tensor.
fn flatten_params_inner<B: AutodiffBackend, M: Module<B>>(
module: &M,
) -> Tensor<B::InnerBackend, 1> {
let mut tensors = Vec::new();
let mut visitor = FlattenParamsVisitorInner::<B> {
tensors: &mut tensors,
};
module.visit(&mut visitor);
if tensors.is_empty() {
return Tensor::empty([0], &module.devices()[0]);
}
Tensor::cat(tensors, 0)
}
struct FlattenParamsVisitorInner<'a, B: AutodiffBackend> {
tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
}
impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenParamsVisitorInner<'_, B> {
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
let t = param.val().inner();
let numel = t.shape().num_elements();
self.tensors.push(t.reshape([numel]));
}
}
/// Flatten gradients for a module.
fn flatten_grads_inner<B: AutodiffBackend, M: Module<B>>(
module: &M,
grads: &GradientsParams,
) -> Tensor<B::InnerBackend, 1> {
let mut tensors = Vec::new();
let mut visitor = FlattenGradsVisitorInner {
grads,
tensors: &mut tensors,
};
module.visit(&mut visitor);
if tensors.is_empty() {
return Tensor::empty([0], &module.devices()[0]);
}
Tensor::cat(tensors, 0)
}
/// Mapper that assigns each float param from a flat inner-backend 1D tensor.
struct ParamsFromFlatMapperInner<'a, B: AutodiffBackend> {
flat: &'a Tensor<B::InnerBackend, 1>,
offset: &'a mut usize,
}
impl<B: AutodiffBackend> ParamsFromFlatMapperInner<'_, B> {
fn take_slice(&mut self, numel: usize) -> Tensor<B::InnerBackend, 1> {
let start = *self.offset;
*self.offset += numel;
self.flat.clone().slice(start..*self.offset)
}
}
impl<B: AutodiffBackend> ModuleMapper<B> for ParamsFromFlatMapperInner<'_, B> {
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
let (id, tensor, mapper) = param.consume();
let numel = tensor.shape().num_elements();
let slice_1d = self.take_slice(numel);
let new_inner = slice_1d.reshape(tensor.shape());
let new_tensor = Tensor::from_inner(new_inner).require_grad();
Param::from_mapped_value(id, new_tensor, mapper)
}
}
/// Overwrite module parameters from a flat inner-backend 1D tensor
fn set_params_from_flat_inner<B: AutodiffBackend, M: Module<B>>(
module: M,
flat: Tensor<B::InnerBackend, 1>,
) -> M {
let mut offset = 0;
let mut mapper = ParamsFromFlatMapperInner {
flat: &flat,
offset: &mut offset,
};
module.map(&mut mapper)
}
/// L-BFGS optimizer state
#[derive(Clone, Record)]
pub struct LBFGSState<B: Backend> {
/// Historical displacement vectors
pub history_s: Vec<Tensor<B, 1>>,
/// Historical gradient difference vectors
pub history_y: Vec<Tensor<B, 1>>,
/// Search direction
pub d: Option<Tensor<B, 1>>,
/// Step size from the previous iteration
pub t: Option<f64>,
/// Flattened gradient from the previous iteration
pub prev_flat_grad: Option<Tensor<B, 1>>,
/// Loss value from the previous iteration
pub prev_loss: Option<f64>,
/// Global iteration count
pub g_iter: usize,
}
impl<B: Backend> LBFGSState<B> {
/// Moves all historical tensors to the target device.
pub fn to_device(self, device: &B::Device) -> Self {
Self {
history_s: self
.history_s
.into_iter()
.map(|t| t.to_device(device))
.collect(),
history_y: self
.history_y
.into_iter()
.map(|t| t.to_device(device))
.collect(),
d: self.d.map(|t| t.to_device(device)),
t: self.t,
prev_flat_grad: self.prev_flat_grad.map(|t| t.to_device(device)),
prev_loss: self.prev_loss,
g_iter: self.g_iter,
}
}
}
impl<B: Backend> Default for LBFGSState<B> {
fn default() -> Self {
Self {
history_s: Vec::new(),
history_y: Vec::new(),
d: None,
t: Some(1.0),
prev_flat_grad: None,
prev_loss: None,
g_iter: 0,
}
}
}
/// L-BFGS optimizer.
///
/// Ported from [pytorch](https://github.com/pytorch/pytorch/torch/optim/lbfgs.py). Heavily inspired by [miniFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html)
///
/// See also:
/// - [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS)
///
/// # Note
/// This optimizer is memory intensive
#[derive(Clone)]
pub struct LBFGS<B: Backend + AutodiffBackend> {
config: LBFGSConfig,
state: LBFGSState<B::InnerBackend>,
}
impl<B: Backend + AutodiffBackend> LBFGS<B> {
/// A single optimization step for any tensor that represents the parameters of a model.
pub fn step<M, F>(&mut self, lr: LearningRate, mut module: M, mut closure: F) -> (M, f64)
where
M: AutodiffModule<B> + Clone,
F: FnMut(M) -> (f64, GradientsParams),
{
// evaluate initial f(x) and df/dx
let (mut loss, grads) = closure(module.clone());
let mut current_evals = 1;
let mut flat_grad = flatten_grads_inner::<B, M>(&module, &grads);
let mut x_flat = flatten_params_inner::<B, M>(&module);
let opt_cond =
flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad;
// optimal condition
if opt_cond {
return (module, loss);
}
// tensors cached in state
let mut d = self
.state
.d
.take()
.unwrap_or_else(|| flat_grad.clone().neg());
let mut t = self.state.t.unwrap_or(lr);
let mut prev_flat_grad = self.state.prev_flat_grad.take();
let mut n_iter = 0;
// optimize for a max of max_iter iterations
while n_iter < self.config.max_iter {
// keep track of nb of iterations
n_iter += 1;
self.state.g_iter += 1;
// compute gradient descent direction
if self.state.g_iter == 1 {
d = flat_grad.clone().neg();
self.state.history_s.clear();
self.state.history_y.clear();
} else {
// do lbfgs update (update memory)
if let Some(pg) = prev_flat_grad.as_ref() {
let y = flat_grad.clone().sub(pg.clone());
let s = d.clone().mul_scalar(t);
let ys = y.clone().dot(s.clone()).into_scalar().to_f64();
if ys > 1e-10 {
// updating memory
if self.state.history_s.len() >= self.config.history_size {
// shift history by one (limited-memory)
self.state.history_s.remove(0);
self.state.history_y.remove(0);
}
self.state.history_s.push(s);
self.state.history_y.push(y);
}
}
// compute the approximate (L-BFGS) inverse Hessian
// multiplied by the gradient
let num_old = self.state.history_s.len();
let mut q = flat_grad.clone().neg();
let mut alphas: Vec<Tensor<B::InnerBackend, 1>> =
vec![Tensor::zeros([1], &flat_grad.device()); num_old];
if num_old > 0 {
// multiply by initial Hessian
// r/d is the final direction
for i in (0..num_old).rev() {
let s = &self.state.history_s[i];
let y = &self.state.history_y[i];
let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);
let alpha = rho.clone().mul(s.clone().dot(q.clone()));
alphas[i] = alpha.clone();
q = q.sub(y.clone().mul(alpha));
}
let last_s = &self.state.history_s[num_old - 1];
let last_y = &self.state.history_y[num_old - 1];
let ys = last_y.clone().dot(last_s.clone());
let yy = last_y.clone().dot(last_y.clone());
let h_diag = ys.div(yy);
let mut r = q.mul(h_diag);
for ((s, y), alpha) in self
.state
.history_s
.iter()
.zip(self.state.history_y.iter())
.zip(alphas.into_iter())
.take(num_old)
{
let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);
let beta = rho.mul(y.clone().dot(r.clone()));
r = r.add(s.clone().mul(alpha.sub(beta)));
}
d = r;
} else {
d = q;
}
}
prev_flat_grad = Some(flat_grad.clone());
let prev_loss_iter = loss;
// compute step len
if self.state.g_iter == 1 {
let grad_l1 = flat_grad.clone().abs().sum().into_scalar().to_f64();
t = (1.0f64 / grad_l1).min(1.0) * lr;
} else {
t = lr;
}
// directional derivative
let gtd = flat_grad.clone().dot(d.clone()).into_scalar().to_f64();
if gtd > -self.config.tolerance_change {
break;
}
let ls_func_evals;
if let LineSearchFn::StrongWolfe = self.config.line_search_fn {
// perform line search, using user function
let mut obj_func =
|current_x: &Tensor<B::InnerBackend, 1>,
step: f64,
dir: &Tensor<B::InnerBackend, 1>| {
let update = dir.clone().mul_scalar(step);
let new_x = current_x.clone().add(update);
let tmp_module = set_params_from_flat_inner::<B, M>(module.clone(), new_x);
let (l, g) = closure(tmp_module);
(l, flatten_grads_inner::<B, M>(&module, &g))
};
let (ls_f, ls_g, ls_t, evals) = strong_wolfe(
&mut obj_func,
&x_flat,
t,
&d,
loss,
flat_grad.clone(),
gtd,
1e-4,
0.9,
self.config.tolerance_change,
self.config.max_eval.unwrap() - current_evals,
);
loss = ls_f;
flat_grad = ls_g;
t = ls_t;
ls_func_evals = evals;
x_flat = x_flat.add(d.clone().mul_scalar(t));
module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
} else {
// no line search, simply move with fixed-step
let step_vec = d.clone().mul_scalar(t);
x_flat = x_flat.add(step_vec);
module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
// re-evaluate function only if not in last iteration
// the reason we do this: in a stochastic setting,
// no use to re-evaluate that function here
let (new_loss, new_grads) = closure(module.clone());
loss = new_loss;
flat_grad = flatten_grads_inner::<B, M>(&module, &new_grads);
ls_func_evals = 1;
}
// update func eval
current_evals += ls_func_evals;
// check conditions
if current_evals >= self.config.max_eval.unwrap() {
break;
}
if flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad {
break;
}
if d.clone().mul_scalar(t).abs().max().into_scalar().to_f64()
<= self.config.tolerance_change
{
break;
}
if (loss - prev_loss_iter).abs() < self.config.tolerance_change {
break;
}
}
self.state.d = Some(d);
self.state.t = Some(t);
self.state.prev_flat_grad = prev_flat_grad;
self.state.prev_loss = Some(loss);
(module, loss)
}
/// Moves the optimizer state to the specified device.
pub fn to_device(self, device: &B::Device) -> Self {
Self {
config: self.config,
// History tensors reside in InnerBackend, so we convert the device accordingly
state: self.state.to_device(device),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::GradientsParams;
use crate::TestAutodiffBackend;
use burn::module::{Module, Param};
use burn::tensor::{Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
let record = LinearRecord {
weight: Param::from_data(weight, &device),
bias: Some(Param::from_data(bias, &device)),
};
LinearConfig::new(6, 6).init(&device).load_record(record)
}
#[test]
fn test_cubic_interpolate() {
let tolerance = 1e-8;
// basic
let (x1, f1, g1, x2, f2, g2) = (-1.0, 1.0, -2.0, 1.0, 1.0, 2.0);
let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
assert!(
(result - 0.00000).abs() < tolerance,
"Basic: Result {} should be close to 0.0",
result
);
// bound
let (x1, f1, g1, x2, f2, g2) = (0.0, 0.25, -1.0, 1.0, 0.25, 1.0);
let bounds = Some((0.6, 1.0));
let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds);
assert!(
(result - 0.6000000000).abs() < tolerance,
"Bound: Result {} should be clamped to 0.6",
result
);
// d2_square < 0,should return mid value
let (x1, f1, g1, x2, f2, g2) = (0.0, 0.0, 10.0, 1.0, 5.0, 10.0);
let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((0.0, 1.0)));
assert!(
(result - 0.5000000).abs() < tolerance,
"Fallback: Result {} should be midpoint 0.5",
result
);
// asymmetric
let (x1, f1, g1, x2, f2, g2) = (0.0, 1.0, -5.0, 1.0, 0.5, 1.0);
let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
assert!(
(result - 0.4606553370833684).abs() < tolerance,
"Asymmetric: Result {} should be 0.4606553370833684",
result
);
// not good value
let (x1, f1, g1, x2, f2, g2) = (
1.231232145,
-0.12567458754,
9.1231243007,
8.239105015,
-100.9012398021,
123201321.0293982,
);
let result_1 = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
let result_2 = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((-4.4, 4.4)));
assert!(
(result_1 - 5.9031480234724434).abs() < tolerance,
"not good value 1: Result {} should be 5.9031480234724434",
result
);
assert!(
(result_2 - 4.4000000000000004).abs() < tolerance,
"not good value 2: Result {} should be 4.4000000000000004",
result
);
}
#[test]
fn test_strong_wolfe_direct_comparison() {
let device = Default::default();
let tol = 1e-8;
{
let x = Tensor::<TestAutodiffBackend, 1>::from_floats([2.1321912957_f64], &device);
let d = Tensor::<TestAutodiffBackend, 1>::from_floats([0.91312321_f64], &device);
let t_initial = 1.213132_f64;
fn func<B: Backend>(
x_base: &Tensor<B, 1>,
t_val: f64,
d_vec: &Tensor<B, 1>,
) -> (f64, Tensor<B, 1>) {
let curr_x = x_base.clone().add(d_vec.clone().mul_scalar(t_val));
let x2 = curr_x.clone().mul(curr_x.clone());
let x3 = x2.clone().mul(curr_x.clone());
let x4 = x2.clone().mul(x2.clone());
// f(x) = x^4 - 2*x^2 + x
let f_elements = x4 - x2.mul_scalar(2.0) + curr_x.clone();
let f_val = f_elements.sum().into_scalar().to_f64();
// g(x) = 4*x^3 - 4*x + 1
let g = x3.mul_scalar(4.0) - curr_x.clone().mul_scalar(4.0)
+ Tensor::ones_like(&curr_x);
(f_val, g)
}
let (f_init, g_init) = func(&x, 0.0, &d);
let gtd_init = g_init.clone().dot(d.clone()).into_scalar().to_f64();
println!("Initial State: f={},gtd = {}", f_init, gtd_init);
assert!((f_init - 13.7080059052).abs() < tol);
assert!((gtd_init - 28.5305728912).abs() < tol);
let mut obj_func =
|xb: &Tensor<TestAutodiffBackend, 1>,
tv: f64,
dv: &Tensor<TestAutodiffBackend, 1>| func(xb, tv, dv);
let (f_final, _g_final, t_final, evals) = strong_wolfe(
&mut obj_func,
&x,
t_initial,
&d,
f_init,
g_init,
gtd_init,
1e-4, // c1
0.9, // c2
1e-9, // tolerance_change
10, // max_ls
);
let g_f = _g_final.into_scalar().to_f64();
println!(
"f_final:{:?},_g_final:{:?},t_final:{:?},evals:{:?}",
f_final, g_f, t_final, evals
);
assert!((f_final - 13.708005905151367).abs() < tol);
assert!((g_f - 31.2450428009).abs() < tol);
assert!((t_final.to_f64() - 0.0).abs() < tol);
assert!((evals == 11));
}
}
#[test]
fn test_lbfgs_strong_wolfe_comparison() {
let device = Default::default();
let tol = 1e-5;
let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
let weight = TensorData::from([[0.5f64]]);
let bias = TensorData::from([0.1f64]);
let module = given_linear_layer(weight, bias);
let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
.with_line_search_fn(LineSearchFn::StrongWolfe)
.init();
let mut closure = |mod_in: Linear<TestAutodiffBackend>| {
let output = mod_in.forward(x_data.clone());
let loss = burn_nn::loss::MseLoss::new().forward(
output,
y_true.clone(),
burn_nn::loss::Reduction::Sum,
);
let grads = loss.backward();
let grads_params = GradientsParams::from_grads(grads, &mod_in);
(loss.into_scalar().to_f64(), grads_params)
};
let initial_loss = closure(module.clone()).0;
assert!((initial_loss - 50.1300048828).abs() < tol);
let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
assert!((final_loss - 0.0234732367).abs() < tol);
let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
let optimized_bias: f64 = updated_module
.bias
.as_ref()
.unwrap()
.val()
.into_scalar()
.to_f64();
assert!((optimized_data - 2.0570652485).abs() < tol);
assert!((optimized_bias - 0.8106800914).abs() < tol);
}
#[test]
fn test_lbfgs_no_strong_wolfe_comparison() {
let device = Default::default();
let tol = 1e-5;
let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
let weight = TensorData::from([[0.5f64]]);
let bias = TensorData::from([0.1f64]);
let module = given_linear_layer(weight, bias);
let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
.with_line_search_fn(LineSearchFn::None)
.init();
let mut closure = |mod_in: Linear<TestAutodiffBackend>| {
let output = mod_in.forward(x_data.clone());
let loss = burn_nn::loss::MseLoss::new().forward(
output,
y_true.clone(),
burn_nn::loss::Reduction::Sum,
);
let grads = loss.backward();
let grads_params = GradientsParams::from_grads(grads, &mod_in);
(loss.into_scalar().to_f64(), grads_params)
};
let initial_loss = closure(module.clone()).0;
assert!((initial_loss - 50.1300048828).abs() < tol);
let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
assert!((final_loss - 48.2181930542).abs() < tol);
let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
let optimized_bias: f64 = updated_module
.bias
.as_ref()
.unwrap()
.val()
.into_scalar()
.to_f64();
assert!((optimized_data - 0.5302446192).abs() < tol);
assert!((optimized_bias - 0.1142520783).abs() < tol);
}
}

View File

@@ -0,0 +1,30 @@
/// Weight decay module for optimizers.
pub mod decay;
/// Momentum module for optimizers.
pub mod momentum;
mod adagrad;
mod adam;
mod adamw;
mod base;
mod grad_accum;
mod grads;
mod lbfgs;
mod muon;
mod rmsprop;
mod sgd;
mod simple;
mod visitor;
pub use adagrad::*;
pub use adam::*;
pub use adamw::*;
pub use base::*;
pub use grad_accum::*;
pub use grads::*;
pub use lbfgs::*;
pub use muon::*;
pub use rmsprop::*;
pub use sgd::*;
pub use simple::*;

View File

@@ -0,0 +1,94 @@
use burn_core as burn;
use burn::config::Config;
use burn::record::Record;
use burn::tensor::backend::Backend;
use burn::tensor::{ElementConversion, Tensor};
/// Configuration to create [momentum](Momentum).
#[derive(Config, Debug)]
pub struct MomentumConfig {
/// Momentum factor
#[config(default = 0.9)]
pub momentum: f64,
/// Dampening factor.
#[config(default = 0.1)]
pub dampening: f64,
/// Enables Nesterov momentum, see [On the importance of initialization and
/// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
#[config(default = false)]
pub nesterov: bool,
}
/// State of [momentum](Momentum).
#[derive(Record, Clone, new)]
pub struct MomentumState<B: Backend, const D: usize> {
velocity: Tensor<B, D>,
}
/// Momentum implementation that transforms gradients.
#[derive(Clone)]
pub struct Momentum<B: Backend> {
momentum: B::FloatElem,
dampening: f64,
nesterov: bool,
}
impl<B: Backend> Momentum<B> {
/// Creates a new [momentum](Momentum) from a [config](MomentumConfig).
pub fn new(config: &MomentumConfig) -> Self {
Self {
momentum: config.momentum.elem(),
dampening: config.dampening,
nesterov: config.nesterov,
}
}
/// Transforms a gradient.
///
/// # Arguments
///
/// * `grad` - Gradient to transform.
/// * `state` - State of the optimizer.
///
/// # Returns
///
/// * `grad` - Transformed gradient.
/// * `state` - State of the optimizer.
pub fn transform<const D: usize>(
&self,
grad: Tensor<B, D>,
state: Option<MomentumState<B, D>>,
) -> (Tensor<B, D>, MomentumState<B, D>) {
let velocity = if let Some(state) = state {
grad.clone()
.mul_scalar(1.0 - self.dampening)
.add(state.velocity.mul_scalar(self.momentum))
} else {
grad.clone()
};
let grad = match self.nesterov {
true => velocity.clone().mul_scalar(self.momentum).add(grad),
false => velocity.clone(),
};
(grad, MomentumState::new(velocity))
}
}
impl<B: Backend, const D: usize> MomentumState<B, D> {
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.velocity = self.velocity.to_device(device);
self
}
}

View File

@@ -0,0 +1,775 @@
use burn_core as burn;
use burn::{module::AutodiffModule, record::Record};
use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use serde::{Deserialize, Serialize};
use super::{
SimpleOptimizer,
adaptor::OptimizerAdaptor,
decay::WeightDecayConfig,
momentum::{Momentum, MomentumConfig, MomentumState},
};
use crate::LearningRate;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Learning rate adjustment method for Muon optimizer.
///
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
/// RMS across rectangular matrices.
///
/// # References
///
/// - Original: [Muon: An optimizer for hidden layers](https://kellerjordan.github.io/posts/muon/)
/// - Moonshot: [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdjustLrFn {
/// Keller Jordan's original method: `lr * sqrt(max(1, A/B))`
///
/// This scales the learning rate based on the aspect ratio of the weight matrix,
/// ensuring that tall matrices (more rows than columns) get proportionally larger
/// learning rates.
///
/// # Example
///
/// For a [1024, 512] matrix: `lr * sqrt(1024/512) = lr * 1.414`
#[default]
Original,
/// Moonshot's method: `lr * 0.2 * sqrt(max(A, B))`
///
/// This method is designed to match AdamW's RMS, allowing Muon to directly reuse
/// learning rates and weight decay values tuned for AdamW without retuning.
///
/// # Example
///
/// For a [1024, 512] matrix: `lr * 0.2 * sqrt(1024) = lr * 6.4`
MatchRmsAdamW,
}
impl AdjustLrFn {
/// Calculate the learning rate adjustment ratio for a given parameter shape.
///
/// # Arguments
///
/// * `shape` - Parameter shape (uses first two dimensions)
///
/// # Returns
///
/// Adjustment ratio to multiply with the base learning rate
fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
if shape.len() < 2 {
return 1.0;
}
let a = shape[0] as f64;
let b = shape[1] as f64;
match self {
Self::Original => {
// sqrt(max(1, A/B))
let ratio = a / b;
ratio.max(1.0).sqrt()
}
Self::MatchRmsAdamW => {
// 0.2 * sqrt(max(A, B))
0.2 * a.max(b).sqrt()
}
}
}
}
/// Muon configuration.
///
/// Muon is an optimizer specifically designed for 2D parameters of neural network
/// hidden layers (weight matrices). Other parameters such as biases and embeddings
/// should be optimized using a standard method such as AdamW.
///
/// # Learning Rate Adjustment
///
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
/// RMS across rectangular matrices. Two methods are available:
///
/// - **Original**: Uses `sqrt(max(1, A/B))` where A and B are the first two dimensions.
/// This is Keller Jordan's method and is the default.
///
/// - **MatchRmsAdamW**: Uses `0.2 * sqrt(max(A, B))`. This is Moonshot's method
/// designed to match AdamW's RMS, allowing direct reuse of AdamW hyperparameters.
///
/// # Example
///
/// ```ignore
/// use burn_optim::{MuonConfig, AdjustLrFn};
///
/// // Using default (Original) method
/// let optimizer = MuonConfig::new().init();
///
/// // Using MatchRmsAdamW for AdamW-compatible hyperparameters
/// let optimizer = MuonConfig::new()
/// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
/// .init();
/// ```
///
/// # References
///
/// - [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/)
/// - [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
/// - [PyTorch Implementation](https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py)
/// - [Original Implementation](https://github.com/KellerJordan/Muon)
#[derive(Config, Debug)]
pub struct MuonConfig {
/// [Weight decay](WeightDecayConfig) config.
weight_decay: Option<WeightDecayConfig>,
/// [Momentum](MomentumConfig) config.
///
/// Muon always uses momentum. Default configuration:
/// - momentum: 0.95
/// - dampening: 0.0
/// - nesterov: true
#[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
momentum: MomentumConfig,
/// Newton-Schulz iteration coefficients (a, b, c).
///
/// These coefficients are selected to maximize the slope at zero for the
/// quintic iteration. Default values are from Keller Jordan's implementation.
#[config(default = "(3.4445, -4.775, 2.0315)")]
ns_coefficients: (f32, f32, f32),
/// Epsilon for numerical stability.
#[config(default = 1e-7)]
epsilon: f32,
/// Number of Newton-Schulz iteration steps.
#[config(default = 5)]
ns_steps: usize,
/// Learning rate adjustment method.
///
/// Controls how the learning rate is adjusted based on parameter shape.
/// See [`AdjustLrFn`] for available methods.
#[config(default = "AdjustLrFn::Original")]
adjust_lr_fn: AdjustLrFn,
}
impl MuonConfig {
/// Initialize Muon optimizer.
///
/// # Returns
///
/// Returns an optimizer adaptor that can be used to optimize a module.
///
/// # Example
///
/// ```ignore
/// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
///
/// // Basic configuration with default (Original) LR adjustment
/// let optimizer = MuonConfig::new()
/// .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
/// .init();
///
/// // With AdamW-compatible settings using MatchRmsAdamW
/// let optimizer = MuonConfig::new()
/// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
/// .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
/// .init();
///
/// // Custom momentum and NS settings
/// let optimizer = MuonConfig::new()
/// .with_momentum(MomentumConfig {
/// momentum: 0.9,
/// dampening: 0.1,
/// nesterov: false,
/// })
/// .with_ns_steps(7)
/// .init();
/// ```
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
let momentum = Momentum::new(&self.momentum);
let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);
let optim = Muon {
momentum,
ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
weight_decay_penalty,
epsilon: self.epsilon,
adjust_lr_fn: self.adjust_lr_fn,
};
OptimizerAdaptor::from(optim)
}
}
/// Parameters for Newton-Schulz orthogonalization.
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
a: f32,
b: f32,
c: f32,
steps: usize,
}
impl NewtonSchulzParams {
fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
Self {
a: coefficients.0,
b: coefficients.1,
c: coefficients.2,
steps,
}
}
}
/// Muon optimizer.
///
/// Muon internally runs standard SGD-momentum, and then performs an orthogonalization
/// post-processing step, in which each 2D parameter's update is replaced with the
/// nearest orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz
/// iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.
///
/// # Important Notes
///
/// 1. **Only for 2D+ parameters**: Muon is designed for weight matrices. Use AdamW
/// or SGD for biases, embeddings, and layer norms.
///
/// 2. **Learning rate adjustment**: Muon automatically adjusts the learning rate based
/// on parameter shape. See [`AdjustLrFn`] for details.
///
/// 3. **Weight decay timing**: Unlike typical optimizers, Muon applies weight decay
/// AFTER orthogonalization but uses the original (unadjusted) learning rate for it.
#[derive(Clone)]
pub struct Muon<B: Backend> {
momentum: Momentum<B>,
ns_params: NewtonSchulzParams,
weight_decay_penalty: Option<f32>,
epsilon: f32,
adjust_lr_fn: AdjustLrFn,
}
impl<B: Backend> Muon<B> {
/// Adjust learning rate based on parameter shape.
///
/// # Arguments
///
/// * `lr` - Base learning rate
/// * `shape` - Parameter shape (uses first two dimensions)
///
/// # Returns
///
/// Adjusted learning rate
///
/// ```ignore
/// // For a [1024, 512] weight matrix with lr=0.01:
/// // Original: 0.01 * sqrt(1024/512) = 0.01 * 1.414 = 0.01414
/// // MatchRmsAdamW: 0.01 * 0.2 * sqrt(1024) = 0.01 * 0.2 * 32 = 0.064
/// ```
fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
lr * self.adjust_lr_fn.adjustment_ratio(shape)
}
/// Perform Newton-Schulz orthogonalization on a gradient tensor.
///
/// This computes the zeroth power (orthogonalization) of the input matrix G
/// using a quintic Newton-Schulz iteration.
///
/// # Algorithm
///
/// 1. Transpose if tall matrix (A > B)
/// 2. Normalize: X = X / ||X||
/// 3. For k steps:
/// - A = X @ X^T
/// - B = b*A + c*A^2
/// - X = a*X + B@X
/// 4. Transpose back if needed
///
/// # References
///
/// - Original: https://github.com/KellerJordan/Muon/blob/master/muon.py
/// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py
fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
let shape = g.shape();
let dim_m2 = shape[D - 2];
let dim_m1 = shape[D - 1];
// Step 1: Transpose if tall matrix (more rows than columns)
let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
(g.swap_dims(D - 2, D - 1), true)
} else {
(g, false)
};
// Step 2: Normalize by Frobenius norm
// X = X / (||X|| + epsilon)
let norm = x
.clone()
.powf_scalar(2.0)
.sum()
.sqrt()
.clamp_min(self.epsilon)
.unsqueeze();
x = x.div(norm);
// Step 3: Newton-Schulz iteration
// This is the quintic iteration with coefficients (a, b, c)
let NewtonSchulzParams { a, b, c, steps } = self.ns_params;
for _ in 0..steps {
// A = X @ X^T
let x_t = x.clone().swap_dims(D - 2, D - 1);
let a_matrix = x.clone().matmul(x_t);
// B = b*A + c*A@A
let a_squared = a_matrix.clone().matmul(a_matrix.clone());
let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));
// X = a*X + B@X
x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
}
// Step 4: Restore transpose if it was a tall matrix
if needs_transpose {
x = x.swap_dims(D - 2, D - 1);
}
x
}
}
/// Muon state.
#[derive(Record, Clone, new)]
pub struct MuonState<B: Backend, const D: usize> {
/// Current momentum state
pub momentum: MomentumState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
type State<const D: usize> = MuonState<B, D>;
/// Perform a single Muon optimization step.
///
/// # Algorithm
///
/// 1. Apply momentum to gradient
/// 2. Orthogonalize update via Newton-Schulz
/// 3. Adjust learning rate based on parameter shape
/// 4. Apply weight decay (using original lr)
/// 5. Update parameter (using adjusted lr)
///
/// # Notes
///
/// Unlike typical optimizers, the weight decay and parameter update use
/// different learning rates:
/// - Weight decay uses the original `lr`
/// - Parameter update uses the shape-adjusted `lr`
///
/// # Panics
/// This function will panic if the input tensors are not 2D.
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
assert!(
D == 2,
"Newton-Schulz iteration requires 2D tensors, got {}D",
D
);
// Step 1: Apply momentum
let state_momentum = state.map(|s| s.momentum);
let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);
// Step 2: Orthogonalize via Newton-Schulz
let update = self.zeropower_via_newtonschulz(grad);
// Step 3: Adjust learning rate based on parameter shape
let adjusted_lr = self.adjust_lr(lr, &tensor.shape());
// Step 4: Apply weight decay (using ORIGINAL lr, not adjusted)
// Muon applies weight decay AFTER orthogonalization
let tensor = if let Some(penalty) = self.weight_decay_penalty {
let decay_factor = 1.0 - lr * penalty as f64;
tensor.mul_scalar(decay_factor)
} else {
tensor
};
// Step 5: Update parameter (using ADJUSTED lr)
let delta = update.mul_scalar(adjusted_lr);
let new_state = MuonState::new(new_momentum_state);
(tensor - delta, Some(new_state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
state.momentum = state.momentum.to_device(device);
state
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use crate::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
type TestBackend = burn_ndarray::NdArray<f32>;
const TOLERANCE: f64 = 1e-8;
fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
let record = LinearRecord {
weight: Param::from_data(weight, &device),
bias: None, //No bias for Muon optimizer
};
LinearConfig::new(4, 4)
.with_bias(false)
.init(&device)
.load_record(record)
}
#[test]
fn test_adjust_lr_fn_original() {
let method = AdjustLrFn::Original;
// Square matrix [512, 512] -> sqrt(1) = 1.0
let ratio = method.adjustment_ratio(&[512, 512]);
assert!((ratio - 1.0).abs() < TOLERANCE);
// Tall matrix [1024, 512] -> sqrt(2) ≈ 1.414
let ratio = method.adjustment_ratio(&[1024, 512]);
let expected = (2.0f64).sqrt();
assert!((ratio - expected).abs() < TOLERANCE);
// Wide matrix [512, 1024] -> max(1, 0.5) = 1.0
let ratio = method.adjustment_ratio(&[512, 1024]);
assert!((ratio - 1.0).abs() < TOLERANCE);
}
#[test]
fn test_adjust_lr_fn_match_rms_adamw() {
let method = AdjustLrFn::MatchRmsAdamW;
// [1024, 512] -> 0.2 * sqrt(1024) = 6.4
let ratio = method.adjustment_ratio(&[1024, 512]);
let expected = 0.2 * 1024.0f64.sqrt();
assert!((ratio - expected).abs() < TOLERANCE);
// [512, 512] -> 0.2 * sqrt(512) ≈ 4.525
let ratio = method.adjustment_ratio(&[512, 512]);
let expected = 0.2 * 512.0f64.sqrt();
assert!((ratio - expected).abs() < TOLERANCE);
}
#[test]
#[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
fn test_1d_tensor_panics() {
let device = Default::default();
let config = MuonConfig::new();
let optim: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);
let _ = optim.step(0.01, tensor_1d, grad_1d, None);
}
#[test]
fn test_muon_optimizer_save_load_state() {
let device = Default::default();
// Use Linear layer WITHOUT bias for Muon optimizer
let linear = LinearConfig::new(6, 6)
.with_bias(false) // No bias - only 2D weight matrix
.init::<TestAutodiffBackend>(&device);
let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
let mut optimizer =
MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let _linear = optimizer.step(0.01, linear, grads);
let state_before = optimizer.to_record();
let state_before_copy = optimizer.to_record();
let optimizer_new =
MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
let optimizer_loaded = optimizer_new.load_record(state_before_copy);
let state_after = optimizer_loaded.to_record();
assert_eq!(state_before.len(), state_after.len());
}
#[test]
fn test_muon_with_weight_decay() {
let device = Default::default();
// Create Linear layer WITHOUT bias for Muon
let linear = given_linear_layer_no_bias(TensorData::from([
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
]));
let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
[[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
&device,
)
.require_grad();
let mut optimizer = MuonConfig::new()
.with_weight_decay(Some(WeightDecayConfig::new(0.01)))
.init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(0.01, linear, grads);
let state = linear.into_record();
let weight = state.weight.to_data();
for val in weight.as_slice::<f32>().unwrap() {
assert!(
*val < 1.0,
"Weight should be reduced by weight decay, got {}",
val
);
}
}
#[test]
fn test_newton_schulz_orthogonalization() {
let device = Default::default();
let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);
let config = MuonConfig::new();
let muon: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
let o_t = orthogonalized.clone().transpose();
let product = orthogonalized.matmul(o_t);
let data = product.into_data();
let values = data.as_slice::<f32>().unwrap();
assert!(
(values[0] - 1.0).abs() < 0.1,
"Product[0,0] should be ~1.0, got {}",
values[0]
);
assert!(
(values[3] - 1.0).abs() < 0.1,
"Product[1,1] should be ~1.0, got {}",
values[3]
);
}
#[test]
fn test_tall_matrix_transpose() {
// Test that tall matrices (A > B) are transposed during Newton-Schulz iteration
// and then transposed back
let device = Default::default();
// Create a tall matrix: [8, 4] (more rows than columns)
let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2],
[0.5, 1.0, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5],
[0.2, 0.1, 0.5, 1.0],
[0.1, 0.2, 0.3, 0.4],
[0.4, 0.3, 0.2, 0.1],
[0.2, 0.4, 0.1, 0.3],
[0.3, 0.1, 0.4, 0.2],
],
&device,
);
let config = MuonConfig::new();
let muon: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
// Perform Newton-Schulz orthogonalization
let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());
// Verify shape is preserved (should be transposed internally but returned in original shape)
let original_shape = tall_matrix.shape();
let result_shape = orthogonalized.shape();
assert_eq!(
original_shape.dims::<2>(),
result_shape.dims::<2>(),
"Shape should be preserved: [8, 4]"
);
// Verify output is different from input (orthogonalization happened)
let original_data = tall_matrix.into_data();
let result_data = orthogonalized.into_data();
assert_ne!(
original_data.as_slice::<f32>().unwrap(),
result_data.as_slice::<f32>().unwrap(),
"Orthogonalized matrix should differ from input"
);
// For comparison, test a wide matrix [4, 8] should NOT be transposed
let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
[0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
[0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
],
&device,
);
let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());
// Verify wide matrix shape is also preserved
let wide_original_shape = wide_matrix.shape();
let wide_result_shape = orthogonalized_wide.shape();
assert_eq!(
wide_original_shape.dims::<2>(),
wide_result_shape.dims::<2>(),
"Wide matrix shape should be preserved: [4, 8]"
);
}
#[test]
fn test_zero_gradient() {
// Test that Muon handles zero gradients gracefully
let device = Default::default();
let tensor = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2],
[0.5, 1.0, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5],
[0.2, 0.1, 0.5, 1.0],
],
&device,
);
// Zero gradient - all zeros
let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);
let config = MuonConfig::new();
let muon: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
// Should not panic or produce NaN
let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);
// Verify state was created
assert!(state.is_some());
// With zero gradient and no weight decay, tensor should remain unchanged
let original_data = tensor.into_data();
let updated_data = updated_tensor.clone().into_data();
let original_vals = original_data.as_slice::<f32>().unwrap();
let updated_vals = updated_data.as_slice::<f32>().unwrap();
for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
assert!(
(orig - upd).abs() < 1e-6,
"With zero gradient, tensor should remain unchanged (or very close)"
);
}
// Verify no NaN values
for val in updated_vals {
assert!(
!val.is_nan(),
"Result should not contain NaN values with zero gradient"
);
}
// Test with weight decay - should still work
let muon_with_decay: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: Some(0.01),
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let tensor2 = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2],
[0.5, 1.0, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5],
[0.2, 0.1, 0.5, 1.0],
],
&device,
);
let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);
let (updated_tensor_decay, _) =
muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);
// With zero gradient but with weight decay, tensor should be slightly reduced
let updated_decay_data = updated_tensor_decay.into_data();
let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();
for val in updated_decay_vals {
assert!(
!val.is_nan(),
"Result should not contain NaN with zero gradient and weight decay"
);
}
// With weight decay, values should be slightly smaller than original
let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
if orig.abs() > 1e-6 {
// Non-zero values should be reduced by weight decay
assert!(
upd.abs() < orig.abs(),
"Weight decay should reduce magnitude: original={}, updated={}",
orig,
upd
);
}
}
}
}

View File

@@ -0,0 +1,566 @@
use burn_core as burn;
use burn::{module::AutodiffModule, record::Record};
use super::{
SimpleOptimizer,
adaptor::OptimizerAdaptor,
decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
use burn::config::Config;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};
/// Configuration to create the [RmsProp](RmsProp) optimizer.
#[derive(Config, Debug)]
pub struct RmsPropConfig {
/// Smoothing constant.
#[config(default = 0.99)]
alpha: f32,
/// momentum for RmsProp.
#[config(default = 0.9)]
momentum: f32,
/// A value required for numerical stability.
#[config(default = 1e-5)]
epsilon: f32,
/// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance
#[config(default = false)]
centered: bool,
/// [Weight decay](WeightDecayConfig) config.
weight_decay: Option<WeightDecayConfig>,
/// [Gradient Clipping](GradientClippingConfig) config.
grad_clipping: Option<GradientClippingConfig>,
}
impl RmsPropConfig {
/// Initialize RmsProp optimizer.
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
) -> OptimizerAdaptor<RmsProp, M, B> {
let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
let mut optim = OptimizerAdaptor::from(RmsProp {
alpha: self.alpha,
centered: self.centered,
weight_decay,
momentum: RmsPropMomentum {
momentum: self.momentum,
epsilon: self.epsilon,
},
});
if let Some(config) = &self.grad_clipping {
optim = optim.with_grad_clipping(config.init());
}
optim
}
}
/// Optimizer that implements stochastic gradient descent with momentum.
/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
#[derive(Clone)]
pub struct RmsProp {
alpha: f32,
// epsilon: f32,
centered: bool,
// momentum: Option<Momentum<B>>,
momentum: RmsPropMomentum,
weight_decay: Option<WeightDecay>,
}
impl<B: Backend> SimpleOptimizer<B> for RmsProp {
type State<const D: usize> = RmsPropState<B, D>;
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
mut grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
// fetch state for params
let mut state_square_avg = None;
let mut state_centered = None;
let mut state_momentum = None;
if let Some(state) = state {
state_square_avg = Some(state.square_avg);
state_centered = Some(state.centered);
state_momentum = state.momentum;
}
// weight_decay transform
if let Some(weight_decay) = &self.weight_decay {
grad = weight_decay.transform(grad, tensor.clone());
}
// square_avg transform
let (grad, state_square_avg) =
SquareAvgState::transform(self.alpha, grad, state_square_avg);
// centered transform
let (grad, state_square_avg, state_centered) = CenteredState::transform(
self.alpha,
self.centered,
grad,
state_square_avg,
state_centered,
);
// momentum transform
let (grad, state_centered, state_momentum) =
self.momentum
.transform(grad, state_centered, state_momentum);
// transition state
let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);
// tensor param transform
let delta = grad.mul_scalar(lr);
(tensor - delta, Some(state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
state.square_avg = state.square_avg.to_device(device);
state.centered = state.centered.to_device(device);
state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
state
}
}
/// State of [RmsProp](RmsProp)
#[derive(Record, Clone, new)]
pub struct RmsPropState<B: Backend, const D: usize> {
/// Current squared average state.
pub square_avg: SquareAvgState<B, D>,
/// Current centered state
pub centered: CenteredState<B, D>,
/// Current gradient momentum, if any.
pub momentum: Option<RmsPropMomentumState<B, D>>,
}
/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
#[derive(Record, Clone, new)]
pub struct SquareAvgState<B: Backend, const D: usize> {
/// Current squared average.
pub square_avg: Tensor<B, D>,
}
impl<B: Backend, const D: usize> SquareAvgState<B, D> {
/// transform [SquareAvgState] to the next step
fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
match state {
Some(state) => {
let square_avg = state
.square_avg
.mul_scalar(alpha)
.add(grad.clone().square().mul_scalar(1. - alpha));
(grad, Self { square_avg })
}
_ => {
let square_avg = grad.clone().square().mul_scalar(1. - alpha);
(grad, Self { square_avg })
}
}
}
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.square_avg = self.square_avg.to_device(device);
self
}
}
/// [CenteredState](CenteredState) is to store and pass optimizer step params.
#[derive(Record, Clone, new)]
pub struct CenteredState<B: Backend, const D: usize> {
/// The averaged gradient to calculate the centered gradient, if available.
pub grad_avg: Option<Tensor<B, D>>,
/// The current average value.
pub avg: Tensor<B, D>,
}
impl<B: Backend, const D: usize> CenteredState<B, D> {
/// transform [CenteredState] to the next step
fn transform(
alpha: f32,
centered: bool,
grad: Tensor<B, D>,
square_avg_state: SquareAvgState<B, D>,
centered_state: Option<Self>,
) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
if centered {
let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
let grad_avg = match centered_state {
Some(state) => state
.grad_avg
.map_or(grad_avg_constant.clone(), move |grad_avg| {
grad_avg.mul_scalar(alpha).add(grad_avg_constant)
}),
_ => grad_avg_constant,
};
let avg = square_avg_state
.square_avg
.clone()
.sub(grad_avg.clone().square());
(
grad,
square_avg_state,
Self {
grad_avg: Some(grad_avg),
avg,
},
)
} else {
(
grad,
square_avg_state.clone(),
Self {
grad_avg: None,
avg: square_avg_state.square_avg,
},
)
}
}
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
self.avg = self.avg.to_device(device);
self
}
}
/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
#[derive(Clone)]
pub struct RmsPropMomentum {
momentum: f32,
epsilon: f32,
}
impl RmsPropMomentum {
/// transform [grad](Tensor) and [RmsPropMomentumState] to the next step
fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
centered_state: CenteredState<B, D>,
momentum_state: Option<RmsPropMomentumState<B, D>>,
) -> (
Tensor<B, D>,
CenteredState<B, D>,
Option<RmsPropMomentumState<B, D>>,
) {
let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
if self.momentum > 0. {
let buf = match momentum_state {
Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
_ => grad,
};
(
buf.clone(),
centered_state,
Some(RmsPropMomentumState { buf }),
)
} else {
(grad, centered_state, None)
}
}
}
/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.
#[derive(Record, Clone, new)]
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
buf: Tensor<B, D>,
}
impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.buf = self.buf.to_device(device);
self
}
}
#[cfg(test)]
mod tests {
use burn::tensor::ops::FloatElem;
use burn::tensor::{Shape, Tolerance};
use super::*;
use crate::TestAutodiffBackend;
use crate::optim::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
type FT = FloatElem<TestAutodiffBackend>;
const LEARNING_RATE: LearningRate = 0.01;
#[test]
fn test_rmsprop_optimizer_save_load_state() {
let device = Default::default();
let linear = LinearConfig::new(6, 6).init(&device);
let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
let mut optimizer = create_rmsprop();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let _linear = optimizer.step(LEARNING_RATE, linear, grads);
#[cfg(feature = "std")]
{
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
BinFileRecorder::<FullPrecisionSettings>::default()
.record(
optimizer.to_record(),
std::env::temp_dir().as_path().join("test_optim_rmsprop"),
)
.unwrap();
}
#[cfg(not(feature = "std"))]
{
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
let result = BinBytesRecorder::<FullPrecisionSettings>::default()
.record(optimizer.to_record(), ())
.unwrap();
assert!(!result.is_empty());
}
let state_optim_before = optimizer.to_record();
let state_optim_before_copy = optimizer.to_record();
let optimizer = create_rmsprop();
let optimizer = optimizer.load_record(state_optim_before_copy);
let state_optim_after = optimizer.to_record();
assert_eq!(state_optim_before.len(), state_optim_after.len());
}
/// used for test differences and debug
#[test]
fn test_rmsprop_optimizer_with_numbers_basic() {
let linear = given_linear_layer(
TensorData::from([
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
]),
TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
);
let device = Default::default();
let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&device,
)
.require_grad();
let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&device,
)
.require_grad();
let mut optimizer = RmsPropConfig::new()
.with_alpha(0.99)
.with_epsilon(1e-8)
.with_weight_decay(WeightDecayConfig::new(0.05).into())
.with_momentum(0.9)
.with_centered(false)
.init();
// println!("linear is {:?}", linear);
let grads = linear.forward(x_1).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
// println!("linear is {:?}", linear);
let grads = linear.forward(x_2).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
// println!("linear is {:?}", linear);
let state_updated = linear.into_record();
let (weight_updated, bias_updated) = (
state_updated.weight.to_data(),
state_updated.bias.unwrap().to_data(),
);
// println!("\nweight_updated\n{:?}", weight_updated);
// println!("\nbias_updated\n{:?}", bias_updated);
let weights_expected = TensorData::from([
[0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
[0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
[0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
[0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
[0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
[0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
]);
let bias_expected =
TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);
let tolerance = Tolerance::absolute(1e-6);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
#[test]
fn test_rmsprop_optimizer_with_numbers() {
let linear = given_linear_layer(
TensorData::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);
let device = Default::default();
let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
[0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
],
&device,
)
.require_grad();
let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
[
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
],
&device,
)
.require_grad();
let mut optimizer = RmsPropConfig::new()
.with_alpha(0.99)
.with_epsilon(1e-8)
.with_weight_decay(WeightDecayConfig::new(0.05).into())
.with_momentum(0.9)
.with_centered(false)
.init();
let grads = linear.forward(x_1).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let grads = linear.forward(x_2).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);
let state_updated = linear.into_record();
let weights_expected = TensorData::from([
[
-0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
],
[
-0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
],
[
-0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
],
[
-0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
],
[
0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
],
[
-0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
],
]);
let bias_expected = TensorData::from([
-0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
]);
let (weight_updated, bias_updated) = (
state_updated.weight.to_data(),
state_updated.bias.unwrap().to_data(),
);
// println!("\nweight_updated\n{:?}", weight_updated);
// println!("\nbias_updated\n{:?}", bias_updated);
let tolerance = Tolerance::absolute(1e-6);
bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
let record = LinearRecord {
weight: Param::from_data(weight, &device),
bias: Some(Param::from_data(bias, &device)),
};
LinearConfig::new(6, 6).init(&device).load_record(record)
}
#[allow(dead_code)]
fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
Tensor::<TestAutodiffBackend, 2>::random(
Shape::new([2, 20]),
Distribution::Default,
&Default::default(),
)
}
fn create_rmsprop()
-> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
RmsPropConfig {
alpha: 0.99,
epsilon: 1e-9,
centered: false,
weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
momentum: 0.9,
grad_clipping: None,
}
.init()
}
}

View File

@@ -0,0 +1,181 @@
use burn_core as burn;
use super::SimpleOptimizer;
use super::adaptor::OptimizerAdaptor;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig, MomentumState};
use crate::LearningRate;
use crate::grad_clipping::GradientClippingConfig;
use burn::config::Config;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::Tensor;
use burn::tensor::backend::{AutodiffBackend, Backend};
/// Configuration to create the [Sgd](Sgd) optimizer.
#[derive(Config, Debug)]
pub struct SgdConfig {
/// [Weight decay](WeightDecayConfig) config.
weight_decay: Option<WeightDecayConfig>,
/// [Momentum](MomentumConfig) config.
momentum: Option<MomentumConfig>,
/// [Gradient Clipping](GradientClippingConfig) config.
gradient_clipping: Option<GradientClippingConfig>,
}
/// Optimizer that implements stochastic gradient descent with momentum.
///
/// The optimizer can be configured with [SgdConfig](SgdConfig).
#[derive(Clone)]
pub struct Sgd<B: Backend> {
momentum: Option<Momentum<B>>,
weight_decay: Option<WeightDecay>,
}
/// State of [Sgd](Sgd).
#[derive(Record, Clone, new)]
pub struct SgdState<B: Backend, const D: usize> {
/// The current state of the momentum (if any).
pub momentum: Option<MomentumState<B, D>>,
}
impl SgdConfig {
/// Creates a new [SgdConfig](SgdConfig) with default values.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
let momentum = self.momentum.as_ref().map(Momentum::new);
let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
let mut optim = OptimizerAdaptor::from(Sgd {
momentum,
weight_decay,
});
if let Some(config) = &self.gradient_clipping {
optim = optim.with_grad_clipping(config.init());
}
optim
}
}
impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
type State<const D: usize> = SgdState<B, D>;
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
mut grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
let mut state_momentum = None;
if let Some(state) = state {
state_momentum = state.momentum;
}
if let Some(weight_decay) = &self.weight_decay {
grad = weight_decay.transform(grad, tensor.clone());
}
if let Some(momentum) = &self.momentum {
let (grad_out, state) = momentum.transform(grad, state_momentum);
state_momentum = Some(state);
grad = grad_out;
}
let state = SgdState::new(state_momentum);
let delta = grad.mul_scalar(lr);
(tensor - delta, Some(state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
state.momentum = state.momentum.map(|state| state.to_device(device));
state
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
TestAutodiffBackend, TestBackend,
grad_clipping::GradientClipping,
optim::{GradientsParams, Optimizer},
};
use burn::tensor::{Distribution, Shape};
use burn_nn::{Linear, LinearConfig};
const LEARNING_RATE: LearningRate = 0.02;
#[test]
fn with_updated_params_should_have_state() {
let device = Default::default();
let layer = layer::<TestAutodiffBackend>(&device);
let mut optim = sgd_with_all();
let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &layer);
let _layer = optim.step(LEARNING_RATE, layer, grads);
let record = optim.to_record();
assert!(!record.is_empty());
}
#[test]
fn without_updated_params_should_not_have_state() {
let optim = sgd_with_all();
let record = optim.to_record();
assert!(record.is_empty());
}
#[test]
fn can_attach_gradient_clipping() {
let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));
assert!(optim.has_gradient_clipping());
}
#[test]
fn should_load_state() {
let device = Default::default();
let layer = layer::<TestAutodiffBackend>(&device);
let mut optim = sgd_with_all();
let loss = layer.forward(random_tensor(&device));
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &layer);
let _layer = optim.step(LEARNING_RATE, layer, grads);
let record = optim.to_record();
let optim_new = sgd_with_all();
let record_new = optim_new.to_record();
let optim_new = optim_new.load_record(record.clone());
let state_restored = optim_new.to_record();
assert_ne!(record.len(), record_new.len());
assert_eq!(record.len(), state_restored.len());
}
fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)
}
fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
LinearConfig::new(20, 20).init(device)
}
fn sgd_with_all()
-> OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
SgdConfig {
weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
momentum: Some(MomentumConfig {
momentum: 0.9,
dampening: 0.1,
nesterov: true,
}),
gradient_clipping: None,
}
.init()
}
}

View File

@@ -0,0 +1,210 @@
use burn_core::{self as burn, prelude::Backend, tensor::Device};
use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
LearningRate, MultiGradientsParams,
grad_clipping::GradientClipping,
optim::{GradientsParams, Optimizer},
};
use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;
/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
/// an [optimizer](Optimizer).
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
O: SimpleOptimizer<B::InnerBackend>,
M: AutodiffModule<B>,
B: AutodiffBackend,
{
optim: O,
records: HashMap<ParamId, AdaptorRecord<O, B>>,
module: PhantomData<M>,
grad_clipping: Option<GradientClipping>,
}
impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
B: AutodiffBackend,
M: AutodiffModule<B>,
O: SimpleOptimizer<B::InnerBackend>,
{
fn from(optim: O) -> Self {
Self {
optim,
records: HashMap::new(),
module: PhantomData,
grad_clipping: None,
}
}
}
impl<O, M, B> OptimizerAdaptor<O, M, B>
where
O: SimpleOptimizer<B::InnerBackend>,
M: AutodiffModule<B>,
B: AutodiffBackend,
{
/// Sets the gradient clipping.
///
/// # Arguments
///
/// * `gradient_clipping` - The gradient clipping.
///
/// # Returns
///
/// The optimizer.
pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
self.grad_clipping = Some(gradient_clipping);
self
}
#[cfg(test)]
pub(crate) fn has_gradient_clipping(&self) -> bool {
self.grad_clipping.is_some()
}
}
impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
B: AutodiffBackend,
M: AutodiffModule<B>,
O: SimpleOptimizer<B::InnerBackend>,
{
type Record = HashMap<ParamId, AdaptorRecord<O, B>>;
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {
let mut grads = GradAdaptor::Single(grads);
let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
&self.optim,
&mut self.records,
&mut grads,
lr,
self.grad_clipping.as_ref(),
);
module.map(&mut mapper)
}
fn step_multi(&mut self, lr: LearningRate, module: M, grads: crate::MultiGradientsParams) -> M {
let mut grads = GradAdaptor::Multi(grads);
let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
&self.optim,
&mut self.records,
&mut grads,
lr,
self.grad_clipping.as_ref(),
);
module.map(&mut mapper)
}
fn to_record(&self) -> Self::Record {
self.records.clone()
}
fn load_record(mut self, record: Self::Record) -> Self {
self.records = record;
self
}
}
enum GradAdaptor {
Single(GradientsParams),
Multi(MultiGradientsParams),
}
impl GradAdaptor {
fn remove<B: Backend, const D: usize>(
&mut self,
id: ParamId,
) -> Option<(Tensor<B, D>, Device<B>)> {
match self {
GradAdaptor::Single(grads) => grads.remove(id).map(|t| {
let device = t.device();
(t, device)
}),
GradAdaptor::Multi(grads) => grads.remove(id),
}
}
}
#[derive(new)]
struct SimpleOptimizerMapper<'a, M, B, O>
where
M: AutodiffModule<B>,
B: AutodiffBackend,
O: SimpleOptimizer<B::InnerBackend>,
{
optimizer: &'a O,
records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
grads: &'a mut GradAdaptor,
lr: LearningRate,
phantom: PhantomData<M>,
grad_clipping: Option<&'a GradientClipping>,
}
impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
where
M: AutodiffModule<B>,
B: AutodiffBackend,
O: SimpleOptimizer<B::InnerBackend>,
{
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
let (id, tensor, mapper) = param.consume();
let grad = self.grads.remove(id);
let tensor = if let Some((grad, device)) = grad {
let is_require_grad = tensor.is_require_grad();
let (key, record) = self.records.remove_entry(&id).unzip();
let tensor = if tensor.device() != device {
tensor.to_device(&device)
} else {
tensor
};
debug_assert_eq!(
grad.device(),
device,
"The gradient is on the provided device"
);
let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
g_clipping.clip_gradient(grad)
} else {
grad
};
debug_assert_eq!(
tensor.device(),
device,
"Tensor and gradients are on the same device."
);
let (tensor, state) = self.optimizer.step(
self.lr,
tensor.inner(),
clipped_grad,
record.map(|record| O::to_device(record.into_state(), &device)),
);
if let Some(state) = state {
self.records
.insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
}
let mut tensor = Tensor::from_inner(tensor);
if is_require_grad {
tensor = tensor.require_grad();
}
tensor
} else {
tensor
};
Param::from_mapped_value(id, tensor, mapper)
}
}

View File

@@ -0,0 +1,36 @@
use burn_core as burn;
use crate::LearningRate;
use burn::record::Record;
use burn::tensor::{Tensor, backend::Backend};
/// Simple optimizer is an opinionated trait to simplify the process of implementing an
/// optimizer.
///
/// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
/// module parameter structure, handle tracked and untracked tensors, and the likes.
pub trait SimpleOptimizer<B>: Send + Sync + Clone
where
B: Backend,
{
/// The state of the optimizer. It also implements [record](Record), so that it can be saved.
type State<const D: usize>: Record<B> + Clone + 'static;
/// The optimizer step is performed for one tensor at a time with its gradient and state.
///
/// Note that the state is passed as parameter, so implementations don't have to handle
/// the saving and loading of recorded states.
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>);
/// Change the device of the state.
///
/// This function will be called accordingly to have the state on the same device as the
/// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
}

View File

@@ -0,0 +1,8 @@
mod base;
pub use base::*;
/// Adaptor module for optimizers.
pub mod adaptor;
/// Record module for optimizers.
pub mod record;

View File

@@ -0,0 +1,93 @@
use burn_core as burn;
use super::{AdaptorRecordItemV1, AdaptorRecordV1};
use crate::optim::SimpleOptimizer;
use burn::record::{PrecisionSettings, Record};
use burn::tensor::backend::AutodiffBackend;
use serde::{Deserialize, Serialize};
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
///
/// Records are versioned for backward compatibility, so old records can be loaded.
pub enum AdaptorRecord<O, B>
where
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
/// Version 1.
V1(AdaptorRecordV1<O, B::InnerBackend>),
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub enum AdaptorRecordItem<
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
S: PrecisionSettings,
> {
/// Version 1.
V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
}
impl<O, B> Record<B> for AdaptorRecord<O, B>
where
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
match self {
AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()),
}
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),
}
}
}
impl<O, B> Clone for AdaptorRecord<O, B>
where
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
fn clone(&self) -> Self {
match self {
AdaptorRecord::V1(record) => Self::V1(record.clone()),
}
}
}
impl<O, B> AdaptorRecord<O, B>
where
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
/// Converts the record into the optimizer state.
///
/// # Returns
///
/// The optimizer state.
pub fn into_state<const D: usize>(self) -> O::State<D> {
match self {
AdaptorRecord::V1(record) => record.into_state(),
}
}
/// Converts the optimizer state into the record.
///
/// # Arguments
///
/// * `state`: The optimizer state.
///
/// # Returns
///
/// The record.
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
Self::V1(AdaptorRecordV1::from_state(state))
}
}

View File

@@ -0,0 +1,5 @@
mod base;
mod v1;
pub use base::*;
pub use v1::*;

View File

@@ -0,0 +1,201 @@
use burn_core as burn;
use crate::optim::SimpleOptimizer;
use burn::record::{PrecisionSettings, Record};
use burn::tensor::backend::Backend;
use core::any::Any;
use serde::{Deserialize, Serialize};
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
/// Rank 0.
Rank0(O::State<0>),
/// Rank 1.
Rank1(O::State<1>),
/// Rank 2.
Rank2(O::State<2>),
/// Rank 3.
Rank3(O::State<3>),
/// Rank 4.
Rank4(O::State<4>),
/// Rank 5.
Rank5(O::State<5>),
/// Rank 6.
Rank6(O::State<6>),
/// Rank 7.
Rank7(O::State<7>),
/// Rank 8.
Rank8(O::State<8>),
}
impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
fn clone(&self) -> Self {
match self {
AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()),
AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()),
AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()),
AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()),
AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()),
AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()),
AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()),
AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()),
AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()),
}
}
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
/// Rank 0.
Rank0(<O::State<0> as Record<B>>::Item<S>),
/// Rank 1.
Rank1(<O::State<1> as Record<B>>::Item<S>),
/// Rank 2.
Rank2(<O::State<2> as Record<B>>::Item<S>),
/// Rank 3.
Rank3(<O::State<3> as Record<B>>::Item<S>),
/// Rank 4.
Rank4(<O::State<4> as Record<B>>::Item<S>),
/// Rank 5.
Rank5(<O::State<5> as Record<B>>::Item<S>),
/// Rank 6.
Rank6(<O::State<6> as Record<B>>::Item<S>),
/// Rank 7.
Rank7(<O::State<7> as Record<B>>::Item<S>),
/// Rank 8.
Rank8(<O::State<8> as Record<B>>::Item<S>),
}
impl<O, B> AdaptorRecordV1<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
{
/// Convert the record into the state.
///
/// # Returns
///
/// The state.
///
/// # Panics
///
/// Panics if the state dimension is not supported.
pub fn into_state<const D: usize>(self) -> O::State<D> {
let boxed_state: Box<dyn Any> = match self {
AdaptorRecordV1::Rank0(s) => Box::new(s),
AdaptorRecordV1::Rank1(s) => Box::new(s),
AdaptorRecordV1::Rank2(s) => Box::new(s),
AdaptorRecordV1::Rank3(s) => Box::new(s),
AdaptorRecordV1::Rank4(s) => Box::new(s),
AdaptorRecordV1::Rank5(s) => Box::new(s),
AdaptorRecordV1::Rank6(s) => Box::new(s),
AdaptorRecordV1::Rank7(s) => Box::new(s),
AdaptorRecordV1::Rank8(s) => Box::new(s),
};
let state = boxed_state
.downcast::<O::State<D>>()
.expect("Unsupported state dimension, dimension up to 8 are supported.");
*state
}
/// Convert the state into the record.
///
/// # Arguments
///
/// * `state`: The state.
///
/// # Returns
///
/// The record.
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
let state: Box<dyn Any> = Box::new(state);
match D {
0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),
1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),
2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),
3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),
4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()),
5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()),
6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()),
7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()),
8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()),
_ => panic!("Unsupported state dimension, dimension up to 8 are supported."),
}
}
}
impl<O, B> Record<B> for AdaptorRecordV1<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = AdaptorRecordItemV1<O, B, S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
match self {
AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),
AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),
AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),
AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),
AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()),
AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()),
AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()),
AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()),
AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()),
}
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
AdaptorRecordItemV1::Rank0(item) => {
AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank1(item) => {
AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank2(item) => {
AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank3(item) => {
AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank4(item) => {
AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank5(item) => {
AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank6(item) => {
AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank7(item) => {
AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank8(item) => {
AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))
}
}
}
}

View File

@@ -0,0 +1,60 @@
use burn_core as burn;
use super::GradientsParams;
use burn::module::{AutodiffModule, ModuleVisitor, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[derive(new)]
pub struct GradientsParamsConverter<'a, M: AutodiffModule<B>, B: AutodiffBackend> {
grads: &'a mut B::Gradients,
grads_params: &'a mut GradientsParams,
phatom: PhantomData<M>,
filter: Option<Vec<ParamId>>,
}
#[derive(new)]
pub struct GradientsParamsChangeDevice<'a, M: AutodiffModule<B>, B: AutodiffBackend> {
device: &'a B::Device,
grads: &'a mut GradientsParams,
phatom: PhantomData<M>,
}
impl<B, M> ModuleVisitor<B> for GradientsParamsConverter<'_, M, B>
where
B: AutodiffBackend,
M: AutodiffModule<B>,
{
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
if let Some(filter) = self.filter.as_ref()
&& !filter.contains(&param.id)
{
return;
}
let Some(grad) = param.val().grad_remove(self.grads) else {
return;
};
self.grads_params
.register::<B::InnerBackend, D>(param.id, grad);
}
}
impl<B, M> ModuleVisitor<B> for GradientsParamsChangeDevice<'_, M, B>
where
B: AutodiffBackend,
M: AutodiffModule<B>,
{
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
let Some(grad) = self.grads.remove::<B::InnerBackend, D>(param.id) else {
return;
};
self.grads
.register::<B::InnerBackend, D>(param.id, grad.to_device(self.device));
}
}