feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
102
crates/stable-diffusion-burn/burn-crates/burn-optim/Cargo.toml
Normal file
102
crates/stable-diffusion-burn/burn-crates/burn-optim/Cargo.toml
Normal file
@@ -0,0 +1,102 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Optimizer building blocks for the Burn deep learning framework"
|
||||
documentation = "https://docs.rs/burn-optim"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-optim"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-optim"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"std",
|
||||
"burn-core/default",
|
||||
]
|
||||
doc = [
|
||||
"std",
|
||||
# Doc features
|
||||
"burn-core/doc",
|
||||
]
|
||||
std = [
|
||||
"burn-core/std",
|
||||
"num-traits/std",
|
||||
"serde/std",
|
||||
"log",
|
||||
]
|
||||
tracing = [
|
||||
"burn-collective?/tracing",
|
||||
"burn-core/tracing",
|
||||
"burn-cuda?/tracing",
|
||||
"burn-fusion?/tracing",
|
||||
"burn-remote?/tracing",
|
||||
"burn-rocm?/tracing",
|
||||
"burn-router?/tracing",
|
||||
"burn-tch?/tracing",
|
||||
"burn-wgpu?/tracing",
|
||||
]
|
||||
|
||||
collective = ["burn-collective"]
|
||||
|
||||
test-cuda = [
|
||||
"burn-cuda/default",
|
||||
] # To use cuda during testing, default uses ndarray.
|
||||
test-rocm = [
|
||||
"burn-rocm/default",
|
||||
] # To use hip during testing, default uses ndarray.
|
||||
test-tch = [
|
||||
"burn-tch/default",
|
||||
] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = [
|
||||
"burn-wgpu/default",
|
||||
] # To use wgpu during testing, default uses ndarray.
|
||||
test-vulkan = [
|
||||
"test-wgpu",
|
||||
"burn-wgpu/vulkan",
|
||||
] # To use wgpu-spirv during testing, default uses ndarray.
|
||||
test-metal = [
|
||||
"test-wgpu",
|
||||
"burn-wgpu/metal",
|
||||
] # To use wgpu-metal during testing, default uses ndarray.
|
||||
|
||||
# Memory checks are disabled by default
|
||||
test-memory-checks = ["burn-fusion/memory-checks"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
|
||||
num-traits = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
log = { workspace = true, optional = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)
|
||||
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
|
||||
|
||||
# FOR TESTING
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2" }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
|
||||
rstest = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Burn Optimizers
|
||||
|
||||
Core building blocks for Burn optimizers.
|
||||
@@ -0,0 +1,144 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::{config::Config, tensor::Tensor};
|
||||
|
||||
/// Gradient Clipping provides a way to mitigate exploding gradients
|
||||
#[derive(Config, Debug)]
|
||||
pub enum GradientClippingConfig {
|
||||
/// Clip the gradient by value.
|
||||
Value(f32),
|
||||
|
||||
/// Clip the gradient by norm.
|
||||
Norm(f32),
|
||||
}
|
||||
|
||||
impl GradientClippingConfig {
|
||||
/// Initialize the gradient clipping.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The gradient clipping.
|
||||
pub fn init(&self) -> GradientClipping {
|
||||
match self {
|
||||
GradientClippingConfig::Value(val) => GradientClipping::Value(*val),
|
||||
GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gradient clipping strategy used to mitigate exploding gradients.
///
/// The gradient is clipped either element-wise by value or as a whole by its norm.
#[derive(Clone)]
pub enum GradientClipping {
    /// Clip the gradient by value; the wrapped `f32` is the clipping threshold.
    Value(f32),

    /// Clip the gradient by norm; the wrapped `f32` is the maximum norm.
    Norm(f32),
}
|
||||
|
||||
impl GradientClipping {
|
||||
/// Clip the gradient.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `grad` - The gradient to clip.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The clipped gradient.
|
||||
pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {
|
||||
match self {
|
||||
GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),
|
||||
GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm),
|
||||
}
|
||||
}
|
||||
|
||||
fn clip_by_value<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
threshold: f32,
|
||||
) -> Tensor<B, D> {
|
||||
let greater_mask = grad.clone().greater_elem(threshold);
|
||||
let lower_mask = grad.clone().lower_elem(-threshold);
|
||||
|
||||
let clipped_grad = grad.mask_fill(greater_mask, threshold);
|
||||
|
||||
clipped_grad.mask_fill(lower_mask, -threshold)
|
||||
}
|
||||
|
||||
fn clip_by_norm<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
threshold: f32,
|
||||
) -> Tensor<B, D> {
|
||||
let norm = Self::l2_norm(grad.clone());
|
||||
let clip_coef = threshold / norm.add_scalar(1e-6); // avoid div by zero
|
||||
let clip_coef_clamped = clip_coef.clamp_max(1.0);
|
||||
grad.mul(clip_coef_clamped.unsqueeze())
|
||||
}
|
||||
|
||||
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
|
||||
let squared = tensor.square();
|
||||
let sum = squared.sum();
|
||||
sum.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Tensor;

    /// Builds the 2x6 gradient shared by the value and norm clipping tests.
    fn sample_gradient() -> Tensor<TestBackend, 2> {
        Tensor::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &Default::default(),
        )
    }

    #[test]
    fn test_clip_by_value() {
        let clipped = GradientClipping::Value(0.5)
            .clip_gradient(sample_gradient())
            .into_data();

        // No element may exceed the clipping threshold.
        assert!(clipped.iter::<f32>().all(|value| value <= 0.5));
    }

    #[test]
    fn test_clip_by_norm() {
        let clipped = GradientClipping::Norm(2.2)
            .clip_gradient(sample_gradient())
            .into_data();

        assert!(clipped.iter::<f32>().all(|value| value <= 0.88));
    }

    #[test]
    fn test_clip_by_norm_no_clipping() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [[0.3, 0.4, 0.5, 0.2], [0.1, 0.6, 0.3, 0.4]],
            &Default::default(),
        );

        // The norm of this gradient is below the maximum, so it must come back unchanged.
        let expected = gradient.clone().into_data();
        let clipped = GradientClipping::Norm(2.2).clip_gradient(gradient);

        clipped.into_data().assert_eq(&expected, true);
    }
}
|
||||
@@ -0,0 +1,2 @@
|
||||
mod base;
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,63 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
//! Burn optimizers.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Optimizer module.
|
||||
pub mod optim;
|
||||
pub use optim::*;
|
||||
|
||||
/// Gradient clipping module.
|
||||
pub mod grad_clipping;
|
||||
|
||||
/// Learning rate scheduler module.
|
||||
#[cfg(feature = "std")]
|
||||
pub mod lr_scheduler;
|
||||
|
||||
/// Type alias for the learning rate.
///
/// `LearningRate` also implements the
/// [learning rate scheduler](crate::lr_scheduler::LrScheduler) trait, so a plain value can be
/// used wherever a constant learning rate schedule is needed.
pub type LearningRate = f64; // We could potentially change the type.
|
||||
|
||||
/// Backend for test cases (ndarray, used when no `test-*` backend feature is enabled)
#[cfg(all(
    test,
    not(any(
        feature = "test-tch",
        feature = "test-wgpu",
        feature = "test-cuda",
        feature = "test-rocm"
    ))
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

/// Backend for test cases
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend for test cases
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend for test cases
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend for test cases
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// Backend for autodiff test cases
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

// Only compiled for test builds with the `test-memory-checks` feature enabled.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}
|
||||
@@ -0,0 +1,85 @@
|
||||
pub(super) use alloc::string::String;
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::record::Record;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use crate::LearningRate;
|
||||
|
||||
/// Learning rate scheduler defines how the learning rate will evolve during training.
|
||||
pub trait LrScheduler: Clone + Send + Sync {
|
||||
/// Scheduler associative type to be used when saving and loading the state.
|
||||
type Record<B: Backend>: Record<B>;
|
||||
|
||||
/// Perform the scheduler step, potentially updating its state, and returning the effective
|
||||
/// learning rate.
|
||||
fn step(&mut self) -> LearningRate;
|
||||
|
||||
/// Get the current state of the scheduler as a [record](Record).
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B>;
|
||||
|
||||
/// Load the state of the scheduler as a [record](Record).
|
||||
fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;
|
||||
}
|
||||
|
||||
#[cfg(test)]
pub(super) mod test_utils {
    use super::*;
    use crate::TestBackend;

    // Tolerance for learning rate comparisons. Depending on how learning rates are computed, the
    // accumulated floating-point error might exceed f64::EPSILON, so a larger value is used.
    const LOOSE_EPSILON: LearningRate = 1e-10;

    // Steps the scheduler once per expected value and checks that each produced learning rate is
    // approximately equal to the expected one.
    pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
    where
        I: IntoIterator<Item = LearningRate>,
        S: LrScheduler,
    {
        for (i, expected) in expected_lrs.into_iter().enumerate() {
            let lr = scheduler.step();
            assert!(
                (lr - expected).abs() < LOOSE_EPSILON,
                "Scheduled learning rate {lr} is not approximately equal to the expected value \
                 {expected} at step {i}",
            );
        }
    }

    // save_at_step is the number of steps to run the scheduler before saving and loading back its
    // state.
    pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)
    where
        S: Clone + LrScheduler,
    {
        let mut truth = scheduler.clone();

        // Advance both copies in lockstep before the save/load round trip.
        for _ in 0..save_at_step {
            truth.step();
            scheduler.step();
        }

        let rec = scheduler.to_record::<TestBackend>();
        scheduler = scheduler.load_record::<TestBackend>(rec);

        // The reloaded scheduler must resume from where it left off.
        compare_steps(&mut scheduler, &mut truth, save_at_step);
    }

    // Checks that two schedulers produce the same learning rate sequence over the specified
    // number of steps.
    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
        for i in 0..num_steps {
            let lr_a = a.step();
            let lr_b = b.step();
            assert!(
                (lr_a - lr_b).abs() < LOOSE_EPSILON,
                "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
                 sequences are not approximately equal",
            );
        }
    }
}
|
||||
@@ -0,0 +1,195 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};
|
||||
use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};
|
||||
use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};
|
||||
use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};
|
||||
use super::{LrScheduler, String};
|
||||
use crate::LearningRate;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Compose multiple [learning rate schedulers](LrScheduler) together.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ComposedLrSchedulerConfig {
|
||||
#[config(default = "Vec::new()")]
|
||||
schedulers: Vec<LrSchedulerConfig>,
|
||||
#[config(default = "SchedulerReduction::Prod")]
|
||||
reduction: SchedulerReduction,
|
||||
}
|
||||
|
||||
/// Compose multiple [learning rate schedulers](LrScheduler) together.
|
||||
#[derive(Clone)]
|
||||
pub struct ComposedLrScheduler {
|
||||
schedulers: Vec<LrSchedulerItem>,
|
||||
reduction: SchedulerReduction,
|
||||
}
|
||||
|
||||
/// Defines how the learning rates generated by the schedulers are combined.
|
||||
#[derive(Config, Debug, Copy)]
|
||||
pub enum SchedulerReduction {
|
||||
/// All learning rates are averaged.
|
||||
Avg,
|
||||
/// All learning rates are summed.
|
||||
Sum,
|
||||
/// All learning rates are multiplied.
|
||||
Prod,
|
||||
}
|
||||
|
||||
impl ComposedLrSchedulerConfig {
|
||||
/// Initialize the learning rate scheduler.
|
||||
pub fn init(&self) -> Result<ComposedLrScheduler, String> {
|
||||
let mut schedulers = Vec::with_capacity(self.schedulers.len());
|
||||
for config in self.schedulers.iter() {
|
||||
let config = match config {
|
||||
LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),
|
||||
LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),
|
||||
LrSchedulerConfig::Exponential(config) => {
|
||||
LrSchedulerItem::Exponential(config.init()?)
|
||||
}
|
||||
LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),
|
||||
};
|
||||
schedulers.push(config);
|
||||
}
|
||||
|
||||
Ok(ComposedLrScheduler {
|
||||
schedulers,
|
||||
reduction: self.reduction,
|
||||
})
|
||||
}
|
||||
|
||||
/// Appends a [linear scheduler](LinearLrScheduler).
|
||||
pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {
|
||||
self.schedulers.push(LrSchedulerConfig::Linear(config));
|
||||
self
|
||||
}
|
||||
|
||||
/// Appends a [cosine scheduler](ComposedLrSchedulerConfig).
|
||||
pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {
|
||||
self.schedulers.push(LrSchedulerConfig::Cosine(config));
|
||||
self
|
||||
}
|
||||
|
||||
/// Appends an [exponential scheduler](ExponentialLrScheduler).
|
||||
pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {
|
||||
self.schedulers.push(LrSchedulerConfig::Exponential(config));
|
||||
self
|
||||
}
|
||||
|
||||
/// Appends a [noam scheduler](NoamLrScheduler).
|
||||
pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {
|
||||
self.schedulers.push(LrSchedulerConfig::Noam(config));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
enum LrSchedulerConfig {
|
||||
Linear(LinearLrSchedulerConfig),
|
||||
Cosine(CosineAnnealingLrSchedulerConfig),
|
||||
Exponential(ExponentialLrSchedulerConfig),
|
||||
Noam(NoamLrSchedulerConfig),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum LrSchedulerItem {
|
||||
Linear(LinearLrScheduler),
|
||||
Cosine(CosineAnnealingLrScheduler),
|
||||
Exponential(ExponentialLrScheduler),
|
||||
Noam(NoamLrScheduler),
|
||||
}
|
||||
|
||||
#[derive(Record)]
|
||||
/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).
|
||||
pub enum LrSchedulerRecord<B: Backend> {
|
||||
/// The linear variant.
|
||||
Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
|
||||
/// The cosine variant.
|
||||
Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),
|
||||
/// The exponential variant.
|
||||
Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),
|
||||
/// The noam variant.
|
||||
Noam(<NoamLrScheduler as LrScheduler>::Record<B>),
|
||||
}
|
||||
|
||||
#[derive(Record)]
|
||||
/// Records for the [composed learning rate scheduler](ComposedLrScheduler).
|
||||
pub struct ComposedLrSchedulerRecord<B: Backend> {
|
||||
schedulers: Vec<LrSchedulerRecord<B>>,
|
||||
}
|
||||
|
||||
impl LrScheduler for ComposedLrScheduler {
|
||||
type Record<B: Backend> = ComposedLrSchedulerRecord<B>;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
let mut step = match self.reduction {
|
||||
SchedulerReduction::Avg => 0.0,
|
||||
SchedulerReduction::Sum => 0.0,
|
||||
SchedulerReduction::Prod => 1.0,
|
||||
};
|
||||
let num_scheduler = self.schedulers.len() as f64;
|
||||
|
||||
for lr in self.schedulers.iter_mut().map(|s| match s {
|
||||
LrSchedulerItem::Linear(item) => item.step(),
|
||||
LrSchedulerItem::Cosine(item) => item.step(),
|
||||
LrSchedulerItem::Exponential(item) => item.step(),
|
||||
LrSchedulerItem::Noam(item) => item.step(),
|
||||
}) {
|
||||
step = match self.reduction {
|
||||
SchedulerReduction::Avg => step + (lr / num_scheduler),
|
||||
SchedulerReduction::Sum => step + lr,
|
||||
SchedulerReduction::Prod => step * lr,
|
||||
}
|
||||
}
|
||||
|
||||
step
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {
|
||||
ComposedLrSchedulerRecord::<B> {
|
||||
schedulers: self
|
||||
.schedulers
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
LrSchedulerItem::Linear(item) => {
|
||||
LrSchedulerRecord::Linear(item.to_record::<B>())
|
||||
}
|
||||
LrSchedulerItem::Cosine(item) => {
|
||||
LrSchedulerRecord::Linear(item.to_record::<B>())
|
||||
}
|
||||
LrSchedulerItem::Exponential(item) => {
|
||||
LrSchedulerRecord::Exponential(item.to_record::<B>())
|
||||
}
|
||||
LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
|
||||
self.schedulers = self
|
||||
.schedulers
|
||||
.into_iter()
|
||||
.zip(record.schedulers)
|
||||
.map(|scheduler| match scheduler {
|
||||
(LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {
|
||||
LrSchedulerItem::Linear(item.load_record::<B>(record))
|
||||
}
|
||||
(LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {
|
||||
LrSchedulerItem::Cosine(item.load_record::<B>(record))
|
||||
}
|
||||
(LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {
|
||||
LrSchedulerItem::Exponential(item.load_record::<B>(record))
|
||||
}
|
||||
(LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {
|
||||
LrSchedulerItem::Noam(item.load_record::<B>(record))
|
||||
}
|
||||
_ => panic!("Invalid state"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use super::LrScheduler;
|
||||
use crate::LearningRate;
|
||||
|
||||
/// Constant learning rate implementing [learning rate scheduler](LrScheduler).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// You can also use [learning rate](LearningRate) which the same effect.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct ConstantLr {
|
||||
lr: LearningRate,
|
||||
}
|
||||
|
||||
impl From<LearningRate> for ConstantLr {
    /// Wraps a plain learning rate into a constant scheduler.
    fn from(value: LearningRate) -> Self {
        Self { lr: value }
    }
}
|
||||
|
||||
impl LrScheduler for ConstantLr {
|
||||
type Record<B: Backend> = ();
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
self.lr
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {}
|
||||
|
||||
fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl LrScheduler for LearningRate {
|
||||
type Record<B: Backend> = ();
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
*self
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {}
|
||||
|
||||
fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::{LrScheduler, String};
|
||||
use crate::LearningRate;
|
||||
use burn::config::Config;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// The configuration for creating a [Cosine Annealing learning rate scheduler with warm
|
||||
/// restarts](CosineAnnealingLrScheduler).
|
||||
///
|
||||
/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by
|
||||
/// following a cosine function. After `num_iters` iterations, the learning rate is reset to
|
||||
/// `initial_lr`.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct CosineAnnealingLrSchedulerConfig {
|
||||
// The initial learning rate.
|
||||
initial_lr: LearningRate,
|
||||
// The final learning rate.
|
||||
#[config(default = 0.0)]
|
||||
min_lr: LearningRate,
|
||||
// The number of iterations between two restarts. The two restart iterations themselves are not
|
||||
// included.
|
||||
num_iters: usize,
|
||||
}
|
||||
|
||||
impl CosineAnnealingLrSchedulerConfig {
|
||||
/// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// An error will be returned if any of the following conditions is true:
|
||||
///
|
||||
/// * `initial_lr` is out of range (0.0, 1.0]
|
||||
/// * `min_lr` is out of range [0.0, `initial_lr`]
|
||||
/// * `num_iters` is 0
|
||||
pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
|
||||
if self.initial_lr <= 0. || self.initial_lr > 1. {
|
||||
return Err("Initial learning rate must be greater than 0 and at most 1".into());
|
||||
}
|
||||
if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
|
||||
return Err(
|
||||
"Minimum learning rate must be at least 0 and at most equal to the initial \
|
||||
learning rate"
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
if self.num_iters == 0 {
|
||||
return Err("Number of iterations must be at least 1".into());
|
||||
}
|
||||
|
||||
Ok(CosineAnnealingLrScheduler {
|
||||
min_lr: self.min_lr,
|
||||
max_lr: self.initial_lr,
|
||||
num_iters: self.num_iters,
|
||||
current_iter: usize::MAX,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Cosine Annealing learning rate scheduler.
|
||||
///
|
||||
/// This scheduler is described in [SGDR: Stochastic Gradient Descent with Warm
|
||||
/// Restarts](https://arxiv.org/abs/1608.03983). See [CosineAnnealingLrSchedulerConfig] for more
|
||||
/// information.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct CosineAnnealingLrScheduler {
|
||||
min_lr: LearningRate,
|
||||
max_lr: LearningRate,
|
||||
num_iters: usize,
|
||||
current_iter: usize,
|
||||
}
|
||||
|
||||
impl LrScheduler for CosineAnnealingLrScheduler {
|
||||
type Record<B: Backend> = usize;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
// Make current_iter overflow from usize::MAX to 0 to get the initial learning rate on the
|
||||
// first call. We could've used i64 with an initial value -1, but keeping it in usize saves
|
||||
// us from some type casting here.
|
||||
self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);
|
||||
self.min_lr
|
||||
+ 0.5
|
||||
* (self.max_lr - self.min_lr)
|
||||
* (1.0
|
||||
+ (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)
|
||||
.cos())
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {
|
||||
self.current_iter
|
||||
}
|
||||
|
||||
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
|
||||
self.current_iter = record;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    /// Asserts that initializing `config` fails with exactly the `expected` message.
    fn assert_init_error(config: CosineAnnealingLrSchedulerConfig, expected: &str) {
        let r = config.init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(r.unwrap_err(), expected, "Error messages should match");
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_error(
            CosineAnnealingLrSchedulerConfig::new(0., 10),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_error(
            CosineAnnealingLrSchedulerConfig::new(1.5, 10),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_min_lr_too_low() {
        assert_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10).with_min_lr(-0.1),
            "Minimum learning rate must be at least 0 and at most equal to the initial learning \
             rate",
        );
    }

    #[test]
    fn config_min_lr_too_high() {
        assert_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10).with_min_lr(0.6),
            "Minimum learning rate must be at least 0 and at most equal to the initial learning \
             rate",
        );
    }

    #[test]
    fn config_num_iters_too_low() {
        assert_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 0),
            "Number of iterations must be at least 1",
        );
    }

    #[test]
    fn test_lr_change() {
        const INITIAL_LR: LearningRate = 0.5;
        const MIN_LR: LearningRate = 0.1;

        let config = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2).with_min_lr(MIN_LR);
        let scheduler = config.init().unwrap();

        let expected_lrs = [
            INITIAL_LR,                  // cos(0)
            (INITIAL_LR + MIN_LR) * 0.5, // cos(PI/2)
            MIN_LR,                      // cos(PI)
            INITIAL_LR,                  // restart
        ];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 9;

        let config = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS);
        let scheduler = config.init().unwrap();

        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}
|
||||
@@ -0,0 +1,139 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::{LrScheduler, String};
|
||||
use crate::LearningRate;
|
||||
use burn::config::Config;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler).
|
||||
///
|
||||
/// This scheduler returns the learning rate `initial_lr` at the first step, then multiplies it by
|
||||
/// a constant `gamma` at every iteration. At any iteration `i` (which starts from 0), the learning
|
||||
/// rate is given by `initial_lr * gamma^i`.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ExponentialLrSchedulerConfig {
|
||||
// The initial learning rate.
|
||||
initial_lr: LearningRate,
|
||||
// The constant that the learning rate is multiplied by on each iteration.
|
||||
gamma: f64,
|
||||
}
|
||||
|
||||
impl ExponentialLrSchedulerConfig {
|
||||
/// Initializes a [exponential learning rate scheduler](ExponentialLrScheduler).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// An error will be returned if any of the following conditions is true:
|
||||
///
|
||||
/// * `initial_lr` is out of range (0.0, 1.0]
|
||||
/// * `gamma` is out of range (0.0, 1.0]
|
||||
pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
|
||||
if self.initial_lr <= 0. || self.initial_lr > 1. {
|
||||
return Err("Initial learning rate must be greater than 0 and at most 1".into());
|
||||
}
|
||||
if self.gamma <= 0. || self.gamma > 1. {
|
||||
return Err("Gamma must be greater than 0 and at most 1".into());
|
||||
}
|
||||
|
||||
Ok(ExponentialLrScheduler {
|
||||
// Such an initial value eliminates the need for special-case handling of the first
|
||||
// learning rate.
|
||||
previous_lr: self.initial_lr / self.gamma,
|
||||
gamma: self.gamma,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A exponential learning rate scheduler.
|
||||
///
|
||||
/// See [ExponentialLrSchedulerConfig] for more information.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ExponentialLrScheduler {
|
||||
// The previous iteration's learning rate.
|
||||
previous_lr: LearningRate,
|
||||
// The constant that the learning rate is multiplied by on each iteration.
|
||||
gamma: f64,
|
||||
}
|
||||
|
||||
impl LrScheduler for ExponentialLrScheduler {
|
||||
type Record<B: Backend> = LearningRate;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
self.previous_lr *= self.gamma;
|
||||
self.previous_lr
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {
|
||||
self.previous_lr
|
||||
}
|
||||
|
||||
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
|
||||
self.previous_lr = record;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    const INITIAL_LR_ERROR: &str = "Initial learning rate must be greater than 0 and at most 1";
    const GAMMA_ERROR: &str = "Gamma must be greater than 0 and at most 1";

    /// Asserts that a config built from the arguments is rejected with exactly `expected`.
    fn assert_rejected(initial_lr: LearningRate, gamma: f64, expected: &str) {
        let outcome = ExponentialLrSchedulerConfig::new(initial_lr, gamma).init();
        assert!(outcome.is_err(), "Should return an error");
        assert_eq!(
            outcome.unwrap_err(),
            expected,
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_rejected(0., 0.5, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_rejected(1.5, 0.5, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_gamma_too_low() {
        assert_rejected(0.5, 0.0, GAMMA_ERROR);
    }

    #[test]
    fn config_gamma_too_high() {
        assert_rejected(0.5, 1.5, GAMMA_ERROR);
    }

    #[test]
    fn test_lr_change() {
        // Every step shrinks the rate by a factor of ten, starting from the initial value.
        let under_test = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
        test_utils::check_lr_sequence(under_test, [0.8, 0.08, 0.008, 0.0008, 0.00008]);
    }

    #[test]
    fn test_save_and_load() {
        let config = ExponentialLrSchedulerConfig::new(0.083, 0.3);
        let under_test = config.init().unwrap();
        test_utils::check_save_load(under_test, 7);
    }
}
|
||||
@@ -0,0 +1,173 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::{LrScheduler, String};
|
||||
use crate::LearningRate;
|
||||
use burn::config::Config;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// The configuration for creating a [linear learning rate scheduler](LinearLrScheduler).
|
||||
///
|
||||
/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by a
|
||||
/// constant amount on each iteration until reaching a final learning rate `final_lr`. The
|
||||
/// `num_iters` parameter controls how many iterations are needed to go from `initial_lr` to
|
||||
/// `final_lr`.
///
/// The values are validated by `init`: `initial_lr` must lie in (0.0, 1.0], `final_lr` in
/// [0.0, 1.0], and `num_iters` must be at least 1.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LinearLrSchedulerConfig {
|
||||
// The initial learning rate, returned by the first step.
|
||||
initial_lr: LearningRate,
|
||||
// The final learning rate, returned once `num_iters` further steps have been taken.
|
||||
final_lr: LearningRate,
|
||||
// The number of iterations before reaching the final learning rate.
|
||||
num_iters: usize,
|
||||
}
|
||||
|
||||
impl LinearLrSchedulerConfig {
|
||||
/// Initializes a [linear learning rate scheduler](LinearLrScheduler).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// An error will be returned if any of the following conditions is true:
|
||||
///
|
||||
/// * `initial_lr` is out of range (0.0, 1.0]
|
||||
/// * `final_lr` is out of range [0.0, 1.0]
|
||||
/// * `num_iters` is 0
|
||||
pub fn init(&self) -> Result<LinearLrScheduler, String> {
|
||||
if self.initial_lr <= 0. || self.initial_lr > 1. {
|
||||
return Err("Initial learning rate must be greater than 0 and at most 1".into());
|
||||
}
|
||||
if self.final_lr < 0. || self.final_lr > 1. {
|
||||
return Err("Final learning rate must be at least 0 and at most 1".into());
|
||||
}
|
||||
if self.num_iters == 0 {
|
||||
return Err("Number of iterations must be at least 1".into());
|
||||
}
|
||||
|
||||
Ok(LinearLrScheduler {
|
||||
final_lr: self.final_lr,
|
||||
step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,
|
||||
remaining_iters: self.num_iters + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A linear learning rate scheduler.
|
||||
///
/// Each step returns `final_lr - step_size * remaining_iters` after decrementing
/// `remaining_iters` (which never goes below 0).
///
|
||||
/// See [LinearLrSchedulerConfig] for more information.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct LinearLrScheduler {
|
||||
// The final learning rate after the linear changing process stops.
|
||||
final_lr: LearningRate,
|
||||
// The amount that the learning rate changes by on each iteration (negative when decreasing).
|
||||
step_size: f64,
|
||||
// The number of iterations left before reaching the final learning rate. This counter is the
// only value saved in the record.
|
||||
remaining_iters: usize,
|
||||
}
|
||||
|
||||
impl LrScheduler for LinearLrScheduler {
|
||||
type Record<B: Backend> = usize;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
self.remaining_iters -= (self.remaining_iters != 0) as usize;
|
||||
self.final_lr - self.step_size * self.remaining_iters as f64
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {
|
||||
self.remaining_iters
|
||||
}
|
||||
|
||||
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
|
||||
self.remaining_iters = record;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    const INITIAL_LR_ERROR: &str = "Initial learning rate must be greater than 0 and at most 1";
    const FINAL_LR_ERROR: &str = "Final learning rate must be at least 0 and at most 1";
    const NUM_ITERS_ERROR: &str = "Number of iterations must be at least 1";

    /// Asserts that a config built from the arguments is rejected with exactly `expected`.
    fn assert_rejected(
        initial_lr: LearningRate,
        final_lr: LearningRate,
        num_iters: usize,
        expected: &str,
    ) {
        let outcome = LinearLrSchedulerConfig::new(initial_lr, final_lr, num_iters).init();
        assert!(outcome.is_err(), "Should return an error");
        assert_eq!(
            outcome.unwrap_err(),
            expected,
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_rejected(0., 0.5, 100, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_rejected(1.5, 0.5, 100, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_final_lr_too_low() {
        assert_rejected(0.5, -0.5, 100, FINAL_LR_ERROR);
    }

    #[test]
    fn config_final_lr_too_high() {
        assert_rejected(0.5, 1.5, 100, FINAL_LR_ERROR);
    }

    #[test]
    fn config_num_iters_too_low() {
        assert_rejected(0.9, 0.1, 0, NUM_ITERS_ERROR);
    }

    #[test]
    fn test_lr_decreasing() {
        let under_test = LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap();
        test_utils::check_lr_sequence(under_test, [0.9, 0.8, 0.7, 0.6, 0.5, 0.5]);
    }

    #[test]
    fn test_lr_increasing() {
        let under_test = LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap();
        test_utils::check_lr_sequence(under_test, [0.01, 0.02, 0.03, 0.04, 0.04]);
    }

    #[test]
    fn test_lr_unchanging() {
        let under_test = LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap();
        test_utils::check_lr_sequence(under_test, [0.3, 0.3, 0.3, 0.3]);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 6;

        let config = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS);
        let under_test = config.init().unwrap();
        // Save part-way through the schedule.
        test_utils::check_save_load(under_test, NUM_ITERS / 3 * 2);
    }
}
|
||||
@@ -0,0 +1,24 @@
|
||||
/// Constant learning rate scheduler
|
||||
pub mod constant;
|
||||
|
||||
/// Composed learning rate scheduler
|
||||
pub mod composed;
|
||||
|
||||
/// Linear learning rate scheduler
|
||||
pub mod linear;
|
||||
|
||||
/// Noam learning rate scheduler
|
||||
pub mod noam;
|
||||
|
||||
/// Exponential learning rate scheduler
|
||||
pub mod exponential;
|
||||
|
||||
/// Cosine learning rate scheduler
|
||||
pub mod cosine;
|
||||
|
||||
/// Step learning rate scheduler
|
||||
pub mod step;
|
||||
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,136 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use super::{LrScheduler, String};
|
||||
use crate::LearningRate;
|
||||
|
||||
/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
///
/// `init` rejects a `warmup_steps` or `model_size` of 0.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct NoamLrSchedulerConfig {
|
||||
/// The overall scale factor for the learning rate decay.
|
||||
factor: f64,
|
||||
/// The number of steps before the exponential decay starts.
|
||||
#[config(default = 4000)]
|
||||
warmup_steps: usize,
|
||||
/// The size of the model.
|
||||
#[config(default = 512)]
|
||||
model_size: usize,
|
||||
}
|
||||
|
||||
/// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Each step returns
/// `factor * embedding_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)`.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NoamLrScheduler {
|
||||
// The configured number of warmup steps, stored as a float for the rate formula.
|
||||
warmup_steps: f64,
|
||||
// The configured model size, stored as a float for the rate formula.
|
||||
embedding_size: f64,
|
||||
// The overall scale factor applied to the learning rate.
|
||||
factor: f64,
|
||||
// The number of steps taken so far; recorded as a `usize` when saving.
|
||||
step: f64,
|
||||
}
|
||||
|
||||
impl NoamLrSchedulerConfig {
|
||||
/// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// An error will be returned if any of the following conditions is true:
|
||||
///
|
||||
/// * `warmup_steps` is 0
|
||||
/// * `model_size` is 0
|
||||
pub fn init(&self) -> Result<NoamLrScheduler, String> {
|
||||
if self.warmup_steps == 0 {
|
||||
return Err(
|
||||
"Number of steps before exponential decay starts must be greater than 0".into(),
|
||||
);
|
||||
}
|
||||
if self.model_size == 0 {
|
||||
return Err("Model size must be greater than 0".into());
|
||||
}
|
||||
|
||||
Ok(NoamLrScheduler {
|
||||
warmup_steps: self.warmup_steps as f64,
|
||||
embedding_size: self.model_size as f64,
|
||||
factor: self.factor,
|
||||
step: 0.0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LrScheduler for NoamLrScheduler {
|
||||
type Record<B: Backend> = usize;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
self.step += 1.0;
|
||||
|
||||
let arg1 = self.step.powf(-0.5);
|
||||
let arg2 = self.step * self.warmup_steps.powf(-1.5);
|
||||
|
||||
self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {
|
||||
self.step as usize
|
||||
}
|
||||
|
||||
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
|
||||
self.step = record as f64;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_warmup_steps_invalid() {
        let outcome = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init();
        assert!(outcome.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_warmup_steps_valid() {
        let outcome = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init();
        assert!(outcome.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_model_size_invalid() {
        let outcome = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init();
        assert!(outcome.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_model_size_valid() {
        let outcome = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init();
        assert!(outcome.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_function_increase_and_decrease() {
        let warmup_steps = 100;
        let config = NoamLrSchedulerConfig::new(10.0).with_warmup_steps(warmup_steps);
        let mut under_test = config.init().unwrap();
        let mut previous_lr = 0.0;

        // Warmup phase: every step must raise the learning rate.
        for _ in 0..warmup_steps {
            let next_lr = under_test.step();
            assert!(
                next_lr > previous_lr,
                "Learning rate should increase before the warmup_steps is reached."
            );
            previous_lr = next_lr;
        }

        // Decay phase: every step must lower the learning rate.
        for _ in 0..warmup_steps {
            let next_lr = under_test.step();
            assert!(
                next_lr < previous_lr,
                "Learning rate should decrease after the warmup_steps is reached."
            );
            previous_lr = next_lr;
        }
    }
}
|
||||
@@ -0,0 +1,218 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
use super::{LrScheduler, String};
|
||||
use crate::LearningRate;
|
||||
|
||||
/// The configuration for creating a [step learning rate scheduler](StepLrScheduler).
|
||||
///
|
||||
/// This scheduler returns the learning rate `initial_lr` from the start, and keeps doing so until
|
||||
/// the same value has been given for `step_size` times. Then it multiplies the learning rate by
|
||||
/// `gamma` before repeating the process.
|
||||
///
|
||||
/// Gamma values out of range (0.0, 1.0) and non-positive initial learning rates are acceptable, but
|
||||
/// a warning log will be output for such a value in case of mistyping.
|
||||
///
|
||||
/// ## Notes
|
||||
///
|
||||
/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than
|
||||
/// `i32::MAX + 1` times.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct StepLrSchedulerConfig {
|
||||
// The learning rate at the initial step.
|
||||
initial_lr: LearningRate,
|
||||
// The number of iterations over which the learning rate remains unchanged before the next
|
||||
// update. Must be greater than 0, otherwise `init` returns an error.
|
||||
step_size: usize,
|
||||
/// The factor by which the learning rate is multiplied with each update. Default: 0.1.
|
||||
#[config(default = 0.1)]
|
||||
gamma: f64,
|
||||
}
|
||||
|
||||
impl StepLrSchedulerConfig {
|
||||
/// Initializes a [step learning rate scheduler](StepLrScheduler).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// An error will be returned if `step_size` is 0.
|
||||
pub fn init(&self) -> Result<StepLrScheduler, String> {
|
||||
if self.step_size == 0 {
|
||||
return Err("Step size must be greater than 0".into());
|
||||
}
|
||||
|
||||
// Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful
|
||||
// in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).
|
||||
if self.initial_lr <= 0.0 {
|
||||
log::warn!(
|
||||
"Initial learning rate value of {} is not a positive number. Ignore this warning \
|
||||
if it is intended.",
|
||||
self.initial_lr
|
||||
);
|
||||
}
|
||||
if self.gamma <= 0.0 || self.gamma >= 1.0 {
|
||||
log::warn!(
|
||||
"Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
|
||||
intended.",
|
||||
self.gamma
|
||||
);
|
||||
}
|
||||
|
||||
Ok(StepLrScheduler {
|
||||
init_lr: self.initial_lr,
|
||||
step_size: self.step_size,
|
||||
gamma: self.gamma,
|
||||
iter_idx: -1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Step learning rate scheduler.
///
/// Each step returns `init_lr * gamma^(iter_idx / step_size)`, using integer division on the
/// index of the current iteration.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StepLrScheduler {
|
||||
// The learning rate returned while `iter_idx / step_size` is 0.
|
||||
init_lr: LearningRate,
|
||||
// The number of consecutive iterations that share one learning rate.
|
||||
step_size: usize,
|
||||
// The factor applied to the learning rate every `step_size` iterations.
|
||||
gamma: f64,
|
||||
// The index of the current iteration.
|
||||
// `i32` is used for avoiding truncating the exponent when taking powers of `gamma`.
|
||||
iter_idx: i32,
|
||||
}
|
||||
|
||||
impl LrScheduler for StepLrScheduler {
|
||||
type Record<B: Backend> = i32;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
self.iter_idx = self
|
||||
.iter_idx
|
||||
.checked_add(1)
|
||||
.expect("`.step()` should be called no more than `i32::MAX + 1` times");
|
||||
// Type casting below causes no truncation, as all the values fall within the ranges.
|
||||
self.init_lr
|
||||
* self
|
||||
.gamma
|
||||
.powi((self.iter_idx as usize / self.step_size) as i32)
|
||||
}
|
||||
|
||||
fn to_record<B: Backend>(&self) -> Self::Record<B> {
|
||||
self.iter_idx
|
||||
}
|
||||
|
||||
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
|
||||
self.iter_idx = record;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    // The warning logs for the initial LR and gamma are left untested because there seems to be
    // no straightforward way to check them.
    //
    // A mock logger collecting logs into a `String` for later inspection looks like a possible
    // solution, but unit tests run in parallel inside one process, so the single logger would be
    // shared by several tests and their logs would be mixed up with no easy way to separate
    // them.
    // "--test-threads=1" would prevent the mixup, but it is questionable whether testing the
    // logging is worth the slowdown. A `std` synchronization primitive around the logger is not
    // an option either, since `no-std` has to be supported.
    // The mocking approach may be worth revisiting once tests can be run in separate processes,
    // as proposed in:
    // https://github.com/rust-lang/rust/issues/47506
    //
    // As a side note, a helper crate exists for this exact purpose:
    // https://crates.io/crates/testing_logger
    // but it is unmaintained and using it would introduce another dependency.

    /// Builds a scheduler whose iteration index is already `i32::MAX - 1`.
    fn scheduler_near_limit() -> StepLrScheduler {
        StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1)
    }

    #[test]
    fn test_config_step_size_zero() {
        let outcome = StepLrSchedulerConfig::new(1.0, 0).init();
        assert!(outcome.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_step_size_nonzero() {
        let outcome = StepLrSchedulerConfig::new(1.0, 1).init();
        assert!(outcome.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_default_gamma() {
        const INIT_LR: LearningRate = 0.4;
        const STEP_SIZE: usize = 2;

        let base = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE);
        let mut with_default = base.init().unwrap();

        let overridden = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE).with_gamma(0.1);
        let mut with_explicit = overridden.init().unwrap();

        test_utils::compare_steps(&mut with_default, &mut with_explicit, 3 * STEP_SIZE);
    }

    #[test]
    fn test_lr_decreasing() {
        let config = StepLrSchedulerConfig::new(0.5, 3).with_gamma(0.1);
        let under_test = config.init().unwrap();
        test_utils::check_lr_sequence(
            under_test,
            [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005],
        );
    }

    #[test]
    fn test_lr_increasing() {
        let config = StepLrSchedulerConfig::new(0.1, 2).with_gamma(2.0);
        let under_test = config.init().unwrap();
        test_utils::check_lr_sequence(under_test, [0.1, 0.1, 0.2, 0.2, 0.4, 0.4]);
    }

    #[test]
    fn test_lr_unchanging() {
        let config = StepLrSchedulerConfig::new(3.1, 1).with_gamma(1.0);
        let under_test = config.init().unwrap();
        test_utils::check_lr_sequence(under_test, [3.1, 3.1, 3.1]);
    }

    #[test]
    fn test_save_and_load() {
        const STEP_SIZE: usize = 10;

        let config = StepLrSchedulerConfig::new(0.007, STEP_SIZE).with_gamma(0.03);
        let under_test = config.init().unwrap();
        test_utils::check_save_load(under_test, 3 * STEP_SIZE / 2);
    }

    // Actually running a scheduler for `i32::MAX` steps would take far too long, so these tests
    // rely on private state (the recorded iteration index) instead.
    #[test]
    fn test_number_of_calls_within_limit() {
        // The scheduler behaves as if it had already run `i32::MAX` steps.
        let mut under_test = scheduler_near_limit();
        under_test.step();
    }

    #[test]
    #[should_panic = "i32::MAX"]
    fn test_number_of_calls_over_limit() {
        // The scheduler behaves as if it had already run `i32::MAX` steps.
        let mut under_test = scheduler_near_limit();
        under_test.step();
        under_test.step();
    }
}
|
||||
@@ -0,0 +1,306 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::{module::AutodiffModule, record::Record};
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use burn::tensor::{backend::Backend, ops::Device};
|
||||
|
||||
use super::{
|
||||
SimpleOptimizer,
|
||||
adaptor::OptimizerAdaptor,
|
||||
decay::{WeightDecay, WeightDecayConfig},
|
||||
};
|
||||
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
|
||||
|
||||
/// AdaGrad configuration.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AdaGradConfig {
|
||||
/// Learning rate decay: the effective rate at update `t` (counted from 1) is
/// `lr / (1 + (t - 1) * lr_decay)`. Default: 0.
#[config(default = 0.)]
|
||||
lr_decay: f64,
|
||||
/// Value added to the square root of the accumulated squared gradients before dividing.
/// Default: 1e-5.
#[config(default = 1e-5)]
|
||||
epsilon: f32,
|
||||
/// [Weight decay](WeightDecayConfig) config.
|
||||
weight_decay: Option<WeightDecayConfig>,
|
||||
/// [Gradient Clipping](GradientClippingConfig) config.
|
||||
grad_clipping: Option<GradientClippingConfig>,
|
||||
}
|
||||
|
||||
/// AdaGrad optimizer.
///
/// Built by [AdaGradConfig::init](AdaGradConfig::init).
|
||||
#[derive(Clone)]
|
||||
pub struct AdaGrad {
|
||||
// Scales the gradient by the decayed learning rate and the accumulated squared gradients.
|
||||
lr_decay: LrDecay,
|
||||
// Optional weight decay applied to the gradient before the scaling.
|
||||
weight_decay: Option<WeightDecay>,
|
||||
}
|
||||
|
||||
/// AdaGrad state.
///
/// Wraps the per-parameter [LrDecayState] carried from one optimizer step to the next.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct AdaGradState<B: Backend, const D: usize> {
|
||||
// Update counter and running sum of squared gradients.
|
||||
lr_decay: LrDecayState<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
|
||||
type State<const D: usize> = AdaGradState<B, D>;
|
||||
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
lr: LearningRate,
|
||||
tensor: Tensor<B, D>,
|
||||
mut grad: Tensor<B, D>,
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||
let mut state_lr_decay = None;
|
||||
|
||||
if let Some(state) = state {
|
||||
state_lr_decay = Some(state.lr_decay);
|
||||
}
|
||||
|
||||
if let Some(weight_decay) = &self.weight_decay {
|
||||
grad = weight_decay.transform(grad, tensor.clone());
|
||||
}
|
||||
|
||||
let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay);
|
||||
|
||||
let state = AdaGradState::new(state_lr_decay);
|
||||
|
||||
(tensor - grad, Some(state))
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
|
||||
state.lr_decay = state.lr_decay.to_device(device);
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaGradConfig {
|
||||
/// Initialize AdaGrad optimizer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns an optimizer that can be used to optimize a module.
|
||||
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
&self,
|
||||
) -> OptimizerAdaptor<AdaGrad, M, B> {
|
||||
let optim = AdaGrad {
|
||||
lr_decay: LrDecay {
|
||||
lr_decay: self.lr_decay,
|
||||
epsilon: self.epsilon,
|
||||
},
|
||||
weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
|
||||
};
|
||||
|
||||
let mut optim = OptimizerAdaptor::from(optim);
|
||||
if let Some(config) = &self.grad_clipping {
|
||||
optim = optim.with_grad_clipping(config.init());
|
||||
}
|
||||
optim
|
||||
}
|
||||
}
|
||||
|
||||
/// Learning rate decay state (also includes sum state).
|
||||
#[derive(Record, new, Clone)]
|
||||
pub struct LrDecayState<B: Backend, const D: usize> {
|
||||
// Number of updates applied so far; 1 after the first update.
|
||||
time: usize,
|
||||
// Running sum of the squared gradients seen so far.
|
||||
sum: Tensor<B, D>,
|
||||
}
|
||||
|
||||
// Gradient transform shared by AdaGrad: accumulates squared gradients and applies the decayed
// learning rate.
#[derive(Clone)]
|
||||
struct LrDecay {
|
||||
// Decay coefficient: the effective rate is `lr / (1 + (time - 1) * lr_decay)`.
|
||||
lr_decay: f64,
|
||||
// Added to the square root of the accumulated sum before dividing the gradient by it.
|
||||
epsilon: f32,
|
||||
}
|
||||
|
||||
impl LrDecay {
|
||||
pub fn transform<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
lr: LearningRate,
|
||||
lr_decay_state: Option<LrDecayState<B, D>>,
|
||||
) -> (Tensor<B, D>, LrDecayState<B, D>) {
|
||||
let state = if let Some(mut state) = lr_decay_state {
|
||||
state.sum = state.sum.add(grad.clone().square());
|
||||
state.time += 1;
|
||||
state
|
||||
} else {
|
||||
LrDecayState::new(1, grad.clone().square())
|
||||
};
|
||||
|
||||
let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
|
||||
|
||||
let grad = grad
|
||||
.div(state.sum.clone().sqrt().add_scalar(self.epsilon))
|
||||
.mul_scalar(new_lr);
|
||||
|
||||
(grad, state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> LrDecayState<B, D> {
    /// Transfers the state to another device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device.
    pub fn to_device(self, device: &B::Device) -> Self {
        // Only the tensor holds device memory; the counter is carried over unchanged.
        let Self { time, sum } = self;
        Self {
            time,
            sum: sum.to_device(device),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adagrad_optimizer_save_load_state() {
        let device = Default::default();
        let layer = LinearConfig::new(6, 6).init(&device);
        let input =
            Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adagrad();

        // One optimization step so that the optimizer actually holds some state.
        let raw_grads = layer.forward(input).backward();
        let grads = GradientsParams::from_grads(raw_grads, &layer);
        let _layer = optimizer.step(LEARNING_RATE, layer, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            let target = std::env::temp_dir().as_path().join("test_optim_adagrad");
            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), target)
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let bytes = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!bytes.is_empty());
        }

        // A fresh optimizer loaded with the record must expose as many entries as the original.
        let record_before = optimizer.to_record();
        let record_to_load = optimizer.to_record();
        let restored = create_adagrad().load_record(record_to_load);
        let record_after = restored.to_record();

        assert_eq!(record_before.len(), record_after.len());
    }

    #[test]
    fn test_adagrad_optimizer_with_numbers() {
        let device = Default::default();
        let initial_weights = TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]);
        let initial_bias = TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]);
        let layer = given_linear_layer(initial_weights, initial_bias);

        let first_input = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let second_input = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdaGradConfig::new()
            .with_epsilon(1e-8)
            .with_lr_decay(0.5)
            .init();

        // First update.
        let raw_grads = layer.forward(first_input).backward();
        let grads = GradientsParams::from_grads(raw_grads, &layer);
        let layer = optimizer.step(LEARNING_RATE, layer, grads);

        // Second update.
        let raw_grads = layer.forward(second_input).backward();
        let grads = GradientsParams::from_grads(raw_grads, &layer);
        let layer = optimizer.step(LEARNING_RATE, layer, grads);

        let updated = layer.into_record();
        let weights_expected = TensorData::from([
            [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711],
            [
                0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756,
            ],
            [
                -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538,
            ],
            [
                -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964,
            ],
            [
                0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504,
            ],
            [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895],
        ]);
        let bias_expected = TensorData::from([
            -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714,
        ]);

        let weight_updated = updated.weight.val().into_data();
        let bias_updated = updated.bias.unwrap().val().into_data();

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    /// Creates a 6x6 linear layer whose parameters are replaced by the given data.
    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let weight = Param::from_data(weight, &device);
        let bias = Some(Param::from_data(bias, &device));

        LinearConfig::new(6, 6)
            .init(&device)
            .load_record(LinearRecord { weight, bias })
    }

    /// Creates an AdaGrad optimizer with the default configuration, without gradient clipping.
    fn create_adagrad()
    -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdaGradConfig::new();
        let lr_decay = LrDecay {
            lr_decay: config.lr_decay,
            epsilon: config.epsilon,
        };
        let weight_decay = config.weight_decay.as_ref().map(WeightDecay::new);

        AdaGrad {
            lr_decay,
            weight_decay,
        }
        .into()
    }
}
|
||||
@@ -0,0 +1,521 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::{module::AutodiffModule, record::Record};
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use burn::tensor::{backend::Backend, ops::Device};
|
||||
|
||||
use super::{
|
||||
SimpleOptimizer,
|
||||
adaptor::OptimizerAdaptor,
|
||||
decay::{WeightDecay, WeightDecayConfig},
|
||||
};
|
||||
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Adam configuration.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AdamConfig {
|
||||
/// Parameter for Adam. Default: 0.9.
///
/// NOTE(review): presumably the decay rate of the first-moment estimate (β₁ in the Adam
/// paper) — confirm against `AdaptiveMomentum`.
|
||||
#[config(default = 0.9)]
|
||||
beta_1: f32,
|
||||
/// Parameter for Adam. Default: 0.999.
///
/// NOTE(review): presumably the decay rate of the second-moment estimate (β₂ in the Adam
/// paper) — confirm against `AdaptiveMomentum`.
|
||||
#[config(default = 0.999)]
|
||||
beta_2: f32,
|
||||
/// A value required for numerical stability. Default: 1e-5.
|
||||
#[config(default = 1e-5)]
|
||||
epsilon: f32,
|
||||
/// Whether to use AMSGrad algorithm. Default: false.
|
||||
#[config(default = false)]
|
||||
amsgrad: bool,
|
||||
/// [Weight decay](WeightDecayConfig) config.
|
||||
weight_decay: Option<WeightDecayConfig>,
|
||||
/// [Gradient Clipping](GradientClippingConfig) config.
|
||||
grad_clipping: Option<GradientClippingConfig>,
|
||||
}
|
||||
|
||||
/// Adam optimizer.
|
||||
///
|
||||
/// See:
|
||||
/// - [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
|
||||
/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
///
/// Configured by [`AdamConfig`].
|
||||
#[derive(Clone)]
|
||||
pub struct Adam {
|
||||
// Moment hyper-parameters; turns a gradient into the bias-corrected update.
momentum: AdaptiveMomentum,
|
||||
// Optional weight decay applied to the gradient before the moment update.
weight_decay: Option<WeightDecay>,
|
||||
}
|
||||
|
||||
/// Adam state: the per-parameter data carried from one optimization step to
/// the next.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct AdamState<B: Backend, const D: usize> {
|
||||
/// The current adaptive momentum (step counter and moment estimates).
|
||||
pub momentum: AdaptiveMomentumState<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleOptimizer<B> for Adam {
|
||||
type State<const D: usize> = AdamState<B, D>;
|
||||
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
lr: LearningRate,
|
||||
tensor: Tensor<B, D>,
|
||||
mut grad: Tensor<B, D>,
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||
let mut state_momentum = None;
|
||||
|
||||
if let Some(state) = state {
|
||||
state_momentum = Some(state.momentum);
|
||||
}
|
||||
|
||||
if let Some(weight_decay) = &self.weight_decay {
|
||||
grad = weight_decay.transform(grad, tensor.clone());
|
||||
}
|
||||
|
||||
let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);
|
||||
|
||||
let state = AdamState::new(state_momentum);
|
||||
let delta = grad.mul_scalar(lr);
|
||||
|
||||
(tensor - delta, Some(state))
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
|
||||
state.momentum = state.momentum.to_device(device);
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
impl AdamConfig {
|
||||
/// Initialize Adam optimizer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns an optimizer that can be used to optimize a module.
|
||||
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
|
||||
let optim = Adam {
|
||||
momentum: AdaptiveMomentum {
|
||||
beta_1: self.beta_1,
|
||||
beta_2: self.beta_2,
|
||||
epsilon: self.epsilon,
|
||||
amsgrad: self.amsgrad,
|
||||
},
|
||||
weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
|
||||
};
|
||||
|
||||
let mut optim = OptimizerAdaptor::from(optim);
|
||||
if let Some(config) = &self.grad_clipping {
|
||||
optim = optim.with_grad_clipping(config.init());
|
||||
}
|
||||
optim
|
||||
}
|
||||
}
|
||||
|
||||
/// Adaptive momentum state: the running moment estimates kept for one
/// parameter tensor between optimization steps.
|
||||
#[derive(Record, new, Clone)]
|
||||
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
|
||||
/// The number of iterations aggregated; used as the exponent in the bias
/// correction terms.
|
||||
pub time: usize,
|
||||
/// The first order momentum (running average of the gradients).
|
||||
pub moment_1: Tensor<B, D>,
|
||||
/// The second order momentum (running average of the squared gradients).
|
||||
pub moment_2: Tensor<B, D>,
|
||||
/// Element-wise max of the second order momentum over all steps (for AMSGrad).
/// Left as `None` by the derived `new` constructor.
|
||||
#[new(default)]
|
||||
pub max_moment_2: Option<Tensor<B, D>>,
|
||||
}
|
||||
|
||||
// Hyper-parameters of Adam's moment estimation; see `transform` for how each
// one enters the update.
#[derive(Clone)]
|
||||
struct AdaptiveMomentum {
|
||||
// Decay factor of the first moment.
beta_1: f32,
|
||||
// Decay factor of the second moment.
beta_2: f32,
|
||||
// Stability term added to the denominator of the update.
epsilon: f32,
|
||||
// When true, the running maximum of the second moment is used as denominator.
amsgrad: bool,
|
||||
}
|
||||
|
||||
impl AdaptiveMomentum {
|
||||
pub fn transform<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
momentum_state: Option<AdaptiveMomentumState<B, D>>,
|
||||
) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
|
||||
let state = if let Some(mut state) = momentum_state {
|
||||
let factor = 1.0 - self.beta_1;
|
||||
state.moment_1 = state
|
||||
.moment_1
|
||||
.mul_scalar(self.beta_1)
|
||||
.add(grad.clone().mul_scalar(factor));
|
||||
|
||||
let factor = 1.0 - self.beta_2;
|
||||
state.moment_2 = state
|
||||
.moment_2
|
||||
.mul_scalar(self.beta_2)
|
||||
.add(grad.square().mul_scalar(factor));
|
||||
if self.amsgrad {
|
||||
let max_v = state
|
||||
.max_moment_2
|
||||
.take()
|
||||
.unwrap_or_else(|| state.moment_2.clone());
|
||||
|
||||
let new_max = max_v.max_pair(state.moment_2.clone());
|
||||
state.max_moment_2 = Some(new_max);
|
||||
}
|
||||
|
||||
state.time += 1;
|
||||
|
||||
state
|
||||
} else {
|
||||
let factor = 1.0 - self.beta_1;
|
||||
let moment_1 = grad.clone().mul_scalar(factor);
|
||||
|
||||
let factor = 1.0 - self.beta_2;
|
||||
let moment_2 = grad.square().mul_scalar(factor);
|
||||
let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
|
||||
AdaptiveMomentumState {
|
||||
time: 1,
|
||||
moment_1,
|
||||
moment_2,
|
||||
max_moment_2,
|
||||
}
|
||||
};
|
||||
|
||||
let time = state.time as i32;
|
||||
let bias_correction2_sqrt = (1.0 - self.beta_2.powi(time)).sqrt();
|
||||
let combined_factor = bias_correction2_sqrt / (1.0 - self.beta_1.powi(time));
|
||||
|
||||
let v_to_use = if self.amsgrad {
|
||||
state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
|
||||
} else {
|
||||
&state.moment_2
|
||||
};
|
||||
|
||||
let grad = state.moment_1.clone().mul_scalar(combined_factor).div(
|
||||
v_to_use
|
||||
.clone()
|
||||
.sqrt()
|
||||
.add_scalar(self.epsilon * bias_correction2_sqrt),
|
||||
);
|
||||
(grad, state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device; the step counter is carried over unchanged.
    pub fn to_device(self, device: &B::Device) -> Self {
        let Self {
            time,
            moment_1,
            moment_2,
            max_moment_2,
        } = self;

        Self {
            time,
            moment_1: moment_1.to_device(device),
            moment_2: moment_2.to_device(device),
            max_moment_2: max_moment_2.map(|max_moment| max_moment.to_device(device)),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn::tensor::Tolerance;
|
||||
use burn::tensor::ops::FloatElem;
|
||||
|
||||
use super::*;
|
||||
use crate::TestAutodiffBackend;
|
||||
use crate::{GradientsParams, Optimizer};
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{Distribution, Tensor, TensorData};
|
||||
use burn_nn::{Linear, LinearConfig, LinearRecord};
|
||||
|
||||
const LEARNING_RATE: LearningRate = 0.01;
|
||||
|
||||
/// Runs one step, records the optimizer state, and checks that loading the
/// record into a fresh optimizer yields a state with the same number of entries.
#[test]
fn test_adam_optimizer_save_load_state() {
    let device = Default::default();
    let linear = LinearConfig::new(6, 6).init(&device);
    let input = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
    let mut optimizer = create_adam();
    let grads = GradientsParams::from_grads(linear.forward(input).backward(), &linear);
    let _linear = optimizer.step(LEARNING_RATE, linear, grads);

    #[cfg(feature = "std")]
    {
        use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        let record = optimizer.to_record();
        let path = std::env::temp_dir().join("test_optim_adam");
        recorder.record(record, path).unwrap();
    }
    #[cfg(not(feature = "std"))]
    {
        use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = recorder.record(optimizer.to_record(), ()).unwrap();
        assert!(!bytes.is_empty());
    }

    let record_before = optimizer.to_record();
    let record_to_load = optimizer.to_record();
    let restored = create_adam().load_record(record_to_load);
    let record_after = restored.to_record();

    assert_eq!(record_before.len(), record_after.len());
}
|
||||
/// Runs 50 AMSGrad-enabled Adam steps on a fixed layer and compares the
/// resulting parameters against pre-computed reference values.
#[test]
fn test_adam_optimizer_with_amsgrad_50_steps() {
    let device = Default::default();
    let mut linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );

    let mut optimizer = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_amsgrad(true)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

    // The input is a constant tensor whose magnitude grows with the step index.
    for i in 1..=50 {
        let scale = i as f32 * 0.1;
        let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
            .mul_scalar(scale)
            .require_grad();

        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        linear = optimizer.step(LEARNING_RATE, linear, grads);
    }

    let LinearRecord { weight, bias } = linear.into_record();
    let weight_updated = weight.to_data();
    let bias_updated = bias.unwrap().to_data();

    let weights_expected = TensorData::from([
        [
            -0.9125810265541077,
            -0.45855265855789185,
            -0.1915993094444275,
            -0.2759990692138672,
            -0.5099529027938843,
            -0.5287043452262878,
        ],
        [
            -0.5181325674057007,
            -0.6139854788780212,
            -0.9574727416038513,
            -0.34102925658226013,
            -0.400514155626297,
            -0.8847861886024475,
        ],
        [
            -0.614483118057251,
            -0.5611032247543335,
            -0.8887064456939697,
            -0.34762972593307495,
            -0.8708556890487671,
            -0.2830044627189636,
        ],
        [
            -0.8904699683189392,
            -0.8151527643203735,
            -0.9621278643608093,
            -0.8905676603317261,
            -0.671261191368103,
            -0.4333854615688324,
        ],
        [
            -0.26599061489105225,
            -0.8119961023330688,
            -0.22424538433551788,
            -0.7672406435012817,
            -0.2163349837064743,
            -0.6258266568183899,
        ],
        [
            -0.611397922039032,
            -0.6075160503387451,
            -0.4701341986656189,
            -0.4039117991924286,
            -0.5663845539093018,
            -0.21262989938259125,
        ],
    ]);
    let bias_expected = TensorData::from([
        -0.8817203044891357,
        -0.4038999378681183,
        -0.5889149308204651,
        -0.37475723028182983,
        -0.3557940721511841,
        -0.47914788126945496,
    ]);

    type FT = FloatElem<TestAutodiffBackend>;
    let tolerance = Tolerance::absolute(1e-5);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
}
|
||||
/// Runs two Adam steps (with weight decay) on fixed inputs and compares the
/// resulting parameters against reference values.
#[test]
fn test_adam_optimizer_with_numbers() {
    let device = Default::default();
    let mut linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );
    let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
            [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
        ],
        &device,
    )
    .require_grad();
    let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &device,
    )
    .require_grad();

    let mut optimizer = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

    // One optimization step per input, in order.
    for x in [x_1, x_2] {
        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        linear = optimizer.step(LEARNING_RATE, linear, grads);
    }

    let state_updated = linear.into_record();
    let weights_expected = TensorData::from([
        [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
        [0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133],
        [-0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047],
        [-0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651],
        [0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343],
        [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
    ]);
    let bias_expected = TensorData::from([
        -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
    ]);

    let weight_updated = state_updated.weight.to_data();
    let bias_updated = state_updated.bias.unwrap().to_data();

    type FT = FloatElem<TestAutodiffBackend>;
    let tolerance = Tolerance::absolute(1e-2);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
|
||||
|
||||
/// Runs two Adam steps on the same input and checks that no weight of the
/// updated layer is NaN.
///
/// Every weight is inspected, not only the first element, so a NaN anywhere in
/// the matrix fails the test.
#[test]
fn test_adam_optimizer_no_nan() {
    let linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );

    let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &Default::default(),
    )
    .require_grad();

    let mut optimizer = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

    let grads = linear.forward(x.clone()).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let grads = linear.forward(x).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let weights = linear.into_record().weight.to_data();
    let values = weights.as_slice::<f32>().unwrap();
    assert!(
        values.iter().all(|value| !value.is_nan()),
        "updated weights contain NaN: {values:?}"
    );
}
|
||||
|
||||
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
|
||||
let device = Default::default();
|
||||
let record = LinearRecord {
|
||||
weight: Param::from_data(weight, &device),
|
||||
bias: Some(Param::from_data(bias, &device)),
|
||||
};
|
||||
|
||||
LinearConfig::new(6, 6).init(&device).load_record(record)
|
||||
}
|
||||
|
||||
fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
|
||||
let config = AdamConfig::new();
|
||||
Adam {
|
||||
momentum: AdaptiveMomentum {
|
||||
beta_1: config.beta_1,
|
||||
beta_2: config.beta_2,
|
||||
epsilon: config.epsilon,
|
||||
amsgrad: config.amsgrad,
|
||||
},
|
||||
weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,598 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use burn::tensor::{backend::Backend, ops::Device};
|
||||
use burn::{module::AutodiffModule, record::Record};
|
||||
|
||||
use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};
|
||||
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// [`AdamW`] Configuration.
///
/// Call [`AdamWConfig::init`] to build the optimizer from these values.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AdamWConfig {
|
||||
/// Decay factor of the running first moment: each step keeps `beta_1` of the
/// previous moment and adds `1 - beta_1` of the new gradient. Defaults to `0.9`.
|
||||
#[config(default = 0.9)]
|
||||
beta_1: f32,
|
||||
/// Decay factor of the running second moment: each step keeps `beta_2` of the
/// previous moment and adds `1 - beta_2` of the squared gradient. Defaults to `0.999`.
|
||||
#[config(default = 0.999)]
|
||||
beta_2: f32,
|
||||
/// A value required for numerical stability; it is added to the square root
/// of the bias-corrected second moment. Defaults to `1e-5`.
|
||||
#[config(default = 1e-5)]
|
||||
epsilon: f32,
|
||||
/// Weight decay coefficient. It is multiplied by the learning rate and applied
/// directly to the parameters rather than to the gradient. Defaults to `1e-4`.
|
||||
#[config(default = 1e-4)]
|
||||
weight_decay: f32,
|
||||
|
||||
/// Cautious weight decay config. When enabled, the decay is skipped for
/// elements whose parameter sign differs from the sign of the first moment.
/// Defaults to `false`.
|
||||
///
|
||||
/// See: <https://arxiv.org/abs/2510.12402>
|
||||
#[config(default = false)]
|
||||
cautious_weight_decay: bool,
|
||||
|
||||
/// Whether to use the AMSGrad algorithm, i.e. divide by the element-wise
/// maximum of all second moments seen so far. Defaults to `false`.
|
||||
#[config(default = false)]
|
||||
amsgrad: bool,
|
||||
/// [Gradient Clipping](GradientClippingConfig) config. When set, clipping is
/// attached to the optimizer returned by [`AdamWConfig::init`].
|
||||
grad_clipping: Option<GradientClippingConfig>,
|
||||
}
|
||||
|
||||
/// AdamW optimizer.
|
||||
///
|
||||
/// See:
|
||||
/// - [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
|
||||
/// - [Cautious Weight Decay, 2025](https://arxiv.org/abs/2510.12402)
|
||||
/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
|
||||
///
|
||||
/// Configured by [`AdamWConfig`].
|
||||
#[derive(Clone)]
|
||||
pub struct AdamW {
|
||||
// Moment hyper-parameters; turns a gradient into the bias-corrected update.
momentum: AdaptiveMomentumW,
|
||||
// Decay coefficient; multiplied by the learning rate in `step`.
weight_decay: f32,
|
||||
// When true, decay is masked out where parameter and first-moment signs differ.
cautious_weight_decay: bool,
|
||||
}
|
||||
|
||||
/// AdamW state: the per-parameter data carried from one optimization step to
/// the next.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct AdamWState<B: Backend, const D: usize> {
|
||||
/// The current adaptive momentum state.
|
||||
pub momentum: AdaptiveMomentumState<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleOptimizer<B> for AdamW {
|
||||
type State<const D: usize> = AdamWState<B, D>;
|
||||
|
||||
/// A single optimization step for any tensor that represents the parameters of a model.
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
// Learning rate.
|
||||
lr: LearningRate,
|
||||
// Any tensor that represents the parameters of a model.
|
||||
tensor: Tensor<B, D>,
|
||||
// Gradient of the loss w.r.t. the parameters.
|
||||
grad: Tensor<B, D>,
|
||||
// State of the optimizer.
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||
let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));
|
||||
|
||||
let decay_rate = lr * (self.weight_decay as f64);
|
||||
|
||||
let decayed_tensor = if decay_rate == 0.0 {
|
||||
tensor.clone()
|
||||
} else if self.cautious_weight_decay {
|
||||
// Cautious weight decay.
|
||||
// See: https://arxiv.org/abs/2510.12402
|
||||
let tensor_pos = tensor.clone().greater_equal_elem(0.0);
|
||||
let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);
|
||||
let differ = tensor_pos.not_equal(grad_pos);
|
||||
|
||||
// Zero out the decay where the decay is counter to the update direction.
|
||||
tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)
|
||||
} else {
|
||||
tensor.clone().mul_scalar(1.0 - decay_rate)
|
||||
};
|
||||
|
||||
let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);
|
||||
|
||||
let state = AdamWState {
|
||||
momentum: momentum_state,
|
||||
};
|
||||
|
||||
(tensor_updated, Some(state))
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
|
||||
state.momentum = state.momentum.to_device(device);
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
impl AdamWConfig {
|
||||
/// Initialize AdamW optimizer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns an optimizer that can be used to optimize a module.
|
||||
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
|
||||
let optim = AdamW {
|
||||
momentum: AdaptiveMomentumW {
|
||||
beta_1: self.beta_1,
|
||||
beta_2: self.beta_2,
|
||||
epsilon: self.epsilon,
|
||||
amsgrad: self.amsgrad,
|
||||
},
|
||||
weight_decay: self.weight_decay,
|
||||
cautious_weight_decay: self.cautious_weight_decay,
|
||||
};
|
||||
|
||||
let mut optim = OptimizerAdaptor::from(optim);
|
||||
if let Some(config) = &self.grad_clipping {
|
||||
optim = optim.with_grad_clipping(config.init());
|
||||
}
|
||||
optim
|
||||
}
|
||||
}
|
||||
|
||||
// Hyper-parameters of AdamW's moment estimation; see `transform` for how each
// one enters the update.
#[derive(Clone)]
|
||||
struct AdaptiveMomentumW {
|
||||
// Decay factor of the first moment.
beta_1: f32,
|
||||
// Decay factor of the second moment.
beta_2: f32,
|
||||
// Stability term added to the square root of the corrected second moment.
epsilon: f32,
|
||||
// When true, the running maximum of the second moment is used as denominator.
amsgrad: bool,
|
||||
}
|
||||
|
||||
impl AdaptiveMomentumW {
|
||||
pub fn transform<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
state: Option<AdaptiveMomentumState<B, D>>,
|
||||
) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
|
||||
let factor_1 = 1.0 - self.beta_1;
|
||||
let factor_2 = 1.0 - self.beta_2;
|
||||
|
||||
let state = if let Some(mut state) = state {
|
||||
// Update first moment estimate.
|
||||
state.moment_1 = state
|
||||
.moment_1
|
||||
.mul_scalar(self.beta_1)
|
||||
.add(grad.clone().mul_scalar(factor_1));
|
||||
|
||||
// Update second moment estimate.
|
||||
state.moment_2 = state
|
||||
.moment_2
|
||||
.mul_scalar(self.beta_2)
|
||||
.add(grad.square().mul_scalar(factor_2));
|
||||
|
||||
if self.amsgrad {
|
||||
let max_v = state
|
||||
.max_moment_2
|
||||
.take()
|
||||
.unwrap_or_else(|| state.moment_2.clone());
|
||||
state.max_moment_2 = Some(max_v.max_pair(state.moment_2.clone()));
|
||||
}
|
||||
|
||||
// Update time.
|
||||
state.time += 1;
|
||||
|
||||
state
|
||||
} else {
|
||||
// Initialize first moment estimate.
|
||||
let moment_1 = grad.clone().mul_scalar(factor_1);
|
||||
|
||||
// Initialize second moment estimate.
|
||||
let moment_2 = grad.square().mul_scalar(factor_2);
|
||||
let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
|
||||
AdaptiveMomentumState {
|
||||
time: 1,
|
||||
moment_1,
|
||||
moment_2,
|
||||
max_moment_2,
|
||||
}
|
||||
};
|
||||
|
||||
let time: i32 = state.time as i32;
|
||||
|
||||
// Compute bias-corrected first and second moment estimates.
|
||||
let moment_1_corrected = state
|
||||
.moment_1
|
||||
.clone()
|
||||
.div_scalar(1f32 - self.beta_1.powi(time));
|
||||
|
||||
let v_to_use = if self.amsgrad {
|
||||
state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
|
||||
} else {
|
||||
&state.moment_2
|
||||
};
|
||||
|
||||
let moment_2_corrected = v_to_use.clone().div_scalar(1f32 - self.beta_2.powi(time));
|
||||
|
||||
let update_delta =
|
||||
moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
|
||||
|
||||
(update_delta, state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestAutodiffBackend;
|
||||
use crate::{GradientsParams, Optimizer};
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{Distribution, Tensor, TensorData};
|
||||
use burn::tensor::{Tolerance, ops::FloatElem};
|
||||
use burn_nn::{Linear, LinearConfig, LinearRecord};
|
||||
|
||||
type FT = FloatElem<TestAutodiffBackend>;
|
||||
|
||||
const LEARNING_RATE: LearningRate = 0.01;
|
||||
|
||||
/// Runs one step, records the optimizer state, and checks that loading the
/// record into a fresh optimizer yields a state with the same number of entries.
#[test]
fn test_adamw_optimizer_save_load_state() {
    let device = Default::default();
    let linear = LinearConfig::new(6, 6).init(&device);
    let input = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
    let mut optimizer = create_adamw();
    let grads = GradientsParams::from_grads(linear.forward(input).backward(), &linear);
    let _linear = optimizer.step(LEARNING_RATE, linear, grads);

    #[cfg(feature = "std")]
    {
        use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        let record = optimizer.to_record();
        let path = std::env::temp_dir().join("test_optim_adamw");
        recorder.record(record, path).unwrap();
    }
    #[cfg(not(feature = "std"))]
    {
        use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = recorder.record(optimizer.to_record(), ()).unwrap();
        assert!(!bytes.is_empty());
    }

    let record_before = optimizer.to_record();
    let record_to_load = optimizer.to_record();
    let restored = create_adamw().load_record(record_to_load);
    let record_after = restored.to_record();

    assert_eq!(record_before.len(), record_after.len());
}
|
||||
/// Runs 50 AMSGrad-enabled AdamW steps on a fixed layer and compares the
/// resulting parameters against pre-computed reference values.
#[test]
fn test_adamw_optimizer_with_amsgrad_50_steps() {
    let device = Default::default();
    let mut linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );

    let mut optimizer = AdamWConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_amsgrad(true)
        .with_weight_decay(0.5)
        .init();

    // The input is a constant tensor whose magnitude grows with the step index.
    for i in 1..=50 {
        let scale = i as f32 * 0.1;
        let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
            .mul_scalar(scale)
            .require_grad();

        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        linear = optimizer.step(LEARNING_RATE, linear, grads);
    }

    let LinearRecord { weight, bias } = linear.into_record();
    let weight_updated = weight.to_data();
    let bias_updated = bias.unwrap().to_data();

    let weights_expected = TensorData::from([
        [
            -0.7822558283805847,
            -0.42578864097595215,
            -0.21805696189403534,
            -0.28366872668266296,
            -0.46587175130844116,
            -0.4805040955543518,
        ],
        [
            -0.4722539782524109,
            -0.5471276640892029,
            -0.8181359767913818,
            -0.33425918221473694,
            -0.3805687427520752,
            -0.7601516842842102,
        ],
        [
            -0.5475167632102966,
            -0.5057991743087769,
            -0.763265073299408,
            -0.3393959403038025,
            -0.7490996718406677,
            -0.28911691904067993,
        ],
        [
            -0.7646660208702087,
            -0.7050473093986511,
            -0.8218720555305481,
            -0.7647438049316406,
            -0.5919585227966309,
            -0.40617525577545166,
        ],
        [
            -0.27588561177253723,
            -0.7025567889213562,
            -0.24343004822731018,
            -0.6672990918159485,
            -0.23728127777576447,
            -0.556389570236206,
        ],
        [
            -0.5451040267944336,
            -0.5420684814453125,
            -0.4348171353340149,
            -0.3832150399684906,
            -0.5099242925643921,
            -0.23440153896808624,
        ],
    ]);
    let bias_expected = TensorData::from([
        -0.7473056316375732,
        -0.3745720386505127,
        -0.5188710689544678,
        -0.35184532403945923,
        -0.33705732226371765,
        -0.4332566559314728,
    ]);

    type FT = FloatElem<TestAutodiffBackend>;
    let tolerance = Tolerance::absolute(1e-5);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
}
|
||||
/// Runs two AdamW steps (plain weight decay) on fixed inputs and compares the
/// resulting parameters against reference values.
#[test]
fn test_adamw_optimizer_with_numbers() {
    let mut linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );
    let device = Default::default();
    let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
            [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
        ],
        &device,
    )
    .require_grad();
    let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &device,
    )
    .require_grad();

    let mut optimizer = AdamWConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(0.5)
        .init();

    // One optimization step per input, in order.
    for x in [x_1, x_2] {
        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        linear = optimizer.step(LEARNING_RATE, linear, grads);
    }

    let state_updated = linear.into_record();
    let weights_expected = TensorData::from([
        [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
        [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
        [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
        [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
        [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
        [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
    ]);
    let bias_expected = TensorData::from([
        -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
    ]);

    let weight_updated = state_updated.weight.to_data();
    let bias_updated = state_updated.bias.unwrap().to_data();

    let tolerance = Tolerance::absolute(1e-2);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
|
||||
|
||||
/// Runs two AdamW steps with cautious weight decay enabled and compares the
/// resulting parameters against reference values.
#[test]
fn test_adamw_optimizer_with_numbers_cautious() {
    let mut linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );
    let device = Default::default();
    let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
            [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
        ],
        &device,
    )
    .require_grad();
    // The last entry of the second input is negative.
    let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],
        ],
        &device,
    )
    .require_grad();

    let mut optimizer = AdamWConfig::new()
        .with_cautious_weight_decay(true)
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(0.5)
        .init();

    // One optimization step per input, in order.
    for x in [x_1, x_2] {
        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        linear = optimizer.step(LEARNING_RATE, linear, grads);
    }

    let state_updated = linear.into_record();
    let weights_expected = TensorData::from([
        [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
        [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
        [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
        [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
        [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
        [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332],
    ]);
    let bias_expected = TensorData::from([
        -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
    ]);

    let weight_updated = state_updated.weight.to_data();
    let bias_updated = state_updated.bias.unwrap().to_data();

    let tolerance = Tolerance::absolute(1e-2);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
|
||||
|
||||
/// Runs two AdamW steps on the same input and checks that none of the
/// updated weights is NaN.
///
/// Every weight is inspected, not just the first one, so a NaN anywhere in
/// the matrix fails the test.
#[test]
fn test_adam_optimizer_no_nan() {
    let linear = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );

    let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &Default::default(),
    )
    .require_grad();

    let mut optimizer = AdamWConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(0.5)
        .init();

    let grads = linear.forward(x.clone()).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let grads = linear.forward(x).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let state_updated = linear.into_record();
    let weights = state_updated.weight.to_data();
    let values = weights.as_slice::<f32>().unwrap();
    assert!(
        values.iter().all(|value| !value.is_nan()),
        "updated weights contain NaN"
    );
}
|
||||
|
||||
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
|
||||
let device = Default::default();
|
||||
let record = LinearRecord {
|
||||
weight: Param::from_data(weight, &device),
|
||||
bias: Some(Param::from_data(bias, &device)),
|
||||
};
|
||||
|
||||
LinearConfig::new(6, 6).init(&device).load_record(record)
|
||||
}
|
||||
|
||||
fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
|
||||
let config = AdamWConfig::new();
|
||||
AdamW {
|
||||
momentum: AdaptiveMomentumW {
|
||||
beta_1: config.beta_1,
|
||||
beta_2: config.beta_2,
|
||||
epsilon: config.epsilon,
|
||||
amsgrad: config.amsgrad,
|
||||
},
|
||||
weight_decay: config.weight_decay,
|
||||
cautious_weight_decay: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use burn_core::{self as burn, Tensor};
|
||||
|
||||
use burn_core::module::ParamId;
|
||||
use burn_core::prelude::{Backend, DeviceOps};
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::DeviceId;
|
||||
|
||||
use super::GradientsParams;
|
||||
use crate::LearningRate;
|
||||
use alloc::vec::Vec;
|
||||
use burn::module::AutodiffModule;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
|
||||
#[derive(Default)]
|
||||
/// Exposes multiple gradients for each parameter.
///
/// Holds several [GradientsParams] collections, one per source, so the same
/// parameter id may have a gradient in more than one entry.
|
||||
pub struct MultiGradientsParams {
|
||||
/// Each [GradientsParams] has its associated [DeviceId].
///
/// The position of an entry in this vector is what `remove` uses to pick the
/// entry a parameter is accumulated on.
|
||||
pub grads: Vec<(GradientsParams, DeviceId)>,
|
||||
}
|
||||
|
||||
impl MultiGradientsParams {
|
||||
/// Removes the gradients for the given [parameter id](ParamId).
|
||||
///
|
||||
/// Potentially accumulates the gradients from multiple sources using a device associated with
|
||||
/// a parameter id. The same parameter will be accumulated using the same device during
|
||||
/// all training.
|
||||
pub fn remove<B: Backend, const D: usize>(
|
||||
&mut self,
|
||||
id: ParamId,
|
||||
) -> Option<(Tensor<B, D>, Device<B>)> {
|
||||
let (mut tensor, device, index) = self.select(id)?;
|
||||
|
||||
for (i, (grads, _)) in self.grads.iter_mut().enumerate() {
|
||||
if i == index {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(grad) = grads.remove::<B, D>(id) {
|
||||
tensor = tensor + grad.to_device(&device);
|
||||
}
|
||||
}
|
||||
|
||||
Some((tensor, device))
|
||||
}
|
||||
|
||||
fn select<B: Backend, const D: usize>(
|
||||
&mut self,
|
||||
id: ParamId,
|
||||
) -> Option<(Tensor<B, D>, Device<B>, usize)> {
|
||||
let id_val = id.val() as usize;
|
||||
for i in 0..self.grads.len() {
|
||||
let selected_device_index = (id_val + i) % self.grads.len();
|
||||
|
||||
if let Some(acc) = self.grads[selected_device_index].0.remove::<B, D>(id) {
|
||||
let device_id = self.grads[selected_device_index].1;
|
||||
let device = <B::Device as DeviceOps>::from_id(device_id);
|
||||
return Some((acc.to_device(&device), device, selected_device_index));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// General trait to optimize [module](AutodiffModule).
|
||||
pub trait Optimizer<M, B>: Send + Clone
|
||||
where
|
||||
M: AutodiffModule<B>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
/// Optimizer associative type to be used when saving and loading the state.
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Perform the optimizer step using the given learning rate and gradients.
|
||||
/// The updated module is returned.
|
||||
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
|
||||
|
||||
/// Perform the optimizer step using the given learning rate and gradients.
|
||||
/// The updated module is returned.
|
||||
fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M;
|
||||
|
||||
/// Get the current state of the optimizer as a [record](Record).
|
||||
fn to_record(&self) -> Self::Record;
|
||||
|
||||
/// Load the state of the optimizer as a [record](Record).
|
||||
fn load_record(self, record: Self::Record) -> Self;
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
/// Configuration to create [weight decay](WeightDecay).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct WeightDecayConfig {
|
||||
/// L2 penalty.
///
/// `WeightDecay::transform` multiplies the parameter tensor by this value
/// before adding it to the gradient.
|
||||
pub penalty: f32,
|
||||
}
|
||||
|
||||
/// State of [weight decay](WeightDecay).
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct WeightDecayState<B: Backend, const D: usize> {
|
||||
/// Tensor carried by the state; `to_device` moves it to another device.
/// NOTE(review): going by the name this is the gradient of the previous step,
/// but `WeightDecay::transform` in this file never reads it — confirm usage.
pub(crate) grad_last_step: Tensor<B, D>,
|
||||
}
|
||||
|
||||
/// Weight decay implementation that transforms gradients.
///
/// Created from a [WeightDecayConfig]; holds only the penalty factor.
|
||||
#[derive(Clone)]
|
||||
pub struct WeightDecay {
|
||||
/// Factor applied to the parameter tensor in `transform`.
penalty: f32,
|
||||
}
|
||||
|
||||
impl WeightDecay {
|
||||
/// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).
|
||||
pub fn new(config: &WeightDecayConfig) -> Self {
|
||||
Self {
|
||||
penalty: config.penalty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Transforms a gradient.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `grad` - Gradient to transform.
|
||||
/// * `tensor` - Tensor param of the last iteration.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `grad` - Transformed gradient.
|
||||
pub fn transform<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
tensor: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
tensor.mul_scalar(self.penalty).add(grad)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> WeightDecayState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - The state with its tensor moved to `device`.
    pub fn to_device(self, device: &B::Device) -> Self {
        let Self { grad_last_step } = self;
        Self {
            grad_last_step: grad_last_step.to_device(device),
        }
    }
}
|
||||
@@ -0,0 +1,121 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use burn::module::{AutodiffModule, ModuleVisitor, Param};
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
|
||||
use super::GradientsParams;
|
||||
|
||||
/// Accumulates gradients over several calls into a single [GradientsParams].
///
/// `M` is the module type whose parameters are accumulated; it is only used
/// at the type level.
|
||||
pub struct GradientsAccumulator<M> {
|
||||
/// Running accumulation of the gradients registered so far.
grads: GradientsParams,
|
||||
/// Ties the accumulator to the module type without storing a module.
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<M> Default for GradientsAccumulator<M> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<M> GradientsAccumulator<M> {
|
||||
/// Create a new gradients accumulator.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
grads: GradientsParams::new(),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M> GradientsAccumulator<M> {
|
||||
/// Accumulate the given gradients for each parameter in the given module.
|
||||
pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)
|
||||
where
|
||||
M: AutodiffModule<B>,
|
||||
{
|
||||
let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
|
||||
module.visit(&mut visitor);
|
||||
}
|
||||
|
||||
/// Return the accumulated gradients and reset the accumulator state.
|
||||
pub fn grads(&mut self) -> GradientsParams {
|
||||
let mut grads = GradientsParams::new();
|
||||
core::mem::swap(&mut self.grads, &mut grads);
|
||||
|
||||
grads
|
||||
}
|
||||
}
|
||||
|
||||
// Module visitor that merges `grads_new` into `grads`, one float parameter at
// a time.
#[derive(new)]
|
||||
struct ModuleGradsAccumulator<'a, M> {
|
||||
// Accumulated gradients, updated in place.
grads: &'a mut GradientsParams,
|
||||
// Gradients of the current step; entries are removed as they are merged.
grads_new: GradientsParams,
|
||||
// Marker for the module type; carries no data.
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {
|
||||
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
|
||||
let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(param.id) {
|
||||
Some(new) => match self.grads.remove::<B::InnerBackend, D>(param.id) {
|
||||
Some(grad) => grad.add(new),
|
||||
None => new,
|
||||
},
|
||||
None => match self.grads.remove::<B::InnerBackend, D>(param.id) {
|
||||
Some(grad) => grad,
|
||||
None => return,
|
||||
},
|
||||
};
|
||||
|
||||
self.grads
|
||||
.register::<B::InnerBackend, D>(param.id, grad_updated);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    /// 20x20 linear layer used by every test in this module.
    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    /// Random `[2, 20]` input drawn from the default distribution.
    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }

    /// A single accumulation leaves gradients in the accumulator.
    #[test]
    fn test_accumulate_gradients_one_step() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let model = layer::<TestAutodiffBackend>(&device);

        let output = model.forward(random_tensor::<TestAutodiffBackend>(&device));
        accumulator.accumulate(
            &model,
            GradientsParams::from_grads(output.backward(), &model),
        );

        let accumulated = accumulator.grads();
        assert!(!accumulated.is_empty())
    }

    /// Two accumulations still produce one gradient per parameter (weight and bias).
    #[test]
    fn test_accumulate_gradients_two_steps() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let model = layer::<TestAutodiffBackend>(&device);

        let first_output = model.forward(random_tensor(&device));
        let second_output = model.forward(random_tensor(&device));
        let first_grads = GradientsParams::from_grads(first_output.backward(), &model);
        let second_grads = GradientsParams::from_grads(second_output.backward(), &model);

        accumulator.accumulate(&model, first_grads);
        accumulator.accumulate(&model, second_grads);

        let accumulated = accumulator.grads();
        assert_eq!(accumulated.len(), 2)
    }
}
|
||||
@@ -0,0 +1,192 @@
|
||||
use burn_core as burn;
|
||||
|
||||
#[cfg(feature = "collective")]
|
||||
use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};
|
||||
|
||||
use burn::{
|
||||
Tensor,
|
||||
tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
container::TensorContainer,
|
||||
},
|
||||
};
|
||||
|
||||
use burn::module::{AutodiffModule, ParamId};
|
||||
|
||||
use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};
|
||||
|
||||
/// Data type that contains gradients for parameters.
///
/// Gradients are stored per [ParamId]; see `get`, `remove` and `register`.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct GradientsParams {
|
||||
/// Gradient tensors keyed by the id of the parameter they belong to.
container: TensorContainer<ParamId>,
|
||||
}
|
||||
|
||||
impl GradientsParams {
|
||||
/// Creates a new [GradientsParams](GradientsParams).
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Extract each tensor gradients for the given [module](AutodiffModule).
|
||||
///
|
||||
/// Note: This consumes the gradients. See ['from_module'] to extract gradients only for
|
||||
/// a specific module.
|
||||
pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
grads: B::Gradients,
|
||||
module: &M,
|
||||
) -> Self {
|
||||
let mut grads = grads;
|
||||
Self::from_module(&mut grads, module)
|
||||
}
|
||||
|
||||
/// Extract each tensor gradients for the given [module](AutodiffModule).
|
||||
pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
grads: &mut B::Gradients,
|
||||
module: &M,
|
||||
) -> Self {
|
||||
let mut grads_params = GradientsParams::new();
|
||||
let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
|
||||
module.visit(&mut visitor);
|
||||
grads_params
|
||||
}
|
||||
|
||||
/// Extract tensor gradients for the given [module](AutodiffModule) and given parameters.
|
||||
pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
grads: &mut B::Gradients,
|
||||
module: &M,
|
||||
params: &[ParamId],
|
||||
) -> Self {
|
||||
let mut grads_params = GradientsParams::new();
|
||||
let mut visitor =
|
||||
GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
|
||||
module.visit(&mut visitor);
|
||||
grads_params
|
||||
}
|
||||
|
||||
/// Get the gradients for the given [parameter id](ParamId).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// You should use [remove](GradientsParams::remove) if you want to get the gradients
|
||||
/// only one time.
|
||||
pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
self.container.get(&id).map(Tensor::from_primitive)
|
||||
}
|
||||
|
||||
/// Remove the gradients for the given [parameter id](ParamId).
|
||||
pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
self.container.remove(&id).map(Tensor::from_primitive)
|
||||
}
|
||||
|
||||
/// Register a gradients tensor for the given [parameter id](ParamId).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced.
|
||||
pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
self.container.register(id, value.into_primitive())
|
||||
}
|
||||
|
||||
/// The number of gradients tensors registered.
|
||||
pub fn len(&self) -> usize {
|
||||
self.container.len()
|
||||
}
|
||||
|
||||
/// If any tensor is contained.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Change the device of each tensor gradients registered for the given [module](AutodiffModule).
|
||||
pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
mut self,
|
||||
device: &B::Device,
|
||||
module: &M,
|
||||
) -> Self {
|
||||
let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
|
||||
module.visit(&mut visitor);
|
||||
self
|
||||
}
|
||||
|
||||
/// Syncs the gradient params with the other peers in the collective.
|
||||
#[cfg(feature = "collective")]
|
||||
pub fn all_reduce<B: Backend>(
|
||||
mut self,
|
||||
peer_id: PeerId,
|
||||
op: ReduceOperation,
|
||||
) -> Result<Self, CollectiveError> {
|
||||
let mut ids = self
|
||||
.container
|
||||
.ids()
|
||||
.into_iter()
|
||||
.copied()
|
||||
.collect::<Vec<ParamId>>();
|
||||
// This is crucial, since the all-reduce operations need to happen in the same order for the same parameters on all nodes!
|
||||
ids.sort();
|
||||
|
||||
for id in ids {
|
||||
let Some(grad) = self.container.remove::<B>(&id) else {
|
||||
todo!()
|
||||
};
|
||||
|
||||
let grad = match grad {
|
||||
burn::tensor::TensorPrimitive::Float(grad) => {
|
||||
let grad = all_reduce::<B>(peer_id, grad, op)?;
|
||||
burn::tensor::TensorPrimitive::Float(grad)
|
||||
}
|
||||
burn::tensor::TensorPrimitive::QFloat(_grad) => {
|
||||
unimplemented!("quantized all-reduce unimplemented")
|
||||
}
|
||||
};
|
||||
|
||||
self.container.register::<B>(id, grad);
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, list_param_ids};
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    /// 20x20 linear layer used by the test below.
    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    /// Random `[2, 20]` input drawn from the default distribution.
    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }

    /// A layer and its fork share parameter ids, and each produces one
    /// gradient per parameter.
    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let original = layer::<TestAutodiffBackend>(&device);
        let forked = original.clone().fork(&device);

        let original_output = original.forward(random_tensor(&device));
        let forked_output = forked.forward(random_tensor(&device));
        let original_grads = GradientsParams::from_grads(original_output.backward(), &original);
        let forked_grads = GradientsParams::from_grads(forked_output.backward(), &forked);

        let original_ids = list_param_ids(&original);
        let forked_ids = list_param_ids(&forked);

        assert_eq!(original_ids, forked_ids);
        assert_eq!(original_grads.len(), original_ids.len());
        assert_eq!(forked_grads.len(), forked_ids.len());
    }
}
|
||||
@@ -0,0 +1,978 @@
|
||||
#![allow(clippy::excessive_precision)]
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use super::GradientsParams;
|
||||
use crate::LearningRate;
|
||||
use burn::config::Config;
|
||||
use burn::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
|
||||
use burn::prelude::ToElement;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Cubic interpolation between two sampled points.
///
/// Builds the cubic that matches the values `f1`, `f2` and the first
/// derivatives `g1`, `g2` at `x1` and `x2`, and returns its minimizer clamped
/// to the search interval. The interval is `bounds` when given, otherwise the
/// span between `x1` and `x2`.
///
/// With `d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)` and
/// `d2 = sqrt(d1^2 - g1 * g2)`, the minimizer is
/// `far - (far - near) * ((g_far + d2 - d1) / (g_far - g_near + 2 * d2))`,
/// where `far` is the larger of the two abscissas. When `d1^2 - g1 * g2` is
/// negative the midpoint of the interval is returned instead.
fn cubic_interpolate(
    x1: f64,
    f1: f64,
    g1: f64,
    x2: f64,
    f2: f64,
    g2: f64,
    bounds: Option<(f64, f64)>,
) -> f64 {
    let ascending = x1 <= x2;

    // Search interval: explicit bounds win over the span of the two points.
    let (min_bound, max_bound) = match bounds {
        Some(explicit) => explicit,
        None if ascending => (x1, x2),
        None => (x2, x1),
    };

    let d1 = g1 + g2 - 3.0 * (f1 - f2) / (x1 - x2);
    let d2_square = d1 * d1 - g1 * g2;

    if d2_square >= 0.0 {
        let d2 = d2_square.sqrt();
        // Order the samples so that `far` is the one with the larger abscissa.
        let (near_x, near_g, far_x, far_g) = if ascending {
            (x1, g1, x2, g2)
        } else {
            (x2, g2, x1, g1)
        };
        let ratio = (far_g + d2 - d1) / (far_g - near_g + 2.0 * d2);
        let min_pos = far_x - (far_x - near_x) * ratio;
        min_pos.max(min_bound).min(max_bound)
    } else {
        (min_bound + max_bound) / 2.0
    }
}
|
||||
/// Auxiliary struct for `strong_wolfe`: one evaluated point of the line search.
|
||||
struct LineSearchSample<B: Backend> {
|
||||
// Step size at which the objective was evaluated.
|
||||
t: f64,
|
||||
// Loss value returned by the objective at this step size.
|
||||
f: f64,
|
||||
// Gradient returned by the objective at this step size.
|
||||
g: Tensor<B, 1>,
|
||||
// Directional derivative: dot product of the gradient with the search direction.
|
||||
gtd: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn strong_wolfe<B: Backend, F>(
|
||||
// obj_func(x,step size,direction) -> (loss,grad)
|
||||
obj_func: &mut F,
|
||||
x: &Tensor<B, 1>,
|
||||
// initial step size
|
||||
mut t: f64,
|
||||
d: &Tensor<B, 1>,
|
||||
f: f64,
|
||||
g: Tensor<B, 1>,
|
||||
gtd: f64,
|
||||
c1: f64,
|
||||
c2: f64,
|
||||
tolerance_change: f64,
|
||||
max_ls: usize,
|
||||
) -> (f64, Tensor<B, 1>, f64, usize)
|
||||
where
|
||||
F: FnMut(&Tensor<B, 1>, f64, &Tensor<B, 1>) -> (f64, Tensor<B, 1>),
|
||||
{
|
||||
let d_norm = d.clone().abs().max().into_scalar().to_f64();
|
||||
|
||||
// evaluate objective and gradient using initial step
|
||||
let (mut f_new, mut g_new) = obj_func(x, t, d);
|
||||
let mut ls_func_evals = 1;
|
||||
let mut gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
|
||||
|
||||
// bracket an interval [t_prev,t] containing a point satisfying the Wolfe criteria
|
||||
let (mut t_prev, mut f_prev, mut g_prev, mut gtd_prev) = (0.0, f, g.clone(), gtd);
|
||||
let mut done = false;
|
||||
let mut ls_iter = 0;
|
||||
|
||||
// the interval [low,high] using for Zoom phase
|
||||
let mut bracket: Option<[LineSearchSample<B>; 2]> = None;
|
||||
// point which satisfy the wolfe condition
|
||||
let mut wolfe_bracket: Option<LineSearchSample<B>> = None;
|
||||
while ls_iter < max_ls {
|
||||
// Checking Conditions.
|
||||
|
||||
// Checking the Armijo Condition and function value increasing condition.
|
||||
// Armijo: f(x+t*d) <= f(x) + c_1 t gtd
|
||||
if f_new > (f + c1 * t * gtd) || (ls_iter > 1 && f_new >= f_prev) {
|
||||
bracket = Some([
|
||||
LineSearchSample {
|
||||
t: t_prev,
|
||||
f: f_prev,
|
||||
g: g_prev,
|
||||
gtd: gtd_prev,
|
||||
},
|
||||
LineSearchSample {
|
||||
t,
|
||||
f: f_new,
|
||||
g: g_new.clone(),
|
||||
gtd: gtd_new,
|
||||
},
|
||||
]);
|
||||
break;
|
||||
}
|
||||
|
||||
// Checking Strong Wolfe Condition
|
||||
// |gtd_new| <= -c_2 gtd
|
||||
if gtd_new.abs() <= -c2 * gtd {
|
||||
wolfe_bracket = Some(LineSearchSample {
|
||||
t,
|
||||
f: f_new,
|
||||
g: g_new.clone(),
|
||||
gtd: gtd_new,
|
||||
});
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// gtd_new >=0 , there must be a local minimum in the interval.
|
||||
if gtd_new >= 0.0 {
|
||||
bracket = Some([
|
||||
LineSearchSample {
|
||||
t: t_prev,
|
||||
f: f_prev,
|
||||
g: g_prev,
|
||||
gtd: gtd_prev,
|
||||
},
|
||||
LineSearchSample {
|
||||
t,
|
||||
f: f_new,
|
||||
g: g_new.clone(),
|
||||
gtd: gtd_new,
|
||||
},
|
||||
]);
|
||||
break;
|
||||
}
|
||||
|
||||
// interpolate
|
||||
let min_step = t + 0.01 * (t - t_prev);
|
||||
let max_step = t * 10.0;
|
||||
let t_next = cubic_interpolate(
|
||||
t_prev,
|
||||
f_prev,
|
||||
gtd_prev,
|
||||
t,
|
||||
f_new,
|
||||
gtd_new,
|
||||
Some((min_step, max_step)),
|
||||
);
|
||||
t_prev = t;
|
||||
f_prev = f_new;
|
||||
g_prev = g_new;
|
||||
gtd_prev = gtd_new;
|
||||
|
||||
// next step
|
||||
t = t_next;
|
||||
(f_new, g_new) = obj_func(x, t, d);
|
||||
ls_func_evals += 1;
|
||||
gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
|
||||
ls_iter += 1;
|
||||
}
|
||||
if let Some(sample) = wolfe_bracket {
|
||||
return (sample.f, sample.g, sample.t, ls_func_evals);
|
||||
}
|
||||
|
||||
let mut bracket = bracket.unwrap_or_else(|| {
|
||||
[
|
||||
LineSearchSample {
|
||||
t: 0.0,
|
||||
f,
|
||||
g: g.clone(),
|
||||
gtd,
|
||||
},
|
||||
LineSearchSample {
|
||||
t,
|
||||
f: f_new,
|
||||
g: g_new.clone(),
|
||||
gtd: gtd_new,
|
||||
},
|
||||
]
|
||||
});
|
||||
|
||||
// zoom phase
|
||||
let mut insuf_progress = false;
|
||||
|
||||
// find high and low points in bracket
|
||||
let (mut low_idx, mut high_idx) = if bracket[0].f <= bracket[1].f {
|
||||
(0, 1)
|
||||
} else {
|
||||
(1, 0)
|
||||
};
|
||||
|
||||
while !done && ls_iter < max_ls {
|
||||
let diff = (bracket[1].t - bracket[0].t).abs();
|
||||
// line-search bracket is so small
|
||||
if diff * d_norm < tolerance_change {
|
||||
break;
|
||||
}
|
||||
|
||||
// compute new trial value
|
||||
t = cubic_interpolate(
|
||||
bracket[0].t,
|
||||
bracket[0].f,
|
||||
bracket[0].gtd,
|
||||
bracket[1].t,
|
||||
bracket[1].f,
|
||||
bracket[1].gtd,
|
||||
None,
|
||||
);
|
||||
|
||||
let b_min = bracket[0].t.min(bracket[1].t);
|
||||
let b_max = bracket[0].t.max(bracket[1].t);
|
||||
let eps = 0.1 * (b_max - b_min);
|
||||
|
||||
if (b_max - t).min(t - b_min) < eps {
|
||||
// interpolation close to boundary
|
||||
if insuf_progress || t >= b_max || t <= b_min {
|
||||
t = if (t - b_max).abs() < (t - b_min).abs() {
|
||||
b_max - eps
|
||||
} else {
|
||||
b_min + eps
|
||||
};
|
||||
insuf_progress = false;
|
||||
} else {
|
||||
insuf_progress = true;
|
||||
}
|
||||
} else {
|
||||
insuf_progress = false;
|
||||
}
|
||||
|
||||
// Evaluate new point
|
||||
(f_new, g_new) = obj_func(x, t, d);
|
||||
|
||||
ls_func_evals += 1;
|
||||
gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
|
||||
ls_iter += 1;
|
||||
|
||||
let armijo_holds = f_new <= (f + c1 * t * gtd) && f_new < bracket[low_idx].f;
|
||||
|
||||
if !armijo_holds {
|
||||
bracket[high_idx] = LineSearchSample {
|
||||
t,
|
||||
f: f_new,
|
||||
g: g_new,
|
||||
gtd: gtd_new,
|
||||
};
|
||||
} else {
|
||||
if gtd_new.abs() <= -c2 * gtd {
|
||||
return (f_new, g_new, t, ls_func_evals);
|
||||
}
|
||||
|
||||
if gtd_new * (bracket[high_idx].t - bracket[low_idx].t) >= 0.0 {
|
||||
bracket[high_idx] = LineSearchSample {
|
||||
t: bracket[low_idx].t,
|
||||
f: bracket[low_idx].f,
|
||||
g: bracket[low_idx].g.clone(),
|
||||
gtd: bracket[low_idx].gtd,
|
||||
};
|
||||
}
|
||||
bracket[low_idx] = LineSearchSample {
|
||||
t,
|
||||
f: f_new,
|
||||
g: g_new,
|
||||
gtd: gtd_new,
|
||||
};
|
||||
}
|
||||
|
||||
if bracket[0].f <= bracket[1].f {
|
||||
low_idx = 0;
|
||||
high_idx = 1;
|
||||
} else {
|
||||
low_idx = 1;
|
||||
high_idx = 0;
|
||||
}
|
||||
}
|
||||
// return stuff
|
||||
(
|
||||
bracket[low_idx].f,
|
||||
bracket[low_idx].g.clone(),
|
||||
bracket[low_idx].t,
|
||||
ls_func_evals,
|
||||
)
|
||||
}
|
||||
|
||||
/// Strategy for the line search optimization phase.
|
||||
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum LineSearchFn {
|
||||
/// No line search performed. This is the default variant.
|
||||
#[default]
|
||||
None,
|
||||
/// Line search based on the strong Wolfe conditions.
|
||||
///
|
||||
/// See: <https://en.wikipedia.org/wiki/Wolfe_conditions>
|
||||
StrongWolfe,
|
||||
}
|
||||
|
||||
/// LBFGS Configuration.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LBFGSConfig {
|
||||
/// Maximal number of iterations per optimization step (default: 20).
|
||||
#[config(default = 20)]
|
||||
pub max_iter: usize,
|
||||
/// Update history size (default: 100).
|
||||
#[config(default = 100)]
|
||||
pub history_size: usize,
|
||||
/// Termination tolerance on first order optimality (default: 1e-7).
|
||||
#[config(default = 1e-7)]
|
||||
pub tolerance_grad: f64,
|
||||
/// Termination tolerance on function value/parameter changes (default: 1e-9).
|
||||
#[config(default = 1e-9)]
|
||||
pub tolerance_change: f64,
|
||||
/// Maximal number of function evaluations per optimization step.
/// When left as `None`, `init` uses `max_iter * 5 / 4` (integer arithmetic).
|
||||
#[config(default = "None")]
|
||||
pub max_eval: Option<usize>,
|
||||
/// Line search strategy: `LineSearchFn::StrongWolfe` or `LineSearchFn::None` (default: `LineSearchFn::None`).
|
||||
#[config(default = "LineSearchFn::None")]
|
||||
pub line_search_fn: LineSearchFn,
|
||||
}
|
||||
|
||||
impl LBFGSConfig {
|
||||
/// Initialize AdamW optimizer
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns an optimizer that can be used to optimize a module
|
||||
pub fn init<B: AutodiffBackend>(&self) -> LBFGS<B> {
|
||||
// by default max_eval = max_iter * 5/4
|
||||
let max_eval = self.max_eval.unwrap_or(self.max_iter * 5 / 4);
|
||||
LBFGS {
|
||||
config: LBFGSConfig {
|
||||
max_iter: self.max_iter,
|
||||
history_size: self.history_size,
|
||||
tolerance_grad: self.tolerance_grad,
|
||||
tolerance_change: self.tolerance_change,
|
||||
max_eval: Some(max_eval),
|
||||
line_search_fn: self.line_search_fn,
|
||||
},
|
||||
state: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Collects gradients in module visit order.
///
/// Each gradient found for a visited float parameter is reshaped to 1D and
/// appended to `tensors`.
|
||||
struct FlattenGradsVisitorInner<'a, B: AutodiffBackend> {
|
||||
// Source of the gradients, looked up by parameter id.
grads: &'a GradientsParams,
|
||||
// Output list of flattened gradients, on the inner backend.
tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenGradsVisitorInner<'_, B> {
|
||||
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
|
||||
if let Some(g) = self.grads.get::<B::InnerBackend, D>(param.id) {
|
||||
let numel = g.shape().num_elements();
|
||||
self.tensors.push(g.reshape([numel]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flatten params to inner backend 1D tensor.
|
||||
fn flatten_params_inner<B: AutodiffBackend, M: Module<B>>(
|
||||
module: &M,
|
||||
) -> Tensor<B::InnerBackend, 1> {
|
||||
let mut tensors = Vec::new();
|
||||
let mut visitor = FlattenParamsVisitorInner::<B> {
|
||||
tensors: &mut tensors,
|
||||
};
|
||||
module.visit(&mut visitor);
|
||||
if tensors.is_empty() {
|
||||
return Tensor::empty([0], &module.devices()[0]);
|
||||
}
|
||||
Tensor::cat(tensors, 0)
|
||||
}
|
||||
|
||||
// Module visitor that collects every float parameter, flattened to 1D on the
// inner backend, in visit order.
struct FlattenParamsVisitorInner<'a, B: AutodiffBackend> {
|
||||
// Output list the flattened parameters are appended to.
tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenParamsVisitorInner<'_, B> {
|
||||
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
|
||||
let t = param.val().inner();
|
||||
let numel = t.shape().num_elements();
|
||||
self.tensors.push(t.reshape([numel]));
|
||||
}
|
||||
}
|
||||
|
||||
/// Flatten gradients for a module.
|
||||
fn flatten_grads_inner<B: AutodiffBackend, M: Module<B>>(
|
||||
module: &M,
|
||||
grads: &GradientsParams,
|
||||
) -> Tensor<B::InnerBackend, 1> {
|
||||
let mut tensors = Vec::new();
|
||||
let mut visitor = FlattenGradsVisitorInner {
|
||||
grads,
|
||||
tensors: &mut tensors,
|
||||
};
|
||||
module.visit(&mut visitor);
|
||||
if tensors.is_empty() {
|
||||
return Tensor::empty([0], &module.devices()[0]);
|
||||
}
|
||||
Tensor::cat(tensors, 0)
|
||||
}
|
||||
|
||||
/// Mapper that assigns each float param from a flat inner-backend 1D tensor.
|
||||
struct ParamsFromFlatMapperInner<'a, B: AutodiffBackend> {
|
||||
flat: &'a Tensor<B::InnerBackend, 1>,
|
||||
offset: &'a mut usize,
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> ParamsFromFlatMapperInner<'_, B> {
|
||||
fn take_slice(&mut self, numel: usize) -> Tensor<B::InnerBackend, 1> {
|
||||
let start = *self.offset;
|
||||
*self.offset += numel;
|
||||
self.flat.clone().slice(start..*self.offset)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> ModuleMapper<B> for ParamsFromFlatMapperInner<'_, B> {
|
||||
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
let numel = tensor.shape().num_elements();
|
||||
let slice_1d = self.take_slice(numel);
|
||||
let new_inner = slice_1d.reshape(tensor.shape());
|
||||
let new_tensor = Tensor::from_inner(new_inner).require_grad();
|
||||
Param::from_mapped_value(id, new_tensor, mapper)
|
||||
}
|
||||
}
|
||||
|
||||
/// Overwrite module parameters from a flat inner-backend 1D tensor
|
||||
fn set_params_from_flat_inner<B: AutodiffBackend, M: Module<B>>(
|
||||
module: M,
|
||||
flat: Tensor<B::InnerBackend, 1>,
|
||||
) -> M {
|
||||
let mut offset = 0;
|
||||
let mut mapper = ParamsFromFlatMapperInner {
|
||||
flat: &flat,
|
||||
offset: &mut offset,
|
||||
};
|
||||
module.map(&mut mapper)
|
||||
}
|
||||
|
||||
/// L-BFGS optimizer state
|
||||
#[derive(Clone, Record)]
|
||||
pub struct LBFGSState<B: Backend> {
|
||||
/// Historical displacement vectors
|
||||
pub history_s: Vec<Tensor<B, 1>>,
|
||||
/// Historical gradient difference vectors
|
||||
pub history_y: Vec<Tensor<B, 1>>,
|
||||
/// Search direction
|
||||
pub d: Option<Tensor<B, 1>>,
|
||||
/// Step size from the previous iteration
|
||||
pub t: Option<f64>,
|
||||
/// Flattened gradient from the previous iteration
|
||||
pub prev_flat_grad: Option<Tensor<B, 1>>,
|
||||
/// Loss value from the previous iteration
|
||||
pub prev_loss: Option<f64>,
|
||||
/// Global iteration count
|
||||
pub g_iter: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> LBFGSState<B> {
|
||||
/// Moves all historical tensors to the target device.
|
||||
pub fn to_device(self, device: &B::Device) -> Self {
|
||||
Self {
|
||||
history_s: self
|
||||
.history_s
|
||||
.into_iter()
|
||||
.map(|t| t.to_device(device))
|
||||
.collect(),
|
||||
history_y: self
|
||||
.history_y
|
||||
.into_iter()
|
||||
.map(|t| t.to_device(device))
|
||||
.collect(),
|
||||
d: self.d.map(|t| t.to_device(device)),
|
||||
t: self.t,
|
||||
prev_flat_grad: self.prev_flat_grad.map(|t| t.to_device(device)),
|
||||
prev_loss: self.prev_loss,
|
||||
g_iter: self.g_iter,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<B: Backend> Default for LBFGSState<B> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
history_s: Vec::new(),
|
||||
history_y: Vec::new(),
|
||||
d: None,
|
||||
t: Some(1.0),
|
||||
prev_flat_grad: None,
|
||||
prev_loss: None,
|
||||
g_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// L-BFGS optimizer.
|
||||
///
|
||||
/// Ported from [pytorch](https://github.com/pytorch/pytorch/torch/optim/lbfgs.py). Heavily inspired by [miniFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html)
|
||||
///
|
||||
/// See also:
|
||||
/// - [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS)
|
||||
///
|
||||
/// # Note
|
||||
/// This optimizer is memory intensive
|
||||
#[derive(Clone)]
|
||||
pub struct LBFGS<B: Backend + AutodiffBackend> {
|
||||
config: LBFGSConfig,
|
||||
state: LBFGSState<B::InnerBackend>,
|
||||
}
|
||||
|
||||
impl<B: Backend + AutodiffBackend> LBFGS<B> {
|
||||
/// A single optimization step for any tensor that represents the parameters of a model.
|
||||
pub fn step<M, F>(&mut self, lr: LearningRate, mut module: M, mut closure: F) -> (M, f64)
|
||||
where
|
||||
M: AutodiffModule<B> + Clone,
|
||||
F: FnMut(M) -> (f64, GradientsParams),
|
||||
{
|
||||
// evaluate initial f(x) and df/dx
|
||||
let (mut loss, grads) = closure(module.clone());
|
||||
let mut current_evals = 1;
|
||||
|
||||
let mut flat_grad = flatten_grads_inner::<B, M>(&module, &grads);
|
||||
let mut x_flat = flatten_params_inner::<B, M>(&module);
|
||||
|
||||
let opt_cond =
|
||||
flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad;
|
||||
// optimal condition
|
||||
if opt_cond {
|
||||
return (module, loss);
|
||||
}
|
||||
|
||||
// tensors cached in state
|
||||
let mut d = self
|
||||
.state
|
||||
.d
|
||||
.take()
|
||||
.unwrap_or_else(|| flat_grad.clone().neg());
|
||||
let mut t = self.state.t.unwrap_or(lr);
|
||||
let mut prev_flat_grad = self.state.prev_flat_grad.take();
|
||||
|
||||
let mut n_iter = 0;
|
||||
|
||||
// optimize for a max of max_iter iterations
|
||||
while n_iter < self.config.max_iter {
|
||||
// keep track of nb of iterations
|
||||
n_iter += 1;
|
||||
self.state.g_iter += 1;
|
||||
|
||||
// compute gradient descent direction
|
||||
if self.state.g_iter == 1 {
|
||||
d = flat_grad.clone().neg();
|
||||
self.state.history_s.clear();
|
||||
self.state.history_y.clear();
|
||||
} else {
|
||||
// do lbfgs update (update memory)
|
||||
if let Some(pg) = prev_flat_grad.as_ref() {
|
||||
let y = flat_grad.clone().sub(pg.clone());
|
||||
let s = d.clone().mul_scalar(t);
|
||||
|
||||
let ys = y.clone().dot(s.clone()).into_scalar().to_f64();
|
||||
|
||||
if ys > 1e-10 {
|
||||
// updating memory
|
||||
if self.state.history_s.len() >= self.config.history_size {
|
||||
// shift history by one (limited-memory)
|
||||
self.state.history_s.remove(0);
|
||||
self.state.history_y.remove(0);
|
||||
}
|
||||
self.state.history_s.push(s);
|
||||
self.state.history_y.push(y);
|
||||
}
|
||||
}
|
||||
|
||||
// compute the approximate (L-BFGS) inverse Hessian
|
||||
// multiplied by the gradient
|
||||
let num_old = self.state.history_s.len();
|
||||
let mut q = flat_grad.clone().neg();
|
||||
let mut alphas: Vec<Tensor<B::InnerBackend, 1>> =
|
||||
vec![Tensor::zeros([1], &flat_grad.device()); num_old];
|
||||
|
||||
if num_old > 0 {
|
||||
// multiply by initial Hessian
|
||||
// r/d is the final direction
|
||||
for i in (0..num_old).rev() {
|
||||
let s = &self.state.history_s[i];
|
||||
let y = &self.state.history_y[i];
|
||||
let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);
|
||||
let alpha = rho.clone().mul(s.clone().dot(q.clone()));
|
||||
alphas[i] = alpha.clone();
|
||||
q = q.sub(y.clone().mul(alpha));
|
||||
}
|
||||
|
||||
let last_s = &self.state.history_s[num_old - 1];
|
||||
let last_y = &self.state.history_y[num_old - 1];
|
||||
let ys = last_y.clone().dot(last_s.clone());
|
||||
let yy = last_y.clone().dot(last_y.clone());
|
||||
let h_diag = ys.div(yy);
|
||||
|
||||
let mut r = q.mul(h_diag);
|
||||
|
||||
for ((s, y), alpha) in self
|
||||
.state
|
||||
.history_s
|
||||
.iter()
|
||||
.zip(self.state.history_y.iter())
|
||||
.zip(alphas.into_iter())
|
||||
.take(num_old)
|
||||
{
|
||||
let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);
|
||||
|
||||
let beta = rho.mul(y.clone().dot(r.clone()));
|
||||
|
||||
r = r.add(s.clone().mul(alpha.sub(beta)));
|
||||
}
|
||||
d = r;
|
||||
} else {
|
||||
d = q;
|
||||
}
|
||||
}
|
||||
|
||||
prev_flat_grad = Some(flat_grad.clone());
|
||||
let prev_loss_iter = loss;
|
||||
|
||||
// compute step len
|
||||
if self.state.g_iter == 1 {
|
||||
let grad_l1 = flat_grad.clone().abs().sum().into_scalar().to_f64();
|
||||
t = (1.0f64 / grad_l1).min(1.0) * lr;
|
||||
} else {
|
||||
t = lr;
|
||||
}
|
||||
|
||||
// directional derivative
|
||||
let gtd = flat_grad.clone().dot(d.clone()).into_scalar().to_f64();
|
||||
|
||||
if gtd > -self.config.tolerance_change {
|
||||
break;
|
||||
}
|
||||
|
||||
let ls_func_evals;
|
||||
|
||||
if let LineSearchFn::StrongWolfe = self.config.line_search_fn {
|
||||
// perform line search, using user function
|
||||
let mut obj_func =
|
||||
|current_x: &Tensor<B::InnerBackend, 1>,
|
||||
step: f64,
|
||||
dir: &Tensor<B::InnerBackend, 1>| {
|
||||
let update = dir.clone().mul_scalar(step);
|
||||
let new_x = current_x.clone().add(update);
|
||||
let tmp_module = set_params_from_flat_inner::<B, M>(module.clone(), new_x);
|
||||
let (l, g) = closure(tmp_module);
|
||||
(l, flatten_grads_inner::<B, M>(&module, &g))
|
||||
};
|
||||
|
||||
let (ls_f, ls_g, ls_t, evals) = strong_wolfe(
|
||||
&mut obj_func,
|
||||
&x_flat,
|
||||
t,
|
||||
&d,
|
||||
loss,
|
||||
flat_grad.clone(),
|
||||
gtd,
|
||||
1e-4,
|
||||
0.9,
|
||||
self.config.tolerance_change,
|
||||
self.config.max_eval.unwrap() - current_evals,
|
||||
);
|
||||
|
||||
loss = ls_f;
|
||||
flat_grad = ls_g;
|
||||
t = ls_t;
|
||||
ls_func_evals = evals;
|
||||
|
||||
x_flat = x_flat.add(d.clone().mul_scalar(t));
|
||||
module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
|
||||
} else {
|
||||
// no line search, simply move with fixed-step
|
||||
let step_vec = d.clone().mul_scalar(t);
|
||||
x_flat = x_flat.add(step_vec);
|
||||
module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
|
||||
// re-evaluate function only if not in last iteration
|
||||
// the reason we do this: in a stochastic setting,
|
||||
// no use to re-evaluate that function here
|
||||
let (new_loss, new_grads) = closure(module.clone());
|
||||
loss = new_loss;
|
||||
flat_grad = flatten_grads_inner::<B, M>(&module, &new_grads);
|
||||
ls_func_evals = 1;
|
||||
}
|
||||
|
||||
// update func eval
|
||||
current_evals += ls_func_evals;
|
||||
|
||||
// check conditions
|
||||
|
||||
if current_evals >= self.config.max_eval.unwrap() {
|
||||
break;
|
||||
}
|
||||
|
||||
if flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad {
|
||||
break;
|
||||
}
|
||||
|
||||
if d.clone().mul_scalar(t).abs().max().into_scalar().to_f64()
|
||||
<= self.config.tolerance_change
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (loss - prev_loss_iter).abs() < self.config.tolerance_change {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.state.d = Some(d);
|
||||
self.state.t = Some(t);
|
||||
self.state.prev_flat_grad = prev_flat_grad;
|
||||
self.state.prev_loss = Some(loss);
|
||||
(module, loss)
|
||||
}
|
||||
/// Moves the optimizer state to the specified device.
|
||||
pub fn to_device(self, device: &B::Device) -> Self {
|
||||
Self {
|
||||
config: self.config,
|
||||
// History tensors reside in InnerBackend, so we convert the device accordingly
|
||||
state: self.state.to_device(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::GradientsParams;
|
||||
use crate::TestAutodiffBackend;
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{Tensor, TensorData};
|
||||
use burn_nn::{Linear, LinearConfig, LinearRecord};
|
||||
|
||||
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
|
||||
let device = Default::default();
|
||||
let record = LinearRecord {
|
||||
weight: Param::from_data(weight, &device),
|
||||
bias: Some(Param::from_data(bias, &device)),
|
||||
};
|
||||
|
||||
LinearConfig::new(6, 6).init(&device).load_record(record)
|
||||
}
|
||||
/// Checks `cubic_interpolate` against fixed reference values for a symmetric
/// case, a bounded case, the midpoint fallback, an asymmetric case and an
/// ill-conditioned case (with and without bounds).
#[test]
fn test_cubic_interpolate() {
    let tolerance = 1e-8;

    // basic
    let (x1, f1, g1, x2, f2, g2) = (-1.0, 1.0, -2.0, 1.0, 1.0, 2.0);
    let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
    assert!(
        (result - 0.00000).abs() < tolerance,
        "Basic: Result {} should be close to 0.0",
        result
    );

    // bound
    let (x1, f1, g1, x2, f2, g2) = (0.0, 0.25, -1.0, 1.0, 0.25, 1.0);
    let bounds = Some((0.6, 1.0));
    let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds);
    assert!(
        (result - 0.6000000000).abs() < tolerance,
        "Bound: Result {} should be clamped to 0.6",
        result
    );

    // d2_square < 0, should return mid value
    let (x1, f1, g1, x2, f2, g2) = (0.0, 0.0, 10.0, 1.0, 5.0, 10.0);
    let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((0.0, 1.0)));
    assert!(
        (result - 0.5000000).abs() < tolerance,
        "Fallback: Result {} should be midpoint 0.5",
        result
    );

    // asymmetric
    let (x1, f1, g1, x2, f2, g2) = (0.0, 1.0, -5.0, 1.0, 0.5, 1.0);
    let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
    assert!(
        (result - 0.4606553370833684).abs() < tolerance,
        "Asymmetric: Result {} should be 0.4606553370833684",
        result
    );

    // not good value
    let (x1, f1, g1, x2, f2, g2) = (
        1.231232145,
        -0.12567458754,
        9.1231243007,
        8.239105015,
        -100.9012398021,
        123201321.0293982,
    );
    let result_1 = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
    let result_2 = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((-4.4, 4.4)));
    // Report the value under test, not the `result` left over from the previous case.
    assert!(
        (result_1 - 5.9031480234724434).abs() < tolerance,
        "not good value 1: Result {} should be 5.9031480234724434",
        result_1
    );
    assert!(
        (result_2 - 4.4000000000000004).abs() < tolerance,
        "not good value 2: Result {} should be 4.4000000000000004",
        result_2
    );
}
|
||||
/// Runs `strong_wolfe` on f(x) = x^4 - 2x^2 + x from a fixed start point and
/// direction, and compares the outcome with fixed reference values.
#[test]
fn test_strong_wolfe_direct_comparison() {
    /// Objective and analytic gradient evaluated at `x_base + t_val * d_vec`.
    fn func<B: Backend>(
        x_base: &Tensor<B, 1>,
        t_val: f64,
        d_vec: &Tensor<B, 1>,
    ) -> (f64, Tensor<B, 1>) {
        let curr_x = x_base.clone().add(d_vec.clone().mul_scalar(t_val));
        let x2 = curr_x.clone().mul(curr_x.clone());
        let x3 = x2.clone().mul(curr_x.clone());
        let x4 = x2.clone().mul(x2.clone());

        // f(x) = x^4 - 2*x^2 + x
        let f_elements = x4 - x2.mul_scalar(2.0) + curr_x.clone();
        let f_val = f_elements.sum().into_scalar().to_f64();

        // g(x) = 4*x^3 - 4*x + 1
        let g = x3.mul_scalar(4.0) - curr_x.clone().mul_scalar(4.0) + Tensor::ones_like(&curr_x);

        (f_val, g)
    }

    let device = Default::default();
    let tol = 1e-8;

    let x = Tensor::<TestAutodiffBackend, 1>::from_floats([2.1321912957_f64], &device);
    let d = Tensor::<TestAutodiffBackend, 1>::from_floats([0.91312321_f64], &device);
    let t_initial = 1.213132_f64;

    // State at the start point (step length 0).
    let (f_init, g_init) = func(&x, 0.0, &d);
    let gtd_init = g_init.clone().dot(d.clone()).into_scalar().to_f64();
    println!("Initial State: f={},gtd = {}", f_init, gtd_init);
    assert!((f_init - 13.7080059052).abs() < tol);
    assert!((gtd_init - 28.5305728912).abs() < tol);

    let mut obj_func =
        |xb: &Tensor<TestAutodiffBackend, 1>, tv: f64, dv: &Tensor<TestAutodiffBackend, 1>| {
            func(xb, tv, dv)
        };

    let (f_final, grad_final, t_final, evals) = strong_wolfe(
        &mut obj_func,
        &x,
        t_initial,
        &d,
        f_init,
        g_init,
        gtd_init,
        1e-4, // c1
        0.9,  // c2
        1e-9, // tolerance_change
        10,   // max_ls
    );
    let g_f = grad_final.into_scalar().to_f64();
    println!(
        "f_final:{:?},_g_final:{:?},t_final:{:?},evals:{:?}",
        f_final, g_f, t_final, evals
    );
    assert!((f_final - 13.708005905151367).abs() < tol);
    assert!((g_f - 31.2450428009).abs() < tol);
    assert!((t_final.to_f64() - 0.0).abs() < tol);
    assert!(evals == 11);
}
|
||||
/// Fits a 1x1 linear layer with one L-BFGS step using the strong Wolfe line
/// search and compares loss, weight and bias with fixed reference values.
#[test]
fn test_lbfgs_strong_wolfe_comparison() {
    let device = Default::default();
    let tol = 1e-5;
    let inputs = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
    let targets = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
    let module = given_linear_layer(TensorData::from([[0.5f64]]), TensorData::from([0.1f64]));

    let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
        .with_line_search_fn(LineSearchFn::StrongWolfe)
        .init();

    // Summed squared error of the candidate module, together with its gradients.
    let mut closure = |candidate: Linear<TestAutodiffBackend>| {
        let prediction = candidate.forward(inputs.clone());
        let loss = burn_nn::loss::MseLoss::new().forward(
            prediction,
            targets.clone(),
            burn_nn::loss::Reduction::Sum,
        );
        let grads_params = GradientsParams::from_grads(loss.backward(), &candidate);

        (loss.into_scalar().to_f64(), grads_params)
    };

    let initial_loss = closure(module.clone()).0;
    assert!((initial_loss - 50.1300048828).abs() < tol);

    let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
    assert!((final_loss - 0.0234732367).abs() < tol);

    let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
    let fitted_bias = updated_module.bias.as_ref().unwrap();
    let optimized_bias: f64 = fitted_bias.val().into_scalar().to_f64();
    assert!((optimized_data - 2.0570652485).abs() < tol);
    assert!((optimized_bias - 0.8106800914).abs() < tol);
}
|
||||
/// Fits a 1x1 linear layer with one L-BFGS step without line search and
/// compares loss, weight and bias with fixed reference values.
#[test]
fn test_lbfgs_no_strong_wolfe_comparison() {
    let device = Default::default();
    let tol = 1e-5;
    let inputs = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
    let targets = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
    let module = given_linear_layer(TensorData::from([[0.5f64]]), TensorData::from([0.1f64]));

    let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
        .with_line_search_fn(LineSearchFn::None)
        .init();

    // Summed squared error of the candidate module, together with its gradients.
    let mut closure = |candidate: Linear<TestAutodiffBackend>| {
        let prediction = candidate.forward(inputs.clone());
        let loss = burn_nn::loss::MseLoss::new().forward(
            prediction,
            targets.clone(),
            burn_nn::loss::Reduction::Sum,
        );
        let grads_params = GradientsParams::from_grads(loss.backward(), &candidate);

        (loss.into_scalar().to_f64(), grads_params)
    };

    let initial_loss = closure(module.clone()).0;
    assert!((initial_loss - 50.1300048828).abs() < tol);

    let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
    assert!((final_loss - 48.2181930542).abs() < tol);

    let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
    let fitted_bias = updated_module.bias.as_ref().unwrap();
    let optimized_bias: f64 = fitted_bias.val().into_scalar().to_f64();

    assert!((optimized_data - 0.5302446192).abs() < tol);
    assert!((optimized_bias - 0.1142520783).abs() < tol);
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
/// Weight decay module for optimizers.
|
||||
pub mod decay;
|
||||
|
||||
/// Momentum module for optimizers.
|
||||
pub mod momentum;
|
||||
|
||||
mod adagrad;
|
||||
mod adam;
|
||||
mod adamw;
|
||||
mod base;
|
||||
mod grad_accum;
|
||||
mod grads;
|
||||
mod lbfgs;
|
||||
mod muon;
|
||||
mod rmsprop;
|
||||
mod sgd;
|
||||
mod simple;
|
||||
mod visitor;
|
||||
|
||||
pub use adagrad::*;
|
||||
pub use adam::*;
|
||||
pub use adamw::*;
|
||||
pub use base::*;
|
||||
pub use grad_accum::*;
|
||||
pub use grads::*;
|
||||
pub use lbfgs::*;
|
||||
pub use muon::*;
|
||||
pub use rmsprop::*;
|
||||
pub use sgd::*;
|
||||
pub use simple::*;
|
||||
@@ -0,0 +1,94 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{ElementConversion, Tensor};
|
||||
|
||||
/// Configuration to create [momentum](Momentum).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MomentumConfig {
|
||||
/// Momentum factor
|
||||
#[config(default = 0.9)]
|
||||
pub momentum: f64,
|
||||
/// Dampening factor.
|
||||
#[config(default = 0.1)]
|
||||
pub dampening: f64,
|
||||
/// Enables Nesterov momentum, see [On the importance of initialization and
|
||||
/// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
|
||||
#[config(default = false)]
|
||||
pub nesterov: bool,
|
||||
}
|
||||
|
||||
/// State of [momentum](Momentum).
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct MomentumState<B: Backend, const D: usize> {
|
||||
velocity: Tensor<B, D>,
|
||||
}
|
||||
|
||||
/// Momentum implementation that transforms gradients.
|
||||
#[derive(Clone)]
|
||||
pub struct Momentum<B: Backend> {
|
||||
momentum: B::FloatElem,
|
||||
dampening: f64,
|
||||
nesterov: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> Momentum<B> {
|
||||
/// Creates a new [momentum](Momentum) from a [config](MomentumConfig).
|
||||
pub fn new(config: &MomentumConfig) -> Self {
|
||||
Self {
|
||||
momentum: config.momentum.elem(),
|
||||
dampening: config.dampening,
|
||||
nesterov: config.nesterov,
|
||||
}
|
||||
}
|
||||
|
||||
/// Transforms a gradient.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `grad` - Gradient to transform.
|
||||
/// * `state` - State of the optimizer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `grad` - Transformed gradient.
|
||||
/// * `state` - State of the optimizer.
|
||||
pub fn transform<const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
state: Option<MomentumState<B, D>>,
|
||||
) -> (Tensor<B, D>, MomentumState<B, D>) {
|
||||
let velocity = if let Some(state) = state {
|
||||
grad.clone()
|
||||
.mul_scalar(1.0 - self.dampening)
|
||||
.add(state.velocity.mul_scalar(self.momentum))
|
||||
} else {
|
||||
grad.clone()
|
||||
};
|
||||
|
||||
let grad = match self.nesterov {
|
||||
true => velocity.clone().mul_scalar(self.momentum).add(grad),
|
||||
false => velocity.clone(),
|
||||
};
|
||||
|
||||
(grad, MomentumState::new(velocity))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> MomentumState<B, D> {
    /// Transfers the state to another device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - The state, with its velocity living on `device`.
    pub fn to_device(self, device: &B::Device) -> Self {
        let velocity = self.velocity.to_device(device);
        Self { velocity }
    }
}
|
||||
@@ -0,0 +1,775 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::{module::AutodiffModule, record::Record};
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use burn::tensor::{backend::Backend, ops::Device};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{
|
||||
SimpleOptimizer,
|
||||
adaptor::OptimizerAdaptor,
|
||||
decay::WeightDecayConfig,
|
||||
momentum::{Momentum, MomentumConfig, MomentumState},
|
||||
};
|
||||
use crate::LearningRate;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Learning rate adjustment method for Muon optimizer.
|
||||
///
|
||||
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
|
||||
/// RMS across rectangular matrices.
|
||||
///
|
||||
/// # References
|
||||
///
|
||||
/// - Original: [Muon: An optimizer for hidden layers](https://kellerjordan.github.io/posts/muon/)
|
||||
/// - Moonshot: [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
|
||||
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum AdjustLrFn {
|
||||
/// Keller Jordan's original method: `lr * sqrt(max(1, A/B))`
|
||||
///
|
||||
/// This scales the learning rate based on the aspect ratio of the weight matrix,
|
||||
/// ensuring that tall matrices (more rows than columns) get proportionally larger
|
||||
/// learning rates.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// For a [1024, 512] matrix: `lr * sqrt(1024/512) = lr * 1.414`
|
||||
#[default]
|
||||
Original,
|
||||
|
||||
/// Moonshot's method: `lr * 0.2 * sqrt(max(A, B))`
|
||||
///
|
||||
/// This method is designed to match AdamW's RMS, allowing Muon to directly reuse
|
||||
/// learning rates and weight decay values tuned for AdamW without retuning.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// For a [1024, 512] matrix: `lr * 0.2 * sqrt(1024) = lr * 6.4`
|
||||
MatchRmsAdamW,
|
||||
}
|
||||
|
||||
impl AdjustLrFn {
|
||||
/// Calculate the learning rate adjustment ratio for a given parameter shape.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - Parameter shape (uses first two dimensions)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Adjustment ratio to multiply with the base learning rate
|
||||
fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
|
||||
if shape.len() < 2 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
let a = shape[0] as f64;
|
||||
let b = shape[1] as f64;
|
||||
|
||||
match self {
|
||||
Self::Original => {
|
||||
// sqrt(max(1, A/B))
|
||||
let ratio = a / b;
|
||||
ratio.max(1.0).sqrt()
|
||||
}
|
||||
Self::MatchRmsAdamW => {
|
||||
// 0.2 * sqrt(max(A, B))
|
||||
0.2 * a.max(b).sqrt()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Muon configuration.
|
||||
///
|
||||
/// Muon is an optimizer specifically designed for 2D parameters of neural network
|
||||
/// hidden layers (weight matrices). Other parameters such as biases and embeddings
|
||||
/// should be optimized using a standard method such as AdamW.
|
||||
///
|
||||
/// # Learning Rate Adjustment
|
||||
///
|
||||
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
|
||||
/// RMS across rectangular matrices. Two methods are available:
|
||||
///
|
||||
/// - **Original**: Uses `sqrt(max(1, A/B))` where A and B are the first two dimensions.
|
||||
/// This is Keller Jordan's method and is the default.
|
||||
///
|
||||
/// - **MatchRmsAdamW**: Uses `0.2 * sqrt(max(A, B))`. This is Moonshot's method
|
||||
/// designed to match AdamW's RMS, allowing direct reuse of AdamW hyperparameters.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn_optim::{MuonConfig, AdjustLrFn};
|
||||
///
|
||||
/// // Using default (Original) method
|
||||
/// let optimizer = MuonConfig::new().init();
|
||||
///
|
||||
/// // Using MatchRmsAdamW for AdamW-compatible hyperparameters
|
||||
/// let optimizer = MuonConfig::new()
|
||||
/// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
|
||||
/// .init();
|
||||
/// ```
|
||||
///
|
||||
/// # References
|
||||
///
|
||||
/// - [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/)
|
||||
/// - [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
|
||||
/// - [PyTorch Implementation](https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py)
|
||||
/// - [Original Implementation](https://github.com/KellerJordan/Muon)
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MuonConfig {
|
||||
/// [Weight decay](WeightDecayConfig) config.
|
||||
weight_decay: Option<WeightDecayConfig>,
|
||||
|
||||
/// [Momentum](MomentumConfig) config.
|
||||
///
|
||||
/// Muon always uses momentum. Default configuration:
|
||||
/// - momentum: 0.95
|
||||
/// - dampening: 0.0
|
||||
/// - nesterov: true
|
||||
#[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
|
||||
momentum: MomentumConfig,
|
||||
|
||||
/// Newton-Schulz iteration coefficients (a, b, c).
|
||||
///
|
||||
/// These coefficients are selected to maximize the slope at zero for the
|
||||
/// quintic iteration. Default values are from Keller Jordan's implementation.
|
||||
#[config(default = "(3.4445, -4.775, 2.0315)")]
|
||||
ns_coefficients: (f32, f32, f32),
|
||||
|
||||
/// Epsilon for numerical stability.
|
||||
#[config(default = 1e-7)]
|
||||
epsilon: f32,
|
||||
|
||||
/// Number of Newton-Schulz iteration steps.
|
||||
#[config(default = 5)]
|
||||
ns_steps: usize,
|
||||
|
||||
/// Learning rate adjustment method.
|
||||
///
|
||||
/// Controls how the learning rate is adjusted based on parameter shape.
|
||||
/// See [`AdjustLrFn`] for available methods.
|
||||
#[config(default = "AdjustLrFn::Original")]
|
||||
adjust_lr_fn: AdjustLrFn,
|
||||
}
|
||||
|
||||
impl MuonConfig {
|
||||
/// Initialize Muon optimizer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns an optimizer adaptor that can be used to optimize a module.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
|
||||
///
|
||||
/// // Basic configuration with default (Original) LR adjustment
|
||||
/// let optimizer = MuonConfig::new()
|
||||
/// .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
|
||||
/// .init();
|
||||
///
|
||||
/// // With AdamW-compatible settings using MatchRmsAdamW
|
||||
/// let optimizer = MuonConfig::new()
|
||||
/// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
|
||||
/// .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
|
||||
/// .init();
|
||||
///
|
||||
/// // Custom momentum and NS settings
|
||||
/// let optimizer = MuonConfig::new()
|
||||
/// .with_momentum(MomentumConfig {
|
||||
/// momentum: 0.9,
|
||||
/// dampening: 0.1,
|
||||
/// nesterov: false,
|
||||
/// })
|
||||
/// .with_ns_steps(7)
|
||||
/// .init();
|
||||
/// ```
|
||||
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
&self,
|
||||
) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
|
||||
let momentum = Momentum::new(&self.momentum);
|
||||
let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);
|
||||
|
||||
let optim = Muon {
|
||||
momentum,
|
||||
ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
|
||||
weight_decay_penalty,
|
||||
epsilon: self.epsilon,
|
||||
adjust_lr_fn: self.adjust_lr_fn,
|
||||
};
|
||||
|
||||
OptimizerAdaptor::from(optim)
|
||||
}
|
||||
}
|
||||
|
||||
/// Coefficients and step count of the Newton-Schulz orthogonalization.
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
    /// First iteration coefficient.
    a: f32,
    /// Second iteration coefficient.
    b: f32,
    /// Third iteration coefficient.
    c: f32,
    /// Number of iteration steps.
    steps: usize,
}

impl NewtonSchulzParams {
    /// Unpacks the `(a, b, c)` coefficient tuple and pairs it with the step count.
    fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
        let (a, b, c) = coefficients;
        Self { a, b, c, steps }
    }
}
|
||||
|
||||
/// Muon optimizer.
|
||||
///
|
||||
/// Muon internally runs standard SGD-momentum, and then performs an orthogonalization
|
||||
/// post-processing step, in which each 2D parameter's update is replaced with the
|
||||
/// nearest orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz
|
||||
/// iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
///
|
||||
/// # Important Notes
|
||||
///
|
||||
/// 1. **Only for 2D+ parameters**: Muon is designed for weight matrices. Use AdamW
|
||||
/// or SGD for biases, embeddings, and layer norms.
|
||||
///
|
||||
/// 2. **Learning rate adjustment**: Muon automatically adjusts the learning rate based
|
||||
/// on parameter shape. See [`AdjustLrFn`] for details.
|
||||
///
|
||||
/// 3. **Weight decay timing**: Unlike typical optimizers, Muon applies weight decay
|
||||
/// AFTER orthogonalization but uses the original (unadjusted) learning rate for it.
|
||||
#[derive(Clone)]
|
||||
pub struct Muon<B: Backend> {
|
||||
momentum: Momentum<B>,
|
||||
ns_params: NewtonSchulzParams,
|
||||
weight_decay_penalty: Option<f32>,
|
||||
epsilon: f32,
|
||||
adjust_lr_fn: AdjustLrFn,
|
||||
}
|
||||
|
||||
impl<B: Backend> Muon<B> {
|
||||
/// Adjust learning rate based on parameter shape.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lr` - Base learning rate
|
||||
/// * `shape` - Parameter shape (uses first two dimensions)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Adjusted learning rate
|
||||
///
|
||||
/// ```ignore
|
||||
/// // For a [1024, 512] weight matrix with lr=0.01:
|
||||
/// // Original: 0.01 * sqrt(1024/512) = 0.01 * 1.414 = 0.01414
|
||||
/// // MatchRmsAdamW: 0.01 * 0.2 * sqrt(1024) = 0.01 * 0.2 * 32 = 0.064
|
||||
/// ```
|
||||
fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
|
||||
lr * self.adjust_lr_fn.adjustment_ratio(shape)
|
||||
}
|
||||
|
||||
/// Perform Newton-Schulz orthogonalization on a gradient tensor.
|
||||
///
|
||||
/// This computes the zeroth power (orthogonalization) of the input matrix G
|
||||
/// using a quintic Newton-Schulz iteration.
|
||||
///
|
||||
/// # Algorithm
|
||||
///
|
||||
/// 1. Transpose if tall matrix (A > B)
|
||||
/// 2. Normalize: X = X / ||X||
|
||||
/// 3. For k steps:
|
||||
/// - A = X @ X^T
|
||||
/// - B = b*A + c*A^2
|
||||
/// - X = a*X + B@X
|
||||
/// 4. Transpose back if needed
|
||||
///
|
||||
/// # References
|
||||
///
|
||||
/// - Original: https://github.com/KellerJordan/Muon/blob/master/muon.py
|
||||
/// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py
|
||||
fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let shape = g.shape();
|
||||
let dim_m2 = shape[D - 2];
|
||||
let dim_m1 = shape[D - 1];
|
||||
|
||||
// Step 1: Transpose if tall matrix (more rows than columns)
|
||||
let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
|
||||
(g.swap_dims(D - 2, D - 1), true)
|
||||
} else {
|
||||
(g, false)
|
||||
};
|
||||
|
||||
// Step 2: Normalize by Frobenius norm
|
||||
// X = X / (||X|| + epsilon)
|
||||
let norm = x
|
||||
.clone()
|
||||
.powf_scalar(2.0)
|
||||
.sum()
|
||||
.sqrt()
|
||||
.clamp_min(self.epsilon)
|
||||
.unsqueeze();
|
||||
|
||||
x = x.div(norm);
|
||||
|
||||
// Step 3: Newton-Schulz iteration
|
||||
// This is the quintic iteration with coefficients (a, b, c)
|
||||
let NewtonSchulzParams { a, b, c, steps } = self.ns_params;
|
||||
|
||||
for _ in 0..steps {
|
||||
// A = X @ X^T
|
||||
let x_t = x.clone().swap_dims(D - 2, D - 1);
|
||||
let a_matrix = x.clone().matmul(x_t);
|
||||
|
||||
// B = b*A + c*A@A
|
||||
let a_squared = a_matrix.clone().matmul(a_matrix.clone());
|
||||
let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));
|
||||
|
||||
// X = a*X + B@X
|
||||
x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
|
||||
}
|
||||
|
||||
// Step 4: Restore transpose if it was a tall matrix
|
||||
if needs_transpose {
|
||||
x = x.swap_dims(D - 2, D - 1);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
/// Muon state.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct MuonState<B: Backend, const D: usize> {
|
||||
/// Current momentum state
|
||||
pub momentum: MomentumState<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
|
||||
type State<const D: usize> = MuonState<B, D>;
|
||||
|
||||
/// Perform a single Muon optimization step.
|
||||
///
|
||||
/// # Algorithm
|
||||
///
|
||||
/// 1. Apply momentum to gradient
|
||||
/// 2. Orthogonalize update via Newton-Schulz
|
||||
/// 3. Adjust learning rate based on parameter shape
|
||||
/// 4. Apply weight decay (using original lr)
|
||||
/// 5. Update parameter (using adjusted lr)
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Unlike typical optimizers, the weight decay and parameter update use
|
||||
/// different learning rates:
|
||||
/// - Weight decay uses the original `lr`
|
||||
/// - Parameter update uses the shape-adjusted `lr`
|
||||
///
|
||||
/// # Panics
|
||||
/// This function will panic if the input tensors are not 2D.
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
lr: LearningRate,
|
||||
tensor: Tensor<B, D>,
|
||||
grad: Tensor<B, D>,
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||
assert!(
|
||||
D == 2,
|
||||
"Newton-Schulz iteration requires 2D tensors, got {}D",
|
||||
D
|
||||
);
|
||||
|
||||
// Step 1: Apply momentum
|
||||
let state_momentum = state.map(|s| s.momentum);
|
||||
let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);
|
||||
|
||||
// Step 2: Orthogonalize via Newton-Schulz
|
||||
let update = self.zeropower_via_newtonschulz(grad);
|
||||
|
||||
// Step 3: Adjust learning rate based on parameter shape
|
||||
let adjusted_lr = self.adjust_lr(lr, &tensor.shape());
|
||||
|
||||
// Step 4: Apply weight decay (using ORIGINAL lr, not adjusted)
|
||||
// Muon applies weight decay AFTER orthogonalization
|
||||
let tensor = if let Some(penalty) = self.weight_decay_penalty {
|
||||
let decay_factor = 1.0 - lr * penalty as f64;
|
||||
tensor.mul_scalar(decay_factor)
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
|
||||
// Step 5: Update parameter (using ADJUSTED lr)
|
||||
let delta = update.mul_scalar(adjusted_lr);
|
||||
let new_state = MuonState::new(new_momentum_state);
|
||||
|
||||
(tensor - delta, Some(new_state))
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
|
||||
state.momentum = state.momentum.to_device(device);
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type TestBackend = burn_ndarray::NdArray<f32>;

    const TOLERANCE: f64 = 1e-8;

    /// Builds a 4x4 linear layer without bias (Muon only handles 2D weights)
    /// whose weight is loaded from `weight`.
    fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let layer = LinearConfig::new(4, 4).with_bias(false).init(&device);

        layer.load_record(LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: None, // No bias for Muon optimizer
        })
    }

    /// Builds a bare `Muon` from the default `MuonConfig`, with the given weight
    /// decay penalty, bypassing the `OptimizerAdaptor`.
    fn muon_from_default_config(weight_decay_penalty: Option<f32>) -> Muon<TestBackend> {
        let config = MuonConfig::new();

        Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        }
    }

    #[test]
    fn test_adjust_lr_fn_original() {
        let method = AdjustLrFn::Original;

        let cases: [(&[usize], f64); 3] = [
            // Square matrix [512, 512] -> sqrt(1) = 1.0
            (&[512, 512], 1.0),
            // Tall matrix [1024, 512] -> sqrt(2) ≈ 1.414
            (&[1024, 512], (2.0f64).sqrt()),
            // Wide matrix [512, 1024] -> max(1, 0.5) = 1.0
            (&[512, 1024], 1.0),
        ];

        for &(shape, expected) in &cases {
            let ratio = method.adjustment_ratio(shape);
            assert!((ratio - expected).abs() < TOLERANCE);
        }
    }

    #[test]
    fn test_adjust_lr_fn_match_rms_adamw() {
        let method = AdjustLrFn::MatchRmsAdamW;

        let cases: [(&[usize], f64); 2] = [
            // [1024, 512] -> 0.2 * sqrt(1024) = 6.4
            (&[1024, 512], 0.2 * 1024.0f64.sqrt()),
            // [512, 512] -> 0.2 * sqrt(512) ≈ 4.525
            (&[512, 512], 0.2 * 512.0f64.sqrt()),
        ];

        for &(shape, expected) in &cases {
            let ratio = method.adjustment_ratio(shape);
            assert!((ratio - expected).abs() < TOLERANCE);
        }
    }

    #[test]
    #[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
    fn test_1d_tensor_panics() {
        let device = Default::default();
        let optim = muon_from_default_config(None);

        let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
        let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);

        let _ = optim.step(0.01, tensor_1d, grad_1d, None);
    }

    #[test]
    fn test_muon_optimizer_save_load_state() {
        let device = Default::default();
        let new_optimizer =
            || MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();

        // Linear layer WITHOUT bias: the only parameter is the 2D weight matrix.
        let linear = LinearConfig::new(6, 6)
            .with_bias(false)
            .init::<TestAutodiffBackend>(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);

        // One step so the optimizer holds some state to record.
        let mut optimizer = new_optimizer();
        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        let _linear = optimizer.step(0.01, linear, grads);

        let state_before = optimizer.to_record();
        let state_before_copy = optimizer.to_record();

        // Load the record into a fresh optimizer and read it back.
        let state_after = new_optimizer().load_record(state_before_copy).to_record();

        assert_eq!(state_before.len(), state_after.len());
    }

    #[test]
    fn test_muon_with_weight_decay() {
        let device = Default::default();

        // Linear layer WITHOUT bias, all weights set to one.
        let linear = given_linear_layer_no_bias(TensorData::from([[1.0; 4]; 4]));

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats([[0.5; 4]; 2], &device)
            .require_grad();

        let mut optimizer = MuonConfig::new()
            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
            .init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();

        let grads = GradientsParams::from_grads(linear.forward(x).backward(), &linear);
        let linear = optimizer.step(0.01, linear, grads);

        let weight = linear.into_record().weight.to_data();

        for val in weight.as_slice::<f32>().unwrap() {
            assert!(
                *val < 1.0,
                "Weight should be reduced by weight decay, got {}",
                val
            );
        }
    }

    #[test]
    fn test_newton_schulz_orthogonalization() {
        let device = Default::default();
        let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);

        let muon = muon_from_default_config(None);

        // O @ O^T should have a diagonal close to one.
        let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
        let transposed = orthogonalized.clone().transpose();
        let data = orthogonalized.matmul(transposed).into_data();
        let values = data.as_slice::<f32>().unwrap();

        assert!(
            (values[0] - 1.0).abs() < 0.1,
            "Product[0,0] should be ~1.0, got {}",
            values[0]
        );
        assert!(
            (values[3] - 1.0).abs() < 0.1,
            "Product[1,1] should be ~1.0, got {}",
            values[3]
        );
    }

    #[test]
    fn test_tall_matrix_transpose() {
        // Tall matrices (rows > columns) are transposed for the Newton-Schulz
        // iteration and transposed back afterwards, so the shape must survive.
        let device = Default::default();
        let muon = muon_from_default_config(None);

        // Tall matrix: [8, 4] (more rows than columns)
        let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
                [0.1, 0.2, 0.3, 0.4],
                [0.4, 0.3, 0.2, 0.1],
                [0.2, 0.4, 0.1, 0.3],
                [0.3, 0.1, 0.4, 0.2],
            ],
            &device,
        );

        let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());

        // The result comes back in the original orientation.
        assert_eq!(
            tall_matrix.shape().dims::<2>(),
            orthogonalized.shape().dims::<2>(),
            "Shape should be preserved: [8, 4]"
        );

        // Orthogonalization must actually have changed the values.
        let original_data = tall_matrix.into_data();
        let result_data = orthogonalized.into_data();
        assert_ne!(
            original_data.as_slice::<f32>().unwrap(),
            result_data.as_slice::<f32>().unwrap(),
            "Orthogonalized matrix should differ from input"
        );

        // For comparison, a wide matrix [4, 8] takes the non-transposed path.
        let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
                [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
                [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
            ],
            &device,
        );

        let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());

        assert_eq!(
            wide_matrix.shape().dims::<2>(),
            orthogonalized_wide.shape().dims::<2>(),
            "Wide matrix shape should be preserved: [4, 8]"
        );
    }

    #[test]
    fn test_zero_gradient() {
        // Muon must handle an all-zero gradient gracefully.
        let device = Default::default();

        let sample_weights = || {
            Tensor::<TestBackend, 2>::from_floats(
                [
                    [1.0, 0.5, 0.3, 0.2],
                    [0.5, 1.0, 0.4, 0.1],
                    [0.3, 0.4, 1.0, 0.5],
                    [0.2, 0.1, 0.5, 1.0],
                ],
                &device,
            )
        };
        let zero_gradient = || Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        // --- Without weight decay ---
        let muon = muon_from_default_config(None);
        let tensor = sample_weights();

        // Should not panic or produce NaN
        let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_gradient(), None);

        // A state is created even on the first step.
        assert!(state.is_some());

        let original_data = tensor.into_data();
        let updated_data = updated_tensor.into_data();
        let original_vals = original_data.as_slice::<f32>().unwrap();
        let updated_vals = updated_data.as_slice::<f32>().unwrap();

        // With zero gradient and no weight decay, the tensor stays put.
        for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
            assert!(
                (orig - upd).abs() < 1e-6,
                "With zero gradient, tensor should remain unchanged (or very close)"
            );
        }

        for val in updated_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN values with zero gradient"
            );
        }

        // --- With weight decay ---
        let muon_with_decay = muon_from_default_config(Some(0.01));
        let tensor2 = sample_weights();

        let (updated_tensor_decay, _) =
            muon_with_decay.step(0.01, tensor2.clone(), zero_gradient(), None);

        let updated_decay_data = updated_tensor_decay.into_data();
        let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();

        for val in updated_decay_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN with zero gradient and weight decay"
            );
        }

        // Weight decay alone must shrink every non-zero entry.
        let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
        for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
            if orig.abs() > 1e-6 {
                assert!(
                    upd.abs() < orig.abs(),
                    "Weight decay should reduce magnitude: original={}, updated={}",
                    orig,
                    upd
                );
            }
        }
    }
}
|
||||
@@ -0,0 +1,566 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use burn::{module::AutodiffModule, record::Record};
|
||||
|
||||
use super::{
|
||||
SimpleOptimizer,
|
||||
adaptor::OptimizerAdaptor,
|
||||
decay::{WeightDecay, WeightDecayConfig},
|
||||
};
|
||||
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
|
||||
|
||||
use burn::config::Config;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};
|
||||
|
||||
/// Configuration to create the [RmsProp](RmsProp) optimizer.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct RmsPropConfig {
|
||||
/// Smoothing constant.
|
||||
#[config(default = 0.99)]
|
||||
alpha: f32,
|
||||
/// momentum for RmsProp.
|
||||
#[config(default = 0.9)]
|
||||
momentum: f32,
|
||||
/// A value required for numerical stability.
|
||||
#[config(default = 1e-5)]
|
||||
epsilon: f32,
|
||||
/// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance
|
||||
#[config(default = false)]
|
||||
centered: bool,
|
||||
/// [Weight decay](WeightDecayConfig) config.
|
||||
weight_decay: Option<WeightDecayConfig>,
|
||||
/// [Gradient Clipping](GradientClippingConfig) config.
|
||||
grad_clipping: Option<GradientClippingConfig>,
|
||||
}
|
||||
|
||||
impl RmsPropConfig {
|
||||
/// Initialize RmsProp optimizer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns an optimizer that can be used to optimize a module.
|
||||
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
&self,
|
||||
) -> OptimizerAdaptor<RmsProp, M, B> {
|
||||
let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
|
||||
|
||||
let mut optim = OptimizerAdaptor::from(RmsProp {
|
||||
alpha: self.alpha,
|
||||
centered: self.centered,
|
||||
weight_decay,
|
||||
momentum: RmsPropMomentum {
|
||||
momentum: self.momentum,
|
||||
epsilon: self.epsilon,
|
||||
},
|
||||
});
|
||||
|
||||
if let Some(config) = &self.grad_clipping {
|
||||
optim = optim.with_grad_clipping(config.init());
|
||||
}
|
||||
|
||||
optim
|
||||
}
|
||||
}
|
||||
|
||||
/// Optimizer that implements stochastic gradient descent with momentum.
|
||||
/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
|
||||
#[derive(Clone)]
|
||||
pub struct RmsProp {
|
||||
alpha: f32,
|
||||
// epsilon: f32,
|
||||
centered: bool,
|
||||
// momentum: Option<Momentum<B>>,
|
||||
momentum: RmsPropMomentum,
|
||||
weight_decay: Option<WeightDecay>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleOptimizer<B> for RmsProp {
|
||||
type State<const D: usize> = RmsPropState<B, D>;
|
||||
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
lr: LearningRate,
|
||||
tensor: Tensor<B, D>,
|
||||
mut grad: Tensor<B, D>,
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||
// fetch state for params
|
||||
let mut state_square_avg = None;
|
||||
let mut state_centered = None;
|
||||
let mut state_momentum = None;
|
||||
if let Some(state) = state {
|
||||
state_square_avg = Some(state.square_avg);
|
||||
state_centered = Some(state.centered);
|
||||
state_momentum = state.momentum;
|
||||
}
|
||||
|
||||
// weight_decay transform
|
||||
if let Some(weight_decay) = &self.weight_decay {
|
||||
grad = weight_decay.transform(grad, tensor.clone());
|
||||
}
|
||||
|
||||
// square_avg transform
|
||||
let (grad, state_square_avg) =
|
||||
SquareAvgState::transform(self.alpha, grad, state_square_avg);
|
||||
|
||||
// centered transform
|
||||
let (grad, state_square_avg, state_centered) = CenteredState::transform(
|
||||
self.alpha,
|
||||
self.centered,
|
||||
grad,
|
||||
state_square_avg,
|
||||
state_centered,
|
||||
);
|
||||
|
||||
// momentum transform
|
||||
let (grad, state_centered, state_momentum) =
|
||||
self.momentum
|
||||
.transform(grad, state_centered, state_momentum);
|
||||
|
||||
// transition state
|
||||
let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);
|
||||
|
||||
// tensor param transform
|
||||
let delta = grad.mul_scalar(lr);
|
||||
(tensor - delta, Some(state))
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
|
||||
state.square_avg = state.square_avg.to_device(device);
|
||||
state.centered = state.centered.to_device(device);
|
||||
state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
/// State of [RmsProp](RmsProp)
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct RmsPropState<B: Backend, const D: usize> {
|
||||
/// Current squared average state.
|
||||
pub square_avg: SquareAvgState<B, D>,
|
||||
/// Current centered state
|
||||
pub centered: CenteredState<B, D>,
|
||||
/// Current gradient momentum, if any.
|
||||
pub momentum: Option<RmsPropMomentumState<B, D>>,
|
||||
}
|
||||
|
||||
/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct SquareAvgState<B: Backend, const D: usize> {
|
||||
/// Current squared average.
|
||||
pub square_avg: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> SquareAvgState<B, D> {
|
||||
/// transform [SquareAvgState] to the next step
|
||||
fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
|
||||
match state {
|
||||
Some(state) => {
|
||||
let square_avg = state
|
||||
.square_avg
|
||||
.mul_scalar(alpha)
|
||||
.add(grad.clone().square().mul_scalar(1. - alpha));
|
||||
(grad, Self { square_avg })
|
||||
}
|
||||
_ => {
|
||||
let square_avg = grad.clone().square().mul_scalar(1. - alpha);
|
||||
(grad, Self { square_avg })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Moves the state to a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - Device to move the state to.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `self` - Moved state.
|
||||
pub fn to_device(mut self, device: &B::Device) -> Self {
|
||||
self.square_avg = self.square_avg.to_device(device);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// [CenteredState](CenteredState) is to store and pass optimizer step params.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct CenteredState<B: Backend, const D: usize> {
|
||||
/// The averaged gradient to calculate the centered gradient, if available.
|
||||
pub grad_avg: Option<Tensor<B, D>>,
|
||||
/// The current average value.
|
||||
pub avg: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> CenteredState<B, D> {
|
||||
/// transform [CenteredState] to the next step
|
||||
fn transform(
|
||||
alpha: f32,
|
||||
centered: bool,
|
||||
grad: Tensor<B, D>,
|
||||
square_avg_state: SquareAvgState<B, D>,
|
||||
centered_state: Option<Self>,
|
||||
) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
|
||||
if centered {
|
||||
let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
|
||||
let grad_avg = match centered_state {
|
||||
Some(state) => state
|
||||
.grad_avg
|
||||
.map_or(grad_avg_constant.clone(), move |grad_avg| {
|
||||
grad_avg.mul_scalar(alpha).add(grad_avg_constant)
|
||||
}),
|
||||
_ => grad_avg_constant,
|
||||
};
|
||||
let avg = square_avg_state
|
||||
.square_avg
|
||||
.clone()
|
||||
.sub(grad_avg.clone().square());
|
||||
|
||||
(
|
||||
grad,
|
||||
square_avg_state,
|
||||
Self {
|
||||
grad_avg: Some(grad_avg),
|
||||
avg,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
(
|
||||
grad,
|
||||
square_avg_state.clone(),
|
||||
Self {
|
||||
grad_avg: None,
|
||||
avg: square_avg_state.square_avg,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Moves the state to a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - Device to move the state to.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `self` - Moved state.
|
||||
pub fn to_device(mut self, device: &B::Device) -> Self {
|
||||
self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
|
||||
self.avg = self.avg.to_device(device);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
|
||||
/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
|
||||
#[derive(Clone)]
|
||||
pub struct RmsPropMomentum {
|
||||
momentum: f32,
|
||||
epsilon: f32,
|
||||
}
|
||||
|
||||
impl RmsPropMomentum {
|
||||
/// transform [grad](Tensor) and [RmsPropMomentumState] to the next step
|
||||
fn transform<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
centered_state: CenteredState<B, D>,
|
||||
momentum_state: Option<RmsPropMomentumState<B, D>>,
|
||||
) -> (
|
||||
Tensor<B, D>,
|
||||
CenteredState<B, D>,
|
||||
Option<RmsPropMomentumState<B, D>>,
|
||||
) {
|
||||
let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
|
||||
|
||||
if self.momentum > 0. {
|
||||
let buf = match momentum_state {
|
||||
Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
|
||||
_ => grad,
|
||||
};
|
||||
(
|
||||
buf.clone(),
|
||||
centered_state,
|
||||
Some(RmsPropMomentumState { buf }),
|
||||
)
|
||||
} else {
|
||||
(grad, centered_state, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
|
||||
buf: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
|
||||
/// Moves the state to a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - Device to move the state to.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `self` - Moved state.
|
||||
pub fn to_device(mut self, device: &B::Device) -> Self {
|
||||
self.buf = self.buf.to_device(device);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn::tensor::ops::FloatElem;
|
||||
use burn::tensor::{Shape, Tolerance};
|
||||
|
||||
use super::*;
|
||||
use crate::TestAutodiffBackend;
|
||||
use crate::optim::{GradientsParams, Optimizer};
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{Distribution, Tensor, TensorData};
|
||||
use burn_nn::{Linear, LinearConfig, LinearRecord};
|
||||
|
||||
type FT = FloatElem<TestAutodiffBackend>;
|
||||
|
||||
const LEARNING_RATE: LearningRate = 0.01;
|
||||
|
||||
#[test]
fn test_rmsprop_optimizer_save_load_state() {
    let device = Default::default();
    let layer = LinearConfig::new(6, 6).init(&device);
    let input = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);

    // One step so the optimizer holds some state to record.
    let mut optimizer = create_rmsprop();
    let grads = GradientsParams::from_grads(layer.forward(input).backward(), &layer);
    let _layer = optimizer.step(LEARNING_RATE, layer, grads);

    // With `std` the record is written to a file in the temp directory.
    #[cfg(feature = "std")]
    {
        use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        recorder
            .record(
                optimizer.to_record(),
                std::env::temp_dir().as_path().join("test_optim_rmsprop"),
            )
            .unwrap();
    }
    // Without `std` the record is serialized to bytes instead.
    #[cfg(not(feature = "std"))]
    {
        use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = recorder.record(optimizer.to_record(), ()).unwrap();
        assert!(!bytes.is_empty());
    }

    let record_before = optimizer.to_record();
    let record_to_load = optimizer.to_record();

    // Load the record into a fresh optimizer and read it back.
    let reloaded = create_rmsprop().load_record(record_to_load);
    let record_after = reloaded.to_record();

    assert_eq!(record_before.len(), record_after.len());
}
|
||||
|
||||
/// Two optimizer steps on an all-ones layer, compared against fixed reference values;
/// handy for spotting differences while debugging.
#[test]
fn test_rmsprop_optimizer_with_numbers_basic() {
    let linear = given_linear_layer(
        TensorData::from([[1.; 6]; 6]),
        TensorData::from([0.5; 6]),
    );
    let device = Default::default();
    let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
            [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
        ],
        &device,
    )
    .require_grad();
    let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &device,
    )
    .require_grad();

    let mut optimizer = RmsPropConfig::new()
        .with_alpha(0.99)
        .with_epsilon(1e-8)
        .with_weight_decay(WeightDecayConfig::new(0.05).into())
        .with_momentum(0.9)
        .with_centered(false)
        .init();

    // First step.
    let grads = GradientsParams::from_grads(linear.forward(x_1).backward(), &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    // Second step.
    let grads = GradientsParams::from_grads(linear.forward(x_2).backward(), &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let record = linear.into_record();
    let weight_updated = record.weight.to_data();
    let bias_updated = record.bias.unwrap().to_data();

    // Every row of the expected weight holds a single repeated value.
    let weights_expected = TensorData::from([
        [0.743937; 6],
        [0.783809; 6],
        [0.742881; 6],
        [0.740366; 6],
        [0.748005; 6],
        [0.743710; 6],
    ]);
    let bias_expected = TensorData::from([0.239199; 6]);

    let tolerance = Tolerance::absolute(1e-6);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
|
||||
|
||||
/// Runs two RMSProp steps on a linear layer with fixed, non-uniform
/// parameters and compares the result against precomputed reference values.
#[test]
fn test_rmsprop_optimizer_with_numbers() {
    let initial_weight = TensorData::from([
        [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
        [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
        [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
        [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
        [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
        [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
    ]);
    let initial_bias = TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]);
    let mut linear = given_linear_layer(initial_weight, initial_bias);

    let device = Default::default();
    let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
            [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
        ],
        &device,
    )
    .require_grad();
    let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &device,
    )
    .require_grad();

    let mut optimizer = RmsPropConfig::new()
        .with_alpha(0.99)
        .with_epsilon(1e-8)
        .with_weight_decay(WeightDecayConfig::new(0.05).into())
        .with_momentum(0.9)
        .with_centered(false)
        .init();

    // One optimizer step per input, in order.
    for input in [x_1, x_2] {
        let grads = linear.forward(input).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        linear = optimizer.step(LEARNING_RATE, linear, grads);
    }

    let record = linear.into_record();
    let weight_updated = record.weight.to_data();
    let bias_updated = record.bias.unwrap().to_data();

    let weights_expected = TensorData::from([
        [-0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779],
        [-0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207],
        [-0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967],
        [-0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997],
        [0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912],
        [-0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126],
    ]);
    let bias_expected = TensorData::from([
        -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
    ]);

    let tolerance = Tolerance::absolute(1e-6);
    bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
}
|
||||
|
||||
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
|
||||
let device = Default::default();
|
||||
let record = LinearRecord {
|
||||
weight: Param::from_data(weight, &device),
|
||||
bias: Some(Param::from_data(bias, &device)),
|
||||
};
|
||||
|
||||
LinearConfig::new(6, 6).init(&device).load_record(record)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
|
||||
Tensor::<TestAutodiffBackend, 2>::random(
|
||||
Shape::new([2, 20]),
|
||||
Distribution::Default,
|
||||
&Default::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_rmsprop()
|
||||
-> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
|
||||
RmsPropConfig {
|
||||
alpha: 0.99,
|
||||
epsilon: 1e-9,
|
||||
centered: false,
|
||||
weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
|
||||
momentum: 0.9,
|
||||
grad_clipping: None,
|
||||
}
|
||||
.init()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::SimpleOptimizer;
|
||||
use super::adaptor::OptimizerAdaptor;
|
||||
use super::decay::{WeightDecay, WeightDecayConfig};
|
||||
use super::momentum::{Momentum, MomentumConfig, MomentumState};
|
||||
use crate::LearningRate;
|
||||
use crate::grad_clipping::GradientClippingConfig;
|
||||
use burn::config::Config;
|
||||
use burn::module::AutodiffModule;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
|
||||
/// Configuration to create the [Sgd](Sgd) optimizer.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct SgdConfig {
|
||||
/// [Weight decay](WeightDecayConfig) config.
|
||||
weight_decay: Option<WeightDecayConfig>,
|
||||
/// [Momentum](MomentumConfig) config.
|
||||
momentum: Option<MomentumConfig>,
|
||||
/// [Gradient Clipping](GradientClippingConfig) config.
|
||||
gradient_clipping: Option<GradientClippingConfig>,
|
||||
}
|
||||
|
||||
/// Optimizer that implements stochastic gradient descent with momentum.
|
||||
///
|
||||
/// The optimizer can be configured with [SgdConfig](SgdConfig).
|
||||
#[derive(Clone)]
|
||||
pub struct Sgd<B: Backend> {
|
||||
momentum: Option<Momentum<B>>,
|
||||
weight_decay: Option<WeightDecay>,
|
||||
}
|
||||
|
||||
/// State of [Sgd](Sgd).
|
||||
#[derive(Record, Clone, new)]
|
||||
pub struct SgdState<B: Backend, const D: usize> {
|
||||
/// The current state of the momentum (if any).
|
||||
pub momentum: Option<MomentumState<B, D>>,
|
||||
}
|
||||
|
||||
impl SgdConfig {
|
||||
/// Creates a new [SgdConfig](SgdConfig) with default values.
|
||||
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
&self,
|
||||
) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
|
||||
let momentum = self.momentum.as_ref().map(Momentum::new);
|
||||
let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
|
||||
|
||||
let mut optim = OptimizerAdaptor::from(Sgd {
|
||||
momentum,
|
||||
weight_decay,
|
||||
});
|
||||
if let Some(config) = &self.gradient_clipping {
|
||||
optim = optim.with_grad_clipping(config.init());
|
||||
}
|
||||
optim
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
|
||||
type State<const D: usize> = SgdState<B, D>;
|
||||
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
lr: LearningRate,
|
||||
tensor: Tensor<B, D>,
|
||||
mut grad: Tensor<B, D>,
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||
let mut state_momentum = None;
|
||||
|
||||
if let Some(state) = state {
|
||||
state_momentum = state.momentum;
|
||||
}
|
||||
|
||||
if let Some(weight_decay) = &self.weight_decay {
|
||||
grad = weight_decay.transform(grad, tensor.clone());
|
||||
}
|
||||
|
||||
if let Some(momentum) = &self.momentum {
|
||||
let (grad_out, state) = momentum.transform(grad, state_momentum);
|
||||
state_momentum = Some(state);
|
||||
grad = grad_out;
|
||||
}
|
||||
|
||||
let state = SgdState::new(state_momentum);
|
||||
let delta = grad.mul_scalar(lr);
|
||||
|
||||
(tensor - delta, Some(state))
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
|
||||
state.momentum = state.momentum.map(|state| state.to_device(device));
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend, TestBackend,
        grad_clipping::GradientClipping,
        optim::{GradientsParams, Optimizer},
    };
    use burn::tensor::{Distribution, Shape};
    use burn_nn::{Linear, LinearConfig};

    const LEARNING_RATE: LearningRate = 0.02;

    /// Concrete optimizer type exercised by these tests.
    type TestSgd =
        OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>;

    #[test]
    fn with_updated_params_should_have_state() {
        let optim = sgd_after_one_step();

        assert!(!optim.to_record().is_empty());
    }

    #[test]
    fn without_updated_params_should_not_have_state() {
        let untouched = sgd_with_all();

        assert!(untouched.to_record().is_empty());
    }

    #[test]
    fn can_attach_gradient_clipping() {
        let clipping = GradientClipping::Value(0.5);
        let optim = sgd_with_all().with_grad_clipping(clipping);

        assert!(optim.has_gradient_clipping());
    }

    #[test]
    fn should_load_state() {
        let record = sgd_after_one_step().to_record();
        let trained_len = record.len();

        let fresh = sgd_with_all();
        let fresh_len = fresh.to_record().len();
        let restored_len = fresh.load_record(record).to_record().len();

        assert_ne!(trained_len, fresh_len);
        assert_eq!(trained_len, restored_len);
    }

    /// Runs a single optimizer step on a fresh layer and returns the optimizer.
    fn sgd_after_one_step() -> TestSgd {
        let device = Default::default();
        let linear = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();

        let loss = linear.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(loss.backward(), &linear);
        let _updated = optim.step(LEARNING_RATE, linear, grads);

        optim
    }

    /// Random `[2, 20]` input tensor on `device`.
    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        let shape = Shape::new([2, 20]);
        Tensor::<B, 2>::random(shape, Distribution::Default, device)
    }

    /// Freshly initialised 20x20 linear layer on `device`.
    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        let config = LinearConfig::new(20, 20);
        config.init(device)
    }

    /// SGD with weight decay and Nesterov momentum enabled, without gradient clipping.
    fn sgd_with_all() -> TestSgd {
        let config = SgdConfig {
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: Some(MomentumConfig {
                momentum: 0.9,
                dampening: 0.1,
                nesterov: true,
            }),
            gradient_clipping: None,
        };

        config.init()
    }
}
|
||||
@@ -0,0 +1,210 @@
|
||||
use burn_core::{self as burn, prelude::Backend, tensor::Device};
|
||||
|
||||
use super::{SimpleOptimizer, record::AdaptorRecord};
|
||||
use crate::{
|
||||
LearningRate, MultiGradientsParams,
|
||||
grad_clipping::GradientClipping,
|
||||
optim::{GradientsParams, Optimizer},
|
||||
};
|
||||
|
||||
use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use core::marker::PhantomData;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
|
||||
/// an [optimizer](Optimizer).
|
||||
#[derive(Clone)]
|
||||
pub struct OptimizerAdaptor<O, M, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
M: AutodiffModule<B>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
optim: O,
|
||||
records: HashMap<ParamId, AdaptorRecord<O, B>>,
|
||||
module: PhantomData<M>,
|
||||
grad_clipping: Option<GradientClipping>,
|
||||
}
|
||||
|
||||
impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
fn from(optim: O) -> Self {
|
||||
Self {
|
||||
optim,
|
||||
records: HashMap::new(),
|
||||
module: PhantomData,
|
||||
grad_clipping: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, M, B> OptimizerAdaptor<O, M, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
M: AutodiffModule<B>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
/// Sets the gradient clipping.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `gradient_clipping` - The gradient clipping.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The optimizer.
|
||||
pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
|
||||
self.grad_clipping = Some(gradient_clipping);
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn has_gradient_clipping(&self) -> bool {
|
||||
self.grad_clipping.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
type Record = HashMap<ParamId, AdaptorRecord<O, B>>;
|
||||
|
||||
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {
|
||||
let mut grads = GradAdaptor::Single(grads);
|
||||
|
||||
let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
|
||||
&self.optim,
|
||||
&mut self.records,
|
||||
&mut grads,
|
||||
lr,
|
||||
self.grad_clipping.as_ref(),
|
||||
);
|
||||
module.map(&mut mapper)
|
||||
}
|
||||
|
||||
fn step_multi(&mut self, lr: LearningRate, module: M, grads: crate::MultiGradientsParams) -> M {
|
||||
let mut grads = GradAdaptor::Multi(grads);
|
||||
|
||||
let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
|
||||
&self.optim,
|
||||
&mut self.records,
|
||||
&mut grads,
|
||||
lr,
|
||||
self.grad_clipping.as_ref(),
|
||||
);
|
||||
module.map(&mut mapper)
|
||||
}
|
||||
|
||||
fn to_record(&self) -> Self::Record {
|
||||
self.records.clone()
|
||||
}
|
||||
|
||||
fn load_record(mut self, record: Self::Record) -> Self {
|
||||
self.records = record;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
enum GradAdaptor {
|
||||
Single(GradientsParams),
|
||||
Multi(MultiGradientsParams),
|
||||
}
|
||||
|
||||
impl GradAdaptor {
|
||||
fn remove<B: Backend, const D: usize>(
|
||||
&mut self,
|
||||
id: ParamId,
|
||||
) -> Option<(Tensor<B, D>, Device<B>)> {
|
||||
match self {
|
||||
GradAdaptor::Single(grads) => grads.remove(id).map(|t| {
|
||||
let device = t.device();
|
||||
(t, device)
|
||||
}),
|
||||
GradAdaptor::Multi(grads) => grads.remove(id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct SimpleOptimizerMapper<'a, M, B, O>
|
||||
where
|
||||
M: AutodiffModule<B>,
|
||||
B: AutodiffBackend,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
optimizer: &'a O,
|
||||
records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
|
||||
grads: &'a mut GradAdaptor,
|
||||
lr: LearningRate,
|
||||
phantom: PhantomData<M>,
|
||||
grad_clipping: Option<&'a GradientClipping>,
|
||||
}
|
||||
|
||||
impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
|
||||
where
|
||||
M: AutodiffModule<B>,
|
||||
B: AutodiffBackend,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
let grad = self.grads.remove(id);
|
||||
|
||||
let tensor = if let Some((grad, device)) = grad {
|
||||
let is_require_grad = tensor.is_require_grad();
|
||||
let (key, record) = self.records.remove_entry(&id).unzip();
|
||||
let tensor = if tensor.device() != device {
|
||||
tensor.to_device(&device)
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
|
||||
debug_assert_eq!(
|
||||
grad.device(),
|
||||
device,
|
||||
"The gradient is on the provided device"
|
||||
);
|
||||
let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
|
||||
g_clipping.clip_gradient(grad)
|
||||
} else {
|
||||
grad
|
||||
};
|
||||
|
||||
debug_assert_eq!(
|
||||
tensor.device(),
|
||||
device,
|
||||
"Tensor and gradients are on the same device."
|
||||
);
|
||||
|
||||
let (tensor, state) = self.optimizer.step(
|
||||
self.lr,
|
||||
tensor.inner(),
|
||||
clipped_grad,
|
||||
record.map(|record| O::to_device(record.into_state(), &device)),
|
||||
);
|
||||
|
||||
if let Some(state) = state {
|
||||
self.records
|
||||
.insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
|
||||
}
|
||||
|
||||
let mut tensor = Tensor::from_inner(tensor);
|
||||
if is_require_grad {
|
||||
tensor = tensor.require_grad();
|
||||
}
|
||||
tensor
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
|
||||
Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::LearningRate;
|
||||
use burn::record::Record;
|
||||
use burn::tensor::{Tensor, backend::Backend};
|
||||
|
||||
/// Simple optimizer is an opinionated trait to simplify the process of implementing an
|
||||
/// optimizer.
|
||||
///
|
||||
/// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
|
||||
/// module parameter structure, handle tracked and untracked tensors, and the likes.
|
||||
pub trait SimpleOptimizer<B>: Send + Sync + Clone
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// The state of the optimizer. It also implements [record](Record), so that it can be saved.
|
||||
type State<const D: usize>: Record<B> + Clone + 'static;
|
||||
|
||||
/// The optimizer step is performed for one tensor at a time with its gradient and state.
|
||||
///
|
||||
/// Note that the state is passed as parameter, so implementations don't have to handle
|
||||
/// the saving and loading of recorded states.
|
||||
fn step<const D: usize>(
|
||||
&self,
|
||||
lr: LearningRate,
|
||||
tensor: Tensor<B, D>,
|
||||
grad: Tensor<B, D>,
|
||||
state: Option<Self::State<D>>,
|
||||
) -> (Tensor<B, D>, Option<Self::State<D>>);
|
||||
|
||||
/// Change the device of the state.
|
||||
///
|
||||
/// This function will be called accordingly to have the state on the same device as the
|
||||
/// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
|
||||
fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod base;
|
||||
pub use base::*;
|
||||
|
||||
/// Adaptor module for optimizers.
|
||||
pub mod adaptor;
|
||||
|
||||
/// Record module for optimizers.
|
||||
pub mod record;
|
||||
@@ -0,0 +1,93 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::{AdaptorRecordItemV1, AdaptorRecordV1};
|
||||
use crate::optim::SimpleOptimizer;
|
||||
use burn::record::{PrecisionSettings, Record};
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
|
||||
///
|
||||
/// Records are versioned for backward compatibility, so old records can be loaded.
|
||||
pub enum AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
/// Version 1.
|
||||
V1(AdaptorRecordV1<O, B::InnerBackend>),
|
||||
}
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItem<
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
S: PrecisionSettings,
|
||||
> {
|
||||
/// Version 1.
|
||||
V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
|
||||
}
|
||||
|
||||
impl<O, B> Record<B> for AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
match self {
|
||||
AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, B> Clone for AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
AdaptorRecord::V1(record) => Self::V1(record.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, B> AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
/// Converts the record into the optimizer state.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The optimizer state.
|
||||
pub fn into_state<const D: usize>(self) -> O::State<D> {
|
||||
match self {
|
||||
AdaptorRecord::V1(record) => record.into_state(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the optimizer state into the record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state`: The optimizer state.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The record.
|
||||
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
|
||||
Self::V1(AdaptorRecordV1::from_state(state))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod v1;
|
||||
|
||||
pub use base::*;
|
||||
pub use v1::*;
|
||||
@@ -0,0 +1,201 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use crate::optim::SimpleOptimizer;
|
||||
use burn::record::{PrecisionSettings, Record};
|
||||
use burn::tensor::backend::Backend;
|
||||
use core::any::Any;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use alloc::boxed::Box;
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
|
||||
/// Rank 0.
|
||||
Rank0(O::State<0>),
|
||||
|
||||
/// Rank 1.
|
||||
Rank1(O::State<1>),
|
||||
|
||||
/// Rank 2.
|
||||
Rank2(O::State<2>),
|
||||
|
||||
/// Rank 3.
|
||||
Rank3(O::State<3>),
|
||||
|
||||
/// Rank 4.
|
||||
Rank4(O::State<4>),
|
||||
|
||||
/// Rank 5.
|
||||
Rank5(O::State<5>),
|
||||
|
||||
/// Rank 6.
|
||||
Rank6(O::State<6>),
|
||||
|
||||
/// Rank 7.
|
||||
Rank7(O::State<7>),
|
||||
|
||||
/// Rank 8.
|
||||
Rank8(O::State<8>),
|
||||
}
|
||||
|
||||
impl<O, B> Clone for AdaptorRecordV1<O, B>
where
    O: SimpleOptimizer<B>,
    B: Backend,
{
    /// Clones the state held by whichever rank variant is present.
    fn clone(&self) -> Self {
        match self {
            Self::Rank0(state) => Self::Rank0(state.clone()),
            Self::Rank1(state) => Self::Rank1(state.clone()),
            Self::Rank2(state) => Self::Rank2(state.clone()),
            Self::Rank3(state) => Self::Rank3(state.clone()),
            Self::Rank4(state) => Self::Rank4(state.clone()),
            Self::Rank5(state) => Self::Rank5(state.clone()),
            Self::Rank6(state) => Self::Rank6(state.clone()),
            Self::Rank7(state) => Self::Rank7(state.clone()),
            Self::Rank8(state) => Self::Rank8(state.clone()),
        }
    }
}
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
|
||||
/// Rank 0.
|
||||
Rank0(<O::State<0> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 1.
|
||||
Rank1(<O::State<1> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 2.
|
||||
Rank2(<O::State<2> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 3.
|
||||
Rank3(<O::State<3> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 4.
|
||||
Rank4(<O::State<4> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 5.
|
||||
Rank5(<O::State<5> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 6.
|
||||
Rank6(<O::State<6> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 7.
|
||||
Rank7(<O::State<7> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 8.
|
||||
Rank8(<O::State<8> as Record<B>>::Item<S>),
|
||||
}
|
||||
|
||||
impl<O, B> AdaptorRecordV1<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B>,
|
||||
B: Backend,
|
||||
{
|
||||
/// Convert the record into the state.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The state.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the state dimension is not supported.
|
||||
pub fn into_state<const D: usize>(self) -> O::State<D> {
|
||||
let boxed_state: Box<dyn Any> = match self {
|
||||
AdaptorRecordV1::Rank0(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank1(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank2(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank3(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank4(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank5(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank6(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank7(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank8(s) => Box::new(s),
|
||||
};
|
||||
let state = boxed_state
|
||||
.downcast::<O::State<D>>()
|
||||
.expect("Unsupported state dimension, dimension up to 8 are supported.");
|
||||
*state
|
||||
}
|
||||
|
||||
/// Convert the state into the record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state`: The state.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The record.
|
||||
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
|
||||
let state: Box<dyn Any> = Box::new(state);
|
||||
|
||||
match D {
|
||||
0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),
|
||||
1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),
|
||||
2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),
|
||||
3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),
|
||||
4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()),
|
||||
5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()),
|
||||
6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()),
|
||||
7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()),
|
||||
8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()),
|
||||
_ => panic!("Unsupported state dimension, dimension up to 8 are supported."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, B> Record<B> for AdaptorRecordV1<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = AdaptorRecordItemV1<O, B, S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
match self {
|
||||
AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),
|
||||
AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),
|
||||
AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),
|
||||
AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),
|
||||
AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()),
|
||||
AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()),
|
||||
AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()),
|
||||
AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()),
|
||||
AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
AdaptorRecordItemV1::Rank0(item) => {
|
||||
AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank1(item) => {
|
||||
AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank2(item) => {
|
||||
AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank3(item) => {
|
||||
AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank4(item) => {
|
||||
AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank5(item) => {
|
||||
AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank6(item) => {
|
||||
AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank7(item) => {
|
||||
AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank8(item) => {
|
||||
AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use burn_core as burn;
|
||||
|
||||
use super::GradientsParams;
|
||||
use burn::module::{AutodiffModule, ModuleVisitor, Param, ParamId};
|
||||
use burn::tensor::{Tensor, backend::AutodiffBackend};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GradientsParamsConverter<'a, M: AutodiffModule<B>, B: AutodiffBackend> {
|
||||
grads: &'a mut B::Gradients,
|
||||
grads_params: &'a mut GradientsParams,
|
||||
phatom: PhantomData<M>,
|
||||
filter: Option<Vec<ParamId>>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GradientsParamsChangeDevice<'a, M: AutodiffModule<B>, B: AutodiffBackend> {
|
||||
device: &'a B::Device,
|
||||
grads: &'a mut GradientsParams,
|
||||
phatom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<B, M> ModuleVisitor<B> for GradientsParamsConverter<'_, M, B>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
{
|
||||
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
|
||||
if let Some(filter) = self.filter.as_ref()
|
||||
&& !filter.contains(¶m.id)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(grad) = param.val().grad_remove(self.grads) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.grads_params
|
||||
.register::<B::InnerBackend, D>(param.id, grad);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, M> ModuleVisitor<B> for GradientsParamsChangeDevice<'_, M, B>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
{
|
||||
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
|
||||
let Some(grad) = self.grads.remove::<B::InnerBackend, D>(param.id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.grads
|
||||
.register::<B::InnerBackend, D>(param.id, grad.to_device(self.device));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user