feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,151 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Flexible and Comprehensive Deep Learning Framework in Rust"
documentation = "https://docs.rs/burn-core"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-core"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-core"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"std",
"burn-std/default",
"burn-dataset?/default",
"burn-tensor/default",
]
doc = [
"std",
"dataset",
"audio",
# Doc features
"burn-std/doc",
"burn-dataset/doc",
"burn-tensor/doc",
]
tracing = [
"burn-std/tracing",
"burn-tensor/tracing",
"burn-dataset?/tracing",
"burn-vision?/tracing",
]
dataset = ["burn-dataset"]
network = ["burn-std/network"]
sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
std = [
"bincode/std",
"burn-std/std",
"burn-tensor/std",
"flate2",
"half/std",
"log",
"rand/std",
"rmp-serde",
"serde/std",
"serde_json/std",
"num-traits/std",
]
vision = ["burn-vision", "burn-dataset?/vision"]
audio = ["burn-dataset?/audio"]
# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
test-cuda = [
"burn-cuda/default",
] # To use cuda during testing, default uses ndarray.
test-rocm = [
"burn-rocm/default",
] # To use hip during testing, default uses ndarray.
test-tch = [
"burn-tch/default",
] # To use tch during testing, default uses ndarray.
test-wgpu = [
"burn-wgpu/default",
] # To use wgpu during testing, default uses ndarray.
test-vulkan = [
"test-wgpu",
"burn-wgpu/vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.
test-metal = [
"test-wgpu",
"burn-wgpu/metal",
] # To use wgpu with the Metal backend during testing, default uses ndarray.
# Memory checks are disabled by default
test-memory-checks = ["burn-fusion/memory-checks"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
burn-dataset = { path = "../burn-dataset", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-derive = { path = "../burn-derive", version = "=0.21.0-pre.2" }
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
burn-vision = { path = "../burn-vision", version = "=0.21.0-pre.2", optional = true, default-features = false }
data-encoding = { workspace = true }
uuid = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true, optional = true }
rand = { workspace = true }
# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
# Serialize Deserialize
flate2 = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
ahash = { workspace = true }
bincode = { workspace = true }
half = { workspace = true }
num-traits = { workspace = true }
rmp-serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["alloc"] } #Default enables std
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled
thiserror = { workspace = true, optional = true }
[target.'cfg(target_has_atomic = "ptr")'.dependencies]
regex = { workspace = true }
# FOR TESTING
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
portable-atomic = { workspace = true }
[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
burn-dataset = { path = "../burn-dataset", version = "=0.21.0-pre.2", features = [
"fake",
] }
rstest = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,15 @@
# Burn Core
This crate should be used with [burn](https://github.com/tracel-ai/burn). It contains the core
traits and components for building and training deep learning models with Burn.
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-core.svg)](https://crates.io/crates/burn-core)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-core/blob/master/README.md)
## Feature Flags
This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the
default `std` feature.
- `std` - enables the standard library. Enabled by default.
- `experimental-named-tensor` - enables experimental named tensor.

View File

@@ -0,0 +1,98 @@
use alloc::{format, string::String, string::ToString};
pub use burn_derive::Config;
use core::fmt::Debug;
/// Configuration IO error.
#[derive(Debug)]
pub enum ConfigError {
    /// Invalid format.
    InvalidFormat(String),
    /// File not found.
    FileNotFound(String),
}

impl core::fmt::Display for ConfigError {
    /// Formats the error as `Config error => <kind>: <details>`.
    ///
    /// The message is written straight into the formatter, so no intermediate
    /// `String` has to be allocated just to render the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidFormat(err) => write!(f, "Config error => Invalid format: {err}"),
            Self::FileNotFound(err) => write!(f, "Config error => File not found: {err}"),
        }
    }
}

impl core::error::Error for ConfigError {}
/// Configuration trait.
///
/// Any type that is `Debug`, serializable and deserializable can opt in and gets
/// JSON-based persistence through the provided methods.
pub trait Config: Debug + serde::Serialize + serde::de::DeserializeOwned {
    /// Saves the configuration to a file as pretty-printed JSON.
    ///
    /// # Arguments
    ///
    /// * `file` - File to save the configuration to.
    ///
    /// # Returns
    ///
    /// The output of the underlying file write.
    #[cfg(feature = "std")]
    fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
        let json = config_to_json(self);
        std::fs::write(file, json)
    }

    /// Loads the configuration from a JSON file.
    ///
    /// # Arguments
    ///
    /// * `file` - File to load the configuration from.
    ///
    /// # Returns
    ///
    /// The loaded configuration. Any failure to read the file is reported as
    /// [`ConfigError::FileNotFound`] carrying the (lossily converted) path; a
    /// parse failure is reported as [`ConfigError::InvalidFormat`].
    #[cfg(feature = "std")]
    fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
        let path = file.as_ref();
        match std::fs::read_to_string(path) {
            Ok(content) => config_from_str(&content),
            Err(_) => Err(ConfigError::FileNotFound(
                path.to_string_lossy().to_string(),
            )),
        }
    }

    /// Loads the configuration from a binary buffer holding UTF-8 encoded JSON.
    ///
    /// # Arguments
    ///
    /// * `data` - Binary buffer to load the configuration from.
    ///
    /// # Returns
    ///
    /// The loaded configuration, or [`ConfigError::InvalidFormat`] when the
    /// buffer is not valid UTF-8 or cannot be parsed.
    fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
        let Ok(content) = core::str::from_utf8(data) else {
            return Err(ConfigError::InvalidFormat(
                "Could not parse data as utf-8.".to_string(),
            ));
        };
        config_from_str(content)
    }
}
/// Serializes a configuration into a pretty-printed JSON string.
///
/// # Arguments
///
/// * `config` - Configuration to convert.
///
/// # Returns
///
/// The JSON string.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize the configuration.
pub fn config_to_json<C: Config>(config: &C) -> String {
    let serialized = serde_json::to_string_pretty(config);
    serialized.unwrap()
}
/// Parses a configuration from a JSON string.
///
/// A parse failure is turned into [`ConfigError::InvalidFormat`] carrying the
/// rendered `serde_json` error.
fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
    match serde_json::from_str(content) {
        Ok(config) => Ok(config),
        Err(err) => Err(ConfigError::InvalidFormat(format!("{err}"))),
    }
}

View File

@@ -0,0 +1,49 @@
use burn_tensor::backend::Backend;
pub use crate::data::dataset::{Dataset, DatasetIterator};
use core::iter::Iterator;
use std::sync::Arc;
/// A progress struct that can be used to track the progress of a data loader.
///
/// The `new` derive is expected to generate `Progress::new(items_processed, items_total)`,
/// taking the fields in declaration order.
#[derive(new, Clone, Debug)]
pub struct Progress {
/// The number of items that have been processed.
pub items_processed: usize,
/// The total number of items that need to be processed.
pub items_total: usize,
}
/// A data loader iterator that can be used to iterate over a data loader.
///
/// It is a regular [`Iterator`] yielding outputs of type `O`, extended with a way to
/// query how far the iteration has advanced.
pub trait DataLoaderIterator<O>: Iterator<Item = O> {
/// Returns the progress of the data loader.
fn progress(&self) -> Progress;
}
/// A data loader that can be used to iterate over a dataset.
///
/// `B` is the backend whose device the batches are assigned to and `O` is the type
/// yielded by the iterator. Implementations must be `Send + Sync`.
pub trait DataLoader<B: Backend, O>: Send + Sync {
/// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
///
/// The iterator may borrow from the data loader for the lifetime `'a`.
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
/// The number of items (not the number of batches nor the number of iterations),
/// corresponding to the items_total of the progress returned by the iterator.
fn num_items(&self) -> usize;
/// Move the data loader to the given device, ensuring the batches are assigned to the correct device.
///
/// The data loader itself is left untouched; a new shared instance is returned.
fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>>;
/// Returns a new data loader containing a subset of the data.
///
/// The subset includes items from `start` (inclusive) to `end` (exclusive),
/// preserving the batch size and ordering of the original data loader.
///
/// # Arguments
///
/// * `start` - The starting index of the subset (inclusive).
/// * `end` - The ending index of the subset (exclusive).
///
/// # Returns
///
/// A reference-counted ([`Arc`]) [`DataLoader`] instance containing only the specified range.
fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>>;
}

View File

@@ -0,0 +1,259 @@
use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
use burn_dataset::{
Dataset,
transform::{PartialDataset, ShuffledDataset},
};
use burn_tensor::backend::Backend;
use rand::SeedableRng;
use std::ops::DerefMut;
use std::sync::Arc;
/// A data loader that can be used to iterate over a dataset in batches.
pub struct BatchDataLoader<B: Backend, I, O> {
// Decides when the accumulated items form a batch; cloned for every new iterator.
strategy: Box<dyn BatchStrategy<I>>,
// Source of the items of type `I`.
dataset: Arc<dyn Dataset<I>>,
// Turns a list of items into an output of type `O` on the configured device.
batcher: Arc<dyn Batcher<B, I, O>>,
// Device handed to the batcher for every batch.
device: B::Device,
// When present, the dataset is shuffled with this rng each time an iterator is created.
// The rng sits behind an `Arc`, so clones of the data loader share the same rng state.
rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
}
impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
fn clone(&self) -> Self {
Self {
strategy: self.strategy.clone_dyn(),
dataset: self.dataset.clone(),
batcher: self.batcher.clone(),
device: self.device.clone(),
rng: self.rng.clone(),
}
}
}
impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
    /// Builds a batch data loader from its parts.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    /// * `device`  - The device to use when loading a batch.
    /// * `rng`     - When provided, the dataset is shuffled with it each time a
    ///   dataloader iterator is created; `None` disables shuffling.
    ///
    /// # Returns
    ///
    /// The batch data loader.
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        device: B::Device,
        rng: Option<rand::rngs::StdRng>,
    ) -> Self {
        // Wrap the rng so that it can be advanced through a shared reference.
        let shared_rng = rng.map(spin::Mutex::new).map(Arc::new);

        Self {
            strategy,
            dataset,
            batcher,
            device,
            rng: shared_rng,
        }
    }
}
/// A data loader iterator that can be used to iterate over a data loader.
struct BatchDataloaderIterator<B: Backend, I, O> {
// Index of the next dataset item to fetch; also reported as the number of processed items.
current_index: usize,
// Accumulates items and decides when a batch is ready.
strategy: Box<dyn BatchStrategy<I>>,
// Dataset being iterated (possibly a shuffled view of the original one).
dataset: Arc<dyn Dataset<I>>,
// Converts the accumulated items into the output type `O`.
batcher: Arc<dyn Batcher<B, I, O>>,
// Device passed to the batcher for every batch.
device: B::Device,
}
impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
where
B: Backend,
I: Send + Sync + Clone + 'static,
O: Send + 'static,
{
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
// When starting a new iteration, we first check if the dataloader was created with an rng,
// implying that we should shuffle the dataset beforehand, while advancing the current
// rng to ensure that each new iteration shuffles the dataset differently.
let dataset = match &self.rng {
Some(rng) => Arc::new(ShuffledDataset::new(
self.dataset.clone(),
rng.lock().deref_mut(),
)),
None => self.dataset.clone(),
};
Box::new(BatchDataloaderIterator::new(
self.strategy.clone_dyn(),
dataset,
self.batcher.clone(),
self.device.clone(),
))
}
fn num_items(&self) -> usize {
self.dataset.len()
}
fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
let rng = self.rng.as_ref().map(|rng| {
let mut rng = rng.lock();
rng.fork()
});
Arc::new(Self::new(
self.strategy.clone_dyn(),
self.dataset.clone(),
self.batcher.clone(),
device.clone(),
rng,
))
}
fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
let rng = self.rng.as_ref().map(|rng| {
let mut rng = rng.lock();
rng.fork()
});
let dataloader = Self::new(
self.strategy.clone_dyn(),
Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
self.batcher.clone(),
self.device.clone(),
rng,
);
Arc::new(dataloader)
}
}
impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
    /// Builds an iterator positioned at the first item of `dataset`.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    /// * `device` - The device to use when loading a batch.
    ///
    /// # Returns
    ///
    /// The batch data loader iterator.
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        device: B::Device,
    ) -> Self {
        Self {
            strategy,
            dataset,
            batcher,
            device,
            current_index: 0,
        }
    }
}
impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
    type Item = O;

    /// Produces the next batch, or `None` once the dataset and the strategy are drained.
    fn next(&mut self) -> Option<O> {
        // Feed items to the strategy until it reports a complete batch.
        while let Some(item) = self.dataset.get(self.current_index) {
            self.current_index += 1;
            self.strategy.add(item);

            if let Some(complete) = self.strategy.batch(false) {
                let output = self.batcher.batch(complete, &self.device);
                return Some(output);
            }
        }

        // The dataset is exhausted: force out whatever the strategy still holds.
        let remaining = self.strategy.batch(true)?;
        Some(self.batcher.batch(remaining, &self.device))
    }
}
impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
fn progress(&self) -> Progress {
Progress::new(self.current_index, self.dataset.len())
}
}
// Unit tests for the single-threaded batch data loader.
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::*;
use crate::data::dataloader::FixBatchStrategy;
use crate::data::dataloader::batcher::TestBatcher;
use crate::data::dataset::FakeDataset;
// Iterating the data loader must yield the same set of items as iterating the dataset.
#[test]
fn test_batch_dataloader() {
let batcher = Arc::new(TestBatcher::new());
// 27 items with a batch size of 5, so the last batch is incomplete.
let dataset = Arc::new(FakeDataset::<String>::new(27));
let dataloader = BatchDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
dataset.clone(),
batcher,
Default::default(),
None,
);
let mut items_dataset = HashSet::new();
let mut items_dataloader = HashSet::new();
for item in dataset.iter() {
items_dataset.insert(item);
}
for items in dataloader.iter() {
for item in items {
items_dataloader.insert(item);
}
}
assert_eq!(items_dataset, items_dataloader);
}
// A sliced data loader must yield the items found at positions 5..15 of the full one.
#[test]
fn test_batch_dataloader_slice() {
let batcher = Arc::new(TestBatcher::new());
let dataset = Arc::new(FakeDataset::<String>::new(27));
let dataloader = BatchDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
dataset.clone(),
batcher,
Default::default(),
None,
);
let dataloader_slice = dataloader.slice(5, 15);
let mut items_dataloader = HashSet::new();
let mut items_dataloader_slice = HashSet::new();
// Running item index across all batches of the unsliced data loader.
let mut idx = 0;
for items in dataloader.iter() {
for item in items {
if (5..15).contains(&idx) {
items_dataloader.insert(item);
}
idx += 1;
}
}
for items in dataloader_slice.iter() {
for item in items {
items_dataloader_slice.insert(item);
}
}
assert_eq!(items_dataloader, items_dataloader_slice);
}
}

View File

@@ -0,0 +1,31 @@
use burn_tensor::backend::Backend;
#[cfg(test)]
use crate::TestBackend;
/// A trait for batching items of type `I` into items of type `O`.
///
/// Implementations must be `Send + Sync`.
pub trait Batcher<B: Backend, I, O>: Send + Sync {
/// Batches the given items on the specified device.
///
/// # Arguments
///
/// * `items` - The items to batch.
/// * `device` - The backend device to use.
///
/// # Returns
///
/// The batched items.
fn batch(&self, items: Vec<I>, device: &B::Device) -> O;
}
/// Test batcher
///
/// Only compiled for tests. It performs no batching work and hands the items back unchanged.
#[cfg(test)]
#[derive(new, Clone)]
pub struct TestBatcher;
#[cfg(test)]
impl<I> Batcher<TestBackend, I, Vec<I>> for TestBatcher {
// Identity batcher: the device is ignored and the input vector is returned as the batch.
fn batch(&self, items: Vec<I>, _device: &<TestBackend as Backend>::Device) -> Vec<I> {
items
}
}

View File

@@ -0,0 +1,260 @@
use super::{
BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
batcher::Batcher,
};
use burn_dataset::Dataset;
use burn_tensor::backend::Backend;
use rand::{SeedableRng, rngs::StdRng};
use std::sync::Arc;
/// A builder for data loaders.
pub struct DataLoaderBuilder<B: Backend, I, O> {
// Batch strategy; when unset, `build` falls back to a fixed batch size of 1.
strategy: Option<Box<dyn BatchStrategy<I>>>,
// Batcher converting items of type `I` into outputs of type `O`.
batcher: Arc<dyn Batcher<B, I, O>>,
// Number of worker threads; `None` or `Some(0)` builds a single-threaded data loader.
num_threads: Option<usize>,
// Seed used for shuffling; `None` disables shuffling.
shuffle: Option<u64>,
// Device used when loading a batch; when unset, the default device is used.
device: Option<B::Device>,
}
impl<B, I, O> DataLoaderBuilder<B, I, O>
where
    B: Backend,
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Clone + std::fmt::Debug + 'static,
{
    /// Starts a builder around the given batcher; every other option is left unset.
    ///
    /// # Arguments
    ///
    /// * `batcher` - The batcher.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn new<Bt>(batcher: Bt) -> Self
    where
        Bt: Batcher<B, I, O> + 'static,
    {
        let batcher = Arc::new(batcher);
        Self {
            strategy: None,
            batcher,
            num_threads: None,
            shuffle: None,
            device: None,
        }
    }

    /// Uses a fixed batch size.
    ///
    /// The [fix batch strategy](FixBatchStrategy) will be used.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - The batch size.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        let fixed = FixBatchStrategy::new(batch_size);
        self.strategy = Some(Box::new(fixed));
        self
    }

    /// Enables shuffling with the given seed.
    ///
    /// Each time the dataloader starts a new iteration, the dataset will be shuffled.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn shuffle(mut self, seed: u64) -> Self {
        self.shuffle = Some(seed);
        self
    }

    /// Sets the number of workers.
    ///
    /// - `0` (or never calling this method): the dataloader will run without work threads.
    /// - `n > 0`: the dataloader will run with `n` background threads.
    ///
    /// A 1-worker threaded dataloader will run loads in a background thread,
    /// while a 0-worker threaded dataloader will run loads in the main thread.
    ///
    /// # Arguments
    ///
    /// * `num_workers` - The number of workers.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_threads = Some(num_workers);
        self
    }

    /// Sets the data loader device.
    ///
    /// # Arguments
    ///
    /// * `device` - The device to use when loading a batch.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn set_device(mut self, device: B::Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Builds the data loader.
    ///
    /// Unset options fall back to: the default device, no shuffling, and a fixed
    /// batch size of 1. A multi-threaded data loader is built only when a
    /// strictly positive number of workers was requested.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The dataset.
    ///
    /// # Returns
    ///
    /// The data loader.
    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
    where
        D: Dataset<I> + 'static,
    {
        let Self {
            strategy,
            batcher,
            num_threads,
            shuffle,
            device,
        } = self;

        let dataset = Arc::new(dataset);
        let device = device.unwrap_or_default();
        let rng = shuffle.map(StdRng::seed_from_u64);
        let strategy = strategy.unwrap_or_else(|| Box::new(FixBatchStrategy::new(1)));

        match num_threads {
            Some(num_threads) if num_threads > 0 => Arc::new(MultiThreadDataLoader::new(
                strategy,
                dataset,
                batcher,
                num_threads,
                device,
                rng,
            )),
            _ => Arc::new(BatchDataLoader::new(
                strategy, dataset, batcher, device, rng,
            )),
        }
    }
}
// Unit tests for the data loader builder.
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use crate::data::dataset::FakeDataset;
// Batcher that ignores the items and returns the device it was called with, so the
// tests can observe which device each batch was assigned to.
#[derive(new, Clone)]
struct TestBatcherDevice;
#[cfg(test)]
impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
*device
}
}
type TestDevice = <TestBackend as Backend>::Device;
// Without workers, every batch must be produced on the default device.
#[test]
fn test_dataloader_no_workers() {
type TestDevice = <TestBackend as Backend>::Device;
let default_device = TestDevice::default();
let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
.batch_size(1)
.build(FakeDataset::<String>::new(9));
assert_eq!(dataloader.num_items(), 9);
for device in dataloader.iter() {
assert_eq!(device, default_device)
}
}
// With one worker and no explicit device, every batch must still use the default device.
#[test]
fn test_dataloader_default_device() {
let default_device = TestDevice::default();
let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
.batch_size(1)
.num_workers(1)
.build(FakeDataset::<String>::new(9));
assert_eq!(dataloader.num_items(), 9);
for device in dataloader.iter() {
assert_eq!(device, default_device)
}
}
// Slicing a data loader and moving each slice to its own device must keep the item
// counts of the slices and batch on the requested devices.
#[test]
fn test_dataloader_slice_multi_device() {
let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
.batch_size(1)
.num_workers(1)
.build(FakeDataset::<String>::new(11));
// The device pair depends on the backend selected through the test features.
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda")
))]
// Only one device exists...
let (device1, device2) = (
burn_ndarray::NdArrayDevice::Cpu,
burn_ndarray::NdArrayDevice::Cpu,
);
#[cfg(all(test, feature = "test-tch"))]
let (device1, device2) = (
burn_tch::LibTorchDevice::Cuda(0),
burn_tch::LibTorchDevice::Cuda(1),
);
#[cfg(all(test, feature = "test-wgpu"))]
let (device1, device2) = (
burn_wgpu::WgpuDevice::DiscreteGpu(0),
burn_wgpu::WgpuDevice::DiscreteGpu(1),
);
#[cfg(all(test, feature = "test-cuda"))]
let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));
assert_eq!(dataloader.num_items(), 11);
// Uneven split of the 11 items: 5 for the first slice, 6 for the second.
let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);
assert_eq!(dataloader_1.num_items(), 5);
assert_eq!(dataloader_2.num_items(), 6);
let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());
for _ in 0..5 {
assert_eq!(iterator_1.next(), Some(device1));
assert_eq!(iterator_2.next(), Some(device2));
}
assert_eq!(iterator_1.next(), None);
// For uneven split, the last dataloader (partial dataset) will have the remaining item
assert_eq!(iterator_2.next(), Some(device2));
assert_eq!(iterator_2.next(), None);
}
}

View File

@@ -0,0 +1,16 @@
mod base;
mod batch;
mod builder;
mod multithread;
mod strategy;
/// Module for batching items.
pub mod batcher;
/// Module to split a dataloader.
pub mod split;
pub use base::*;
pub use batch::*;
pub use builder::*;
pub use multithread::*;
pub use strategy::*;

View File

@@ -0,0 +1,441 @@
use burn_dataset::Dataset;
use burn_dataset::transform::PartialDataset;
use burn_tensor::backend::Backend;
use rand::distr::{Distribution, StandardUniform};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use super::batcher::Batcher;
use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
use std::sync::{Arc, OnceLock, mpsc};
use std::thread;
// Capacity of the bounded channel between the worker threads and the consuming iterator.
const MAX_QUEUED_ITEMS: usize = 100;
// Seed type of `StdRng`, stored instead of the rng itself.
type RngSeed = <StdRng as SeedableRng>::Seed;
/// A multi-threaded data loader that can be used to iterate over a dataset.
///
/// The dataset is split between `num_threads` single-threaded data loaders, each of
/// which is driven by its own thread when iterating.
pub struct MultiThreadDataLoader<B: Backend, I, O> {
// Configuration parameters needed for initialization
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<B, I, O>>,
device: B::Device,
// Seed used for shuffling; `None` disables shuffling.
seed: Option<RngSeed>,
num_threads: usize,
// The lazily initialized data loaders
dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,
}
/// A message that can be sent between threads.
#[derive(Debug)]
pub enum Message<O> {
/// A batch of items.
///
/// Carries the index of the worker that produced the batch, the batch itself and
/// the progress of that worker after producing it.
Batch(usize, O, Progress),
/// The thread is done.
Done,
}
/// Iterator that collects the batches produced by the worker threads.
struct MultiThreadsDataloaderIterator<O> {
// Number of workers that have reported `Message::Done` so far.
num_done: usize,
// Handles of the worker threads, joined once every worker is done.
workers: Vec<thread::JoinHandle<()>>,
// Receiving end of the channel the workers send their messages to.
receiver: mpsc::Receiver<Message<O>>,
// Latest known progress of each worker, indexed by worker index.
progresses: Vec<Progress>,
}
impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static,
{
    /// Creates a new multi-threaded batch data loader.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    /// * `num_threads` - The number of threads.
    /// * `device` - The device to use when loading a batch.
    /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader
    ///   iterator is created.
    ///
    /// # Returns
    ///
    /// The multi-threaded batch data loader.
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        num_threads: usize,
        device: B::Device,
        rng: Option<rand::rngs::StdRng>,
    ) -> Self {
        // Only a seed is kept: it is derived from the rng's output (stream splitting)
        // rather than by cloning the rng state.
        let seed = rng.map(|mut rng| {
            let mut bytes = RngSeed::default();
            rng.fill_bytes(&mut bytes);
            bytes
        });

        Self::from_seed(strategy, dataset, batcher, num_threads, device, seed)
    }

    /// Creates the data loader from an already derived seed, without initializing
    /// the per-thread data loaders.
    fn from_seed(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        num_threads: usize,
        device: B::Device,
        seed: Option<RngSeed>,
    ) -> Self {
        let dataloaders = OnceLock::new();
        Self {
            strategy,
            dataset,
            batcher,
            device,
            seed,
            num_threads,
            dataloaders,
        }
    }

    /// Force initialization if needed.
    ///
    /// Returns the per-thread data loaders, building them on the first call.
    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
        let dataloaders = self.dataloaders.get_or_init(|| {
            // Pre-shuffle the dataset before the split when shuffling is enabled, so
            // that each thread gets a uniform random sample of the dataset.
            let dataset: Arc<dyn Dataset<I>> = match self.seed {
                Some(seed) => {
                    let mut rng = StdRng::from_seed(seed);
                    Arc::new(burn_dataset::transform::ShuffledDataset::new(
                        self.dataset.clone(),
                        &mut rng,
                    ))
                }
                None => self.dataset.clone(),
            };

            let datasets = if let Some(batch_size) = self.strategy.batch_size() {
                PartialDataset::split_chunks(dataset, self.num_threads, batch_size)
            } else {
                PartialDataset::split(dataset, self.num_threads)
            };

            // One child rng per data loader, each seeded from the parent rng.
            let mut parent_rng = self.seed.map(StdRng::from_seed);
            let child_rngs = (0..self.num_threads).map(|_| {
                parent_rng.as_mut().map(|rng| {
                    StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
                })
            });

            datasets
                .into_iter()
                .zip(child_rngs)
                .map(|(partial, child_rng)| {
                    BatchDataLoader::new(
                        self.strategy.clone_dyn(),
                        Arc::new(partial),
                        self.batcher.clone(),
                        self.device.clone(),
                        child_rng,
                    )
                })
                .collect()
        });

        dataloaders.as_slice()
    }
}
impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static + std::fmt::Debug,
{
    /// Spawns one worker thread per inner data loader and returns an iterator that
    /// receives their batches through a bounded channel.
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
        // Builds the inner data loaders on first use.
        let dataloaders = self.initialize();
        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);

        let mut progresses = Vec::with_capacity(dataloaders.len());
        let mut handlers = Vec::with_capacity(dataloaders.len());

        for (index, dataloader) in dataloaders.iter().enumerate() {
            let worker_loader = dataloader.clone();
            let worker_sender = sender.clone();
            progresses.push(Progress::new(0, worker_loader.num_items()));

            let handle = thread::spawn(move || {
                let mut batches = worker_loader.iter();
                while let Some(batch) = batches.next() {
                    let progress = batches.progress();
                    let sent = worker_sender.send(Message::Batch(index, batch, progress));
                    if sent.is_err() {
                        // The receiver is probably gone: no need to panic, just stop
                        // iterating.
                        return;
                    }
                }
                // A failure to report completion is ignored for the same reason.
                worker_sender.send(Message::Done).ok();
            });
            handlers.push(handle);
        }

        Box::new(MultiThreadsDataloaderIterator::new(
            receiver, handlers, progresses,
        ))
    }

    /// Number of items in the dataset, read directly from it so that the inner data
    /// loaders do not have to be initialized.
    fn num_items(&self) -> usize {
        self.dataset.len()
    }

    /// Creates an uninitialized data loader with the same configuration and seed that
    /// batches on `device`.
    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
        let dataloader = Self::from_seed(
            self.strategy.clone_dyn(),
            self.dataset.clone(),
            self.batcher.clone(),
            self.num_threads,
            device.clone(),
            self.seed,
        );
        Arc::new(dataloader)
    }

    /// Creates an uninitialized data loader with the same configuration and seed,
    /// restricted to the items in `start..end`.
    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
        let subset = PartialDataset::new(self.dataset.clone(), start, end);
        Arc::new(Self::from_seed(
            self.strategy.clone_dyn(),
            Arc::new(subset),
            self.batcher.clone(),
            self.num_threads,
            self.device.clone(),
            self.seed,
        ))
    }
}
impl<O> MultiThreadsDataloaderIterator<O> {
pub fn new(
receiver: mpsc::Receiver<Message<O>>,
workers: Vec<thread::JoinHandle<()>>,
progresses: Vec<Progress>,
) -> Self {
MultiThreadsDataloaderIterator {
num_done: 0,
workers,
receiver,
progresses,
}
}
}
impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
    /// Aggregated progress: the sums of the processed and total item counts last
    /// reported by each worker.
    fn progress(&self) -> Progress {
        let (items_processed, items_total) = self.progresses.iter().fold(
            (0, 0),
            |(processed, total), progress| {
                (
                    processed + progress.items_processed,
                    total + progress.items_total,
                )
            },
        );
        Progress::new(items_processed, items_total)
    }
}
impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
    type Item = O;

    /// Returns the next batch produced by any worker, or `None` once every worker
    /// has reported completion.
    ///
    /// # Panics
    ///
    /// If the channel is disconnected before every worker reported completion (which
    /// happens when a worker thread panicked), the workers are joined and the panic of
    /// the failed worker is propagated to the caller instead of an opaque receive error.
    fn next(&mut self) -> Option<O> {
        if self.workers.is_empty() {
            return None;
        }

        loop {
            let message = match self.receiver.recv() {
                Ok(message) => message,
                Err(_) => {
                    // Every sender is gone although some worker never sent `Done`:
                    // surface the original panic rather than a bare `RecvError`.
                    for worker in self.workers.drain(..) {
                        if let Err(payload) = worker.join() {
                            std::panic::resume_unwind(payload);
                        }
                    }
                    return None;
                }
            };

            match message {
                Message::Batch(index, item, progress) => {
                    // Remember the latest progress of the worker that produced the batch.
                    if let Some(current) = self.progresses.get_mut(index) {
                        *current = progress;
                    }
                    return Some(item);
                }
                Message::Done => {
                    self.num_done += 1;
                }
            };

            if self.num_done == self.workers.len() {
                // All workers finished: join them and leave the list empty so that
                // later calls return `None` immediately.
                while let Some(worker) = self.workers.pop() {
                    worker.join().unwrap();
                }
                return None;
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::data::dataloader::FixBatchStrategy;
use crate::data::dataloader::batcher::TestBatcher;
use crate::data::dataset::FakeDataset;
use burn_dataset::InMemDataset;
use std::collections::HashSet;
#[test]
fn test_multi_thread_batch_dataloader() {
let batcher = Arc::new(TestBatcher::new());
let dataset = Arc::new(FakeDataset::<String>::new(27));
let dataloader_single_thread = BatchDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
dataset.clone(),
batcher.clone(),
Default::default(),
None,
);
let dataloader_multi_thread = MultiThreadDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
dataset,
batcher,
4,
Default::default(),
None,
);
let mut items_single_thread = HashSet::new();
let mut items_multi_thread = HashSet::new();
for items in dataloader_single_thread.iter() {
for item in items {
items_single_thread.insert(item);
}
}
for items in dataloader_multi_thread.iter() {
for item in items {
items_multi_thread.insert(item);
}
}
assert_eq!(items_single_thread, items_multi_thread);
}
#[test]
fn test_multi_thread_batch_dataloader_shuffle() {
    let num_classes = 2;
    let class_size = 100;
    let batch_size = 10;

    // Items is a deliberately ordered dataset: `class_size` copies of each class id,
    // one class after the other.
    let mut items = Vec::new();
    for class in 0..num_classes {
        items.extend(std::iter::repeat(class).take(class_size));
    }

    {
        // Unshuffled multithreaded loader.
        // `items` is cloned here because the shuffled case below still needs it.
        let dataset = Arc::new(InMemDataset::new(items.clone()));
        let batcher = Arc::new(TestBatcher::new());
        let loader = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(batch_size)),
            dataset,
            batcher,
            // NOTE(review): presumably the worker count; the value is reused from `num_classes`.
            num_classes,
            Default::default(),
            // No rng means no shuffling.
            None,
        );

        for batch in loader.iter() {
            let batch_items: HashSet<_> = batch.into_iter().collect();

            // Since the dataset is not shuffled, we expect each batch to contain the same item.
            assert_eq!(batch_items.len(), 1);
        }
    }

    {
        // Shuffled multithreaded loader.
        // Last use of `items`, `dataset` and `batcher`: they are moved, not cloned.
        let dataset = Arc::new(InMemDataset::new(items));
        let batcher = Arc::new(TestBatcher::new());
        let loader = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(batch_size)),
            dataset,
            batcher,
            num_classes,
            Default::default(),
            // The rng enables shuffling.
            Some(StdRng::seed_from_u64(42)),
        );

        for batch in loader.iter() {
            let batch_items: HashSet<_> = batch.into_iter().collect();

            // Since the dataset is shuffled, we expect to see all items.
            assert_eq!(batch_items.len(), num_classes);
        }
    }
}
#[test]
fn test_multi_thread_batch_dataloader_incomplete_batches() {
    // 27 items with a batch size of 5 do not divide evenly, so a partial batch is involved.
    let batcher = Arc::new(TestBatcher::new());
    let dataset = Arc::new(FakeDataset::<String>::new(27));

    let single = BatchDataLoader::new(
        Box::new(FixBatchStrategy::new(5)),
        dataset.clone(),
        batcher.clone(),
        Default::default(),
        None,
    );
    let multi = MultiThreadDataLoader::new(
        Box::new(FixBatchStrategy::new(5)),
        dataset,
        batcher,
        4,
        Default::default(),
        None,
    );

    // Whole batches (not individual items) are stored, so both the number of batches
    // and their exact contents have to match.
    let mut batches_single = HashSet::new();
    let mut batches_multi = HashSet::new();
    let mut count_single = 0;
    let mut count_multi = 0;

    for batch in single.iter() {
        batches_single.insert(batch);
        count_single += 1;
    }
    for batch in multi.iter() {
        batches_multi.insert(batch);
        count_multi += 1;
    }

    assert_eq!(count_single, count_multi);
    assert_eq!(batches_single, batches_multi);
}
}

View File

@@ -0,0 +1,134 @@
use std::sync::Arc;
use burn_tensor::backend::Backend;
use super::DataLoader;
/// Splits a dataloader into multiple partial dataloaders (one per device).
///
/// With zero or one device the original dataloader is returned as the only element.
/// Otherwise every device receives a contiguous slice of `num_items / devices.len()`
/// items, moved to that device; the last device additionally receives the remainder.
pub fn split_dataloader<B: Backend, O>(
    dataloader: Arc<dyn DataLoader<B, O>>,
    devices: &[B::Device],
) -> Vec<Arc<dyn DataLoader<B, O>>> {
    let count = devices.len();
    if count <= 1 {
        // Nothing to split.
        return vec![dataloader];
    }

    let total = dataloader.num_items();
    let chunk = total / count;
    let last = count - 1;

    devices
        .iter()
        .enumerate()
        .map(|(idx, device)| {
            let begin = idx * chunk;
            // The last slice runs to the end so that no item is dropped by the integer division.
            let finish = if idx == last { total } else { begin + chunk };
            dataloader.slice(begin, finish).to_device(device)
        })
        .collect()
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::*;
use crate::TestBackend;
use crate::data::dataloader::batcher::Batcher;
use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};
use crate::data::dataset::FakeDataset;
// Splits a dataloader of 11 items over two devices and checks the size of each part,
// the device attached to every batch, and that no item is lost.
#[test]
fn test_split_batch_dataloader() {
type TestDevice = <TestBackend as Backend>::Device;
// Batcher that returns the items unchanged together with the device it was called with,
// so the test can observe which device each partial dataloader uses.
#[derive(new, Clone)]
pub struct TestBatcher;
#[cfg(test)]
impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {
fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {
(items, *device)
}
}
let batcher = Arc::new(TestBatcher::new());
let dataset = Arc::new(FakeDataset::<String>::new(11));
#[allow(clippy::arc_with_non_send_sync)]
let dataloader = Arc::new(BatchDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
dataset.clone(),
batcher,
Default::default(),
None,
));
// Exactly one of the following `let (device1, device2)` bindings is compiled in,
// depending on the test backend feature.
// NOTE(review): the crate root also declares a `test-rocm` test backend, but there is no
// matching branch here and the default branch below is not disabled by `test-rocm` — confirm.
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda")
))]
// Only one device exists...
let (device1, device2) = (
burn_ndarray::NdArrayDevice::Cpu,
burn_ndarray::NdArrayDevice::Cpu,
);
#[cfg(all(test, feature = "test-tch"))]
let (device1, device2) = (
burn_tch::LibTorchDevice::Cuda(0),
burn_tch::LibTorchDevice::Cuda(1),
);
#[cfg(all(test, feature = "test-wgpu"))]
let (device1, device2) = (
burn_wgpu::WgpuDevice::DiscreteGpu(0),
burn_wgpu::WgpuDevice::DiscreteGpu(1),
);
#[cfg(all(test, feature = "test-cuda"))]
let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));
let dataloaders = split_dataloader(dataloader.clone(), &[device1, device2]);
assert_eq!(dataloaders.len(), 2);
// The length was asserted just above, so the conversion to a two-element array cannot fail.
let [dataloader_1, dataloader_2] = match dataloaders.try_into() {
Ok(arr) => arr,
Err(_) => unreachable!(),
};
// 11 / 2 = 5 items for the first part; the last part also takes the remainder (6 items).
assert_eq!(dataloader_1.num_items(), 5);
assert_eq!(dataloader_2.num_items(), 6);
let mut items_dataloader = HashSet::new();
let mut items_dataloader_split = HashSet::new();
for (items, _device) in dataloader.iter() {
for item in items {
items_dataloader.insert(item);
}
}
for (items, device) in dataloader_1.iter() {
assert_eq!(device, device1);
for item in items {
items_dataloader_split.insert(item);
}
}
for (items, device) in dataloader_2.iter() {
assert_eq!(device, device2);
for item in items {
items_dataloader_split.insert(item);
}
}
// Together the two parts must yield exactly the items of the original dataloader.
assert_eq!(items_dataloader, items_dataloader_split);
}
}

View File

@@ -0,0 +1,87 @@
/// A strategy to batch items.
///
/// Items are buffered one at a time with [`add`](BatchStrategy::add) and handed back as a
/// group by [`batch`](BatchStrategy::batch).
pub trait BatchStrategy<I>: Send + Sync {
/// Adds an item to the strategy.
///
/// # Arguments
///
/// * `item` - The item to add.
fn add(&mut self, item: I);
/// Batches the items.
///
/// # Arguments
///
/// * `force` - Whether to force batching, i.e. to return the buffered items even if the
///   strategy would otherwise keep waiting for more.
///
/// # Returns
///
/// The batched items, or `None` when no batch is produced.
fn batch(&mut self, force: bool) -> Option<Vec<I>>;
/// Creates a new strategy of the same type.
///
/// # Returns
///
/// The new strategy, boxed as a trait object.
fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;
/// Returns the expected batch size for this strategy.
///
/// # Returns
///
/// The batch size, or None if the strategy doesn't have a fixed batch size.
fn batch_size(&self) -> Option<usize>;
}
/// A strategy to batch items with a fixed batch size.
pub struct FixBatchStrategy<I> {
// Items added since the last batch was handed out.
items: Vec<I>,
// Number of buffered items required before an unforced batch is produced.
batch_size: usize,
}
impl<I> FixBatchStrategy<I> {
/// Creates a new strategy to batch items with a fixed batch size.
///
/// # Arguments
///
/// * `batch_size` - The batch size.
///
/// # Returns
///
/// The strategy.
pub fn new(batch_size: usize) -> Self {
FixBatchStrategy {
items: Vec::with_capacity(batch_size),
batch_size,
}
}
}
impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
    /// Appends `item` to the pending batch.
    fn add(&mut self, item: I) {
        self.items.push(item);
    }

    /// Returns the pending items once a batch is ready.
    ///
    /// A batch is ready when at least `batch_size` items are buffered, or when `force` is
    /// set and at least one item is buffered. `None` is returned otherwise, including when
    /// `force` is set but nothing has been added.
    fn batch(&mut self, force: bool) -> Option<Vec<I>> {
        let ready = force || self.items.len() >= self.batch_size;
        // Bail out before allocating: an empty buffer never yields a batch.
        if !ready || self.items.is_empty() {
            return None;
        }

        // Hand out the filled buffer and leave a fresh one sized for the next batch.
        Some(std::mem::replace(
            &mut self.items,
            Vec::with_capacity(self.batch_size),
        ))
    }

    /// Creates a new, empty strategy with the same batch size (buffered items are not copied).
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
        Box::new(Self::new(self.batch_size))
    }

    /// Always returns `Some` with the configured batch size.
    fn batch_size(&self) -> Option<usize> {
        Some(self.batch_size)
    }
}

View File

@@ -0,0 +1,15 @@
/// Dataloader module.
///
/// Only available with the `dataset` feature.
#[cfg(feature = "dataset")]
pub mod dataloader;
/// Dataset module.
///
/// Re-exports everything from the `burn_dataset` crate. Only available with the `dataset` feature.
#[cfg(feature = "dataset")]
pub mod dataset {
pub use burn_dataset::*;
}
/// Network module.
///
/// Re-exports everything from `burn_std::network`. Only available with the `network` feature.
#[cfg(feature = "network")]
pub mod network {
pub use burn_std::network::*;
}

View File

@@ -0,0 +1,118 @@
// The crate is `no_std` unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "135"]
//! The core crate of Burn.
#[macro_use]
extern crate derive_new;
/// Re-export serde for proc macros.
pub use serde;
/// The configuration module.
pub mod config;
/// Data module.
///
/// Only available with the `std` feature.
#[cfg(feature = "std")]
pub mod data;
/// Module for neural network modules.
pub mod module;
/// Module for the recorder.
pub mod record;
/// Module for the tensor.
pub mod tensor;
// Tensor at root: `burn::Tensor`
pub use tensor::Tensor;
/// Module for visual operations.
///
/// Only available with the `vision` feature.
#[cfg(feature = "vision")]
pub mod vision;
// `alloc` is linked explicitly so `alloc::` paths resolve in `no_std` builds as well.
extern crate alloc;
/// Backend for test cases
///
/// Default choice, used when none of the `test-*` backend features is enabled.
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda"),
not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(all(test, feature = "test-tch"))]
/// Backend for test cases (selected by the `test-tch` feature).
pub type TestBackend = burn_tch::LibTorch<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases (selected by the `test-wgpu` feature).
pub type TestBackend = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases (selected by the `test-cuda` feature).
pub type TestBackend = burn_cuda::Cuda;
#[cfg(all(test, feature = "test-rocm"))]
/// Backend for test cases (selected by the `test-rocm` feature).
pub type TestBackend = burn_rocm::Rocm;
/// Backend for autodiff test cases
///
/// Wraps whichever [`TestBackend`] is selected in the autodiff decorator.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
// Only compiled for tests with the `test-memory-checks` feature enabled.
// NOTE(review): presumably `memory_checks!` expands to memory-leak checks for the fusion
// backend — confirm against `burn_fusion`.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
burn_fusion::memory_checks!();
}
#[cfg(test)]
mod test_utils {
    use crate as burn;
    use crate::module::Module;
    use crate::module::Param;
    use burn_tensor::Tensor;
    use burn_tensor::backend::Backend;

    /// Minimal linear layer used as a fixture by the crate's unit tests.
    #[derive(Module, Debug)]
    pub struct SimpleLinear<B: Backend> {
        pub weight: Param<Tensor<B, 2>>,
        pub bias: Option<Param<Tensor<B, 1>>>,
    }

    impl<B: Backend> SimpleLinear<B> {
        /// Creates a layer with a randomly initialised `[out_features, in_features]` weight
        /// and an `[out_features]` bias, both allocated on `device`.
        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
            let weight_shape = [out_features, in_features];
            let bias_shape = [out_features];

            // The weight is sampled before the bias, as the two draws are sequential.
            let weight: Tensor<B, 2> =
                Tensor::random(weight_shape, burn_tensor::Distribution::Default, device);
            let bias: Tensor<B, 1> =
                Tensor::random(bias_shape, burn_tensor::Distribution::Default, device);

            Self {
                weight: Param::from_tensor(weight),
                bias: Some(Param::from_tensor(bias)),
            }
        }
    }
}
pub mod prelude {
//! Structs and macros used by most projects. Add `use
//! burn::prelude::*` to your code to quickly get started with
//! Burn.
pub use crate::{
config::Config,
module::Module,
tensor::{
Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,
backend::Backend, cast::ToElement, s,
},
};
// Exported under another name because `tensor::Device` is already re-exported above.
pub use burn_std::device::Device as DeviceOps;
}

View File

@@ -0,0 +1,470 @@
use super::{Param, ParamId, Quantizer};
use crate::{
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::{string::String, vec::Vec};
pub use burn_derive::Module;
use burn_tensor::{Bool, Int, Tensor, ops::Device};
/// Type alias to `Vec<B::Device>`. It uses the `Vec` from the `alloc` crate, so it also
/// works in `no_std` environments.
pub type Devices<B> = Vec<Device<B>>;
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
//
// The macro has two forms:
// * `map=<module>, ops=<closure>`: applies the closure to every float tensor of the module
//   and returns the mapped module.
// * `visit_float=<module>, ops=<closure>, state=<type>, init=<closure>`: builds a state with
//   `init`, lets the closure update it for every float parameter, and returns the state.
macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map_float<const D: usize>(
&mut self,
param: Param<Tensor<B, D>>,
) -> Param<Tensor<B, D>> {
// Take the parameter apart, transform the tensor, then rebuild the parameter
// with the same id and mapper.
let (id, tensor, mapper) = param.consume();
let func = $item;
let tensor = func(tensor);
Param::from_mapped_value(id, tensor, mapper)
}
}
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
// Only `visit_float` is overridden; int and bool parameters go through the
// default (empty) visitor methods.
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
let func = $item;
func(&param.val(), &mut self.state)
}
}
#[allow(clippy::redundant_closure_call)]
let mut state = $init();
let mut visitor = Visitor {
state: &mut state,
backend: core::marker::PhantomData,
};
$module.visit(&mut visitor);
state
}};
}
/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable and loadable via
/// `state` and `load`.
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
/// necessary to optimize and train the module on any backend.
///
/// ```rust, ignore
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
/// module::Module,
/// nn::Linear,
/// tensor::Tensor,
/// tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
/// my_param: Linear<B>,
/// my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record<B>;
/// Return all the devices found in the underneath module tree added to the given vector
/// without duplicates.
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
/// Return all the devices found in the underneath module tree without duplicates.
///
/// Calls [collect_devices](Module::collect_devices) with an empty vector.
fn devices(&self) -> Devices<B> {
self.collect_devices(Devices::<B>::new())
}
/// Fork the module and all of its sub-modules to the given device.
///
/// # Notes
///
/// This is similar to [to_device](Module::to_device), but it ensures the output module on the
/// new device will have its own autodiff graph.
fn fork(self, device: &B::Device) -> Self;
/// Move the module and all of its sub-modules to the given device.
///
/// # Warnings
///
/// The operation supports autodiff and it will be registered when activated. However, this may
/// not be what you want. The output model will be an intermediary model, meaning that you
/// can't optimize it with gradient descent. If you want to optimize the output network on the
/// target device, use [fork](Module::fork) instead.
fn to_device(self, device: &B::Device) -> Self;
/// Each tensor in the module tree will not require grad.
///
/// # Warnings
///
/// This should not be used for inference, use [valid](AutodiffModule::valid) when using
/// AD modules. This is mostly useful when performing partial finetuning, which is updating only
/// a small fraction of the parameters instead of finetuning all of them.
fn no_grad(self) -> Self {
// Maps every float tensor of the module tree through `set_require_grad(false)`.
module!(
map = self,
ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
)
}
/// Move the module and all of its sub-modules to the autodiff backend.
///
/// # Notes
///
/// * Only plain modules (not already on an autodiff backend) can be moved.
/// * Calling `train()` on a module that is already on an autodiff backend
/// will result in a type error, because the module's inner backend does not match.
fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule
where
AB: AutodiffBackend<InnerBackend = B>,
Self: HasAutodiffModule<AB>,
{
<Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)
}
/// Get the number of parameters the module has, including all of its sub-modules.
///
/// The count is the sum of the number of elements of every float parameter; int and bool
/// parameters are not counted by this default implementation.
fn num_params(&self) -> usize {
module!(
visit_float = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
state = usize,
init = || 0
)
}
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;
/// Load the module state from a record.
fn load_record(self, record: Self::Record) -> Self;
/// Convert the module into a record containing the state.
fn into_record(self) -> Self::Record;
#[cfg(feature = "std")]
/// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
///
/// The module is consumed: it is first converted into its record, which is then handed
/// to the recorder.
///
/// List of supported file recorders:
///
/// * [default](crate::record::DefaultFileRecorder)
/// * [bincode](crate::record::BinFileRecorder)
/// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
/// * [json pretty](crate::record::PrettyJsonFileRecorder)
/// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
/// * [named mpk](crate::record::NamedMpkFileRecorder)
/// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
///
/// ## Notes
///
/// The file extension is automatically added depending on the file recorder provided, you
/// don't have to specify it.
///
/// # Errors
///
/// Returns the [RecorderError](crate::record::RecorderError) reported by the recorder.
fn save_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
) -> Result<(), crate::record::RecorderError>
where
FR: crate::record::FileRecorder<B>,
PB: Into<std::path::PathBuf>,
{
let record = Self::into_record(self);
recorder.record(record, file_path.into())
}
#[cfg(feature = "std")]
/// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
///
/// The recorder should be the same as the one used to save the module, see
/// [save_file](Self::save_file).
///
/// The record is loaded on `device` and then applied to `self` with
/// [load_record](Module::load_record).
///
/// ## Notes
///
/// The file extension is automatically added depending on the file recorder provided, you
/// don't have to specify it.
///
/// # Errors
///
/// Returns the [RecorderError](crate::record::RecorderError) reported by the recorder when
/// the record cannot be loaded.
fn load_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
device: &B::Device,
) -> Result<Self, crate::record::RecorderError>
where
FR: crate::record::FileRecorder<B>,
PB: Into<std::path::PathBuf>,
{
let record = recorder.load(file_path.into(), device)?;
Ok(self.load_record(record))
}
/// Quantize the weights of the module.
///
/// Maps the module with the given [Quantizer], which acts as the [mapper](ModuleMapper).
fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
self.map(quantizer)
}
}
/// Module visitor trait for traversing and inspecting module parameters.
///
/// Every method has an empty default implementation, so a visitor only needs to override
/// the callbacks it is interested in.
pub trait ModuleVisitor<B: Backend> {
/// Visit a float parameter in the module.
///
/// # Parameters
/// - `param`: The float parameter to visit
#[allow(unused_variables)]
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}
/// Visit an int parameter in the module.
///
/// # Parameters
/// - `param`: The integer parameter to visit
#[allow(unused_variables)]
fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}
/// Visit a bool parameter in the module.
///
/// # Parameters
/// - `param`: The boolean parameter to visit
#[allow(unused_variables)]
fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}
/// Called when entering a submodule.
///
/// # Parameters
/// - `name`: The name of the submodule being entered
/// - `container_type`: The type of the container with format:
/// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
/// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
/// - For Vec containers: "Vec" (name is the index)
/// - For Tuple containers: "Tuple" (name is the index)
/// - For Array containers: "Array" (name is the index)
///
/// Note: Option containers do not call enter_module/exit_module to preserve
/// the field name in the path (e.g., "bias" instead of "bias.Some")
#[allow(unused_variables)]
fn enter_module(&mut self, name: &str, container_type: &str) {}
/// Called when exiting a submodule.
///
/// # Parameters
/// - `name`: The name of the submodule being exited
/// - `container_type`: The type of the container with format:
/// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
/// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
/// - For Vec containers: "Vec" (name is the index)
/// - For Tuple containers: "Tuple" (name is the index)
/// - For Array containers: "Array" (name is the index)
///
/// Note: Option containers do not call enter_module/exit_module to preserve
/// the field name in the path (e.g., "bias" instead of "bias.Some")
#[allow(unused_variables)]
fn exit_module(&mut self, name: &str, container_type: &str) {}
/// Visit a float tensor with its full module path.
///
/// # Parameters
/// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
/// Each element represents a module name in the hierarchy, with the final element
/// being the parameter name. This allows efficient reuse of the path stack.
/// - `id`: The unique identifier of the parameter
/// - `tensor`: The float tensor to visit
#[allow(unused_variables)]
fn visit_float_with_path<const D: usize>(
&mut self,
path: &[String],
id: ParamId,
tensor: &Tensor<B, D>,
) {
}
/// Visit an int tensor with its full module path.
///
/// # Parameters
/// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
/// Each element represents a module name in the hierarchy, with the final element
/// being the parameter name. This allows efficient reuse of the path stack.
/// - `id`: The unique identifier of the parameter
/// - `tensor`: The integer tensor to visit
#[allow(unused_variables)]
fn visit_int_with_path<const D: usize>(
&mut self,
path: &[String],
id: ParamId,
tensor: &Tensor<B, D, Int>,
) {
}
/// Visit a bool tensor with its full module path.
///
/// # Parameters
/// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
/// Each element represents a module name in the hierarchy, with the final element
/// being the parameter name. This allows efficient reuse of the path stack.
/// - `id`: The unique identifier of the parameter
/// - `tensor`: The boolean tensor to visit
#[allow(unused_variables)]
fn visit_bool_with_path<const D: usize>(
&mut self,
path: &[String],
id: ParamId,
tensor: &Tensor<B, D, Bool>,
) {
}
}
/// Transforms the parameters of a module while the module tree is traversed.
///
/// All methods come with a default implementation: the `enter_module`/`exit_module` hooks
/// do nothing and the `map_*` methods rebuild the parameter from its own parts, so an
/// implementor only overrides what it needs.
pub trait ModuleMapper<B: Backend> {
    /// Hook invoked when the traversal enters a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: Describes the container, using one of these formats:
    ///   - User-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - User-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - Vec containers: "Vec" (name is the index)
    ///   - Tuple containers: "Tuple" (name is the index)
    ///   - Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module, which keeps
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Hook invoked when the traversal leaves a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: Describes the container, using one of these formats:
    ///   - User-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - User-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - Vec containers: "Vec" (name is the index)
    ///   - Tuple containers: "Tuple" (name is the index)
    ///   - Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module, which keeps
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Maps a float parameter of the module.
    ///
    /// The default implementation takes the parameter apart and rebuilds it from the same
    /// id, tensor and mapper.
    ///
    /// # Parameters
    /// - `param`: The float parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (param_id, value, param_mapper) = param.consume();
        Param::from_mapped_value(param_id, value, param_mapper)
    }

    /// Maps an int parameter of the module.
    ///
    /// The default implementation takes the parameter apart and rebuilds it from the same
    /// id, tensor and mapper.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_int<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Int>>,
    ) -> Param<Tensor<B, D, Int>> {
        let (param_id, value, param_mapper) = param.consume();
        Param::from_mapped_value(param_id, value, param_mapper)
    }

    /// Maps a bool parameter of the module.
    ///
    /// The default implementation takes the parameter apart and rebuilds it from the same
    /// id, tensor and mapper.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_bool<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Bool>>,
    ) -> Param<Tensor<B, D, Bool>> {
        let (param_id, value, param_mapper) = param.consume();
        Param::from_mapped_value(param_id, value, param_mapper)
    }
}
/// Module with auto-differentiation backend.
///
/// Links a module on an autodiff backend to the equivalent module on the inner backend,
/// with conversions in both directions.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
/// Inner module without auto-differentiation.
type InnerModule: Module<B::InnerBackend>;
/// Returns the same module, but on the inner backend without auto-differentiation.
fn valid(&self) -> Self::InnerModule;
/// Wraps an inner module back into an auto-diff module.
fn from_inner(module: Self::InnerModule) -> Self;
}
/// Helper trait to associate a module with its autodiff version.
///
/// It lets [Module::train] name its return type from the autodiff backend alone.
pub trait HasAutodiffModule<B: AutodiffBackend> {
/// The module with auto-differentiation.
///
/// Its inner module must be the implementing type.
type TrainModule: AutodiffModule<B, InnerModule = Self>;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use crate::test_utils::SimpleLinear;
// Walks a module through valid -> train -> no_grad -> valid -> train and checks, at each step,
// both the effective `is_require_grad()` of the weight and the `require_grad` flag stored on
// the parameter (the assertions marked "stateful").
#[test]
fn test_module_val_train_stateful() {
let device = Default::default();
let module = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);
assert!(module.weight.is_require_grad());
assert!(module.weight.require_grad);
// On the inner backend the tensor no longer requires grad, but the stored flag is kept.
let module = module.valid();
assert!(!module.weight.is_require_grad());
assert!(module.weight.require_grad); // stateful
// Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying
// let module: SimpleLinear<TestAutodiffBackend> = module.train();
let module = module.train::<TestAutodiffBackend>();
assert!(module.weight.is_require_grad());
assert!(module.weight.require_grad); // stateful
// `no_grad` clears both the effective requirement and the stored flag.
let module = module.no_grad();
assert!(!module.weight.is_require_grad());
assert!(!module.weight.require_grad); // stateful
let module = module.valid();
assert!(!module.weight.is_require_grad()); // always
assert!(!module.weight.require_grad); // stateful
// Going back to training does not re-enable gradients once the flag has been cleared.
let module = module.train::<TestAutodiffBackend>();
assert!(!module.weight.is_require_grad());
assert!(!module.weight.require_grad); // stateful
}
}

View File

@@ -0,0 +1,543 @@
use alloc::{
borrow::ToOwned,
format,
string::{String, ToString},
vec::Vec,
};
use core::any;
use core::fmt::{Display, Write};
/// Default display settings for a module.
pub trait ModuleDisplayDefault {
/// Attributes of the module used for display purposes.
///
/// # Arguments
///
/// * `_content` - The content object that contains display settings and attributes.
///
/// # Returns
///
/// An optional content object containing the display attributes.
fn content(&self, _content: Content) -> Option<Content>;
/// Gets the number of parameters of the module.
///
/// The default implementation returns 0.
fn num_params(&self) -> usize {
0
}
}
/// Trait to implement custom display settings for a module.
///
/// In order to implement custom display settings for a module,
/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
/// 2. Implement ModuleDisplay trait for the module
pub trait ModuleDisplay: ModuleDisplayDefault {
/// Formats the module with provided display settings.
///
/// # Arguments
///
/// * `passed_settings` - Display settings passed to the module.
///
/// # Returns
///
/// A string representation of the formatted module.
///
/// # Panics
///
/// Panics if the default [content](ModuleDisplayDefault::content) returns `None` when it
/// is needed (no custom content, or `show_all_attributes` is set).
fn format(&self, passed_settings: DisplaySettings) -> String {
// Values set in the passed settings take precedence over the module's custom settings.
let settings = if let Some(custom_settings) = self.custom_settings() {
custom_settings.inherit(passed_settings)
} else {
passed_settings
};
// Attributes are indented at the current level, the closing brace one level less.
// NOTE(review): `settings.level - 1` would underflow for a level of 0; the default level
// is 1 and `level_up` only increments it — confirm no other path produces 0.
let indent = " ".repeat(settings.level * settings.indentation_size());
let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
// From here on the settings are one level deeper; these are the ones given to the content.
let settings = settings.level_up();
let self_type = extract_type_name::<Self>();
// Use custom content if it is implemented and show_all_attributes is false,
// otherwise use default content
let content = if !settings.show_all_attributes() {
self.custom_content(Content::new(settings.clone()))
.unwrap_or_else(|| {
self.content(Content::new(settings.clone()))
.unwrap_or_else(|| {
panic!("Default content should be implemented for {self_type}.")
})
})
} else {
self.content(Content::new(settings.clone()))
.unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
};
// The content may override the displayed type name.
let top_level_type = if let Some(top_level_type) = content.top_level_type {
top_level_type.to_owned()
} else {
self_type.to_owned()
};
// If there is only one item in the content, return it or no attributes
if let Some(item) = content.single_item {
return item;
} else if content.attributes.is_empty() {
return top_level_type.to_string();
}
let mut result = String::new();
// Print the struct name
if settings.new_line_after_attribute() {
writeln!(result, "{top_level_type} {{").unwrap();
} else {
write!(result, "{top_level_type} {{").unwrap();
}
// One attribute per line in multi-line mode, otherwise a comma-separated list.
for (i, attribute) in content.attributes.iter().enumerate() {
if settings.new_line_after_attribute() {
writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
} else if i == 0 {
write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
} else {
write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
}
}
// The parameter count is appended as a last pseudo-attribute, only when it is non-zero.
if settings.show_num_parameters() {
let num_params = self.num_params();
if num_params > 0 {
if settings.new_line_after_attribute() {
writeln!(result, "{indent}params: {num_params}").unwrap();
} else {
write!(result, ", params: {num_params}").unwrap();
}
}
}
if settings.new_line_after_attribute() {
write!(result, "{indent_close_braces}}}").unwrap();
} else {
write!(result, "}}").unwrap();
}
result
}
/// Custom display settings for the module.
///
/// # Returns
///
/// An optional display settings object. The default implementation returns `None`.
fn custom_settings(&self) -> Option<DisplaySettings> {
None
}
/// Custom attributes for the module.
///
/// # Arguments
///
/// * `_content` - The content object that contains display settings and attributes.
///
/// # Returns
///
/// An optional content object containing the custom attributes. The default
/// implementation returns `None`.
fn custom_content(&self, _content: Content) -> Option<Content> {
None
}
}
/// Custom module display settings.
///
/// Each option is stored as an `Option`: `None` means "not set", in which case the
/// corresponding getter substitutes a default value.
#[derive(Debug, Clone)]
pub struct DisplaySettings {
/// Whether to print the module parameter ids.
show_param_id: Option<bool>,
/// Whether to print all the module attributes.
show_all_attributes: Option<bool>,
/// Whether to print the module number of parameters.
show_num_parameters: Option<bool>,
/// Print new line after an attribute.
new_line_after_attribute: Option<bool>,
/// Indentation size.
indentation_size: Option<usize>,
/// Level of indentation.
level: usize,
}
impl Default for DisplaySettings {
fn default() -> Self {
DisplaySettings {
show_param_id: None,
show_all_attributes: None,
show_num_parameters: None,
new_line_after_attribute: None,
indentation_size: None,
level: 1,
}
}
}
impl DisplaySettings {
/// Create a new format settings.
///
/// # Returns
///
/// A new instance of `DisplaySettings`.
pub fn new() -> Self {
Default::default()
}
/// Sets a flag to show module parameters.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show module parameters.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_param_id(mut self, flag: bool) -> Self {
self.show_param_id = Some(flag);
self
}
/// Sets a flag to show module attributes.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show all module attributes.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
self.show_all_attributes = Some(flag);
self
}
/// Sets a flag to show the number of module parameters.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show the number of module parameters.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
self.show_num_parameters = Some(flag);
self
}
/// Sets a flag to print a new line after an attribute.
///
/// # Arguments
///
/// * `flag` - Boolean flag to print a new line after an attribute.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
self.new_line_after_attribute = Some(flag);
self
}
/// Sets the indentation size.
///
/// # Arguments
///
/// * `size` - The size of the indentation.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_indentation_size(mut self, size: usize) -> Self {
self.indentation_size = Some(size);
self
}
/// Inherits settings from the provided settings and return a new settings object.
///
/// # Arguments
///
/// * `top` - The top level `DisplaySettings` to inherit from.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn inherit(self, top: Self) -> Self {
let mut updated = self.clone();
if let Some(show_param_id) = top.show_param_id {
updated.show_param_id = Some(show_param_id);
};
if let Some(show_all_attributes) = top.show_all_attributes {
updated.show_all_attributes = Some(show_all_attributes);
}
if let Some(show_num_parameters) = top.show_num_parameters {
updated.show_num_parameters = Some(show_num_parameters);
}
if let Some(new_line_after_attribute) = top.new_line_after_attribute {
updated.new_line_after_attribute = Some(new_line_after_attribute);
}
if let Some(indentation_size) = top.indentation_size {
updated.indentation_size = Some(indentation_size);
}
updated.level = top.level;
updated
}
/// Convenience wrapper that places these `DisplaySettings` inside an `Option`.
///
/// # Returns
///
/// `Some` containing the settings; this method never returns `None`.
pub fn optional(self) -> Option<Self> {
    Option::from(self)
}
/// Increases the level of indentation.
///
/// # Returns
///
/// Updated `DisplaySettings` instance with increased indentation level.
pub fn level_up(mut self) -> Self {
self.level += 1;
self
}
/// Reads the `show_param_id` flag, falling back to `false` when it is not set.
///
/// This flag is used to print the module parameter ids.
///
/// # Returns
///
/// `true` only when parameter ids were explicitly requested.
pub fn show_param_id(&self) -> bool {
    matches!(self.show_param_id, Some(true))
}
/// Reads the `show_all_attributes` flag, falling back to `false` when it is not set.
///
/// This flag is used to force printing all module attributes, overriding custom attributes.
///
/// # Returns
///
/// `true` only when showing all attributes was explicitly requested.
pub fn show_all_attributes(&self) -> bool {
    match self.show_all_attributes {
        Some(flag) => flag,
        None => false,
    }
}
/// Reads the `show_num_parameters` flag, falling back to `true` when it is not set.
///
/// This flag is used to print the number of module parameters.
///
/// # Returns
///
/// `false` only when the parameter count was explicitly disabled.
pub fn show_num_parameters(&self) -> bool {
    match self.show_num_parameters {
        Some(flag) => flag,
        None => true,
    }
}
/// Reads the `new_line_after_attribute` flag, falling back to `true` when it is not set.
///
/// This flag is used to print a new line after an attribute.
///
/// # Returns
///
/// `false` only when the new line was explicitly disabled.
pub fn new_line_after_attribute(&self) -> bool {
    !matches!(self.new_line_after_attribute, Some(false))
}
/// Reads the `indentation_size` setting, falling back to 2 when it is not set.
///
/// This value is used to set the size of indentation.
///
/// # Returns
///
/// The configured indentation size, or 2 by default.
pub fn indentation_size(&self) -> usize {
    match self.indentation_size {
        Some(size) => size,
        None => 2,
    }
}
}
/// Struct to store the attributes of a module for formatting.
///
/// A `Content` holds either a list of named attributes or one single item; the builder
/// methods `add`, `add_single` and `add_formatted` panic when the two are mixed.
#[derive(Clone, Debug)]
pub struct Content {
/// List of attributes, in the order they were added with `add`.
pub attributes: Vec<Attribute>,
/// Single item content, already rendered as a string.
pub single_item: Option<String>,
/// Display settings handed to every value formatted through `add` / `add_single`.
pub display_settings: DisplaySettings,
/// Top level type name, if one was set with `set_top_level_type`.
pub top_level_type: Option<String>,
}
impl Content {
    /// Creates an empty `Content` that carries the given display settings.
    ///
    /// # Arguments
    ///
    /// * `display_settings` - Display settings for the content.
    ///
    /// # Returns
    ///
    /// A new instance of `Content` with no attributes, no single item and no top level type.
    pub fn new(display_settings: DisplaySettings) -> Self {
        Content {
            attributes: Vec::new(),
            single_item: None,
            display_settings,
            top_level_type: None,
        }
    }

    /// Adds a named attribute. The value is formatted with this content's display settings
    /// and stored as a string, together with the full type name of `T`.
    ///
    /// # Arguments
    ///
    /// * `name` - Name of the attribute.
    /// * `value` - Value of the attribute.
    ///
    /// # Returns
    ///
    /// Updated `Content` instance with the new attribute added.
    ///
    /// # Panics
    ///
    /// Panics if a single item has already been set.
    pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
        if self.single_item.is_some() {
            panic!("Cannot add multiple attributes when single item is set.");
        }
        let attribute = Attribute {
            name: name.to_owned(),
            value: value.format(self.display_settings.clone()), // TODO level + 1
            ty: any::type_name::<T>().to_string(),
        };
        self.attributes.push(attribute);
        self
    }

    /// Sets the single item by formatting `value` with this content's display settings.
    ///
    /// # Arguments
    ///
    /// * `value` - The item to format and store as the single item.
    ///
    /// # Returns
    ///
    /// Updated `Content` instance with the single item added.
    ///
    /// # Panics
    ///
    /// Panics if attributes have already been added.
    pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
        if !self.attributes.is_empty() {
            panic!("Cannot add single item when attributes are set.");
        }
        self.single_item = Some(value.format(self.display_settings.clone()));
        self
    }

    /// Sets the single item from any value implementing `Display`, bypassing the
    /// display settings.
    ///
    /// # Arguments
    ///
    /// * `value` - Value rendered through its `Display` implementation.
    ///
    /// # Returns
    ///
    /// Updated `Content` instance with the formatted single item added.
    ///
    /// # Panics
    ///
    /// Panics if attributes have already been added.
    pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
        if !self.attributes.is_empty() {
            panic!("Cannot add single item when attributes are set.");
        }
        self.single_item = Some(value.to_string());
        self
    }

    /// A convenience method to wrap the `Content` struct in an option
    /// because it is often used as an optional field.
    ///
    /// # Returns
    ///
    /// `None` when there are no attributes, no single item and no top level type;
    /// otherwise `Some` containing this `Content`.
    pub fn optional(self) -> Option<Self> {
        if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
        {
            None
        } else {
            Some(self)
        }
    }

    /// Sets the top level type name.
    ///
    /// # Arguments
    ///
    /// * `ty` - The type name to set.
    ///
    /// # Returns
    ///
    /// Updated `Content` instance with the top level type name set.
    pub fn set_top_level_type(mut self, ty: &str) -> Self {
        self.top_level_type = Some(ty.to_owned());
        self
    }
}
/// Attribute to print in the display method.
///
/// Instances are created by `Content::add`, which stores the already formatted value.
#[derive(Clone, Debug)]
pub struct Attribute {
/// Name of the attribute.
pub name: String,
/// Value of the attribute, rendered as a string.
pub value: String,
/// Type of the attribute, as reported by `any::type_name` (full path and generics).
pub ty: String,
}
/// Extracts the short name of a type `T`.
///
/// The full name reported by `any::type_name` is cut at the first `<` (dropping generic
/// parameters) and then at the last `::` (dropping the module path). For example a
/// `Vec<String>` yields `"Vec"`.
///
/// # Returns
///
/// A string slice representing the short name of the type. If nothing follows the last
/// `::`, the name without generic parameters is returned unchanged.
pub fn extract_type_name<T: ?Sized>() -> &'static str {
    // Full type name of T, including module path and generic parameters.
    let full = any::type_name::<T>();
    // Drop the generic parameters: keep everything before the first '<'.
    let base = full.find('<').map_or(full, |idx| &full[..idx]);
    // Drop the module path: keep what follows the last "::", unless that would be empty.
    match base.rsplit_once("::") {
        Some((_, short)) if !short.is_empty() => short,
        _ => base,
    }
}

View File

@@ -0,0 +1,627 @@
use crate::tensor::Shape;
use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};
use crate as burn;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Enum specifying with what values a tensor should be initialized
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
/// Fills tensor with specified value everywhere
Constant {
/// The value to fill the tensor with
value: f64,
},
/// Fills tensor with 1s everywhere
Ones,
/// Fills tensor with 0s everywhere
Zeros,
/// Fills tensor with values drawn uniformly between specified values
Uniform {
/// The minimum value to draw from
min: f64,
/// The maximum value to draw from
max: f64,
},
/// Fills tensor with values drawn from normal distribution with specified mean and std
Normal {
/// The mean of the normal distribution
mean: f64,
/// The standard deviation of the normal distribution
std: f64,
},
/// Fills tensor with values according to the uniform version of Kaiming initialization
///
/// Values are drawn uniformly from `[-a, a]` with `a = sqrt(3) * gain / sqrt(fan)`.
KaimingUniform {
/// The gain to use in initialization formula
gain: f64,
/// Whether to use fan out only in initialization formula (fan in is used otherwise)
fan_out_only: bool,
},
/// Fills tensor with values according to the normal version of Kaiming initialization
///
/// Values are drawn from a normal distribution with mean 0 and std `gain / sqrt(fan)`.
KaimingNormal {
/// The gain to use in initialization formula
gain: f64,
/// Whether to use fan out only in initialization formula (fan in is used otherwise)
fan_out_only: bool,
},
/// Fills tensor with values according to the uniform version of Xavier Glorot initialization
/// described in [Understanding the difficulty of training deep feedforward neural networks
/// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
XavierUniform {
/// The gain to use in initialization formula
gain: f64,
},
/// Fills tensor with values according to the normal version of Xavier Glorot initialization
/// described in [Understanding the difficulty of training deep feedforward neural networks
/// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
XavierNormal {
/// The gain to use in initialization formula
gain: f64,
},
/// Fills tensor with values according to the (semi) orthogonal initialization
/// described in "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks"
/// - [Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120)
///
/// Requires a tensor with at least 2 dimensions.
Orthogonal {
/// The gain to use in initialization formula
gain: f64,
},
}
impl Initializer {
/// Inits a tensor parameter of given shape with values depending on initializer kind.
///
/// No fan-in / fan-out is provided, so the Kaiming and Xavier variants cannot be used
/// through this method (see `init_with`).
///
/// # Params
///
/// - shape: Shape of the initiated tensor.
/// - device: Device on which the tensor will be created.
///
/// # Panics
///
/// The returned parameter is lazy: the tensor is only created when the parameter value is
/// first accessed, so the panics below happen at that point rather than in this call.
///
/// - With a Kaiming or Xavier initializer, because no fan value is available.
/// - With the orthogonal initializer when `D < 2`.
pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
&self,
shape: S,
device: &B::Device,
) -> Param<Tensor<B, D>> {
self.init_with(shape, None, None, device)
}
/// Inits a tensor parameter of given shape with values depending on initializer kind.
///
/// The parameter is created uninitialized: the initializer configuration, shape and fan
/// values are captured in a closure that builds the tensor on first access.
///
/// # Params
///
/// - shape: Shape of the initiated tensor.
/// - fan_in: Fan-in used by the Kaiming (unless `fan_out_only`) and Xavier initializers.
/// - fan_out: Fan-out used by the Kaiming (when `fan_out_only`) and Xavier initializers.
/// - device: Device on which the tensor will be created.
///
/// # Panics
///
/// When the value is first accessed: if a required fan value is `None`, or if the
/// orthogonal initializer is used with `D < 2`.
pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
&self,
shape: S,
fan_in: Option<usize>,
fan_out: Option<usize>,
device: &B::Device,
) -> Param<Tensor<B, D>> {
let device = device.clone();
let shape: Shape = shape.into();
let config = self.clone();
// `shape` is moved into the init closure below; this copy is the one handed to
// `Param::uninitialized` so the shape is known without running the closure.
let shape_for_closure = shape.clone();
Param::uninitialized(
ParamId::new(),
move |device, require_grad| {
// NOTE(review): presumably marks the allocations made inside as persistent on the
// backend; confirm against `Backend::memory_persistent_allocations`.
B::memory_persistent_allocations(device, (), move |_| {
let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);
if require_grad {
tensor = tensor.require_grad();
}
tensor
})
},
device,
true,
shape_for_closure,
)
}
/// Builds the tensor for this initializer kind.
///
/// Called from the lazy init closure created by `init_with`.
fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
&self,
shape: S,
fan_in: Option<usize>,
fan_out: Option<usize>,
device: &B::Device,
) -> Tensor<B, D> {
let shape = shape.into();
match self {
Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
Initializer::Ones => Tensor::<B, D>::ones(shape, device),
Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
Initializer::KaimingUniform { gain, fan_out_only } => {
// Uniform bound: sqrt(3) * gain * std.
let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
uniform_draw(shape, -a, a, device)
}
Initializer::KaimingNormal { gain, fan_out_only } => {
let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
normal_draw(shape, 0.0, std, device)
}
Initializer::XavierUniform { gain } => {
// Uniform bound: sqrt(3) * gain * std.
let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
uniform_draw(shape, -a, a, device)
}
Initializer::XavierNormal { gain } => {
let std = *gain * self.xavier_std(fan_in, fan_out);
normal_draw(shape, 0.0, std, device)
}
Initializer::Orthogonal { gain } => {
// following the implementation in pytorch:
// https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574
assert!(
D >= 2,
"Expected D (in Tensor<B, D>) to be greater or equal 2; (D >= 2)"
);
// Flatten to 2D: first dimension as rows, all remaining dimensions as columns.
let rows: usize = shape.dims::<D>()[0];
let cols: usize = shape.num_elements() / rows;
let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);
// Decompose the transposed matrix when there are fewer rows than columns;
// the result is transposed back below.
if rows < cols {
t = t.transpose();
}
let (q, r) = qr_decomposition(t, device);
let [r_rows, r_cols] = r.clone().dims();
// Mask `r` with the identity and sum over rows with a ones row-vector, which
// yields a [1, r_cols] tensor holding the diagonal of `r`.
let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
.matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));
let ph = diag_r.clone().sign();
// NOTE(review): presumably broadcasts `ph` over the rows of `q`, scaling each
// column by the sign of the matching diagonal entry of `r`; confirm.
let mut q = q.mul(ph);
if rows < cols {
q = q.transpose();
}
q.reshape(shape).mul_scalar(*gain)
}
}
}
/// Returns `1 / sqrt(fan)`, where `fan` is `fan_out` when `fan_out_only` is set and
/// `fan_in` otherwise.
///
/// # Panics
///
/// Panics if the selected fan value is `None`.
fn kaiming_std(
&self,
fan_out_only: bool,
fan_in: Option<usize>,
fan_out: Option<usize>,
) -> f64 {
let fan = if fan_out_only { fan_out } else { fan_in };
let fan = fan.expect(
"Can't use Kaiming initialization without specifying fan. Use init_with method.",
);
1.0 / (fan as f64).sqrt()
}
/// Returns `sqrt(2 / (fan_in + fan_out))`.
///
/// # Panics
///
/// Panics if `fan_in` or `fan_out` is `None`.
fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
let fan_in = fan_in.expect(
"Can't use Xavier initialization without specifying fan in. Use init_with method and \
provide fan_in.",
);
let fan_out = fan_out.expect(
"Can't use Xavier initialization without specifying fan out. Use init_with method and \
provide fan_out.",
);
(2.0 / (fan_in + fan_out) as f64).sqrt()
}
}
/// Creates a tensor of the given shape filled with samples from `Distribution::Uniform(low, high)`.
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    Tensor::<B, D>::random(shape, Distribution::Uniform(low, high), device)
}
/// Creates a tensor of the given shape filled with samples from `Distribution::Normal(mean, std)`.
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    Tensor::<B, D>::random(shape, Distribution::Normal(mean, std), device)
}
/// Computes a QR decomposition of the `[m, n]` matrix `a` with the Gram-Schmidt process:
/// <https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process>
///
/// Columns are processed left to right; each column has its projections onto the previously
/// computed columns of `q` removed and is then normalized.
///
/// # Returns
///
/// `(q, r)` where `q` has shape `[m, n]` and `r` has shape `[n, n]`, with only the
/// diagonal and the entries above it written.
///
/// # Notes
///
/// There is no guard against a zero norm: a column that is linearly dependent on the
/// previous ones is divided by a zero `r_jj`.
fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let [m, n] = a.dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);
    for j in 0..n {
        // Residual of column j, reduced by one projection per inner iteration.
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);
        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
            let r_ij = q_i.clone().mul(v.clone()).sum();
            // `r` is reassigned right away, so it is moved into `slice_assign` instead of
            // being cloned first.
            r = r.slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());
            v = v - q_i.mul(r_ij);
        }
        // norm of v
        let r_jj = v
            .clone()
            .powf(Tensor::from_floats([2.0], device))
            .sum()
            .sqrt();
        r = r.slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());
        let q_j = v / r_jj;
        q = q.slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }
    (q, r)
}
#[cfg(test)]
mod tests {
use super::*;
use burn_tensor::{ElementConversion, TensorData};
use num_traits::Pow;
// All tests run on the ndarray backend with f32 elements.
pub type TB = burn_ndarray::NdArray<f32>;
use burn_tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TB>;
/// Asserts that the variance and mean computed with `var_mean(0)` are within 0.1 of the
/// expected values.
///
/// NOTE(review): the statistics are computed along dim 0, yet the loop runs
/// `tensor.shape()[0]` times over the flattened result, so presumably only the first
/// `shape()[0]` entries are checked; confirm the intended axis.
fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
let (actual_vars, actual_means) = tensor.clone().var_mean(0);
let actual_vars = actual_vars.to_data();
let actual_vars = actual_vars.as_slice::<FT>().unwrap();
let actual_means = actual_means.to_data();
let actual_means = actual_means.as_slice::<FT>().unwrap();
for i in 0..tensor.shape()[0] {
let actual_var = actual_vars[i] as f64;
let actual_mean = actual_means[i] as f64;
assert!(
(expected_var - actual_var).abs() <= 0.1,
"Expected variance to be between {expected_var} += 0.1, but got {actual_var}"
);
assert!(
(expected_mean - actual_mean).abs() <= 0.1,
"Expected mean to be between {expected_mean} += 0.1, but got {actual_mean}"
);
}
}
// Uniform init: every value must fall inside the requested range.
#[test]
fn initializer_uniform_init() {
let device = Default::default();
TB::seed(&device, 0);
let (min, max) = (0.0, 1.0);
let uniform = Initializer::Uniform { min, max };
let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();
tensor
.into_data()
.assert_within_range::<FT>(min.elem()..max.elem());
}
// Normal init: sample variance and mean of 1000 draws must be close to 1 and 0.
#[test]
fn initializer_normal_init() {
// seed random generator
let device = Default::default();
TB::seed(&device, 0);
let (mean, std) = (0.0, 1.0);
let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
.init([1000], &Default::default())
.into_value();
let (var_act, mean_act) = normal.var_mean(0);
let var_act: f32 = var_act.into_scalar().elem();
let mean_act: f32 = mean_act.into_scalar().elem();
assert!(
var_act > 0.9 && var_act < 1.1,
"Expected variance to be between 1.0 += 0.1, but got {var_act}"
);
assert!(
mean_act > -0.1 && mean_act < 0.1,
"Expected mean to be between 0.0 += 0.1, but got {mean_act}"
);
}
// Constant init: the 16 elements must sum to 16 * value.
#[test]
fn initializer_constant_init() {
let value = 5.0;
let constants: Tensor<TB, 4> = Initializer::Constant { value }
.init([2, 2, 2, 2], &Default::default())
.into_value();
constants.sum().to_data().assert_approx_eq::<FT>(
&TensorData::from([value as f32 * 16.0]),
Tolerance::default(),
);
}
// Zeros init: the elements must sum to 0.
#[test]
fn initializer_zeros_init() {
let zeros: Tensor<TB, 4> = Initializer::Zeros
.init([2, 2, 2, 2], &Default::default())
.into_value();
zeros
.sum()
.to_data()
.assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
}
// Ones init: the 16 elements must sum to 16.
#[test]
fn initializer_ones_init() {
let ones: Tensor<TB, 4> = Initializer::Ones
.init([2, 2, 2, 2], &Default::default())
.into_value();
ones.sum()
.to_data()
.assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
}
// Kaiming uniform with fan in: values must stay within +-gain * sqrt(3 / fan_in).
#[test]
fn initializer_kaiming_uniform_init() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2_f64;
let (fan_in, fan_out) = (5, 6);
let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();
let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
gain,
fan_out_only: false,
}
.init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
.into_value();
tensor.into_data().assert_within_range(-k..k);
}
// Kaiming normal with fan in: expected variance is (gain / sqrt(fan_in))^2.
#[test]
fn initializer_kaiming_normal_init() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2.;
let (fan_in, fan_out) = (1000, 10);
let expected_mean = 0_f64;
let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
gain,
fan_out_only: false,
}
.init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
.into_value();
assert_normal_init(expected_mean, expected_var, &tensor)
}
// Kaiming uniform on a 1D (bias-like) shape with an explicit fan in.
#[test]
fn initializer_kaiming_uniform_init_bias() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2_f64;
let shape = [3];
let fan_in = 5;
let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();
let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
gain,
fan_out_only: false,
}
.init_with(shape, Some(fan_in), None, &Default::default())
.into_value();
tensor.into_data().assert_within_range(-k..k);
}
// Kaiming uniform with `fan_out_only`: the bound is derived from fan out.
#[test]
fn initializer_kaiming_uniform_init_fan_out() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2_f64;
let (fan_in, fan_out) = (5, 6);
let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();
let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
gain,
fan_out_only: true,
}
.init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
.into_value();
tensor.into_data().assert_within_range(-k..k);
}
// Kaiming without any fan value must panic once the value is materialized.
#[test]
#[should_panic]
fn initializer_kaiming_uniform_no_fan() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2_f64;
let (fan_in, fan_out) = (5, 6);
let _: Tensor<TB, 2> = Initializer::KaimingUniform {
gain,
fan_out_only: false,
}
.init([fan_out, fan_in], &Default::default())
.into_value();
}
// Xavier uniform: values must stay within +-gain * sqrt(6 / (fan_in + fan_out)).
#[test]
fn initializer_xavier_uniform_init() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2.;
let (fan_in, fan_out) = (5, 6);
let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
.init_with(
[fan_out, fan_in],
Some(fan_in),
Some(fan_out),
&Default::default(),
)
.into_value();
tensor.into_data().assert_within_range(-bound..bound);
}
// Xavier normal: expected variance is (gain * sqrt(2 / (fan_in + fan_out)))^2.
#[test]
fn initializer_xavier_normal_init() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2.;
let (fan_in, fan_out) = (1000, 10);
let expected_mean = 0_f64;
let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
.init_with(
[fan_out, fan_in],
Some(fan_in),
Some(fan_out),
&Default::default(),
)
.into_value();
assert_normal_init(expected_mean, expected_var, &tensor)
}
// Xavier without fan values must panic once the value is materialized.
#[test]
#[should_panic]
fn initializer_xavier_uniform_no_fan() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 2.;
let (fan_in, fan_out) = (5, 6);
let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
.init([fan_out, fan_in], &Default::default())
.into_value();
}
// The product Q @ R must reconstruct the decomposed matrix.
#[test]
fn test_qr_decomposition() {
let device = Default::default();
TB::seed(&device, 0);
// test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr
let a = Tensor::<TB, 2>::from_floats(
[[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
&Default::default(),
);
let qr = qr_decomposition(a.clone(), &Default::default());
// Q @ R should reconstruct input `a`
let q_matmul_r = qr.0.clone().matmul(qr.1.clone());
// assert that the difference between input (`a`) and Q @ R is (almost) zero
q_matmul_r
.into_data()
.assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
}
// A square orthogonal init with gain 1 must satisfy Q.T @ Q ~= I.
#[test]
fn initializer_orthogonal_correct() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 1.;
// test 2D tensor
let size = 10;
let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
.init([size, size], &Default::default())
.into_value();
let eye = Tensor::<TB, 2>::eye(size, &Default::default());
// Q.T @ Q should be close to identity matrix
q.clone()
.transpose()
.matmul(q)
.into_data()
.assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
}
// The orthogonal init must return the requested shape for 2D and 3D tensors.
#[test]
fn initializer_orthogonal_init() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 1.;
// test 2D tensor
let shape = [25, 30];
let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
.init(shape, &Default::default())
.into_value();
let dims = t.dims();
assert_eq!(
shape, dims,
"Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
);
// test 3D tensor
let shape = [24, 6, 85];
let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
.init(shape, &Default::default())
.into_value();
let dims = t.dims();
assert_eq!(
shape, dims,
"Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
);
}
// The orthogonal init requires at least 2 dimensions, so a 1D tensor must panic.
#[test]
#[should_panic]
fn initializer_orthogonal_init_1d() {
let device = Default::default();
TB::seed(&device, 0);
let gain = 1.;
// test 1D tensor
let shape = [3];
let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
.init(shape, &Default::default())
.into_value();
}
}

View File

@@ -0,0 +1,16 @@
mod base;
mod display;
mod initializer;
mod param;
mod quantize;
#[cfg(feature = "std")]
mod reinit;
pub use base::*;
pub use display::*;
pub use initializer::*;
pub use param::*;
pub use quantize::*;
#[cfg(feature = "std")]
pub use reinit::*;

View File

@@ -0,0 +1,424 @@
use super::ParamId;
use alloc::{boxed::Box, format};
use burn_std::stub::RwLock;
use burn_tensor::Shape;
use core::cell::OnceCell;
use core::ops::Deref;
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
/// Shared, thread-safe transformation applied to a parameter value.
#[cfg(target_has_atomic = "ptr")]
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;
/// Shared, thread-safe transformation applied to a parameter value.
///
/// On targets without pointer-sized atomics the closure is boxed before being placed in the
/// `portable_atomic_util::Arc`.
// NOTE(review): presumably because that `Arc` cannot hold the unsized `dyn Fn` directly; confirm.
#[cfg(not(target_has_atomic = "ptr"))]
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;
/// Wraps `func` into a [`Mapper`].
#[cfg(target_has_atomic = "ptr")]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
Arc::new(func)
}
/// Wraps `func` into a [`Mapper`], boxing it first (see the `Mapper` alias for this target).
#[cfg(not(target_has_atomic = "ptr"))]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
Arc::new(Box::new(func))
}
/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during
/// training, and loaded during inference. If you don't want to save the tensors
/// and/or don't want to update it during training, you don't need this type to wrap your tensor.
///
/// # Core Lazy Initialization Architecture
///
/// `Param<T>` has a dual-state design using `OnceCell<T>`:
///
/// ## State Management
///
/// **Two possible states:**
///
/// 1. **Initialized**: `state: OnceCell<T>` contains value, `initialization: None`
/// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock<Option<Uninitialized<T>>>)`
pub struct Param<T: Parameter> {
/// The unique ID of this parameter. This is used by eg. optimizers to associate a gradient with a specific parameter.
pub id: ParamId,
/// The OnceCell holding the initialized parameter value.
/// Empty for uninitialized parameters, populated after first access or explicit initialization.
pub(crate) state: OnceCell<T>,
/// The deferred initialization state for lazy parameters.
///
/// **State Transitions:**
/// - Initialized params: `None`
/// - Uninitialized params: `Some(RwLock<Some(Uninitialized<T>)>)`
/// - After lazy init triggers: `Some(RwLock<None>)` (inner Option is taken)
pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
/// The transformations applied to the value when loading and saving (see [ParamMapper]).
pub(crate) param_mapper: ParamMapper<T>,
/// The gradient requirement recorded when the parameter was created.
// For stateful `module.valid()` <> `module.train()`
pub(crate) require_grad: bool,
}
#[derive(Clone)]
/// Applies transformations when loading and saving parameters.
///
/// # Mapper System
///
/// `ParamMapper<T>` allows applying transformations during serialization and deserialization:
/// - `load: Option<Mapper<T>>` - transformation during deserialization (applied in `transform_for_load()`)
/// - `save: Option<Mapper<T>>` - transformation during serialization (applied in `transform_for_save()`)
///
/// These are commonly used for:
/// - Quantization/dequantization
/// - Precision conversion (e.g., FP32 ↔ FP16)
/// - Custom parameter transformations
///
/// Both mappers default to `None`, in which case values pass through unchanged.
pub struct ParamMapper<T: Parameter> {
// Transformation run by `on_load`; set through `Param::load_mapper`.
load: Option<Mapper<T>>,
// Transformation run by `on_save`; set through `Param::save_mapper`.
save: Option<Mapper<T>>,
}
impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
    /// Prints whether a load and a save transformation are registered; the closures
    /// themselves are not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let has_load = self.load.is_some();
        let has_save = self.save.is_some();
        write!(f, "ParamMapper {{ load: {has_load}, save: {has_save} }}")
    }
}
impl<T: Parameter> ParamMapper<T> {
    /// Runs the load transformation on `param`, or returns it untouched when none is registered.
    pub fn on_load(&self, param: T) -> T {
        if let Some(transform) = &self.load {
            return transform(param);
        }
        param
    }

    /// Runs the save transformation on `param`, or returns it untouched when none is registered.
    pub fn on_save(&self, param: T) -> T {
        if let Some(transform) = &self.save {
            return transform(param);
        }
        param
    }
}
impl<T: Parameter> Default for ParamMapper<T> {
    /// Builds a mapper with neither a load nor a save transformation.
    fn default() -> Self {
        let (load, save) = (None, None);
        Self { load, save }
    }
}
impl<T: Parameter> core::fmt::Display for Param<T> {
    /// Writes `Param: <id>`; the value itself is not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("Param: {}", self.id);
        f.write_str(&rendered)
    }
}
impl<T: Parameter> core::fmt::Debug for Param<T> {
    /// Writes `Param: <id> - <param mapper debug output>`; the value itself is not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("Param: {} - {:?}", self.id, self.param_mapper);
        f.write_str(&rendered)
    }
}
/// Trait that defines what is necessary for a type to be a parameter.
///
/// A parameter must be cloneable, debuggable and sendable across threads.
pub trait Parameter: Clone + core::fmt::Debug + Send {
/// The device type to be used.
type Device: Clone;
/// Fetch the device.
fn device(&self) -> Self::Device;
/// Fetch the gradient requirement.
fn is_require_grad(&self) -> bool;
/// Set the gradient requirement, consuming the parameter and returning the updated one.
fn set_require_grad(self, require_grad: bool) -> Self;
}
/// The deferred initialization state for lazy parameters.
///
/// Created by `Param::uninitialized` and kept in `Param::initialization` until the value is
/// first requested.
#[allow(clippy::type_complexity)]
pub(crate) struct Uninitialized<P: Parameter> {
/// The initialization function. Called with `(device, is_require_grad) -> Parameter`.
/// This function is consumed during initialization via `FnOnce`.
init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
/// The target device on which the parameter should be initialized.
/// Used by `lazy_device()` to provide device information without triggering initialization.
pub(crate) device: P::Device,
/// The gradient requirement for the parameter.
/// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization.
pub(crate) is_require_grad: bool,
/// The shape of the tensor parameter.
/// Used by `lazy_shape()` to provide shape information without triggering initialization.
pub(crate) shape: Shape,
}
impl<P: Parameter> Uninitialized<P> {
    /// Runs the deferred initialization function, consuming the pending state.
    ///
    /// [Param::val] calls this the first time an uninitialized parameter is accessed. The
    /// stored device and gradient requirement are handed to the function, whose result is
    /// the initialized parameter.
    fn initialize(self) -> P {
        let Self {
            init,
            device,
            is_require_grad,
            ..
        } = self;
        init(&device, is_require_grad)
    }
}
impl<T: Parameter> Param<T> {
/// Create a new parameter that is already initialized.
pub fn initialized(id: ParamId, value: T) -> Self {
let require_grad = value.is_require_grad();
Self {
id,
state: OnceCell::from(value),
initialization: None,
param_mapper: Default::default(),
require_grad,
}
}
/// Create a new parameter that is not already initialized.
pub fn uninitialized<F>(
id: ParamId,
init: F,
device: T::Device,
is_require_grad: bool,
shape: Shape,
) -> Self
where
F: FnOnce(&T::Device, bool) -> T + Send + 'static,
{
Self {
id,
state: OnceCell::new(),
initialization: Some(RwLock::new(Some(Uninitialized {
init: Box::new(init),
device,
is_require_grad,
shape,
}))),
param_mapper: Default::default(),
require_grad: is_require_grad,
}
}
/// Gets the parameter value, initializing it lazily if needed.
///
/// For initialized parameters, this returns a clone of the cached value.
/// For uninitialized parameters, this triggers initialization:
pub fn val(&self) -> T {
self.state
.get_or_init(|| {
let mut result = self
.initialization
.as_ref()
.expect("Should have an initialization when no state provided.")
.write()
.unwrap();
let state = result.take().expect("Should exist when not initialized");
state.initialize()
})
.clone()
}
/// Check if the parameter has been initialized.
///
/// Returns `true` if the parameter's value has been computed and cached,
/// `false` if it's still lazy and will be initialized on first access.
pub fn is_initialized(&self) -> bool {
self.state.get().is_some()
}
/// Gets the parameter's value while consuming the parameter.
pub fn into_value(self) -> T {
self.consume().1
}
/// Gets the parameter id and value while consuming the parameter.
pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
let tensor = self.val();
core::mem::drop(self.state);
(self.id, tensor, self.param_mapper)
}
/// Execute the given function on the inner value.
pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
let (id, tensor, param_mapper) = self.consume();
let tensor = func(tensor);
let require_grad = tensor.is_require_grad();
Self {
id,
state: OnceCell::from(tensor),
initialization: None,
param_mapper,
require_grad,
}
}
/// Create an initialized parameter with the given id, value, and param mapper.
///
/// This is a helper method for creating parameters while preserving the param mapper,
/// typically used in ModuleMapper implementations.
pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
let require_grad = value.is_require_grad();
Self {
id,
state: OnceCell::from(value),
initialization: None,
param_mapper,
require_grad,
}
}
/// Runs a transformation on the parameter when loading.
pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
self.param_mapper.load = Some(new_mapper(func));
self
}
/// Runs a transformation on the parameter when saving.
pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
self.param_mapper.save = Some(new_mapper(func));
self
}
/// Execute the given function on the inner value.
pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
where
T: 'static,
{
let initialization = match &self.initialization {
Some(init) => init,
None => return self.map(func),
};
let mut init = initialization.write().unwrap();
match init.as_mut() {
Some(value) => {
#[allow(clippy::type_complexity)]
let mut prev: Box<dyn FnOnce(&T::Device, bool) -> T + Send> =
Box::new(|_, _| panic!("Fake func to not have null ref."));
core::mem::swap(&mut prev, &mut value.init);
value.init = Box::new(|a, b| {
let tensor = prev(a, b);
func(tensor)
});
core::mem::drop(init);
self
}
None => {
core::mem::drop(init);
self.map(func)
}
}
}
/// The device on which the parameter is or will be initialized, **without triggering initialization**.
///
/// Needed by the load optimization: loading a tensor into an uninitialized parameter requires
/// the target device, but running the initialization function just to learn it would allocate
/// an unnecessary tensor.
///
/// Prefer this over [crate::tensor::Tensor::device] whenever lazy initialization must be
/// preserved.
pub fn lazy_device(&self) -> T::Device {
    let Some(lock) = self.initialization.as_ref() else {
        return self.device();
    };

    let guard = lock.read().unwrap();
    if let Some(pending) = guard.as_ref() {
        // Still lazy: report the device recorded for the pending initialization.
        return pending.device.clone();
    }
    self.device()
}
/// The gradient requirement the parameter has or will have, **without triggering initialization**.
///
/// Counterpart of [lazy_device](Self::lazy_device) for the load optimization: the gradient
/// setting must be applied to a loaded tensor without running the initialization function.
///
/// # Notes
///
/// This function is crate-private: users are not expected to read `is_require_grad` from an
/// uninitialized module in order to override it. Low-level functions provided by `burn` should
/// handle those details.
pub(crate) fn lazy_is_require_grad(&self) -> bool {
    let Some(lock) = self.initialization.as_ref() else {
        return self.is_require_grad();
    };

    let guard = lock.read().unwrap();
    if let Some(pending) = guard.as_ref() {
        // Still lazy: report the flag recorded for the pending initialization.
        return pending.is_require_grad;
    }
    self.is_require_grad()
}
/// Override the gradient requirement for the current parameter.
///
/// If the parameter is still lazy, only the pending initialization is updated and no value is
/// created. Otherwise the existing value is mapped with the new setting.
pub fn set_require_grad(self, require_grad: bool) -> Self {
    let recorded_lazily = match self.initialization.as_ref() {
        Some(lock) => {
            let mut guard = lock.write().unwrap();
            let found = match guard.as_mut() {
                Some(pending) => {
                    pending.is_require_grad = require_grad;
                    true
                }
                None => false,
            };
            // Release the lock before `self` is returned or mapped.
            core::mem::drop(guard);
            found
        }
        None => false,
    };

    if recorded_lazily {
        return self;
    }
    self.map(|tensor| tensor.set_require_grad(require_grad))
}
}
// Cloning builds an initialized parameter from the current value (`val()`), keeping the same id
// and param mapper.
impl<T: Parameter> Clone for Param<T> {
    fn clone(&self) -> Self {
        let mut duplicate = Self::initialized(self.id, self.val());
        duplicate.param_mapper = self.param_mapper.clone();
        duplicate
    }
}
// Dereferencing a parameter yields its inner value, running the pending initialization the
// first time the value is needed.
impl<T: Parameter> Deref for Param<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// `get_or_init` only runs the closure when `state` holds no value yet.
self.state.get_or_init(|| {
// Without a stored value an initializer must have been registered; panics otherwise.
let mut result = self
.initialization
.as_ref()
.expect("Should have an initialization when no state provided.")
.write()
.unwrap();
// Move the pending initializer out of the lock, leaving `None` behind.
let state = result.take().expect("Should exist when not initialized");
state.initialize()
})
}
}

View File

@@ -0,0 +1,408 @@
use alloc::{format, string::ToString};
use core::{fmt::Display, marker::PhantomData};
use crate as burn;
use crate::{
module::{
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
ModuleMapper, ModuleVisitor,
},
record::{PrecisionSettings, Record},
};
use burn_tensor::{
BasicAutodiffOps, BasicOps, Tensor,
backend::{AutodiffBackend, Backend},
ops::Device,
};
/// Record used for constant types implementing the [module](crate::module::Module) trait.
///
/// It is a unit struct and carries no data.
#[derive(Debug, Clone, Copy, new, Default)]
pub struct ConstantRecord;
// A constant record has no content, so it is serialized as `none`.
impl serde::Serialize for ConstantRecord {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_none()
    }
}
impl<'de> serde::Deserialize<'de> for ConstantRecord {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_option(serde::de::IgnoredAny).ok();
Ok(ConstantRecord::new())
}
}
// The record is its own item type: conversions in both directions are the identity.
impl<B: Backend> Record<B> for ConstantRecord {
    type Item<S: PrecisionSettings> = ConstantRecord;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        self
    }

    fn from_item<S: PrecisionSettings>(record: Self::Item<S>, _device: &B::Device) -> Self {
        record
    }
}
/// Constant macro.
///
/// Generates module implementations for a type that holds no parameters:
///
/// - `constant!(module)` expands to the items of a `Module<B>` impl body that leave the value
///   untouched and use `ConstantRecord` as record.
/// - `constant!(ad_module, Type)` expands to the items of an `AutodiffModule<B>` impl body
///   whose inner module is the type itself.
/// - `constant!(Type)` emits the `Module`, `AutodiffModule`, `ModuleDisplayDefault` and
///   `ModuleDisplay` impls for `Type`.
///
/// The expansion refers to `burn::...` paths, so `burn` must be resolvable at the call site.
#[macro_export]
macro_rules! constant {
// Items of a `Module<B>` impl; `B` comes from the surrounding impl header.
(module) => {
type Record = burn::module::ConstantRecord;
fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn into_record(self) -> Self::Record {
burn::module::ConstantRecord::new()
}
fn to_device(self, _: &B::Device) -> Self {
self
}
fn fork(self, _: &B::Device) -> Self {
self
}
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
devices
}
};
// Items of an `AutodiffModule<B>` impl: the inner module is the type itself.
(ad_module, $type:ty) => {
type InnerModule = $type;
fn valid(&self) -> Self::InnerModule {
self.clone()
}
fn from_inner(module: Self::InnerModule) -> Self {
module
}
};
// Full set of impls for `$type`; the display impl formats the value with `{}`.
($type:ty) => {
impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {
constant!(module);
}
impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
constant!(ad_module, $type);
}
impl burn::module::ModuleDisplayDefault for $type {
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
let string = format!("{}", self);
content.add_formatted(&string).optional()
}
}
impl burn::module::ModuleDisplay for $type {}
};
}
// Constant module implementations for primitive types, generated by `constant!`.

// General Types
constant!(alloc::string::String);
constant!(bool);
// Float Types
constant!(f64);
constant!(f32);
constant!(half::bf16);
constant!(half::f16);
// Unsigned Integer Types
constant!(usize);
constant!(u64);
constant!(u32);
constant!(u16);
constant!(u8);
// Signed Integer Types
constant!(isize);
constant!(i64);
constant!(i32);
constant!(i16);
constant!(i8);
// `str` only gets the display traits here (no `Module` impl in this block).
impl burn::module::ModuleDisplay for str {}
impl burn::module::ModuleDisplayDefault for str {
// Adds the string itself as formatted content.
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
content.add_formatted(&self).optional()
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
type Record = ConstantRecord;
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn into_record(self) -> Self::Record {
ConstantRecord
}
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn to_device(self, device: &B::Device) -> Self {
self.to_device(device)
}
fn fork(self, device: &B::Device) -> Self {
self.to_device(device)
}
fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
let device = self.device();
if !devices.contains(&device) {
devices.push(device)
}
devices
}
}
// Tensors are displayed by rank and shape only, never by content.
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
    fn content(&self, content: Content) -> Option<Content> {
        let shape = self.shape();
        let summary = format!("Tensor {{rank: {D}, shape: {:?}}}", shape.as_slice());
        content.add_single(&summary).optional()
    }
}

impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
// The inner module of an autodiff tensor is the same tensor on the inner backend.
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
    for Tensor<B, D, K>
{
    type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;

    fn valid(&self) -> Self::InnerModule {
        let owned = self.clone();
        owned.inner()
    }

    fn from_inner(inner: Self::InnerModule) -> Self {
        Tensor::from_inner(inner)
    }
}
impl<B: Backend> Module<B> for PhantomData<B> {
type Record = ConstantRecord;
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}
fn to_device(self, _: &Device<B>) -> Self {
self
}
fn fork(self, _: &Device<B>) -> Self {
self
}
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}
// Displayed as the fixed label "PhantomData".
impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
    fn content(&self, content: Content) -> Option<Content> {
        let label = "PhantomData".to_string();
        content.add_single(&label).optional()
    }
}

impl<B: Backend> ModuleDisplay for PhantomData<B> {}
// The inner module is the marker for the inner backend; conversions just create a new marker.
impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
    type InnerModule = PhantomData<B::InnerBackend>;

    fn valid(&self) -> Self::InnerModule {
        PhantomData
    }

    fn from_inner(_inner: Self::InnerModule) -> Self {
        PhantomData
    }
}
/// Container to satisfy the Module trait for types that are not modules.
///
/// The wrapped value is stored in the public tuple field.
#[derive(Clone, Debug)]
pub struct Ignored<T>(pub T);
impl<B, T> Module<B> for Ignored<T>
where
B: Backend,
T: Sync + Send + core::fmt::Debug + Clone,
{
type Record = ConstantRecord;
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}
fn to_device(self, _: &Device<B>) -> Self {
self
}
fn fork(self, _: &Device<B>) -> Self {
self
}
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}
// For now the display is the debug representation of the wrapped value.
impl<T> ModuleDisplayDefault for Ignored<T>
where
    T: Sync + Send + core::fmt::Debug + Clone,
{
    fn content(&self, content: Content) -> Option<Content> {
        let debug_repr = format!("{:?}", self.0);
        content.add_single(&debug_repr).optional()
    }
}

impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
// `Display` prints the wrapped value with its `Debug` formatting.
impl<T> Display for Ignored<T>
where
    T: Sync + Send + core::fmt::Debug + Clone,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{:?}", self.0))
    }
}
// An ignored value is backend-independent, so it is its own inner module.
impl<B, T> AutodiffModule<B> for Ignored<T>
where
    B: AutodiffBackend,
    T: Sync + Send + core::fmt::Debug + Clone,
{
    type InnerModule = Ignored<T>;

    fn valid(&self) -> Self::InnerModule {
        Clone::clone(self)
    }

    fn from_inner(inner: Self::InnerModule) -> Self {
        inner
    }
}
// Implement deref for Ignored
impl<T> core::ops::Deref for Ignored<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
// Tests for the constant module implementations (only built with the `std` feature).
#[cfg(all(test, feature = "std"))]
mod tests {
use core::marker::PhantomData;
use burn_tensor::backend::Backend;
use burn_tensor::{Device, Tensor};
use crate::TestBackend;
use crate::{
TestAutodiffBackend,
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
};
use burn::module::Module;
use crate as burn;
// Loading a record into a plain tensor must leave `require_grad` disabled, both for a
// tensor explicitly marked `no_grad` and for one created with default settings.
#[test]
fn tensor_load_record_setting() {
let device: &Device<TestAutodiffBackend> = &Default::default();
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);
// Round-trip the tensor's record through the in-memory byte recorder.
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
let bytes = Recorder::<TestAutodiffBackend>::record(
&byte_recorder,
tensor.clone().into_record(),
(),
)
.unwrap();
let no_grad_is_require_grad = tensor
.clone()
.no_grad()
.load_record(
Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
.unwrap(),
)
.is_require_grad();
let with_default_is_require_grad = tensor
.load_record(
Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
.unwrap(),
)
.is_require_grad();
assert!(!no_grad_is_require_grad);
assert!(!with_default_is_require_grad);
}
// A derived module whose only field is `PhantomData<B>` must be constructible and
// zero-sized.
#[test]
fn empty_module_with_phantom() {
#[derive(Module, Debug, new)]
struct EmptyModule<B: Backend> {
_phantom: PhantomData<B>,
}
let _module = EmptyModule::<TestBackend>::new();
assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
}
}

View File

@@ -0,0 +1,116 @@
use core::hash::{BuildHasher, Hasher};
use alloc::string::String;
use burn_std::id::IdGenerator;
use data_encoding::BASE32_DNSSEC;
// Hash builder used to hash legacy uuid identifiers down to 64 bits.
//
// Hashbrown changed its default hasher in 0.15, but there are some issues
// https://github.com/rust-lang/hashbrown/issues/577
// Also, `param_serde_deserialize_legacy_uuid` doesn't pass with the default hasher.
type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
/// Parameter ID.
///
/// A thin wrapper around a 64-bit integer value.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct ParamId {
// Raw 64-bit identifier.
value: u64,
}
impl From<u64> for ParamId {
fn from(value: u64) -> Self {
Self { value }
}
}
impl Default for ParamId {
fn default() -> Self {
Self::new()
}
}
impl ParamId {
    /// Create a new parameter ID.
    pub fn new() -> Self {
        Self {
            value: IdGenerator::generate(),
        }
    }

    /// Gets the internal value of the id.
    pub fn val(&self) -> u64 {
        self.value
    }

    /// Convert the parameter ID into a string.
    ///
    /// The little-endian bytes of the id are encoded with `BASE32_DNSSEC`.
    pub fn serialize(self) -> String {
        BASE32_DNSSEC.encode(&self.value.to_le_bytes())
    }

    /// Deserialize a param id.
    ///
    /// Preserves compatibility with previous formats (6 bytes, 16-byte uuid):
    /// a base32 payload shorter than 8 bytes is zero-extended, and a string that is not valid
    /// base32 but parses as a uuid is hashed down to 64 bits.
    ///
    /// # Panics
    ///
    /// Panics with a message starting with `Invalid id.` when `encoded` is neither valid
    /// base32 nor a valid uuid, or when the base32 payload is longer than 8 bytes.
    pub fn deserialize(encoded: &str) -> ParamId {
        let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
            Ok(bytes) => {
                let mut buffer = [0u8; 8];
                // Fail with an explicit message instead of an opaque slice-index panic when
                // the payload cannot fit in a `u64`.
                assert!(
                    bytes.len() <= buffer.len(),
                    "Invalid id. Expected at most {} bytes, got {}.",
                    buffer.len(),
                    bytes.len(),
                );
                // Shorter (legacy) payloads keep their bytes in the low positions.
                buffer[..bytes.len()].copy_from_slice(&bytes);
                u64::from_le_bytes(buffer)
            }
            Err(err) => match uuid::Uuid::try_parse(encoded) {
                // Backward compatibility with uuid parameter identifiers
                Ok(id) => {
                    // Hash the 128-bit uuid to 64-bit
                    // Though not *theoretically* unique, the probability of a collision should be extremely low
                    let mut hasher = DefaultHashBuilder::default().build_hasher();
                    hasher.write(id.as_bytes());
                    hasher.finish()
                }
                Err(_) => panic!("Invalid id. {err}"),
            },
        };

        ParamId::from(u64_id)
    }
}
// An id is displayed in its serialized (string) form.
impl core::fmt::Display for ParamId {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded = self.serialize();
        f.write_str(encoded.as_str())
    }
}
// Tests for `ParamId` serialization and the legacy formats accepted by `deserialize`.
#[cfg(test)]
mod tests {
use super::*;
// Serializing then deserializing must give back the same id.
#[test]
fn param_serde_deserialize() {
let val = ParamId::from(123456u64);
let deserialized = ParamId::deserialize(&val.serialize());
assert_eq!(val, deserialized);
}
// A legacy 6-byte payload fills the low 6 bytes; the 2 remaining bytes are zero.
#[test]
fn param_serde_deserialize_legacy() {
let legacy_val = [45u8; 6];
let param_id = ParamId::deserialize(&BASE32_DNSSEC.encode(&legacy_val));
assert_eq!(param_id.val().to_le_bytes()[0..6], legacy_val);
assert_eq!(param_id.val().to_le_bytes()[6..], [0, 0]);
}
#[test]
fn param_serde_deserialize_legacy_uuid() {
// Ensure support for legacy uuid deserialization and make sure it results in the same output
let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
let param_id1 = ParamId::deserialize(legacy_id);
let param_id2 = ParamId::deserialize(legacy_id);
assert_eq!(param_id1, param_id2);
}
// A string that is neither valid base32 nor a valid uuid must panic with "Invalid id.".
#[test]
#[should_panic = "Invalid id."]
fn param_serde_deserialize_invalid_id() {
let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c";
let _ = ParamId::deserialize(invalid_uuid);
}
}

View File

@@ -0,0 +1,13 @@
// Submodules of the parameter module.
mod base;
mod constant;
mod id;
mod primitive;
mod running;
mod tensor;
mod visitor;
// Re-exports; `primitive` and `tensor` are declared above but not re-exported here.
pub use base::*;
pub use constant::*;
pub use id::*;
pub use running::*;
pub use visitor::*;

View File

@@ -0,0 +1,426 @@
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::Device,
};
use core::fmt::Debug;
// An optional module forwards every operation to the inner module when present and does
// nothing when absent.
impl<T, B> Module<B> for Option<T>
where
    T: Module<B> + Debug + Send + Clone,
    B: Backend,
{
    type Record = Option<T::Record>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        match self {
            Some(inner) => inner.visit(visitor),
            None => {}
        }
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        match self {
            Some(inner) => Some(inner.map(mapper)),
            None => None,
        }
    }

    fn load_record(self, record: Self::Record) -> Self {
        // A module without parameters is a constant: keep it as is, whatever the record holds.
        if self.num_params() == 0 {
            return self;
        }

        // Both the module and the record must be present to load; otherwise the result is `None`.
        match (self, record) {
            (Some(inner), Some(inner_record)) => Some(inner.load_record(inner_record)),
            _ => None,
        }
    }

    fn into_record(self) -> Self::Record {
        self.map(|inner| inner.into_record())
    }

    fn to_device(self, device: &Device<B>) -> Self {
        match self {
            Some(inner) => Some(inner.to_device(device)),
            None => None,
        }
    }

    fn fork(self, device: &Device<B>) -> Self {
        match self {
            Some(inner) => Some(inner.fork(device)),
            None => None,
        }
    }

    fn collect_devices(&self, devices: Vec<B::Device>) -> Vec<B::Device> {
        match self.as_ref() {
            Some(inner) => inner.collect_devices(devices),
            None => devices,
        }
    }
}
// Displays the inner module when present, or the literal "None" otherwise.
impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
    fn content(&self, content: Content) -> Option<Content> {
        if let Some(inner) = self {
            return content.add_single(inner).optional();
        }
        content.add_single("None").optional()
    }
}

impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
// The inner module of an optional module is the optional inner module.
impl<T, B> AutodiffModule<B> for Option<T>
where
    T: AutodiffModule<B> + Debug + Send + Clone,
    B: AutodiffBackend,
{
    type InnerModule = Option<T::InnerModule>;

    fn valid(&self) -> Self::InnerModule {
        match self {
            Some(inner) => Some(inner.valid()),
            None => None,
        }
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.map(T::from_inner)
    }
}
// A vector of modules behaves as a module whose children are its elements, named by index.
impl<T, B> Module<B> for Vec<T>
where
    T: Module<B> + Debug + Send + Clone,
    B: Backend,
{
    type Record = Vec<T::Record>;

    /// Total number of parameters over all elements.
    fn num_params(&self) -> usize {
        self.iter().map(|module| module.num_params()).sum()
    }

    /// Visits every element in order, announcing each one as a `"Vec"` child named by its index.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        for (index, module) in self.iter().enumerate() {
            let index_str = index.to_string();
            visitor.enter_module(&index_str, "Vec");
            module.visit(visitor);
            visitor.exit_module(&index_str, "Vec");
        }
    }

    /// Maps every element in order, announcing each one as a `"Vec"` child named by its index.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        self.into_iter()
            .enumerate()
            .map(|(index, module)| {
                let index_str = index.to_string();
                mapper.enter_module(&index_str, "Vec");
                let mapped = module.map(mapper);
                mapper.exit_module(&index_str, "Vec");
                mapped
            })
            .collect()
    }

    fn into_record(self) -> Self::Record {
        self.into_iter().map(Module::into_record).collect()
    }

    /// Loads each element from the record at the same position.
    ///
    /// # Panics
    ///
    /// Panics when the record and the module do not have the same length.
    fn load_record(self, record: Self::Record) -> Self {
        assert_eq!(
            self.len(),
            record.len(),
            r#"[Load Record Error] The vec record does not have the same length as the module.
Make sure your module initialization is compatible with the record being loaded.
"#,
        );

        self.into_iter()
            .zip(record)
            .map(|(module, record)| module.load_record(record))
            .collect()
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.into_iter()
            .map(|module| module.to_device(device))
            .collect()
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.into_iter().map(|module| module.fork(device)).collect()
    }

    /// Threads the device list through every element in order.
    fn collect_devices(&self, devices: Vec<B::Device>) -> Vec<B::Device> {
        self.iter()
            .fold(devices, |devices, module| module.collect_devices(devices))
    }
}
// Each element is listed under its index; the top-level type is shown as `Vec<0..len>`.
impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
    fn content(&self, content: Content) -> Option<Content> {
        let mut listing = content;
        for (i, module) in self.iter().enumerate() {
            let index = format!("{i}");
            listing = listing.add(&index, module);
        }

        let type_name = format!("Vec<0..{}>", self.len());
        listing.set_top_level_type(type_name.as_str()).optional()
    }
}

impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
// The inner module of a vector is the vector of inner modules, element by element.
impl<T, B> AutodiffModule<B> for Vec<T>
where
    T: AutodiffModule<B> + Debug + Send + Clone,
    B: AutodiffBackend,
{
    type InnerModule = Vec<T::InnerModule>;

    fn valid(&self) -> Self::InnerModule {
        let mut inner = Vec::with_capacity(self.len());
        for module in self.iter() {
            inner.push(module.valid());
        }
        inner
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.into_iter().map(T::from_inner).collect()
    }
}
// A fixed-size array of modules behaves as a module whose children are its elements, named by
// index.
impl<const N: usize, T, B> Module<B> for [T; N]
where
    T: Module<B> + Debug + Send + Clone,
    B: Backend,
{
    type Record = [T::Record; N];

    fn collect_devices(&self, devices: Vec<B::Device>) -> Vec<B::Device> {
        self.iter()
            .fold(devices, |devices, module| module.collect_devices(devices))
    }

    fn num_params(&self) -> usize {
        self.iter().map(|module| module.num_params()).sum()
    }

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        for (position, module) in self.iter().enumerate() {
            let name = alloc::format!("{}", position);
            visitor.enter_module(&name, "Array");
            module.visit(visitor);
            visitor.exit_module(&name, "Array");
        }
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        // Map element by element into a vector, then convert back to a fixed-size array.
        let mapped: Vec<T> = IntoIterator::into_iter(self)
            .enumerate()
            .map(|(position, module)| {
                let name = alloc::format!("{}", position);
                mapper.enter_module(&name, "Array");
                let out = module.map(mapper);
                mapper.exit_module(&name, "Array");
                out
            })
            .collect();

        mapped
            .try_into()
            .unwrap_or_else(|v: Vec<T>| panic!("Expected array of length {}, got {}", N, v.len()))
    }

    fn load_record(self, record: Self::Record) -> Self {
        let loaded: Vec<_> = self
            .into_iter()
            .zip(record)
            .map(|(module, module_record)| module.load_record(module_record))
            .collect();

        loaded.try_into().unwrap()
    }

    fn into_record(self) -> Self::Record {
        self.map(|module| module.into_record())
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|module| module.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.map(|module| module.fork(device))
    }
}
// Each element is listed under its index; the top-level type is shown as `[0..len]`.
impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
    fn content(&self, content: Content) -> Option<Content> {
        let mut listing = content;
        for (i, module) in self.iter().enumerate() {
            let index = format!("{i}");
            listing = listing.add(&index, module);
        }

        let type_name = format!("[0..{}]", self.len());
        listing.set_top_level_type(type_name.as_str()).optional()
    }
}

impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
// The inner module of an array is the array of inner modules, element by element.
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
where
    T: AutodiffModule<B> + Debug + Send + Clone,
    T::InnerModule: Debug,
    B: AutodiffBackend,
{
    type InnerModule = [T::InnerModule; N];

    fn valid(&self) -> Self::InnerModule {
        let owned = self.clone();
        owned.map(|module| module.valid())
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.map(T::from_inner)
    }
}
/// A macro for generating implementations for tuple modules of different sizes.
/// For example: `impl_module_tuple!([L0, L1][0, 1])`.
/// Would generate an implementation for a tuple of size 2.
/// For this macro to work properly, please adhere to the convention:
/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.
///
/// Each invocation emits `Module`, `AutodiffModule`, `ModuleDisplayDefault` and
/// `ModuleDisplay` impls for the tuple type.
macro_rules! impl_module_tuple {
// `$l` represents the generic modules.
// `$i` represents the indices of the modules in the tuple.
([$($l:ident),*][$($i:tt),*]) => {
impl<B, $($l,)*> Module<B> for ($($l,)*)
where
B: Backend,
$($l: Module<B> + Debug + Send + Clone,)*
{
// The record is the tuple of the element records.
type Record = ($($l::Record),*);
// Thread the device list through every element in index order.
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
$(devices = self.$i.collect_devices(devices);)*
devices
}
fn fork(self, device: &Device<B>) -> Self {
($(self.$i.fork(device),)*)
}
fn to_device(self, device: &Device<B>) -> Self {
($(self.$i.to_device(device),)*)
}
// Each element is announced as a "Tuple" child named by its index.
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
$(
let index_str = $i.to_string();
visitor.enter_module(&index_str, "Tuple");
self.$i.visit(visitor);
visitor.exit_module(&index_str, "Tuple");
)*
}
// Same naming scheme as `visit`, rebuilding the tuple from the mapped elements.
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
($(
{
let index_str = $i.to_string();
mapper.enter_module(&index_str, "Tuple");
let mapped = self.$i.map(mapper);
mapper.exit_module(&index_str, "Tuple");
mapped
}
,)*)
}
// Element `i` is loaded from record field `i`.
fn load_record(self, record: Self::Record) -> Self {
($(self.$i.load_record(record.$i),)*)
}
fn into_record(self) -> Self::Record {
($(self.$i.into_record(),)*)
}
}
impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
where
B: AutodiffBackend,
$($l: AutodiffModule<B> + Debug + Send + Clone,)*
{
// The inner module is the tuple of the element inner modules.
type InnerModule = ($($l::InnerModule,)*);
fn valid(&self) -> Self::InnerModule {
($(self.$i.valid(),)*)
}
fn from_inner(module: Self::InnerModule) -> Self {
($($l::from_inner(module.$i),)*)
}
}
impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
where
$($l: ModuleDisplay,)*
{
// Each element is listed under its index; the top-level type is built from the
// generic names passed to the macro.
fn content(&self, content: Content) -> Option<Content> {
let content = content
$(.add(&format!("{}", $i), &self.$i))*
.set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
content.optional()
}
}
impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
};
}
// Tuple module implementations for sizes 2 through 10.
impl_module_tuple!([L0, L1][0, 1]);
impl_module_tuple!([L0, L1, L2][0, 1, 2]);
impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
// Tests for loading records into `Option` modules that hold a constant.
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
// Loading the module's own record must leave the constant unchanged.
#[test]
fn dont_override_constant_module_when_loading_record() {
let module = Some(42);
let record = Module::<TestBackend>::into_record(module);
let loaded = Module::<TestBackend>::load_record(module, record);
assert_eq!(loaded, module);
}
// Loading a `None` record must not erase a constant that is present.
#[test]
fn dont_override_constant_module_when_loading_none_record() {
let module = Some(42);
let record = None;
let loaded = Module::<TestBackend>::load_record(module, record);
assert_eq!(loaded, module);
}
}

View File

@@ -0,0 +1,258 @@
use super::ParamId;
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor, Param,
};
use alloc::string::ToString;
use alloc::vec::Vec;
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
use burn_std::stub::Mutex;
use burn_tensor::{
Tensor,
backend::{AutodiffBackend, Backend},
ops::Device,
};
// With `std`, thread ids and the map type come from the standard library.
#[cfg(feature = "std")]
mod threading {
    pub(super) use std::collections::HashMap;
    pub(super) use std::thread::ThreadId;

    /// Id of the thread this function is called from.
    #[inline(always)]
    pub(super) fn get_thread_current_id() -> ThreadId {
        let current = std::thread::current();
        current.id()
    }
}
// Without `std`, the thread id type is a stub and the map comes from `hashbrown`.
#[cfg(not(feature = "std"))]
mod threading {
pub(super) use burn_std::stub::ThreadId;
pub(super) use hashbrown::HashMap;
// There is no way to query the current thread id here: calling this always panics.
#[inline(always)]
pub(super) fn get_thread_current_id() -> ThreadId {
panic!("Current thread id is not available")
}
}
// Re-export items from the disabled/enabled blocks
use threading::*;
/// A state that can be updated during the forward pass while being thread safe.
///
/// # Note
///
/// The state value is the average of all updates on all threads.
#[derive(Clone, Debug)]
pub struct RunningState<V> {
// Identifier used when the state is exposed as a parameter or record.
id: ParamId,
// Pending updates, one entry per thread id, not yet merged into `value`.
values: Arc<Mutex<HashMap<ThreadId, V>>>,
// Current merged value.
value: Arc<Mutex<V>>,
}
// Displayed as `RunningState(id=...)`, without the value.
impl<V> core::fmt::Display for RunningState<V> {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.write_fmt(format_args!("RunningState(id={})", self.id))
    }
}
// In module displays a running state shows up as the fixed label "RunningState".
impl<V> ModuleDisplayDefault for RunningState<V> {
    fn content(&self, content: Content) -> Option<Content> {
        let label = "RunningState".to_string();
        content.add_formatted(&label).optional()
    }
}

impl<V> ModuleDisplay for RunningState<V> {}
// A running state is exposed to visitors, mappers and records as a float parameter built from
// its current value and id.
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        let current = self.value.lock().unwrap();
        let as_param = Param::initialized(self.id, current.clone());
        visitor.visit_float(&as_param)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let mut current = self.value.lock().unwrap();
        let as_param = Param::initialized(self.id, current.clone());

        // Only the mapped tensor is kept; the other parts returned by `consume` are discarded.
        let (_, mapped, _) = mapper.map_float(as_param).consume();
        *current = mapped;
        core::mem::drop(current);

        self
    }

    fn into_record(self) -> Self::Record {
        // Merge pending per-thread updates before taking the snapshot.
        self.sync();
        let current = self.value.lock().unwrap();
        Param::initialized(self.id, current.clone())
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        let mut current = self.value.lock().unwrap();
        // The loaded value is moved to the device of the value it replaces.
        let target_device = current.device();
        *current = record.val().to_device(&target_device);
        self.id = record.id;
        core::mem::drop(current);

        self
    }

    fn to_device(self, device: &Device<B>) -> Self {
        let mut current = self.value.lock().unwrap();
        let moved = current.clone().to_device(device);
        *current = moved;
        core::mem::drop(current);

        self
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Same thing here since no grad.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let own_device = self.value.lock().unwrap().device();
        let already_listed = devices.contains(&own_device);
        if !already_listed {
            devices.push(own_device);
        }
        devices
    }
}
impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
/// Create a new running state.
pub fn new(value: Tensor<B, D>) -> Self {
Self {
id: ParamId::new(),
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(Mutex::new(value)),
}
}
/// Create a new running state.
pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {
Self {
id,
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(Mutex::new(value)),
}
}
/// Create a new running state from a record.
pub fn from_record(record: Param<Tensor<B, D>>) -> Self {
let tensor = record.val();
Self {
id: record.id,
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(Mutex::new(tensor)),
}
}
/// Update the value on the current thread.
pub fn update(&self, value: Tensor<B, D>) {
let thread_id = get_thread_current_id();
let mut map = self.values.lock().unwrap();
if map.contains_key(&thread_id) {
self.update_value(&mut map);
}
map.insert(thread_id, value);
}
/// Get the current value,
///
/// # Note
///
/// The current value might be outdated by one update.
pub fn value(&self) -> Tensor<B, D> {
let value = self.value.lock().unwrap();
value.clone()
}
/// Get the current value and make sure it is sync.
///
/// # Note
///
/// Don't use this function after an update on the same thread where other threads might have to
/// register their update before the actual synchronization needs to happen.
pub fn value_sync(&self) -> Tensor<B, D> {
let thread_id = get_thread_current_id();
let mut map = self.values.lock().unwrap();
if map.contains_key(&thread_id) {
self.update_value(&mut map);
}
let value = self.value.lock().unwrap();
value.clone()
}
fn sync(&self) {
let mut map = self.values.lock().unwrap();
if !map.is_empty() {
self.update_value(&mut map);
}
}
fn update_value(&self, map: &mut HashMap<ThreadId, Tensor<B, D>>) {
let mut value_updated: Option<Tensor<B, D>> = None;
let mut counter = 0;
for (_key, tensor) in map.drain() {
counter += 1;
value_updated = match value_updated {
Some(current) => {
let device = current.device();
Some(tensor.to_device(&device).add(current))
}
None => Some(tensor),
};
}
if let Some(value) = value_updated {
let value = value.div_scalar(counter);
let mut value_old = self.value.lock().unwrap();
*value_old = value;
}
}
}
// Conversions to and from the inner backend first merge pending updates, then rebuild a state
// with the same id around the converted value.
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>> {
    type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;

    fn valid(&self) -> Self::InnerModule {
        self.sync();
        let inner_value = self.value().inner();

        RunningState::with_id(self.id, inner_value)
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.sync();
        let outer_value = Tensor::from_inner(module.value());

        RunningState::with_id(module.id, outer_value)
    }
}

View File

@@ -0,0 +1,571 @@
use super::{Param, ParamId, Parameter};
use crate::module::{
AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,
ModuleMapper, ModuleVisitor,
};
use crate::tensor::{
Tensor,
backend::{AutodiffBackend, Backend},
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};
// Float tensors forward every `Parameter` operation to the tensor's own methods.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::<B, D, Float>::device(self)
    }

    fn is_require_grad(&self) -> bool {
        Tensor::<B, D, Float>::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Tensor::<B, D, Float>::set_require_grad(self, require_grad)
    }
}
// Int tensors never require gradients: the flag is always `false` and setting it is a no-op.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::<B, D, Int>::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _ignored: bool) -> Self {
        self
    }
}
// Bool tensors never require gradients: the flag is always `false` and setting it is a no-op.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::<B, D, Bool>::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _ignored: bool) -> Self {
        self
    }
}
impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
/// Create a new parameter from a float tensor.
///
/// The parameter gets a newly generated id and the tensor is marked as requiring gradients.
///
/// # Warnings
///
/// We strongly recommend using [Param::uninitialized] if you are using this method to
/// initialize parameters inside a module, since the tensor initialization will be lazy,
/// making the loading of weights more performant.
pub fn from_tensor(value: Tensor<B, D>) -> Self {
    let id = ParamId::new();
    // Mark the tensor as requiring gradients so that an optimizer can update it.
    let tracked = value.require_grad();

    Param::initialized(id, tracked)
}
/// The shape of the parameter, **without triggering initialization**.
///
/// Needed for shape validation during loading: applying tensors to an uninitialized parameter
/// requires checking the shape, but running the initialization function just to learn it would
/// allocate an unnecessary tensor.
///
/// Prefer this over [crate::tensor::Tensor::shape] whenever lazy initialization must be
/// preserved.
pub fn lazy_shape(&self) -> burn_tensor::Shape {
    let Some(lock) = self.initialization.as_ref() else {
        return self.shape();
    };

    let guard = lock.read().unwrap();
    if let Some(pending) = guard.as_ref() {
        // Still lazy: report the shape recorded for the pending initialization.
        return pending.shape.clone();
    }
    self.shape()
}
/// Create a new parameter from data.
pub fn from_data<T>(data: T, device: &B::Device) -> Self
where
T: Into<TensorData>,
{
// When creating a parameter from a float tensor, we automatically mark it as requiring
// gradients, so that it can be updated by an optimizer.
B::memory_persistent_allocations(device, data, |data| {
let value = Tensor::from_data(data, device);
Param::initialized(ParamId::new(), value.require_grad())
})
}
/// Transform a parameter for loading by applying load transformations.
///
/// This method is used to restore a parameter from a tensor (typically during deserialization).
/// It ensures the tensor is moved to the expected device, applies the param mapper's
/// `on_load` transformation, and preserves the autodiff settings (require_grad).
pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {
let mut new_tensor = tensor;
let mapper = self.param_mapper.clone();
let expected_device = self.lazy_device();
let expected_require_grad = self.lazy_is_require_grad();
// Make sure we load the tensor into the same module device.
if new_tensor.device() != expected_device {
new_tensor = new_tensor.to_device(&expected_device).detach();
}
new_tensor = mapper.on_load(new_tensor);
// Make sure we load the tensor with the same autodiff setting.
new_tensor = new_tensor.set_require_grad(expected_require_grad);
let mut loaded = Self::initialized(param_id, new_tensor);
loaded.param_mapper = mapper;
loaded
}
/// Transform a parameter for saving by applying save transformations.
///
/// This method is used to prepare a parameter for saving (typically during serialization).
/// It applies the param mapper's `on_save` transformation, which can be used
/// to modify the tensor before serialization (e.g., quantization, precision conversion).
pub fn transform_for_save(&self) -> Self {
let mut tensor = self.val();
let mapper = self.param_mapper.clone();
tensor = mapper.on_save(tensor);
Self::initialized(self.id, tensor)
}
}
impl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {
    /// Report the parameter's shape **without triggering initialization**.
    ///
    /// Shape validation during loading relies on this: a tensor applied to an uninitialized
    /// parameter must be checked against the expected shape without running the
    /// initialization function (which would allocate a tensor that is never used).
    ///
    /// Prefer this over [crate::tensor::Tensor::shape] whenever lazy initialization must be
    /// preserved.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        // No lazy-initialization state at all: ask the parameter directly.
        let Some(lock) = &self.initialization else {
            return self.shape();
        };
        let pending = lock.read().unwrap();
        // A still-pending initialization carries the shape; otherwise fall back to the
        // parameter's own shape.
        pending
            .as_ref()
            .map_or_else(|| self.shape(), |uninit| uninit.shape.clone())
    }

    /// Produce the parameter that results from loading `tensor` into this one.
    ///
    /// Typically used during deserialization. The tensor is moved to the device this
    /// parameter expects and passed through the param mapper's `on_load` transformation.
    /// The result uses `param_id` and keeps the same param mapper.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {
        let mapper = self.param_mapper.clone();
        let expected_device = self.lazy_device();

        // Load the tensor onto the same device as the module.
        let on_device = if tensor.device() != expected_device {
            tensor.to_device(&expected_device)
        } else {
            tensor
        };
        let mapped = mapper.on_load(on_device);

        let mut loaded = Self::initialized(param_id, mapped);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Produce the parameter to be saved, after applying the save transformation.
    ///
    /// Typically used during serialization. The current value goes through the param
    /// mapper's `on_save` transformation, which can modify the tensor before it is
    /// serialized. The id is kept.
    pub fn transform_for_save(&self) -> Self {
        let current = self.val();
        let saved = self.param_mapper.clone().on_save(current);
        Self::initialized(self.id, saved)
    }
}
impl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {
    /// Report the parameter's shape **without triggering initialization**.
    ///
    /// Shape validation during loading relies on this: a tensor applied to an uninitialized
    /// parameter must be checked against the expected shape without running the
    /// initialization function (which would allocate a tensor that is never used).
    ///
    /// **Returns:**
    /// - While an initialization is still pending: the shape stored in that pending state
    /// - Otherwise: the shape reported by the parameter itself
    ///
    /// Prefer this over [crate::tensor::Tensor::shape] whenever lazy initialization must be
    /// preserved.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        // No lazy-initialization state at all: ask the parameter directly.
        let Some(lock) = &self.initialization else {
            return self.shape();
        };
        let pending = lock.read().unwrap();
        pending
            .as_ref()
            .map_or_else(|| self.shape(), |uninit| uninit.shape.clone())
    }

    /// Produce the parameter that results from loading `tensor` into this one.
    ///
    /// Typically used during deserialization. The tensor is moved to the device this
    /// parameter expects and passed through the param mapper's `on_load` transformation.
    /// The result uses `param_id` and keeps the same param mapper.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {
        let mapper = self.param_mapper.clone();
        let expected_device = self.lazy_device();

        // Load the tensor onto the same device as the module.
        let on_device = if tensor.device() != expected_device {
            tensor.to_device(&expected_device)
        } else {
            tensor
        };
        let mapped = mapper.on_load(on_device);

        let mut loaded = Self::initialized(param_id, mapped);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Produce the parameter to be saved, after applying the save transformation.
    ///
    /// Typically used during serialization. The current value goes through the param
    /// mapper's `on_save` transformation, which can modify the tensor before it is
    /// serialized. The id is kept.
    pub fn transform_for_save(&self) -> Self {
        let current = self.val();
        let saved = self.param_mapper.clone().on_save(current);
        Self::initialized(self.id, saved)
    }
}
/// A float tensor parameter is a leaf module whose record is the parameter itself.
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    /// Hand this parameter to the visitor's float callback.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_float(self)
    }

    /// Replace this parameter with whatever the mapper's float callback returns.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_float(self)
    }

    /// The record is the parameter after its save transformation.
    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    /// Load the record's tensor and id into this parameter via the load transformation.
    fn load_record(self, record: Self::Record) -> Self {
        let (id, tensor, _) = record.consume();
        self.transform_for_load(tensor, id)
    }

    /// Move the underlying tensor to `device`.
    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|value| value.to_device(device))
    }

    /// Move the tensor to `device` and detach it, re-enabling gradient tracking only if it
    /// was enabled before the move.
    fn fork(self, device: &Device<B>) -> Self {
        self.map(|value| {
            let tracked = value.is_require_grad();
            let detached = value.to_device(device).detach();
            if tracked {
                detached.require_grad()
            } else {
                detached
            }
        })
    }

    /// Append this parameter's device to `devices` unless it is already listed.
    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let current = self.val().device();
        if devices.contains(&current) {
            return devices;
        }
        devices.push(current);
        devices
    }
}
/// Default display of a float tensor parameter: rank, shape, kind and (optionally) id.
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
    fn content(&self, content: Content) -> Option<Content> {
        // The id suffix is only rendered when the display settings ask for it.
        let show_id = content.display_settings.show_param_id();
        let id = if show_id {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let shape = self.shape();
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
            shape.as_slice()
        );
        content.add_formatted(&string).optional()
    }
}

/// Uses the provided `ModuleDisplay` behavior without overriding anything.
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
/// An int tensor parameter is a leaf module whose record is the parameter itself.
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
    type Record = Param<Tensor<B, D, Int>>;

    /// Hand this parameter to the visitor's int callback.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_int(self)
    }

    /// Replace this parameter with whatever the mapper's int callback returns.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_int(self)
    }

    /// The record is the parameter after its save transformation.
    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    /// Load the record's tensor and id into this parameter via the load transformation.
    fn load_record(self, record: Self::Record) -> Self {
        let (id, tensor, _) = record.consume();
        self.transform_for_load(tensor, id)
    }

    /// Move the underlying tensor to `device`.
    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|value| value.to_device(device))
    }

    /// Forking is a plain device move here.
    fn fork(self, device: &Device<B>) -> Self {
        // Don't support autodiff.
        self.to_device(device)
    }

    /// Append this parameter's device to `devices` unless it is already listed.
    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let current = self.val().device();
        if devices.contains(&current) {
            return devices;
        }
        devices.push(current);
        devices
    }
}
/// Default display of an int tensor parameter: rank, shape, kind and (optionally) id.
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
    fn content(&self, content: Content) -> Option<Content> {
        // The id suffix is only rendered when the display settings ask for it.
        let show_id = content.display_settings.show_param_id();
        let id = if show_id {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let shape = self.shape();
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
            shape.as_slice()
        );
        content.add_formatted(&string).optional()
    }
}

/// Uses the provided `ModuleDisplay` behavior without overriding anything.
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
/// A bool tensor parameter is a leaf module whose record is the parameter itself.
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
    type Record = Param<Tensor<B, D, Bool>>;

    /// Hand this parameter to the visitor's bool callback.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_bool(self)
    }

    /// Replace this parameter with whatever the mapper's bool callback returns.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_bool(self)
    }

    /// The record is the parameter after its save transformation.
    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    /// Load the record's tensor and id into this parameter via the load transformation.
    fn load_record(self, record: Self::Record) -> Self {
        let (id, tensor, _) = record.consume();
        self.transform_for_load(tensor, id)
    }

    /// Move the underlying tensor to `device`.
    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|value| value.to_device(device))
    }

    /// Forking is a plain device move here.
    fn fork(self, device: &Device<B>) -> Self {
        // Don't support autodiff.
        self.to_device(device)
    }

    /// Append this parameter's device to `devices` unless it is already listed.
    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let current = self.val().device();
        if devices.contains(&current) {
            return devices;
        }
        devices.push(current);
        devices
    }
}
/// Default display of a bool tensor parameter: rank, shape, kind and (optionally) id.
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
    fn content(&self, content: Content) -> Option<Content> {
        // The id suffix is only rendered when the display settings ask for it.
        let show_id = content.display_settings.show_param_id();
        let id = if show_id {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let shape = self.shape();
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
            shape.as_slice()
        );
        content.add_formatted(&string).optional()
    }
}

/// Uses the provided `ModuleDisplay` behavior without overriding anything.
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
/// Conversion of a float tensor parameter between an autodiff backend and its inner backend.
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    /// Build the inner-backend parameter with the same id.
    ///
    /// The inner tensor itself is set to not require gradients, while the parameter-level
    /// `require_grad` flag is carried over unchanged.
    fn valid(&self) -> Self::InnerModule {
        let remembered = self.require_grad;
        let inner = self.val().inner().set_require_grad(false);
        let mut param = Param::initialized(self.id, inner);
        // Keep the parameter's own `require_grad` state even though the inner value's flag
        // was reset above.
        param.require_grad = remembered;
        param
    }

    /// Rebuild the autodiff parameter with the same id, applying the parameter-level
    /// `require_grad` flag to the wrapped tensor.
    fn from_inner(module: Self::InnerModule) -> Self {
        let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);
        let id = module.id;
        Param::initialized(id, tensor)
    }
}
/// Associates the inner-backend float parameter with its autodiff (training) counterpart.
impl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>
for Param<Tensor<B::InnerBackend, D>>
{
// The training module is the same parameter type over the autodiff backend `B`.
type TrainModule = Param<Tensor<B, D>>;
}
/// Conversion of an int tensor parameter between an autodiff backend and its inner backend.
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;

    /// Build the inner-backend parameter with the same id.
    fn valid(&self) -> Self::InnerModule {
        let id = self.id;
        let inner = self.val().inner();
        Param::initialized(id, inner)
    }

    /// Rebuild the autodiff parameter with the same id.
    fn from_inner(module: Self::InnerModule) -> Self {
        let id = module.id;
        let outer = Tensor::from_inner(module.val());
        Param::initialized(id, outer)
    }
}
/// Conversion of a bool tensor parameter between an autodiff backend and its inner backend.
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;

    /// Build the inner-backend parameter with the same id.
    fn valid(&self) -> Self::InnerModule {
        let id = self.id;
        let inner = self.val().inner();
        Param::initialized(id, inner)
    }

    /// Rebuild the autodiff parameter with the same id.
    fn from_inner(module: Self::InnerModule) -> Self {
        let id = module.id;
        let outer = Tensor::from_inner(module.val());
        Param::initialized(id, outer)
    }
}
// Tests for how float tensor parameters keep their `require_grad` setting.
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use crate::{
TestAutodiffBackend,
module::Module,
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
};
// Loading a record must keep the `require_grad` setting of the parameter being loaded into.
#[test]
fn test_load_record_setting() {
let device = Default::default();
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
// Serialize a parameter built from the gradient-tracking tensor.
let bytes = byte_recorder
.record(
Param::initialized(ParamId::new(), tensor.clone()).into_record(),
(),
)
.unwrap();
// Target parameter switched to `no_grad` before loading.
let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
.no_grad()
.load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
.is_require_grad();
// Target parameter left as created from the gradient-tracking tensor.
let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
.load_record(byte_recorder.load(bytes, &device).unwrap())
.is_require_grad();
assert!(!no_grad_is_require_grad);
assert!(with_default_is_require_grad);
}
// The parameter-level `require_grad` field must survive `valid()` / `train()` round trips.
#[test]
fn test_param_require_grad_stateful() {
let device = Default::default();
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();
let param = Param::initialized(ParamId::new(), tensor);
assert!(param.is_require_grad());
assert!(param.require_grad);
// Inner-backend view: the tensor no longer tracks gradients, the field is kept.
let param = param.valid();
assert!(!param.is_require_grad());
assert!(param.require_grad); // stateful
// Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying:
// let param: Param<Tensor<TestAutodiffBackend, _>> = param.train();
let param = param.train::<TestAutodiffBackend>();
assert!(param.is_require_grad());
assert!(param.require_grad); // stateful
// Same round trip after gradients were explicitly disabled.
let param = param.no_grad();
assert!(!param.is_require_grad());
assert!(!param.require_grad); // stateful
let param = param.valid();
assert!(!param.is_require_grad()); // always
assert!(!param.require_grad); // stateful
let param = param.train::<TestAutodiffBackend>();
assert!(!param.is_require_grad());
assert!(!param.require_grad); // stateful
}
}

View File

@@ -0,0 +1,38 @@
use super::{Param, ParamId};
use crate::module::{Module, ModuleVisitor};
use alloc::vec::Vec;
use burn_tensor::{Bool, Int, Tensor, backend::Backend};
use core::marker::PhantomData;
/// Module visitor that records the id of every parameter it is shown.
struct ParamIdCollector<'a, M> {
// Destination the visited parameter ids are pushed into.
param_ids: &'a mut Vec<ParamId>,
// Ties the collector to the module type `M` without storing a value of it.
phantom: PhantomData<M>,
}
/// Every tensor kind is handled the same way: the parameter's id is appended to the list.
impl<B: Backend, M: Module<B>> ModuleVisitor<B> for ParamIdCollector<'_, M> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        let id = param.id;
        self.param_ids.push(id);
    }

    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
        let id = param.id;
        self.param_ids.push(id);
    }

    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
        let id = param.id;
        self.param_ids.push(id);
    }
}
/// Collect the id of every parameter the module exposes to a visitor.
///
/// The ids are returned in the order in which the module visits its parameters.
pub fn list_param_ids<M: Module<B>, B: Backend>(module: &M) -> Vec<ParamId> {
    let mut collected = Vec::new();
    // The collector only borrows the vector for the duration of the visit.
    module.visit(&mut ParamIdCollector {
        param_ids: &mut collected,
        phantom: PhantomData::<M>,
    });
    collected
}

View File

@@ -0,0 +1,65 @@
use burn_tensor::{
Tensor,
backend::Backend,
quantization::{Calibration, QuantScheme, compute_q_params, compute_range},
};
use crate::module::{ModuleMapper, Param};
/// Describes how to quantize a module.
///
/// Both fields are consulted for every float parameter the quantizer is mapped over.
pub struct Quantizer {
/// The calibration method used in quantization.
pub calibration: Calibration,
/// The quantization scheme.
pub scheme: QuantScheme,
}
/// Mapping a module with a [`Quantizer`] quantizes each float parameter it is given.
impl<B: Backend> ModuleMapper<B> for Quantizer {
    /// Quantize one float parameter, keeping its id and param mapper.
    ///
    /// The value range is computed from the tensor with the configured calibration, turned
    /// into quantization parameters for the configured scheme, and applied to the tensor.
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let qparams = compute_q_params(
            &self.scheme,
            compute_range(&self.scheme, &tensor, &self.calibration),
        );
        let quantized = tensor.quantize(&self.scheme, qparams);
        Param::from_mapped_value(id, quantized, mapper)
    }
}
// Quantization round-trip test; compiled out when the `test-tch` feature is enabled.
#[cfg(all(test, not(feature = "test-tch")))]
mod tests {
use crate::test_utils::SimpleLinear;
use crate::{
TestBackend,
module::{Module, Quantizer},
};
use burn_tensor::{
Device, Tolerance,
ops::QuantizedTensor,
quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue},
};
type B = TestBackend;
// Quantizing then dequantizing the weight must stay approximately equal to the original.
#[test]
fn should_quantize_module() {
let device: Device<B> = Default::default();
let module = SimpleLinear::<B>::new(32, 32, &device);
// Start from the backend's default scheme and override value, level and param settings.
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
.with_value(QuantValue::Q8S)
.with_level(QuantLevel::Tensor)
.with_param(QuantParam::F32);
// Reference weight captured before quantization.
let result = module.weight.val();
let calibration = Calibration::MinMax;
let mut quantizer = Quantizer {
calibration,
scheme,
};
let q_module = module.quantize_weights(&mut quantizer);
let q_result = q_module.weight.val().dequantize();
result
.into_data()
.assert_approx_eq::<f32>(&q_result.into_data(), Tolerance::permissive());
}
}

View File

@@ -0,0 +1,203 @@
use super::{Module, ModuleMapper};
use burn_tensor::{
Element, ElementConversion, Tensor, TensorData,
backend::Backend,
ops::{FloatElem, IntElem},
};
use rand::{RngExt, SeedableRng};
#[derive(Debug)]
/// Overrides float and int tensors of [burn modules](super::Module).
///
/// This is useful for testing.
pub struct Reinitializer<B: Backend> {
// Strategy applied to float parameters.
float: ReinitStrategy<FloatElem<B>>,
// Strategy applied to int parameters.
int: ReinitStrategy<IntElem<B>>,
}
// How the values of a reinitialized tensor are chosen, for element type `E`.
#[derive(Debug)]
#[allow(missing_docs)]
enum ReinitStrategy<E> {
// Values derived from the element index, scaled between `min` and `max`.
Range { min: E, max: E },
// Every element set to `value`.
Constant { value: E },
// Values drawn from a generator seeded with `seed`, between `min` and `max`.
Random { seed: u64, min: E, max: E },
}
/// The default reinitializer is the one returned by [`Reinitializer::new`].
impl<B: Backend> Default for Reinitializer<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> Reinitializer<B> {
    /// Create a new [reinitializer](Reinitializer) that sets float and int tensors to zero.
    pub fn new() -> Self {
        let float = ReinitStrategy::Constant {
            value: 0.elem::<FloatElem<B>>(),
        };
        let int = ReinitStrategy::Constant {
            value: 0.elem::<IntElem<B>>(),
        };
        Self { float, int }
    }

    /// Run the reinitialization over the given [module](Module) and return the result.
    pub fn apply<M: Module<B>>(mut self, module: M) -> M {
        let mapper = &mut self;
        module.map(mapper)
    }

    /// Use a constant for all tensors; the int constant is `constant` cast to `i64`.
    pub fn constant(self, constant: f64) -> Self {
        let int_constant = constant as i64;
        self.constant_float(constant).constant_int(int_constant)
    }

    /// Use a constant for float tensors.
    pub fn constant_float(self, constant: f64) -> Self {
        Self {
            float: ReinitStrategy::Constant {
                value: constant.elem(),
            },
            ..self
        }
    }

    /// Use a constant for int tensors.
    pub fn constant_int(self, constant: i64) -> Self {
        Self {
            int: ReinitStrategy::Constant {
                value: constant.elem(),
            },
            ..self
        }
    }

    /// Use seeded random values for all tensors; the int bounds are `min` and `max` cast to
    /// `i64`.
    pub fn random(self, seed: u64, min: f64, max: f64) -> Self {
        let (int_min, int_max) = (min as i64, max as i64);
        self.random_float(seed, min, max)
            .random_int(seed, int_min, int_max)
    }

    /// Use seeded random values for float tensors.
    pub fn random_float(self, seed: u64, min: f64, max: f64) -> Self {
        Self {
            float: ReinitStrategy::Random {
                seed,
                min: min.elem(),
                max: max.elem(),
            },
            ..self
        }
    }

    /// Use seeded random values for int tensors.
    pub fn random_int(self, seed: u64, min: i64, max: i64) -> Self {
        Self {
            int: ReinitStrategy::Random {
                seed,
                min: min.elem(),
                max: max.elem(),
            },
            ..self
        }
    }

    /// Use the range strategy for all tensors; the int bounds are `min` and `max` cast to
    /// `i64`.
    pub fn range(self, min: f64, max: f64) -> Self {
        let (int_min, int_max) = (min as i64, max as i64);
        self.range_float(min, max).range_int(int_min, int_max)
    }

    /// Use the range strategy for float tensors.
    pub fn range_float(self, min: f64, max: f64) -> Self {
        Self {
            float: ReinitStrategy::Range {
                min: min.elem(),
                max: max.elem(),
            },
            ..self
        }
    }

    /// Use the range strategy for int tensors.
    pub fn range_int(self, min: i64, max: i64) -> Self {
        Self {
            int: ReinitStrategy::Range {
                min: min.elem(),
                max: max.elem(),
            },
            ..self
        }
    }
}
impl<B: Backend> ModuleMapper<B> for Reinitializer<B> {
fn map_float<const D: usize>(
&mut self,
param: super::Param<Tensor<B, D>>,
) -> super::Param<Tensor<B, D>> {
let (id, tensor, mapper) = param.consume();
let device = tensor.device();
let shape = tensor.shape();
let num_elements = shape.num_elements();
let tensor = match &self.float {
ReinitStrategy::Range { min, max } => {
let tensor = Tensor::arange(0..num_elements as i64, &device)
.reshape(shape)
.float();
let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);
tensor * factor + bias
}
ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
ReinitStrategy::Random { seed, min, max } => {
let data = TensorData::new(
random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),
shape,
);
Tensor::from_data(data, &device)
}
};
super::Param::from_mapped_value(id, tensor, mapper)
}
fn map_int<const D: usize>(
&mut self,
param: super::Param<Tensor<B, D, burn_tensor::Int>>,
) -> super::Param<Tensor<B, D, burn_tensor::Int>> {
let (id, tensor, mapper) = param.consume();
let device = tensor.device();
let shape = tensor.shape();
let num_elements = shape.num_elements();
let tensor = match &self.int {
ReinitStrategy::Range { min, max } => {
let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);
let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);
tensor * factor + bias
}
ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
ReinitStrategy::Random { seed, min, max } => {
let data = TensorData::new(
random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),
shape,
);
Tensor::from_data(data, &device)
}
};
super::Param::from_mapped_value(id, tensor, mapper)
}
fn map_bool<const D: usize>(
&mut self,
param: super::Param<Tensor<B, D, burn_tensor::Bool>>,
) -> super::Param<Tensor<B, D, burn_tensor::Bool>> {
let (id, tensor, mapper) = param.consume();
super::Param::from_mapped_value(id, tensor, mapper)
}
}
/// Compute the `(factor, bias)` pair used by the range strategy.
///
/// Working in `f64`, the factor is `(max - min) / num_elements` and the bias is `min`; both
/// are converted back to `E` before being returned.
fn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {
    let upper = max.elem::<f64>();
    let lower = min.elem::<f64>();
    let factor = (upper - lower) / num_elements as f64;
    (factor.elem(), lower.elem())
}
/// Generate `num_elements` values from a generator seeded with `seed`.
///
/// Each value is sampled as an `f64` from a uniform distribution built from `min` and `max`,
/// then converted to `E`.
///
/// # Panics
///
/// Panics if the uniform distribution cannot be constructed from `min` and `max`.
fn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {
    let dist = rand::distr::Uniform::new(min, max).unwrap();
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut draw = || {
        let sample: f64 = rng.sample(dist);
        sample.elem::<E>()
    };
    (0..num_elements).map(|_| draw()).collect()
}

View File

@@ -0,0 +1,17 @@
pub use burn_derive::Record;
use burn_tensor::backend::Backend;
use super::PrecisionSettings;
use serde::{Serialize, de::DeserializeOwned};
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
///
/// A record is converted to a serializable item with [`Record::into_item`] and rebuilt from one
/// with [`Record::from_item`].
pub trait Record<B: Backend>: Send {
/// Type of the item that can be serialized and deserialized.
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned + Clone;
/// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
/// Convert the given item into a record.
///
/// The backend device is supplied for the conversion.
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self;
}

View File

@@ -0,0 +1,421 @@
use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
use burn_tensor::backend::Backend;
use core::marker::PhantomData;
use flate2::{Compression, read::GzDecoder, write::GzEncoder};
use serde::{Serialize, de::DeserializeOwned};
use std::io::{BufReader, BufWriter};
use std::{fs::File, path::PathBuf};
/// Recorder trait specialized to save and load data to and from files.
///
/// Implementors are recorders whose record and load arguments are both a [`PathBuf`] and whose
/// record output is `()`.
pub trait FileRecorder<B: Backend>:
Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
/// File extension of the format used by the recorder.
///
/// The extension is given without a leading dot (e.g. `"mpk"`).
fn file_extension() -> &'static str;
}
/// Default [file recorder](FileRecorder): an alias for [`NamedMpkFileRecorder`].
pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
/// File recorder using the [bincode format](bincode).
///
/// Files use the `bin` extension.
#[derive(new, Debug, Default, Clone)]
pub struct BinFileRecorder<S: PrecisionSettings> {
// Marker for the precision settings type; no value is stored.
_settings: PhantomData<S>,
}
/// File recorder using the [bincode format](bincode) compressed with gzip.
///
/// Files use the `bin.gz` extension.
#[derive(new, Debug, Default, Clone)]
pub struct BinGzFileRecorder<S: PrecisionSettings> {
// Marker for the precision settings type; no value is stored.
_settings: PhantomData<S>,
}
/// File recorder using the [json format](serde_json) compressed with gzip.
///
/// Files use the `json.gz` extension.
#[derive(new, Debug, Default, Clone)]
pub struct JsonGzFileRecorder<S: PrecisionSettings> {
// Marker for the precision settings type; no value is stored.
_settings: PhantomData<S>,
}
/// File recorder using [pretty json format](serde_json) for easy readability.
///
/// Files use the `json` extension.
#[derive(new, Debug, Default, Clone)]
pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
// Marker for the precision settings type; no value is stored.
_settings: PhantomData<S>,
}
/// File recorder using the [named msgpack](rmp_serde) format compressed with gzip.
///
/// Files use the `mpk.gz` extension.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
// Marker for the precision settings type; no value is stored.
_settings: PhantomData<S>,
}
/// File recorder using the [named msgpack](rmp_serde) format.
///
/// Files use the `mpk` extension.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
// Marker for the precision settings type; no value is stored.
_settings: PhantomData<S>,
}
/// Gzip-compressed bincode files use the `bin.gz` extension.
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
fn file_extension() -> &'static str {
"bin.gz"
}
}
/// Bincode files use the `bin` extension.
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
fn file_extension() -> &'static str {
"bin"
}
}
/// Gzip-compressed json files use the `json.gz` extension.
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
fn file_extension() -> &'static str {
"json.gz"
}
}
/// Pretty json files use the `json` extension.
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
fn file_extension() -> &'static str {
"json"
}
}
/// Gzip-compressed named msgpack files use the `mpk.gz` extension.
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
fn file_extension() -> &'static str {
"mpk.gz"
}
}
/// Named msgpack files use the `mpk` extension.
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
fn file_extension() -> &'static str {
"mpk"
}
}
// Opens the file named by a path expression for buffered reading.
//
// The path's extension is first replaced, in place, with the recorder's own extension
// (`Self` must implement `FileRecorder<B>` at the expansion site). The macro evaluates to a
// `Result` holding a `BufReader` over the file; a missing file is reported as
// `RecorderError::FileNotFound`, any other open failure as `RecorderError::Unknown`.
macro_rules! str2reader {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let source = $file.as_path();
        File::open(source)
            .map(BufReader::new)
            .map_err(|io_err| {
                if io_err.kind() == std::io::ErrorKind::NotFound {
                    RecorderError::FileNotFound(io_err.to_string())
                } else {
                    RecorderError::Unknown(io_err.to_string())
                }
            })
    }};
}
// Creates the file named by a path expression for buffered writing.
//
// The path's extension is first replaced, in place, with the recorder's own extension
// (`Self` must implement `FileRecorder<B>` at the expansion site). Missing parent directories
// are created on a best-effort basis and an existing file is removed first. The macro
// evaluates to a `Result` holding a `BufWriter` over the new file; because a failed removal is
// propagated with `?`, it must be expanded inside a function returning
// `Result<_, RecorderError>`.
macro_rules! str2writer {
    ($file:expr) => {{
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let target = $file.as_path();
        log::debug!("Writing to file: {:?}", target);

        // Add parent directories if they don't exist; a failure here is ignored and will
        // surface when the file itself is created.
        if let Some(dir) = target.parent() {
            std::fs::create_dir_all(dir).ok();
        }

        if target.exists() {
            log::warn!("File exists, replacing");
            std::fs::remove_file(target)
                .map_err(|io_err| RecorderError::Unknown(io_err.to_string()))?;
        }

        File::create(target)
            .map(BufWriter::new)
            .map_err(|io_err| {
                if io_err.kind() == std::io::ErrorKind::NotFound {
                    RecorderError::FileNotFound(io_err.to_string())
                } else {
                    RecorderError::Unknown(io_err.to_string())
                }
            })
    }};
}
/// Saves and loads items as gzip-compressed bincode files.
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
    type Settings = S;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Serialize `item` with bincode, gzip it and write it to `file`.
    ///
    /// The extension of `file` is replaced with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be created, if encoding fails, or if
    /// finishing the gzip stream or flushing the buffered writer fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        mut file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        let config = bin_config();
        let writer = str2writer!(file)?;
        let mut writer = GzEncoder::new(writer, Compression::default());
        bincode::serde::encode_into_std_write(&item, &mut writer, config)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        // Finish the gzip stream and flush the buffer explicitly: when left to `Drop`, any
        // I/O error is discarded and a truncated file would be reported as a success.
        writer
            .finish()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?
            .into_inner()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(())
    }

    /// Read a gzip-compressed bincode item back from `file`.
    ///
    /// The extension of `file` is replaced, in place, with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be opened or decoded.
    fn load_item<I: DeserializeOwned>(
        &self,
        file: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        let reader = str2reader!(file)?;
        let mut reader = GzDecoder::new(reader);
        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(state)
    }
}
/// Saves and loads items as bincode files.
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
    type Settings = S;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Serialize `item` with bincode and write it to `file`.
    ///
    /// The extension of `file` is replaced with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be created, if encoding fails, or if
    /// flushing the buffered writer fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        mut file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        let config = bin_config();
        let mut writer = str2writer!(file)?;
        bincode::serde::encode_into_std_write(&item, &mut writer, config)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        // Flush the buffer explicitly: when left to `Drop`, a write error is discarded and
        // a truncated file would be reported as a success.
        writer
            .into_inner()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(())
    }

    /// Read a bincode item back from `file`.
    ///
    /// The extension of `file` is replaced, in place, with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be opened or decoded.
    fn load_item<I: DeserializeOwned>(
        &self,
        file: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        let mut reader = str2reader!(file)?;
        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(state)
    }
}
/// Saves and loads items as gzip-compressed json files.
impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
    type Settings = S;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Serialize `item` as json, gzip it and write it to `file`.
    ///
    /// The extension of `file` is replaced with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be created, if serialization fails, or
    /// if finishing the gzip stream or flushing the buffered writer fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        mut file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        let writer = str2writer!(file)?;
        let mut writer = GzEncoder::new(writer, Compression::default());
        // Serialize through a mutable borrow so the encoder can be finished afterwards.
        serde_json::to_writer(&mut writer, &item)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        // Finish the gzip stream and flush the buffer explicitly: when left to `Drop`, any
        // I/O error is discarded and a truncated file would be reported as a success.
        writer
            .finish()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?
            .into_inner()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(())
    }

    /// Read a gzip-compressed json item back from `file`.
    ///
    /// The extension of `file` is replaced, in place, with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be opened or deserialized.
    fn load_item<I: DeserializeOwned>(
        &self,
        file: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        let reader = str2reader!(file)?;
        let reader = GzDecoder::new(reader);
        let state = serde_json::from_reader(reader)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(state)
    }
}
/// Saves and loads items as pretty-printed json files.
impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
    type Settings = S;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Serialize `item` as pretty-printed json and write it to `file`.
    ///
    /// The extension of `file` is replaced with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be created, if serialization fails, or
    /// if flushing the buffered writer fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        mut file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        let mut writer = str2writer!(file)?;
        // Serialize through a mutable borrow so the buffer can be flushed afterwards.
        serde_json::to_writer_pretty(&mut writer, &item)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        // Flush the buffer explicitly: when left to `Drop`, a write error is discarded and
        // a truncated file would be reported as a success.
        writer
            .into_inner()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(())
    }

    /// Read a json item back from `file`.
    ///
    /// The extension of `file` is replaced, in place, with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be opened or deserialized.
    fn load_item<I: DeserializeOwned>(
        &self,
        file: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        let reader = str2reader!(file)?;
        let state = serde_json::from_reader(reader)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(state)
    }
}
/// Saves and loads items as gzip-compressed named msgpack files.
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
    type Settings = S;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Serialize `item` as named msgpack, gzip it and write it to `file`.
    ///
    /// The extension of `file` is replaced with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be created, if encoding fails, or if
    /// finishing the gzip stream or flushing the buffered writer fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        mut file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        let writer = str2writer!(file)?;
        let mut writer = GzEncoder::new(writer, Compression::default());
        rmp_serde::encode::write_named(&mut writer, &item)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        // Finish the gzip stream and flush the buffer explicitly: when left to `Drop`, any
        // I/O error is discarded and a truncated file would be reported as a success.
        writer
            .finish()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?
            .into_inner()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(())
    }

    /// Read a gzip-compressed named msgpack item back from `file`.
    ///
    /// The extension of `file` is replaced, in place, with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be opened or decoded.
    fn load_item<I: DeserializeOwned>(
        &self,
        file: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        let reader = str2reader!(file)?;
        let reader = GzDecoder::new(reader);
        let state = rmp_serde::decode::from_read(reader)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(state)
    }
}
/// Saves and loads items as named msgpack files.
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
    type Settings = S;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Serialize `item` as named msgpack and write it to `file`.
    ///
    /// The extension of `file` is replaced with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be created, if encoding fails, or if
    /// flushing the buffered writer fails.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        mut file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        let mut writer = str2writer!(file)?;
        rmp_serde::encode::write_named(&mut writer, &item)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        // Flush the buffer explicitly: when left to `Drop`, a write error is discarded and
        // a truncated file would be reported as a success.
        writer
            .into_inner()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(())
    }

    /// Read a named msgpack item back from `file`.
    ///
    /// The extension of `file` is replaced, in place, with this recorder's extension.
    ///
    /// # Errors
    ///
    /// Returns a [`RecorderError`] if the file cannot be opened or decoded.
    fn load_item<I: DeserializeOwned>(
        &self,
        file: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        let reader = str2reader!(file)?;
        let state = rmp_serde::decode::from_read(reader)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        Ok(state)
    }
}
// Round-trip tests: one per file recorder format, all sharing `test_can_save_and_load`.
#[cfg(test)]
mod tests {
use super::*;
use crate as burn;
use crate::config::Config;
use crate::module::Ignored;
use crate::test_utils::SimpleLinear;
use crate::{
TestBackend,
module::Module,
record::{BinBytesRecorder, FullPrecisionSettings},
};
use burn_tensor::backend::Backend;
// Base path (without extension) inside the system temp directory, shared by every test.
#[inline(always)]
fn file_path() -> PathBuf {
std::env::temp_dir()
.as_path()
.join("burn_test_file_recorder")
}
#[test]
fn test_can_save_and_load_jsongz_format() {
test_can_save_and_load(JsonGzFileRecorder::<FullPrecisionSettings>::default())
}
#[test]
fn test_can_save_and_load_bin_format() {
test_can_save_and_load(BinFileRecorder::<FullPrecisionSettings>::default())
}
#[test]
fn test_can_save_and_load_bingz_format() {
test_can_save_and_load(BinGzFileRecorder::<FullPrecisionSettings>::default())
}
#[test]
fn test_can_save_and_load_pretty_json_format() {
test_can_save_and_load(PrettyJsonFileRecorder::<FullPrecisionSettings>::default())
}
#[test]
fn test_can_save_and_load_mpkgz_format() {
test_can_save_and_load(NamedMpkGzFileRecorder::<FullPrecisionSettings>::default())
}
#[test]
fn test_can_save_and_load_mpk_format() {
test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
}
// Saves a model with `recorder`, loads it into a second model, and checks that both models
// serialize to the same bytes with the in-memory bincode recorder.
fn test_can_save_and_load<Recorder>(recorder: Recorder)
where
Recorder: FileRecorder<TestBackend>,
{
let device = Default::default();
let model_before = create_model(&device);
recorder
.record(model_before.clone().into_record(), file_path())
.unwrap();
let model_after =
create_model(&device).load_record(recorder.load(file_path(), &device).unwrap());
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
let model_bytes_before = byte_recorder
.record(model_before.into_record(), ())
.unwrap();
let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap();
assert_eq!(model_bytes_after, model_bytes_before);
}
// Config enum used as the payload of the model's ignored field.
#[derive(Config, Debug)]
pub enum PaddingConfig2d {
Same,
Valid,
Explicit(usize, usize),
}
// Dummy model with different record types
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
linear1: SimpleLinear<B>,
phantom: PhantomData<B>,
arr: [usize; 2],
int: usize,
ignore: Ignored<PaddingConfig2d>,
}
// Builds the dummy model on `device` with fixed values for the non-tensor fields.
pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
let linear1 = SimpleLinear::new(32, 32, device);
Model {
linear1,
phantom: PhantomData,
arr: [2, 2],
int: 0,
ignore: Ignored(PaddingConfig2d::Same),
}
}
}

View File

@@ -0,0 +1,141 @@
use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use serde::{Serialize, de::DeserializeOwned};
/// A [`Recorder`] that records into a `Vec<u8>` and loads from a byte source `L`.
///
/// # Notes
///
/// This is especially useful in no_std environment where weights are stored directly in
/// compiled binaries.
pub trait BytesRecorder<B, L>:
    Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = L>
where
    B: Backend,
    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
{
}
/// In memory recorder using the [bincode format](bincode).
///
/// `S` selects the precision settings used for the recorded item and `L` is the byte source
/// accepted when loading; it defaults to `Vec<u8>`.
#[derive(new, Debug, Default, Clone)]
pub struct BinBytesRecorder<
S: PrecisionSettings,
L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default = Vec<u8>,
> {
// Zero-sized markers: the recorder stores no data, it only carries its type parameters.
_settings: core::marker::PhantomData<S>,
_loadargs: core::marker::PhantomData<L>,
}
// Marker implementation: `BinBytesRecorder` is usable as a bytes recorder for any byte
// source `L` satisfying the bounds below.
impl<S, B, L> BytesRecorder<B, L> for BinBytesRecorder<S, L>
where
    S: PrecisionSettings,
    B: Backend,
    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
{
}
impl<
S: PrecisionSettings,
B: Backend,
L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
> Recorder<B> for BinBytesRecorder<S, L>
{
type Settings = S;
type RecordArgs = ();
type RecordOutput = Vec<u8>;
type LoadArgs = L;
fn save_item<I: Serialize>(
&self,
item: I,
_args: Self::RecordArgs,
) -> Result<Self::RecordOutput, RecorderError> {
Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())
}
fn load_item<I: DeserializeOwned>(
&self,
args: &mut Self::LoadArgs,
) -> Result<I, RecorderError> {
let state = bincode::borrow_decode_from_slice::<'_, bincode::serde::BorrowCompat<I>, _>(
args.as_ref(),
bin_config(),
)
.unwrap()
.0;
Ok(state.0)
}
}
#[cfg(feature = "std")]
/// In memory recorder using the [Named MessagePack](rmp_serde).
///
/// Only available with the `std` feature. `S` selects the precision settings used for the
/// recorded item.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
// Zero-sized marker: the recorder stores no data, it only carries `S`.
_settings: core::marker::PhantomData<S>,
}
// Marker implementation: the named MessagePack recorder loads from an owned `Vec<u8>`.
#[cfg(feature = "std")]
impl<S, B> BytesRecorder<B, Vec<u8>> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
}
// In-memory recorder using named MessagePack for both directions.
#[cfg(feature = "std")]
impl<S, B> Recorder<B> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    /// Encodes `item` as named MessagePack and returns the bytes.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        match rmp_serde::encode::to_vec_named(&item) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(RecorderError::Unknown(err.to_string())),
        }
    }

    /// Decodes an item from the named MessagePack bytes in `args`.
    fn load_item<I: DeserializeOwned>(
        &self,
        args: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        match rmp_serde::decode::from_slice(args.as_slice()) {
            Ok(item) => Ok(item),
            Err(err) => Err(RecorderError::Unknown(err.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend, module::Module, record::FullPrecisionSettings, tensor::backend::Backend,
    };

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load(BinBytesRecorder::<FullPrecisionSettings>::default())
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_can_save_and_load_named_mpk_format() {
        test_can_save_and_load(NamedMpkBytesRecorder::<FullPrecisionSettings>::default())
    }

    /// Records two independently created models, loads the first record into the second
    /// model and checks that the second model then serializes to the first model's bytes.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: BytesRecorder<TestBackend, Vec<u8>>,
    {
        let device = Default::default();
        let source = create_model::<TestBackend>(&device);
        let target = create_model::<TestBackend>(&device);

        let source_bytes = recorder.record(source.into_record(), ()).unwrap();
        let target_bytes = recorder.record(target.clone().into_record(), ()).unwrap();

        let loaded = recorder.load(source_bytes.clone(), &device).unwrap();
        let target_after = target.load_record(loaded);
        let target_bytes_after = recorder.record(target_after.into_record(), ()).unwrap();

        // The two models differ before loading and match afterwards.
        assert_ne!(source_bytes, target_bytes);
        assert_eq!(source_bytes, target_bytes_after);
    }

    /// Builds a 32x32 [`SimpleLinear`] on `device`.
    pub fn create_model<B: Backend>(device: &B::Device) -> SimpleLinear<B> {
        SimpleLinear::new(32, 32, device)
    }
}

View File

@@ -0,0 +1,22 @@
mod primitive;
mod tensor;
mod base;
mod memory;
mod recorder;
mod settings;
pub use base::*;
pub use memory::*;
pub use recorder::*;
pub use settings::*;
#[cfg(feature = "std")]
mod file;
#[cfg(feature = "std")]
pub use file::*;
pub use primitive::ParamSerde;
#[cfg(feature = "record-item-custom-serde")]
pub mod serde;

View File

@@ -0,0 +1,336 @@
use alloc::{string::String, vec, vec::Vec};
use core::{fmt, marker::PhantomData};
use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
use super::{PrecisionSettings, Record};
use crate::module::{Param, ParamId};
use burn_tensor::{Bool, Int, Tensor, backend::Backend};
use hashbrown::HashMap;
use serde::{
Deserialize, Serialize,
de::{Error, SeqAccess, Visitor},
ser::SerializeTuple,
};
// The unit type records to itself: there is nothing to convert in either direction.
impl<B: Backend> Record<B> for () {
    type Item<S: PrecisionSettings> = ();

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}

    fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
}
impl<T, B> Record<B> for Vec<T>
where
T: Record<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.into_iter().map(Record::into_item).collect()
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.into_iter()
.map(|i| Record::from_item(i, device))
.collect()
}
}
impl<T, B> Record<B> for Option<T>
where
T: Record<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = Option<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.map(Record::into_item)
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.map(|i| Record::from_item(i, device))
}
}
impl<const N: usize, T, B> Record<B> for [T; N]
where
T: Record<B>,
B: Backend,
{
/// The record item is an array of the record item of the elements.
/// The reason why we wrap the array in a struct is because serde does not support
/// deserializing arrays of variable size,
/// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
/// for backward compatibility reasons. Serde APIs were created before const generics.
type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
Array(self.map(Record::into_item))
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.0.map(|i| Record::from_item(i, device))
}
}
/// A macro for generating implementations for tuple records of different sizes.
/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
/// Would generate an implementation for a tuple of size 2.
/// For this macro to work properly, please adhere to the convention:
/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
///
/// Each generated implementation converts the tuple field by field: the record item of a
/// tuple is the tuple of its elements' record items.
macro_rules! impl_record_tuple {
// `$r` represents the generic records.
// `$i` represents the indices of the records in the tuple.
([$($r:ident),*][$($i:tt),*]) => {
impl<B, $($r,)*> Record<B> for ($($r,)*)
where
B: Backend,
$($r: Record<B>),*
{
type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
($(self.$i.into_item(),)*)
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
($(Record::from_item(item.$i, device),)*)
}
}
};
}
// Implementations for tuples of 2 up to 10 elements.
impl_record_tuple!([R0, R1][0, 1]);
impl_record_tuple!([R0, R1, R2][0, 1, 2]);
impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
// A map keyed by `ParamId` is recorded as a map keyed by the serialized id string.
impl<T: Record<B>, B: Backend> Record<B> for HashMap<ParamId, T> {
    type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        let mut items = HashMap::with_capacity(self.len());
        for (id, record) in self {
            items.insert(id.serialize(), record.into_item::<S>());
        }
        items
    }

    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        let mut records = HashMap::with_capacity(item.len());
        for (id, entry) in item {
            records.insert(ParamId::deserialize(&id), T::from_item::<S>(entry, device));
        }
        records
    }
}
/// (De)serialize parameters into a clean format.
///
/// Pairs the serialized parameter id with the serialized parameter value `T`.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
pub struct ParamSerde<T> {
// Parameter id in its string form.
id: String,
// Serialized parameter value.
param: T,
}
// Float tensor parameters are recorded as their id plus the serialized tensor.
impl<B: Backend, const D: usize> Record<B> for Param<Tensor<B, D>> {
    type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        let (id, raw, mapper) = self.consume();
        // Let the parameter mapper transform the tensor before it is serialized.
        let saved = mapper.on_save(raw);
        let serialized_id = id.serialize();
        ParamSerde::new(serialized_id, saved.into_item())
    }

    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        B::memory_persistent_allocations(device, item, |serde| {
            let id = ParamId::deserialize(&serde.id);
            // Gradients are required, same behavior as when a new Param is created from a
            // tensor.
            let tensor: Tensor<B, D> = Tensor::from_item(serde.param, device).require_grad();
            Param::initialized(id, tensor)
        })
    }
}
// Int tensor parameters are recorded as their id plus the serialized tensor.
impl<B: Backend, const D: usize> Record<B> for Param<Tensor<B, D, Int>> {
    type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        let (id, raw, mapper) = self.consume();
        // Let the parameter mapper transform the tensor before it is serialized.
        let saved = mapper.on_save(raw);
        let serialized_id = id.serialize();
        ParamSerde::new(serialized_id, saved.into_item())
    }

    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        B::memory_persistent_allocations(device, item, |serde| {
            let id = ParamId::deserialize(&serde.id);
            let tensor: Tensor<B, D, Int> = Tensor::from_item(serde.param, device);
            Param::initialized(id, tensor)
        })
    }
}
// Bool tensor parameters are recorded as their id plus the serialized tensor. The item type
// does not depend on the precision settings, so `S` is passed explicitly.
impl<B: Backend, const D: usize> Record<B> for Param<Tensor<B, D, Bool>> {
    type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        let (id, raw, mapper) = self.consume();
        // Let the parameter mapper transform the tensor before it is serialized.
        let saved = mapper.on_save(raw);
        let serialized_id = id.serialize();
        ParamSerde::new(serialized_id, saved.into_item::<S>())
    }

    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        B::memory_persistent_allocations(device, item, |serde| {
            let id = ParamId::deserialize(&serde.id);
            let tensor: Tensor<B, D, Bool> = Tensor::from_item::<S>(serde.param, device);
            Param::initialized(id, tensor)
        })
    }
}
// Implements `Record` for a type that can be serialized as is, without any conversion:
// the record item is the type itself and both conversions return the value unchanged.
macro_rules! primitive {
($type:ty) => {
impl<B: Backend> Record<B> for $type {
type Item<S: PrecisionSettings> = $type;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
item
}
}
};
}
// General Types
primitive!(alloc::string::String);
primitive!(bool);
// Float Types
primitive!(f64);
primitive!(f32);
primitive!(half::bf16);
primitive!(half::f16);
// Unsigned Integer Types
primitive!(usize);
primitive!(u64);
primitive!(u32);
primitive!(u16);
primitive!(u8);
// Signed Integer Types
primitive!(isize);
primitive!(i64);
primitive!(i32);
primitive!(i16);
primitive!(i8);
/// A wrapper around an array of size N, so that it can be serialized and deserialized
/// using serde.
///
/// The array is wrapped in a struct because serde does not support deserializing arrays of
/// arbitrary size: for backward compatibility reasons, its APIs were created before const
/// generics, see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
#[derive(Clone)]
pub struct Array<const N: usize, T>([T; N]);
impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut seq = serializer.serialize_tuple(self.0.len())?;
for element in &self.0 {
seq.serialize_element(element)?;
}
seq.end()
}
}
impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
where
T: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ArrayVisitor<T, const N: usize> {
marker: PhantomData<T>,
}
impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
where
T: Deserialize<'de>,
{
type Value = Array<N, T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a fixed size array")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut items = vec![];
for i in 0..N {
let item = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(i, &self))?;
items.push(item);
}
let array: [T; N] = items
.into_iter()
.collect::<Vec<_>>()
.try_into()
.map_err(|_| "An array of size {N}")
.unwrap();
Ok(Array(array))
}
}
deserializer.deserialize_tuple(
N,
ArrayVisitor {
marker: PhantomData,
},
)
}
}

View File

@@ -0,0 +1,329 @@
use core::any::type_name;
use core::marker::PhantomData;
use alloc::format;
use alloc::string::{String, ToString};
use burn_tensor::backend::Backend;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
#[cfg(feature = "std")]
use super::{
BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,
PrettyJsonFileRecorder,
};
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
pub trait Recorder<B: Backend>:
    Send + Sync + core::default::Default + core::fmt::Debug + Clone
{
    /// Type of the settings used by the recorder.
    type Settings: PrecisionSettings;

    /// Arguments used to record objects.
    type RecordArgs: Clone;

    /// Record output type.
    type RecordOutput;

    /// Arguments used to load recorded objects.
    type LoadArgs;

    /// Records an item.
    ///
    /// The record is converted into its item form using [`Self::Settings`], wrapped together
    /// with the recorder metadata, and handed to [save_item](Recorder::save_item).
    ///
    /// # Arguments
    ///
    /// * `record` - The item to record.
    /// * `args` - Arguments used to record the item.
    ///
    /// # Returns
    ///
    /// The output of the recording.
    fn record<R>(
        &self,
        record: R,
        args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError>
    where
        R: Record<B>,
    {
        let wrapped = BurnRecord::new::<Self>(record.into_item::<Self::Settings>());
        self.save_item(wrapped, args)
    }

    /// Load an item from the given arguments.
    ///
    /// When loading fails, a second attempt reads only the metadata of the stored record.
    /// If that succeeds, the returned error lists every metadata field that differs from
    /// what this recorder expects, followed by the original error.
    fn load<R>(&self, mut args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
    where
        R: Record<B>,
    {
        let loaded: Result<BurnRecord<R::Item<Self::Settings>, B>, RecorderError> =
            self.load_item(&mut args);

        let record = match loaded {
            Ok(record) => record,
            Err(err) => {
                // Without readable metadata there is nothing to add to the original error.
                let Ok(stored) = self.load_item::<BurnRecordNoItem>(&mut args) else {
                    return Err(err);
                };

                let mut message = "Unable to load record.".to_string();
                let expected = recorder_metadata::<Self, B>();
                let found = stored.metadata;

                if expected.float != found.float {
                    message.push_str(&format!(
                        "\nMetadata has a different float type: Actual {:?}, Expected {:?}",
                        found.float, expected.float
                    ));
                }
                if expected.int != found.int {
                    message.push_str(&format!(
                        "\nMetadata has a different int type: Actual {:?}, Expected {:?}",
                        found.int, expected.int
                    ));
                }
                if expected.format != found.format {
                    message.push_str(&format!(
                        "\nMetadata has a different format: Actual {:?}, Expected {:?}",
                        found.format, expected.format
                    ));
                }
                if expected.version != found.version {
                    message.push_str(&format!(
                        "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}",
                        found.version, expected.version
                    ));
                }
                message.push_str(&format!("\nError: {err:?}"));

                return Err(RecorderError::Unknown(message));
            }
        };

        Ok(R::from_item::<Self::Settings>(record.item, device))
    }

    /// Saves an item.
    ///
    /// This method is used by [record](Recorder::record) to save the item.
    ///
    /// # Arguments
    ///
    /// * `item` - Item to save.
    /// * `args` - Arguments to use to save the item.
    ///
    /// # Returns
    ///
    /// The output of the save operation.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError>;

    /// Loads an item.
    ///
    /// This method is used by [load](Recorder::load) to load the item.
    ///
    /// # Arguments
    ///
    /// * `args` - Arguments to use to load the item.
    ///
    /// # Returns
    ///
    /// The loaded item.
    fn load_item<I>(&self, args: &mut Self::LoadArgs) -> Result<I, RecorderError>
    where
        I: DeserializeOwned;
}
/// Builds the metadata describing recorder `R`: the type names of its float and int
/// elements, the type name of the recorder itself, the crate version and the debug
/// representation of its default settings.
fn recorder_metadata<R, B>() -> BurnMetadata
where
    R: Recorder<B>,
    B: Backend,
{
    let float = type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string();
    let int = type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string();
    let format = type_name::<R>().to_string();
    let version = env!("CARGO_PKG_VERSION").to_string();
    let settings = format!("{:?}", R::Settings::default());

    BurnMetadata::new(float, int, format, version, settings)
}
/// Error that can occur when using a [Recorder](Recorder).
#[derive(Debug)]
pub enum RecorderError {
    /// File not found.
    FileNotFound(String),
    /// The recorded data could not be read or deserialized.
    DeserializeError(String),
    /// Other error.
    Unknown(String),
}

impl core::fmt::Display for RecorderError {
    /// Writes the same text as the `Debug` representation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write straight into the formatter instead of building a temporary `String` first.
        write!(f, "{self:?}")
    }
}

impl core::error::Error for RecorderError {}
/// Returns the bincode configuration used by the recorders in this crate: bincode's
/// standard configuration.
pub(crate) fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}
/// Metadata of a record.
///
/// Stored next to the recorded item so that a failed load can report which of these values
/// differ from the ones expected by the loading recorder.
#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct BurnMetadata {
/// Float type used to record the item.
pub float: String,
/// Int type used to record the item.
pub int: String,
/// Format used to record the item.
pub format: String,
/// Burn record version used to record the item.
pub version: String,
/// Settings used to record the item.
pub settings: String,
}
/// Record that can be saved by a [Recorder](Recorder).
///
/// Wraps the item `I` together with the metadata of the recorder that produced it.
#[derive(Serialize, Deserialize, Debug)]
pub struct BurnRecord<I, B: Backend> {
/// Metadata of the record.
pub metadata: BurnMetadata,
/// Item to record.
pub item: I,
// Zero-sized marker tying the record to the backend type `B`.
_b: PhantomData<B>,
}
impl<I, B: Backend> BurnRecord<I, B> {
    /// Creates a new record whose metadata describes the recorder type `R`.
    ///
    /// # Arguments
    ///
    /// * `item` - Item to record.
    ///
    /// # Returns
    ///
    /// The new record.
    pub fn new<R: Recorder<B>>(item: I) -> Self {
        Self {
            metadata: recorder_metadata::<R, B>(),
            item,
            _b: PhantomData,
        }
    }
}
/// Record that can be saved by a [Recorder](Recorder) without the item.
///
/// Only the metadata is (de)serialized, which allows the metadata of a stored record to be
/// read when the item itself cannot be loaded.
#[derive(new, Debug, Serialize, Deserialize)]
pub struct BurnRecordNoItem {
/// Metadata of the record.
pub metadata: BurnMetadata,
}
/// Default recorder.
///
/// It uses the [named msgpack](rmp_serde) format for serialization with full precision.
#[cfg(feature = "std")]
pub type DefaultRecorder = DefaultFileRecorder<FullPrecisionSettings>;
/// Recorder optimized for compactness.
///
/// It uses the [named msgpack](rmp_serde) format for serialization with half precision.
/// If you are looking for the recorder that offers the smallest file size, have a look at
/// [sensitive compact recorder](SensitiveCompactRecorder).
#[cfg(feature = "std")]
pub type CompactRecorder = DefaultFileRecorder<HalfPrecisionSettings>;
/// Recorder optimized for compactness making it a good choice for model deployment.
///
/// It uses the [bincode](bincode) format for serialization and half precision.
/// This format is not resilient to type changes since no metadata is encoded.
/// Favor [default recorder](DefaultRecorder) or [compact recorder](CompactRecorder)
/// for long term data storage.
#[cfg(feature = "std")]
pub type SensitiveCompactRecorder = BinGzFileRecorder<HalfPrecisionSettings>;
/// Training recorder compatible with no-std inference.
///
/// It is a file recorder with full precision, so it requires the `std` feature itself.
#[cfg(feature = "std")]
pub type NoStdTrainingRecorder = BinFileRecorder<FullPrecisionSettings>;
/// Inference recorder compatible with no-std.
///
/// It is an in-memory recorder with full precision that loads from a `&'static [u8]`.
pub type NoStdInferenceRecorder = BinBytesRecorder<FullPrecisionSettings, &'static [u8]>;
/// Debug recorder.
///
/// It uses the [pretty json](serde_json) format for serialization with full precision making it
/// human readable.
#[cfg(feature = "std")]
pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;
#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::TestBackend;

    use super::*;
    use burn_tensor::{Device, ElementConversion};

    /// Base path (without extension) of the scratch file used by the test.
    ///
    /// Built from the platform temporary directory so the test does not depend on a
    /// hard-coded `/tmp`.
    fn file_path() -> std::path::PathBuf {
        std::env::temp_dir().join("burn_test_record")
    }

    #[test]
    #[should_panic]
    fn err_when_invalid_item() {
        #[derive(new, Serialize, Deserialize, Clone)]
        struct Item<S: PrecisionSettings> {
            value: S::FloatElem,
        }

        impl<D, B> Record<B> for Item<D>
        where
            D: PrecisionSettings,
            B: Backend,
        {
            type Item<S: PrecisionSettings> = Item<S>;

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                Item {
                    value: self.value.elem(),
                }
            }

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
                Item {
                    value: item.value.elem(),
                }
            }
        }

        let item = Item::<FullPrecisionSettings>::new(16.elem());
        let device: Device<TestBackend> = Default::default();

        // Serialize in f32.
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        Recorder::<TestBackend>::record(&recorder, item, file_path()).unwrap();

        // Can't deserialize f32 into f16.
        let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
        Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(
            &recorder,
            file_path(),
            &device,
        )
        .unwrap();
    }
}

View File

@@ -0,0 +1,83 @@
use super::data::NestedValue;
/// A trait that defines the adapter for a Burn module.
///
/// This is used to adapt an incoming module to a Burn module.
///
/// Every `adapt_*` method returns its input unchanged by default; implementors override
/// only the modules they need to transform.
pub trait BurnModuleAdapter: Sized {
/// Adapts a module.
///
/// Dispatches on the module `name` to the matching `adapt_*` method; data of modules with
/// any other name is returned unchanged.
fn adapt(name: &str, data: NestedValue) -> NestedValue {
match name {
"BatchNorm" => Self::adapt_batch_norm(data),
"Conv1d" => Self::adapt_conv1d(data),
"Conv2d" => Self::adapt_conv2d(data),
"Conv3d" => Self::adapt_conv3d(data),
"ConvTranspose1d" => Self::adapt_conv_transpose_1d(data),
"ConvTranspose2d" => Self::adapt_conv_transpose_2d(data),
"ConvTranspose3d" => Self::adapt_conv_transpose_3d(data),
"Embedding" => Self::adapt_embedding(data),
"GroupNorm" => Self::adapt_group_norm(data),
"LayerNorm" => Self::adapt_layer_norm(data),
"Linear" => Self::adapt_linear(data),
_ => data,
}
}
/// Adapts a linear module.
fn adapt_linear(data: NestedValue) -> NestedValue {
data
}
/// Adapts a Convolution 1D module.
fn adapt_conv1d(data: NestedValue) -> NestedValue {
data
}
/// Adapts a Convolution 2D module.
fn adapt_conv2d(data: NestedValue) -> NestedValue {
data
}
/// Adapts a Convolution 3D module.
fn adapt_conv3d(data: NestedValue) -> NestedValue {
data
}
/// Adapts convolution transpose 1D module.
fn adapt_conv_transpose_1d(data: NestedValue) -> NestedValue {
data
}
/// Adapts convolution transpose 2D module.
fn adapt_conv_transpose_2d(data: NestedValue) -> NestedValue {
data
}
/// Adapts convolution transpose 3D module.
fn adapt_conv_transpose_3d(data: NestedValue) -> NestedValue {
data
}
/// Adapts embedding module.
fn adapt_embedding(data: NestedValue) -> NestedValue {
data
}
/// Adapts group normalization module.
fn adapt_group_norm(data: NestedValue) -> NestedValue {
data
}
/// Adapts layer normalization module.
fn adapt_layer_norm(data: NestedValue) -> NestedValue {
data
}
/// Adapts batch normalization module.
fn adapt_batch_norm(data: NestedValue) -> NestedValue {
data
}
}
/// Default adapter that takes no action.
///
/// It relies entirely on the default methods of [`BurnModuleAdapter`], which return the
/// data unchanged.
pub struct DefaultAdapter;
impl BurnModuleAdapter for DefaultAdapter {}

View File

@@ -0,0 +1,399 @@
use std::collections::HashMap;
use super::adapter::BurnModuleAdapter;
use super::de::Deserializer;
use super::error::Error;
use super::ser::Serializer;
use crate::record::{PrecisionSettings, Record};
use crate::tensor::backend::Backend;
use alloc::fmt;
use burn_tensor::Bytes;
use num_traits::cast::ToPrimitive;
use regex::Regex;
use serde::Deserialize;
/// The main data structure used for deserialization.
///
/// It can hold tree-like structures of nested maps and vectors.
#[derive(Clone)]
pub enum NestedValue {
/// The default value, which does not hold any value itself and indicates that the value
/// should be populated with the default value. It contains an optional string with
/// the originator field name.
Default(Option<String>),
/// A boolean value.
Bool(bool),
/// A string value.
String(String),
/// Floating point 32-bit value.
F32(f32),
/// Floating point 64-bit value.
F64(f64),
/// Signed 16-bit integer value.
I16(i16),
/// Signed 32-bit integer value.
I32(i32),
/// Signed 64-bit integer value.
I64(i64),
/// Unsigned 8-bit integer value.
U8(u8),
/// Unsigned 16-bit integer value, used for bf16 and f16 serialization.
U16(u16),
/// Unsigned 64-bit integer value.
U64(u64),
/// A map of nested values (typically used for structs).
Map(HashMap<String, NestedValue>),
/// A vector of nested values (typically used for vector of structs or numbers).
Vec(Vec<NestedValue>),
/// A vector of 8-bit unsigned integer values.
U8s(Vec<u8>),
/// A vector of 16-bit unsigned integer values.
U16s(Vec<u16>),
/// A vector of 32-bit floating point values.
F32s(Vec<f32>),
/// An opaque vector of bytes, with alignment.
Bytes(Bytes),
}
impl NestedValue {
    /// Get the nested value as a map.
    ///
    /// Returns `None` for every variant other than `Map`.
    pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
        if let NestedValue::Map(entries) = self {
            Some(entries)
        } else {
            None
        }
    }

    /// Get the nested value as a boolean.
    ///
    /// Returns `None` for every variant other than `Bool`.
    pub fn as_bool(self) -> Option<bool> {
        if let NestedValue::Bool(flag) = self {
            Some(flag)
        } else {
            None
        }
    }

    /// Get the nested value as a string.
    ///
    /// Returns `None` for every variant other than `String`.
    pub fn as_string(self) -> Option<String> {
        if let NestedValue::String(text) = self {
            Some(text)
        } else {
            None
        }
    }

    /// Get the nested value as a f32, converting from `F64` when needed.
    pub fn as_f32(self) -> Option<f32> {
        match self {
            NestedValue::F64(wide) => wide.to_f32(),
            NestedValue::F32(value) => Some(value),
            _ => None,
        }
    }

    /// Get the nested value as a f64, converting from `F32` when needed.
    pub fn as_f64(self) -> Option<f64> {
        match self {
            NestedValue::F32(narrow) => narrow.to_f64(),
            NestedValue::F64(value) => Some(value),
            _ => None,
        }
    }

    /// Get the nested value as an i16.
    ///
    /// `I32`, `I64`, `U16` and `U64` values are converted with a checked cast.
    pub fn as_i16(self) -> Option<i16> {
        match self {
            NestedValue::I16(value) => Some(value),
            NestedValue::U16(other) => other.to_i16(),
            NestedValue::U64(other) => other.to_i16(),
            NestedValue::I32(other) => other.to_i16(),
            NestedValue::I64(other) => other.to_i16(),
            _ => None,
        }
    }

    /// Get the nested value as an i32.
    ///
    /// `I16`, `I64`, `U16` and `U64` values are converted with a checked cast.
    pub fn as_i32(self) -> Option<i32> {
        match self {
            NestedValue::I32(value) => Some(value),
            NestedValue::U16(other) => other.to_i32(),
            NestedValue::U64(other) => other.to_i32(),
            NestedValue::I16(other) => other.to_i32(),
            NestedValue::I64(other) => other.to_i32(),
            _ => None,
        }
    }

    /// Get the nested value as an i64.
    ///
    /// `I16`, `I32`, `U16` and `U64` values are converted with a checked cast.
    pub fn as_i64(self) -> Option<i64> {
        match self {
            NestedValue::I64(value) => Some(value),
            NestedValue::U16(other) => other.to_i64(),
            NestedValue::U64(other) => other.to_i64(),
            NestedValue::I16(other) => other.to_i64(),
            NestedValue::I32(other) => other.to_i64(),
            _ => None,
        }
    }

    /// Get the nested value as a u8.
    ///
    /// `I16`, `I32`, `I64`, `U16` and `U64` values are converted with a checked cast.
    pub fn as_u8(self) -> Option<u8> {
        match self {
            NestedValue::U8(value) => Some(value),
            NestedValue::U16(other) => other.to_u8(),
            NestedValue::U64(other) => other.to_u8(),
            NestedValue::I16(other) => other.to_u8(),
            NestedValue::I32(other) => other.to_u8(),
            NestedValue::I64(other) => other.to_u8(),
            _ => None,
        }
    }

    /// Get the nested value as a u16.
    ///
    /// `I16`, `I32`, `I64` and `U64` values are converted with a checked cast.
    pub fn as_u16(self) -> Option<u16> {
        match self {
            NestedValue::U16(value) => Some(value),
            NestedValue::U64(other) => other.to_u16(),
            NestedValue::I16(other) => other.to_u16(),
            NestedValue::I32(other) => other.to_u16(),
            NestedValue::I64(other) => other.to_u16(),
            _ => None,
        }
    }

    /// Get the nested value as a u64.
    ///
    /// `I16`, `I32`, `I64` and `U16` values are converted with a checked cast.
    pub fn as_u64(self) -> Option<u64> {
        match self {
            NestedValue::U64(value) => Some(value),
            NestedValue::U16(other) => other.to_u64(),
            NestedValue::I16(other) => other.to_u64(),
            NestedValue::I32(other) => other.to_u64(),
            NestedValue::I64(other) => other.to_u64(),
            _ => None,
        }
    }

    /// Get the nested value as a vector of bytes.
    ///
    /// A `U8s` value is converted into [`Bytes`]; a `Bytes` value is returned as is.
    pub fn as_bytes(self) -> Option<Bytes> {
        match self {
            NestedValue::U8s(raw) => Some(Bytes::from_elems(raw)),
            NestedValue::Bytes(bytes) => Some(bytes),
            _ => None,
        }
    }

    /// Deserialize a nested value into a record type.
    ///
    /// The value is first deserialized into the record's item type through a
    /// [`Deserializer`] parameterized by the adapter `A`, then converted into the record
    /// on `device` using the precision settings `PS`.
    pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
    where
        B: Backend,
        T: Record<B>,
        PS: PrecisionSettings,
        A: BurnModuleAdapter,
    {
        let item = T::Item::deserialize(Deserializer::<A>::new(self, false))?;

        Ok(T::from_item::<PS>(item, device))
    }
}
/// Renames the keys of a tensor map by applying a list of regex replacements.
///
/// Every pattern that matches the current name is applied in order with
/// [regex::Regex::replace_all](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace_all),
/// each one operating on the result of the previous replacement.
///
/// # Arguments
///
/// * `tensors` - A map of tensors.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
///
/// # Returns
///
/// A map of tensors with the remapped keys and
/// a vector of tuples containing the remapped and original names. When `key_remap` is
/// empty, the map is returned untouched and every name is paired with itself.
pub fn remap<T>(
    mut tensors: HashMap<String, T>,
    key_remap: Vec<(Regex, String)>,
) -> (HashMap<String, T>, Vec<(String, String)>) {
    if key_remap.is_empty() {
        let identity_names = tensors
            .keys()
            .map(|name| (name.clone(), name.clone()))
            .collect();
        return (tensors, identity_names);
    }

    let mut renamed_tensors = HashMap::new();
    let mut name_pairs = Vec::new();

    for (original_name, tensor) in tensors.drain() {
        let new_name = key_remap.iter().fold(
            original_name.clone(),
            |current, (pattern, replacement)| {
                if pattern.is_match(&current) {
                    pattern
                        .replace_all(&current, replacement.as_str())
                        .to_string()
                } else {
                    current
                }
            },
        );

        name_pairs.push((new_name.clone(), original_name));
        renamed_tensors.insert(new_name, tensor);
    }

    (renamed_tensors, name_pairs)
}
/// Inserts `value` at the path `keys` inside a nested map/vector structure.
///
/// Missing intermediate nodes are created on the way: a vector when the following key
/// parses as an index, a map otherwise. Vectors are grown with empty maps up to the
/// requested index.
///
/// # Panics
///
/// Panics when a key addressing a vector is not a valid index, or when the path goes
/// through a value that is neither a map nor a vector.
fn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) {
    // An empty path means `current` is the destination.
    let Some((head, rest)) = keys.split_first() else {
        *current = value;
        return;
    };

    match current {
        NestedValue::Map(map) => {
            if !map.contains_key(*head) {
                let next_is_index = match rest.first() {
                    Some(key) => key.parse::<usize>().is_ok(),
                    None => false,
                };
                let placeholder = if next_is_index {
                    NestedValue::Vec(Vec::new())
                } else {
                    NestedValue::Map(HashMap::new())
                };
                map.insert((*head).to_string(), placeholder);
            }
            let child = map.get_mut(*head).unwrap();
            insert_nested_value(child, rest, value);
        }
        NestedValue::Vec(vec) => {
            let index = head.parse::<usize>().unwrap();
            if vec.len() <= index {
                vec.resize_with(index + 1, || NestedValue::Map(HashMap::new()));
            }
            insert_nested_value(&mut vec[index], rest, value);
        }
        _ => panic!("Invalid structure encountered"),
    }
}
/// Conversion of a value into a [`NestedValue`] tree.
///
/// `unflatten` calls this on every entry of its input map to obtain the nested
/// representation of that entry.
pub trait Serializable {
/// Serializes `self` into a `NestedValue` with the provided `Serializer`.
///
/// # Type Parameters
/// - `PS`: The precision settings to use during serialization; any type
/// implementing the `PrecisionSettings` trait.
///
/// # Parameters
/// - `serializer`: The `Serializer` that produces the nested value. It is taken by
/// value.
///
/// # Returns
/// - `Ok(NestedValue)` holding the serialized representation on success.
/// - `Err(Error)` when serialization fails.
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, Error>
where
PS: PrecisionSettings;
}
/// Builds one nested value out of a flat map keyed by dot-separated paths.
///
/// Every key is split on `.` and each segment becomes one nesting level: numeric segments
/// address vector elements and all other segments address map entries. Each value is
/// serialized with a fresh [`Serializer`] using the precision settings `PS`. Empty
/// placeholder maps left inside vectors are removed before returning.
///
/// # Errors
///
/// Returns the first error reported while serializing a value.
pub fn unflatten<PS, T>(input: HashMap<String, T>) -> Result<NestedValue, Error>
where
    PS: PrecisionSettings,
    T: Serializable,
{
    let mut root = NestedValue::Map(HashMap::new());

    for (path, item) in input {
        let serialized = item.serialize::<PS>(Serializer::new())?;
        let segments = path.split('.').collect::<Vec<_>>();
        insert_nested_value(&mut root, &segments, serialized);
    }

    cleanup_empty_maps(&mut root);
    Ok(root)
}
/// Drops empty maps that sit directly inside vectors, at any depth.
///
/// Vectors are padded with empty maps when keys carry non-contiguous indices; this pass
/// removes the placeholders that were never filled. Empty maps stored as map values are
/// left in place.
fn cleanup_empty_maps(current: &mut NestedValue) {
    match current {
        NestedValue::Map(entries) => {
            for child in entries.values_mut() {
                cleanup_empty_maps(child);
            }
        }
        NestedValue::Vec(items) => {
            // Clean the children first so that maps emptied by the recursion are removed too.
            for child in items.iter_mut() {
                cleanup_empty_maps(child);
            }
            items.retain(|item| match item {
                NestedValue::Map(inner) => !inner.is_empty(),
                _ => true,
            });
        }
        _ => {}
    }
}
/// Writes a shortened debug view of `vec`: at most the first three elements followed by
/// `, ...]` and the total length, e.g. `Vec([1, 2, 3, ...] len=5)`.
fn write_vec_truncated<T: core::fmt::Debug>(
    vec: &[T],
    f: &mut core::fmt::Formatter,
) -> fmt::Result {
    f.write_str("Vec([")?;

    let mut preview = vec.iter().take(3);
    if let Some(first) = preview.next() {
        write!(f, "{first:?}")?;
        for item in preview {
            write!(f, ", {item:?}")?;
        }
    }

    write!(f, ", ...] len={})", vec.len())
}
impl fmt::Debug for NestedValue {
    /// Formats the value for diagnostics.
    ///
    /// Sequence-like variants holding more than three elements are abbreviated with
    /// `write_vec_truncated`; shorter ones are printed in full as a list. Maps are printed
    /// as debug maps and every scalar variant as a one-field tuple named after the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Prints a sequence in full when it is short and abbreviated otherwise.
        fn fmt_sequence<T: fmt::Debug>(items: &[T], f: &mut fmt::Formatter<'_>) -> fmt::Result {
            if items.len() > 3 {
                write_vec_truncated(items, f)
            } else {
                f.debug_list().entries(items.iter()).finish()
            }
        }

        match self {
            // Sequence-like variants
            NestedValue::Vec(vec) => fmt_sequence(vec, f),
            NestedValue::U8s(vec) => fmt_sequence(vec, f),
            NestedValue::U16s(vec) => fmt_sequence(vec, f),
            NestedValue::F32s(vec) => fmt_sequence(vec, f),
            NestedValue::Bytes(bytes) => fmt_sequence(bytes, f),
            // Map
            NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(),
            // Scalars
            NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
            NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(),
            NestedValue::String(s) => f.debug_tuple("String").field(s).finish(),
            NestedValue::F32(val) => f.debug_tuple("F32").field(val).finish(),
            NestedValue::F64(val) => f.debug_tuple("F64").field(val).finish(),
            NestedValue::I16(val) => f.debug_tuple("I16").field(val).finish(),
            NestedValue::I32(val) => f.debug_tuple("I32").field(val).finish(),
            NestedValue::I64(val) => f.debug_tuple("I64").field(val).finish(),
            NestedValue::U8(val) => f.debug_tuple("U8").field(val).finish(),
            NestedValue::U16(val) => f.debug_tuple("U16").field(val).finish(),
            NestedValue::U64(val) => f.debug_tuple("U64").field(val).finish(),
        }
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,40 @@
use crate::record::RecorderError;
/// The error type for Record serde.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Failed to deserialize.
#[error("failed to deserialize: {0}")]
Deserialize(#[from] serde::de::value::Error),
/// Failed to serialize.
#[error("failed to serialize")]
Serialize(String),
/// Encountered an invalid state.
#[error("invalid state")]
InvalidState,
/// Other error.
#[error("other error: {0}")]
Other(String),
}
impl serde::de::Error for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Error::Deserialize(serde::de::value::Error::custom(msg.to_string()))
}
}
impl serde::ser::Error for Error {
    /// Stores a custom serialization message in [`Error::Serialize`].
    fn custom<T>(msg: T) -> Self
    where
        T: std::fmt::Display,
    {
        let message = msg.to_string();
        Self::Serialize(message)
    }
}
// Implement From trait for Error to RecorderError
impl From<Error> for RecorderError {
fn from(error: Error) -> Self {
RecorderError::DeserializeError(error.to_string())
}
}

View File

@@ -0,0 +1,17 @@
//! Module contains the serde implementation for the record module
//! useful for custom importing model weights, such as PyTorch's pt file format.
/// The adapter trait that is used to convert the nested value to the module type.
pub mod adapter;
/// The main data structure used for deserialization.
pub mod data;
/// The serializer that is used to convert a struct into nested values.
pub mod ser;
/// The deserializer that is used to convert the nested value to the record.
pub mod de;
/// Error types.
pub mod error;

View File

@@ -0,0 +1,387 @@
use std::collections::HashMap;
use super::{
data::NestedValue,
error::{self, Error},
};
use serde::{
Serialize,
ser::{self, SerializeSeq, SerializeStruct, Serializer as SerializerTrait},
};
/// Simple struct serializer that converts a struct into NestedValues.
///
/// The same type also acts as the struct and sequence sub-serializer: fields and
/// elements are accumulated in `state` until `end` is called.
///
/// NOTE: This is used to serialize Param structs into NestedValues and not so much for
/// the actual serialization of modules (although it could be used for that as well if all
/// primitive types are implemented).
#[derive(Clone)]
pub struct Serializer {
/// Container built so far while serializing a struct or a sequence; `None` until the
/// first field or element has been serialized.
state: Option<NestedValue>,
}
impl Serializer {
/// Creates a new serializer.
pub fn new() -> Self {
Serializer { state: None }
}
}
impl Default for Serializer {
fn default() -> Self {
Self::new()
}
}
impl SerializerTrait for Serializer {
type Ok = NestedValue;
type Error = Error;
type SerializeSeq = Self;
type SerializeTuple = ser::Impossible<NestedValue, Self::Error>;
type SerializeTupleStruct = ser::Impossible<NestedValue, Self::Error>;
type SerializeTupleVariant = ser::Impossible<NestedValue, Self::Error>;
type SerializeMap = ser::Impossible<NestedValue, Self::Error>;
type SerializeStruct = Self;
type SerializeStructVariant = ser::Impossible<NestedValue, Self::Error>;
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Self::Error> {
Ok(self)
}
fn serialize_newtype_struct<T>(
self,
_name: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize + ?Sized,
{
value.serialize(self)
}
fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
Ok(self)
}
fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::I32(v))
}
fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::String(v.to_string()))
}
fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::I16(v))
}
fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::I64(v))
}
fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::U16(v))
}
fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::U64(v))
}
fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::F32(v))
}
fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::F64(v))
}
// The following methods are not implemented because they are not needed for the
// serialization of Param structs.
fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::U8s(v.to_vec()))
}
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::Default(None))
}
fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::U8(v))
}
fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
where
T: Serialize + ?Sized,
{
value.serialize(self)
}
fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::Map(HashMap::from([(
_name.to_string(),
NestedValue::String(_variant.to_string()),
)])))
}
fn serialize_newtype_variant<T>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize + ?Sized,
{
unimplemented!()
}
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
unimplemented!()
}
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Self::Error> {
unimplemented!()
}
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant, Self::Error> {
unimplemented!()
}
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
unimplemented!()
}
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
unimplemented!()
}
}
/// Struct serialization: the fields are collected into a [`NestedValue::Map`] keyed by
/// field name.
impl SerializeStruct for Serializer {
    type Ok = NestedValue;
    type Error = Error;

    /// Serializes `value` with a fresh serializer and stores it under `key`.
    ///
    /// # Panics
    ///
    /// Panics if the accumulated state is not a map.
    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
    where
        T: Serialize + ?Sized,
    {
        let field = value.serialize(Serializer::new())?;

        // The map is created lazily when the first field arrives.
        let state = self
            .state
            .get_or_insert_with(|| NestedValue::Map(HashMap::new()));

        match state {
            NestedValue::Map(fields) => {
                fields.insert(key.to_string(), field);
            }
            _ => panic!("Invalid state encountered"),
        }

        Ok(())
    }

    /// Returns the collected map; a struct without fields yields an empty map.
    fn end(self) -> Result<Self::Ok, Self::Error> {
        Ok(self
            .state
            .unwrap_or_else(|| NestedValue::Map(HashMap::new())))
    }
}
impl SerializeSeq for Serializer {
type Ok = NestedValue;
type Error = Error;
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize + ?Sized,
{
let serialized_value = value.serialize(Serializer::new())?;
match self.state {
Some(NestedValue::Vec(ref mut vec)) => {
vec.push(serialized_value); // Inserting into the state
}
Some(NestedValue::U8s(ref mut vec)) => {
if let NestedValue::U8(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(NestedValue::U16s(ref mut vec)) => {
if let NestedValue::U16(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(NestedValue::F32s(ref mut vec)) => {
if let NestedValue::F32(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(_) => {
panic!("Invalid state encountered");
}
None => {
let val = match serialized_value {
NestedValue::U8(val) => NestedValue::U8s(vec![val]),
NestedValue::U16(val) => NestedValue::U16s(vec![val]),
NestedValue::F32(val) => NestedValue::F32s(vec![val]),
_ => NestedValue::Vec(vec![serialized_value]),
};
self.state = Some(val);
}
}
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
if self.state.is_none() {
// If the state is empty, return an empty vector
Ok(NestedValue::Vec(Vec::new()))
} else {
self.state.ok_or(error::Error::InvalidState)
}
}
}
#[cfg(test)]
mod tests {
    use crate::{
        TestBackend,
        module::{Param, ParamId},
        record::{FullPrecisionSettings, Record},
        tensor::Tensor,
    };
    use serde::Deserialize;

    use super::*;

    /// Outer struct nesting the two other fixtures.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct1 {
        a: MyStruct3,
        b: MyStruct2,
    }

    /// Fixture mixing integers, strings and options.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct2 {
        a: i32,
        b: Option<i32>,
        c: String,
        d: Option<String>,
    }

    /// Fixture made of two strings.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct3 {
        x: String,
        y: String,
    }

    /// Nested structs serialize into a nested value of the expected size.
    #[test]
    fn test_serialize() {
        let inner_strings = MyStruct3 {
            x: "Hello".to_owned(),
            y: "World".to_owned(),
        };
        let inner_mixed = MyStruct2 {
            a: 1,
            b: None,
            c: "Hello".to_owned(),
            d: Some("World".to_owned()),
        };
        let fixture = MyStruct1 {
            a: inner_strings,
            b: inner_mixed,
        };

        let nested = fixture
            .serialize(Serializer::new())
            .expect("Should serialize item successfully");

        // HashMap iteration order is not guaranteed, so only the length of the debug
        // output is compared rather than its exact content.
        let rendered = format!("{nested:?}");
        assert_eq!(rendered.len(), 135);
    }

    /// A serialized 2x2 tensor of ones exposes its raw little-endian bytes under
    /// `param.bytes`.
    #[test]
    fn test_param_serde() {
        let device = Default::default();
        let ones: Tensor<TestBackend, 2> = Tensor::ones([2, 2], &device);
        let item = Param::initialized(ParamId::new(), ones).into_item::<FullPrecisionSettings>();

        let nested = item
            .serialize(Serializer::new())
            .expect("Should serialize item successfully");

        let outer = nested.as_map().expect("is a map");
        let param = outer["param"].clone().as_map().expect("param is a map");
        let bytes = param["bytes"].clone().as_bytes().expect("has bytes vec");

        assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
    }
}

View File

@@ -0,0 +1,40 @@
use burn_tensor::Element;
use serde::{Serialize, de::DeserializeOwned};
/// Settings controlling the precision used when (de)serializing items.
///
/// An implementation selects the element types that float and integer data are
/// stored as.
pub trait PrecisionSettings:
Send + Sync + core::fmt::Debug + core::default::Default + Clone
{
/// Element type used for float data.
type FloatElem: Element + Serialize + DeserializeOwned;
/// Element type used for integer data.
type IntElem: Element + Serialize + DeserializeOwned;
}
/// Default precision settings: `f32` floats and `i32` integers.
#[derive(Debug, Default, Clone)]
pub struct FullPrecisionSettings;
/// Precision settings optimized for compactness: `half::f16` floats and `i16` integers.
#[derive(Debug, Default, Clone)]
pub struct HalfPrecisionSettings;
/// Precision settings favoring numerical accuracy: `f64` floats and `i64` integers.
#[derive(Debug, Default, Clone)]
pub struct DoublePrecisionSettings;
// Full precision: 32-bit floats and 32-bit integers.
impl PrecisionSettings for FullPrecisionSettings {
type FloatElem = f32;
type IntElem = i32;
}
// Double precision: 64-bit floats and 64-bit integers.
impl PrecisionSettings for DoublePrecisionSettings {
type FloatElem = f64;
type IntElem = i64;
}
// Half precision: 16-bit floats and 16-bit integers.
impl PrecisionSettings for HalfPrecisionSettings {
type FloatElem = half::f16;
type IntElem = i16;
}

View File

@@ -0,0 +1,159 @@
use core::marker::PhantomData;
use super::{PrecisionSettings, Record};
use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};
use serde::{Deserialize, Serialize};
use alloc::format;
/// Deserialize the value into [`TensorData`].
fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
where
E: Element + Deserialize<'de>,
De: serde::Deserializer<'de>,
{
let data = TensorData::deserialize(deserializer).map_err(|e| {
serde::de::Error::custom(format!(
"{e:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n"
))
})?;
let data = if let DType::QFloat(_) = data.dtype {
data // do not convert quantized tensors
} else {
data.convert::<E>()
};
Ok(data)
}
/// This struct implements serde to lazily serialize and deserialize a float tensor
/// using the given [precision settings](PrecisionSettings).
///
/// It wraps the tensor's [`TensorData`]; `S::FloatElem` is the element type the data is
/// converted to when it is deserialized.
#[derive(new, Clone, Debug)]
pub struct FloatTensorSerde<S: PrecisionSettings> {
data: TensorData,
// Marker tying the wrapper to the float element type of `S`.
_e: PhantomData<S::FloatElem>,
}
/// This struct implements serde to lazily serialize and deserialize an int tensor
/// using the given [precision settings](PrecisionSettings).
///
/// It wraps the tensor's [`TensorData`]; `S::IntElem` is the element type the data is
/// converted to when it is deserialized.
#[derive(new, Clone, Debug)]
pub struct IntTensorSerde<S: PrecisionSettings> {
data: TensorData,
// Marker tying the wrapper to the integer element type of `S`.
_e: PhantomData<S::IntElem>,
}
/// This struct implements serde to lazily serialize and deserialize a bool tensor.
///
/// It wraps the tensor's [`TensorData`] and does not depend on any precision settings.
#[derive(new, Clone, Debug)]
pub struct BoolTensorSerde {
data: TensorData,
}
// --- SERDE IMPLEMENTATIONS --- //
impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
    /// Serializes the wrapped tensor data as is.
    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
    where
        Se: serde::Serializer,
    {
        let Self { data, .. } = self;
        data.serialize(serializer)
    }
}
impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
    /// Deserializes tensor data and converts it to `S::FloatElem`.
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: serde::Deserializer<'de>,
    {
        deserialize_data::<S::FloatElem, De>(deserializer).map(Self::new)
    }
}
impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
    /// Serializes the wrapped tensor data as is.
    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
    where
        Se: serde::Serializer,
    {
        let Self { data, .. } = self;
        data.serialize(serializer)
    }
}
impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
    /// Deserializes tensor data and converts it to `S::IntElem`.
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: serde::Deserializer<'de>,
    {
        deserialize_data::<S::IntElem, De>(deserializer).map(Self::new)
    }
}
impl Serialize for BoolTensorSerde {
    /// Serializes the wrapped tensor data as is.
    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
    where
        Se: serde::Serializer,
    {
        let Self { data } = self;
        data.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for BoolTensorSerde {
    /// Deserializes tensor data and converts it to `bool` elements.
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: serde::Deserializer<'de>,
    {
        deserialize_data::<bool, De>(deserializer).map(Self::new)
    }
}
// --- RECORD IMPLEMENTATIONS --- //
/// Float tensors are recorded as [`FloatTensorSerde`].
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
    type Item<S: PrecisionSettings> = FloatTensorSerde<S>;

    /// Extracts the tensor data and converts it to `S::FloatElem`, leaving quantized
    /// data untouched.
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        let raw = self.into_data();
        // do not convert quantized tensors
        let stored = if matches!(raw.dtype, DType::QFloat(_)) {
            raw
        } else {
            raw.convert::<S::FloatElem>()
        };
        FloatTensorSerde::new(stored)
    }

    /// Rebuilds the tensor on `device`, converting the data to the backend float element
    /// unless it is quantized.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        let stored = item.data;
        // do not convert quantized tensors
        let restored = if matches!(stored.dtype, DType::QFloat(_)) {
            stored
        } else {
            stored.convert::<B::FloatElem>()
        };
        Tensor::from_data(restored, device)
    }
}
/// Int tensors are recorded as [`IntTensorSerde`].
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
    type Item<S: PrecisionSettings> = IntTensorSerde<S>;

    /// Extracts the tensor data and converts it to `S::IntElem`.
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        let stored = self.into_data().convert::<S::IntElem>();
        IntTensorSerde::new(stored)
    }

    /// Rebuilds the tensor on `device` after converting the data to the backend int element.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        let restored = item.data.convert::<B::IntElem>();
        Tensor::from_data(restored, device)
    }
}
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
type Item<S: PrecisionSettings> = BoolTensorSerde;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
BoolTensorSerde::new(self.into_data())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Tensor::from_data(item.data, device)
}
}

View File

@@ -0,0 +1 @@
pub use burn_tensor::*;

View File

@@ -0,0 +1 @@
pub use burn_vision::*;

View File

@@ -0,0 +1,113 @@
use burn::config::{Config, config_to_json};
use burn_core as burn;
/// Config without any field; used as a nested config in `TestStructConfig`.
#[derive(Config, Debug, PartialEq, Eq)]
pub struct TestEmptyStructConfig {}
/// Struct config mixing required fields, fields with a declared default and a nested
/// config. The tests below build it with `new(int, float, string, other_config)`.
#[derive(Config, Debug, PartialEq)]
pub struct TestStructConfig {
int: i32,
// Declared default: 2.
#[config(default = 2)]
int_default: i32,
float: f32,
// Declared default: 2.0.
#[config(default = 2.0)]
float_default: f32,
string: String,
other_config: TestEmptyStructConfig,
}
/// Enum config covering a unit variant, a single-value and a multi-value tuple variant,
/// and a variant with named fields.
#[derive(Config, Debug, PartialEq)]
pub enum TestEnumConfig {
None,
Single(f32),
Multiple(f32, String),
Named { first: f32, second: String },
}
/// Returns the location of `file_name` inside the system temporary directory.
#[cfg(feature = "std")]
#[inline(always)]
fn file_path(file_name: &str) -> std::path::PathBuf {
    let mut location = std::env::temp_dir();
    location.push(file_name);
    location
}
/// A struct config saved to a file loads back equal to the original.
#[cfg(feature = "std")]
#[test]
fn struct_config_should_impl_serde() {
    let original =
        TestStructConfig::new(2, 3.0, String::from("Allow"), TestEmptyStructConfig::new());
    let path = file_path("test_struct_config.json");

    original.save(&path).unwrap();
    let restored = TestStructConfig::load(&path).unwrap();

    assert_eq!(original, restored);
}
/// Cloning a struct config yields an equal value.
#[test]
fn struct_config_should_impl_clone() {
    let original =
        TestStructConfig::new(2, 3.0, String::from("Allow"), TestEmptyStructConfig::new());
    let copy = original.clone();

    assert_eq!(original, copy);
}
/// The display output of a struct config equals its JSON representation.
#[test]
fn struct_config_should_impl_display() {
    let original =
        TestStructConfig::new(2, 3.0, String::from("Allow"), TestEmptyStructConfig::new());
    let as_json = burn::config::config_to_json(&original);
    let displayed = original.to_string();

    assert_eq!(as_json, displayed);
}
/// A unit enum variant saved to a file loads back equal to the original.
#[cfg(feature = "std")]
#[test]
fn enum_config_no_value_should_impl_serde() {
    let original = TestEnumConfig::None;
    let path = file_path("test_enum_no_value_config.json");

    original.save(&path).unwrap();
    let restored = TestEnumConfig::load(&path).unwrap();

    assert_eq!(original, restored);
}
/// A single-value enum variant saved to a file loads back equal to the original.
#[cfg(feature = "std")]
#[test]
fn enum_config_one_value_should_impl_serde() {
    let original = TestEnumConfig::Single(42.0);
    let path = file_path("test_enum_one_value_config.json");

    original.save(&path).unwrap();
    let restored = TestEnumConfig::load(&path).unwrap();

    assert_eq!(original, restored);
}
/// A multi-value enum variant saved to a file loads back equal to the original.
#[cfg(feature = "std")]
#[test]
fn enum_config_multiple_values_should_impl_serde() {
    let original = TestEnumConfig::Multiple(42.0, String::from("Allow"));
    let path = file_path("test_enum_multiple_values_config.json");

    original.save(&path).unwrap();
    let restored = TestEnumConfig::load(&path).unwrap();

    assert_eq!(original, restored);
}
/// Cloning an enum config yields an equal value.
#[test]
fn enum_config_should_impl_clone() {
    let original = TestEnumConfig::Multiple(42.0, String::from("Allow"));
    let copy = original.clone();

    assert_eq!(original, copy);
}
/// The display output of an enum config equals its JSON representation.
#[test]
fn enum_config_should_impl_display() {
    let original = TestEnumConfig::Multiple(42.0, String::from("Allow"));
    let as_json = burn::config::config_to_json(&original);
    let displayed = original.to_string();

    assert_eq!(as_json, displayed);
}
/// A struct config can be rebuilt from the bytes of its JSON representation.
#[test]
fn struct_config_can_load_binary() {
    let original =
        TestStructConfig::new(2, 3.0, String::from("Allow"), TestEmptyStructConfig::new());

    let json_bytes = config_to_json(&original).as_bytes().to_vec();
    let restored = TestStructConfig::load_binary(&json_bytes).unwrap();

    assert_eq!(original, restored);
}

View File

@@ -0,0 +1,323 @@
use std::marker::PhantomData;
use burn::module::Initializer;
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor};
use burn_core as burn;
/// Backend used by the tests in this file: ndarray with `f32` elements.
pub type TestBackend = burn_ndarray::NdArray<f32>;
/// Autodiff wrapper around [`TestBackend`], used by the gradient tests.
#[cfg(feature = "std")]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
/// Minimal module holding a single rank-2 float parameter.
#[derive(Module, Debug)]
pub struct ModuleBasic<B: Backend> {
weight_basic: Param<Tensor<B, 2>>,
}
/// Module whose only field is a plain (non-`Param`) int tensor; it is never constructed
/// here and only has to compile with the derive.
#[derive(Module, Debug)]
#[allow(unused)]
struct ModuleTensorConstInt<B: Backend> {
weight_basic: Tensor<B, 2, Int>,
}
impl<B: Backend> ModuleBasic<B> {
    /// Creates the module with a `[20, 20]` weight drawn from a normal distribution
    /// (mean 0, standard deviation 1) on `device`.
    fn new(device: &B::Device) -> Self {
        let normal = Initializer::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let weight_basic = normal.init([20, 20], device);

        Self { weight_basic }
    }
}
/// Module holding a fixed-size array of `N` basic modules.
#[derive(Module, Debug)]
struct ModuleWithConstGeneric<B: Backend, const N: usize> {
modules: [ModuleBasic<B>; N],
}
/// Module wrapping an arbitrary inner module type `M`.
#[derive(Module, Debug)]
struct ModuleWithGenericModule<B: Backend, M> {
module: M,
// Marker for the otherwise unused backend parameter.
_backend: PhantomData<B>,
}
/// Enum module whose variants wrap a basic or a composed module.
#[derive(Module, Debug)]
#[allow(clippy::large_enum_variant)]
enum ModuleEnum<B: Backend> {
Basic(ModuleBasic<B>),
Composed(ModuleComposed<B>),
}
/// Enum module wrapping another enum module; it is never constructed here and only has
/// to compile with the derive.
#[derive(Module, Debug)]
#[allow(unused)]
enum ModuleEnumNested<B: Backend> {
AnotherEnum(ModuleEnum<B>),
}
/// Enum module with one variant depending on a generic inner module `M`.
#[derive(Module, Debug)]
enum ModuleEnumWithGenericModule<B: Backend, M: Module<B>> {
Basic(ModuleBasic<B>),
Generic(ModuleWithGenericModule<B, M>),
}
/// Module combining its own parameter with a nested module and a tuple of modules,
/// i.e. four `ModuleBasic`-sized weights in total.
#[derive(Module, Debug)]
pub struct ModuleComposed<B: Backend> {
weight: Param<Tensor<B, 2>>,
basic: ModuleBasic<B>,
tuple: (ModuleBasic<B>, ModuleBasic<B>),
}
impl<B: Backend> ModuleComposed<B> {
    /// Creates the module on `device`: a `[20, 20]` normally distributed weight (mean 0,
    /// standard deviation 1) followed by three freshly initialized basic modules.
    fn new(device: &B::Device) -> Self {
        let normal = Initializer::Normal {
            mean: 0.0,
            std: 1.0,
        };

        // Initialization order is kept: own weight, nested module, then the tuple members.
        let weight = normal.init([20, 20], device);
        let basic = ModuleBasic::new(device);
        let tuple = (ModuleBasic::new(device), ModuleBasic::new(device));

        Self {
            weight,
            basic,
            tuple,
        }
    }
}
/// Compile-time assertions that the record items of the test modules implement `Clone`.
/// Nothing in here is ever called; the functions only have to type-check.
#[allow(dead_code)]
mod compiletime_clone_impl_check {
    use burn_core::{
        module::{Module, ModuleDisplay},
        prelude::Backend,
        record::{PrecisionSettings, Record},
    };

    use super::*;

    /// Item type of the record of module `M` on backend `B` with precision settings `S`.
    type RecordItem<M, B, S> = <<M as Module<B>>::Record as Record<B>>::Item<S>;

    /// Compiles only when `T` is `Clone`.
    fn implements_clone<T>()
    where
        T: Clone,
    {
    }

    /// Record items of the non-generic modules are `Clone`.
    fn basic_implements_clone<B, S>()
    where
        B: Backend,
        S: PrecisionSettings,
    {
        implements_clone::<RecordItem<ModuleBasic<B>, B, S>>();
        implements_clone::<RecordItem<ModuleComposed<B>, B, S>>();
    }

    /// Record items of the generic modules are `Clone` whenever the inner module's
    /// record item is.
    fn generic_implements_clone<B, S, M>()
    where
        B: Backend,
        S: PrecisionSettings,
        M: Module<B> + ModuleDisplay,
        RecordItem<M, B, S>: Clone,
    {
        implements_clone::<RecordItem<ModuleWithGenericModule<B, M>, B, S>>();
        implements_clone::<RecordItem<ModuleEnumWithGenericModule<B, M>, B, S>>();
    }
}
/// Tests that loading a record transfers the weights of the module it was taken from.
mod state {
    use super::*;

    #[test]
    fn should_load_from_record_basic() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleBasic::<TestBackend>::new(&device);
        let target = ModuleBasic::<TestBackend>::new(&device);
        let record = source.clone().into_record();

        // Independently initialized modules start with different weights.
        assert_ne!(
            source.weight_basic.to_data(),
            target.weight_basic.to_data()
        );

        let target = target.load_record(record);

        assert_eq!(
            source.weight_basic.to_data(),
            target.weight_basic.to_data()
        );
    }

    #[test]
    fn should_load_from_record_compose() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleComposed::<TestBackend>::new(&device);
        let target = ModuleComposed::<TestBackend>::new(&device);

        assert_ne!(source.weight.to_data(), target.weight.to_data());
        assert_ne!(
            source.basic.weight_basic.to_data(),
            target.basic.weight_basic.to_data()
        );

        let record = source.clone().into_record();
        let target = target.load_record(record);

        assert_eq!(source.weight.to_data(), target.weight.to_data());
        assert_eq!(
            source.basic.weight_basic.to_data(),
            target.basic.weight_basic.to_data()
        );
    }

    #[test]
    fn should_load_from_record_enum() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        let target = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        let record = source.clone().into_record();

        let source_basic = match source {
            ModuleEnum::Basic(inner) => inner,
            _ => panic!("Invalid module type"),
        };
        let target_basic = match target.clone() {
            ModuleEnum::Basic(inner) => inner,
            _ => panic!("Invalid module type"),
        };
        assert_ne!(
            source_basic.weight_basic.to_data(),
            target_basic.weight_basic.to_data()
        );

        let target_basic = match target.load_record(record) {
            ModuleEnum::Basic(inner) => inner,
            _ => panic!("Invalid module type"),
        };
        assert_eq!(
            source_basic.weight_basic.to_data(),
            target_basic.weight_basic.to_data()
        );
    }

    #[test]
    fn should_load_from_record_const_generic() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleWithConstGeneric {
            modules: [
                ModuleBasic::<TestBackend>::new(&device),
                ModuleBasic::<TestBackend>::new(&device),
            ],
        };
        let target = ModuleWithConstGeneric {
            modules: [
                ModuleBasic::<TestBackend>::new(&device),
                ModuleBasic::<TestBackend>::new(&device),
            ],
        };
        let record = source.clone().into_record();

        for index in 0..2 {
            assert_ne!(
                source.modules[index].weight_basic.to_data(),
                target.modules[index].weight_basic.to_data(),
            );
        }

        let target = target.load_record(record);

        for index in 0..2 {
            assert_eq!(
                source.modules[index].weight_basic.to_data(),
                target.modules[index].weight_basic.to_data(),
            );
        }
    }

    /// A record taken from one enum variant cannot be loaded into another variant.
    #[test]
    #[should_panic(expected = "Can't parse record from a different variant")]
    fn should_panic_load_from_incorrect_enum_variant() {
        let device = <TestBackend as Backend>::Device::default();
        let basic = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        let composed = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
        let basic_record = basic.clone().into_record();

        composed.load_record(basic_record);
    }
}
/// Tests for the parameter count reported by the test modules.
mod num_params {
    use super::*;

    /// A basic module holds one 20x20 weight.
    #[test]
    fn should_calculate_num_params_basic() {
        let device = <TestBackend as Backend>::Device::default();
        let basic = ModuleBasic::<TestBackend>::new(&device);

        assert_eq!(20 * 20, basic.num_params());
    }

    /// A composed module holds four 20x20 weights.
    #[test]
    fn should_output_state_composed() {
        let device = <TestBackend as Backend>::Device::default();
        let composed = ModuleComposed::<TestBackend>::new(&device);

        assert_eq!(4 * 20 * 20, composed.num_params());
    }

    /// An enum module reports the count of the variant it currently holds.
    #[test]
    fn should_calculate_num_params_enum() {
        let device = <TestBackend as Backend>::Device::default();

        let basic = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        assert_eq!(20 * 20, basic.num_params());

        let composed = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
        assert_eq!(4 * 20 * 20, composed.num_params());
    }
}
/// Tests for gradient tracking of module parameters on the autodiff backend.
#[cfg(feature = "std")]
mod require_grad {
    use burn_tensor::backend::AutodiffBackend;

    use super::*;

    /// A freshly created parameter receives a gradient.
    #[test]
    fn should_have_grad_by_default() {
        // The device is taken from the autodiff backend, like in the other tests of
        // this module.
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let module = ModuleBasic::<TestAutodiffBackend>::new(&device);
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_some());
    }

    /// After `no_grad` the parameter no longer receives a gradient.
    #[test]
    fn should_have_no_grad_after_no_grad() {
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let module = ModuleBasic::<TestAutodiffBackend>::new(&device).no_grad();
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_none());
    }

    /// A parameter loaded from a record receives a gradient.
    #[test]
    fn should_have_grad_when_from_record() {
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let module = ModuleBasic::<TestAutodiffBackend>::new(&device);
        // The record reuses the module's own parameter.
        let record = ModuleBasicRecord {
            weight_basic: module.weight_basic.clone(),
        };
        let module = module.load_record(record);
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_some());
    }

    /// Runs a backward pass through `weight_basic x ones([20, 20])` and returns the
    /// resulting gradients.
    fn calculate_grads(
        module: &ModuleBasic<TestAutodiffBackend>,
    ) -> <TestAutodiffBackend as AutodiffBackend>::Gradients {
        let device = module.weight_basic.device();
        let x = Tensor::ones([20, 20], &device).require_grad();
        let y = module.weight_basic.val().matmul(x);

        y.backward()
    }
}

View File

@@ -0,0 +1,17 @@
use burn_core as burn;
use burn_core::record::Record;
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;
// Compile-only check: `#[derive(Record)]` on a struct that is generic over the backend.
#[derive(Record)]
pub struct TestWithBackendRecord<B: Backend> {
tensor: Tensor<B, 2>,
}
// Compile-only check: `#[derive(Record)]` on a struct without a backend parameter.
#[derive(Record)]
pub struct TestWithoutBackendRecord {
_tensor: usize,
}

View File

@@ -0,0 +1,344 @@
#[cfg(feature = "std")]
mod tests {
use burn::{
module::{Module, Param},
record::{
BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings,
PrettyJsonFileRecorder, RecorderError,
},
};
use burn_core as burn;
use burn_ndarray::NdArrayDevice;
use burn_tensor::{Tensor, backend::Backend};
use std::path::PathBuf;
/// Backend used by the tests in this module: ndarray with `f32` elements.
type TestBackend = burn_ndarray::NdArray<f32>;
/// Simple linear module with a rank-2 weight and an optional rank-1 bias.
#[derive(Module, Debug)]
pub struct Linear<B: Backend> {
pub weight: Param<Tensor<B, 2>>,
pub bias: Option<Param<Tensor<B, 1>>>,
}
impl<B: Backend> Linear<B> {
pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
let weight = Tensor::random(
[out_features, in_features],
burn_tensor::Distribution::Default,
device,
);
let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);
Self {
weight: Param::from_tensor(weight),
bias: Some(Param::from_tensor(bias)),
}
}
}
/// Baseline model for the record-resilience tests: two constants interleaved with two
/// `Linear` layers. The other `Model*` structs in this module are variations of it.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
single_const: f32,
linear1: Linear<B>,
array_const: [usize; 2],
linear2: Linear<B>,
}
/// `Model` with one extra optional field (`new_field: Option<usize>`) appended.
/// Used to test loading a record when an optional field was added or removed.
#[derive(Module, Debug)]
pub struct ModelNewOptionalField<B: Backend> {
single_const: f32,
linear1: Linear<B>,
array_const: [usize; 2],
linear2: Linear<B>,
// The only difference from `Model`.
new_field: Option<usize>,
}
/// `Model` with one extra non-optional constant field (`new_field: usize`) appended.
/// Used to test loading a record when a constant field was added or removed.
#[derive(Module, Debug)]
pub struct ModelNewConstantField<B: Backend> {
single_const: f32,
linear1: Linear<B>,
array_const: [usize; 2],
linear2: Linear<B>,
// The only difference from `Model`.
new_field: usize,
}
/// Same four fields as `Model`, declared in a different order.
/// Used to test loading a record whose fields were written in `Model`'s order.
#[derive(Module, Debug)]
// NOTE(review): presumably needed because this struct is never constructed in this file —
// only its derived `ModelNewFieldOrdersRecord` type is used. Confirm before removing.
#[allow(unused)]
pub struct ModelNewFieldOrders<B: Backend> {
array_const: [usize; 2],
linear2: Linear<B>,
single_const: f32,
linear1: Linear<B>,
}
/// Default file recorder: a `Model` record must load as a `ModelNewOptionalField` record.
#[test]
fn deserialize_with_new_optional_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_optional_field("default", recorder).unwrap();
}
/// Default file recorder: a `ModelNewOptionalField` record must load as a `Model` record.
#[test]
fn deserialize_with_removed_optional_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_optional_field("default", recorder).unwrap();
}
/// Default file recorder: a `Model` record must load as a `ModelNewConstantField` record.
#[test]
fn deserialize_with_new_constant_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_constant_field("default", recorder).unwrap();
}
/// Default file recorder: a `ModelNewConstantField` record must load as a `Model` record.
#[test]
fn deserialize_with_removed_constant_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_constant_field("default", recorder).unwrap();
}
/// Default file recorder: a `Model` record must load as a `ModelNewFieldOrders` record.
#[test]
fn deserialize_with_new_field_order_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_field_order("default", recorder).unwrap();
}
/// Pretty-JSON recorder: a `Model` record must load as a `ModelNewOptionalField` record.
#[test]
fn deserialize_with_new_optional_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_optional_field("pretty-json", recorder).unwrap();
}
/// Pretty-JSON recorder: a `ModelNewOptionalField` record must load as a `Model` record.
#[test]
fn deserialize_with_removed_optional_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_optional_field("pretty-json", recorder).unwrap();
}
/// Pretty-JSON recorder: a `Model` record must load as a `ModelNewConstantField` record.
#[test]
fn deserialize_with_new_constant_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_constant_field("pretty-json", recorder).unwrap();
}
/// Pretty-JSON recorder: a `ModelNewConstantField` record must load as a `Model` record.
#[test]
fn deserialize_with_removed_constant_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_constant_field("pretty-json", recorder).unwrap();
}
/// Pretty-JSON recorder: a `Model` record must load as a `ModelNewFieldOrders` record.
#[test]
fn deserialize_with_new_field_order_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_field_order("pretty-json", recorder).unwrap();
}
/// Bin recorder: loading a `Model` record as `ModelNewOptionalField` is expected to fail,
/// so the `unwrap` below must panic.
#[test]
#[should_panic]
fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_optional_field("bin", recorder).unwrap();
}
/// Bin recorder: a `ModelNewOptionalField` record must load as a `Model` record.
#[test]
fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_optional_field("bin", recorder).unwrap();
}
/// Bin recorder: a `Model` record must load as a `ModelNewConstantField` record.
#[test]
fn deserialize_with_new_constant_field_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_constant_field("bin", recorder).unwrap();
}
/// Bin recorder: a `ModelNewConstantField` record must load as a `Model` record.
#[test]
fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_constant_field("bin", recorder).unwrap();
}
/// Bin recorder: loading a `Model` record as `ModelNewFieldOrders` is expected to fail.
///
/// Despite the "works" in its name, this test is `#[should_panic]`: it passes only when
/// the `unwrap` below panics.
#[test]
#[should_panic]
fn deserialize_with_new_field_order_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_field_order("bin", recorder).unwrap();
}
/// Returns the location of `filename` inside the system temporary directory.
///
/// Accepts any path-like value (`String`, `&str`, `PathBuf`, ...), so callers are not
/// forced to allocate an owned `String`. The path is only computed; no file is created.
fn file_path(filename: impl AsRef<std::path::Path>) -> PathBuf {
    std::env::temp_dir().join(filename)
}
/// A rank-1 tensor serialised to JSON and parsed back must carry the same data.
#[test]
fn test_tensor_serde() {
    let device = NdArrayDevice::default();
    let original: Tensor<TestBackend, 1> = Tensor::ones([1], &device);

    let json = serde_json::to_string(&original).unwrap();
    let restored: Tensor<TestBackend, 1> = serde_json::from_str(&json).unwrap();

    assert_eq!(original.into_data(), restored.into_data());
}
/// Records a `Model` with `recorder`, then loads the file back as a
/// `ModelNewOptionalFieldRecord` (the same model plus an extra `Option<usize>` field).
///
/// `name` only makes the temporary file name unique per recorder.
///
/// # Errors
/// Returns the `RecorderError` produced by the load step, if any.
///
/// # Panics
/// Panics if the record (setup) step fails.
fn deserialize_with_new_optional_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
where
    R: FileRecorder<TestBackend>,
{
    let device = Default::default();
    let base_path: PathBuf = file_path(format!("deserialize_with_new_optional_field-{name}"));
    // burn's file recorders set their own extension on the path before writing, so the
    // file on disk is `base_path` plus that extension. Removing `base_path` itself would
    // silently do nothing and leak the temp file.
    let recorded_file = base_path.with_extension(<R as FileRecorder<TestBackend>>::file_extension());

    let model = Model {
        single_const: 32.0,
        linear1: Linear::<TestBackend>::new(20, 20, &device),
        array_const: [2, 2],
        linear2: Linear::<TestBackend>::new(20, 20, &device),
    };
    recorder
        .record(model.into_record(), base_path.clone())
        .expect("recording the model to a temporary file should succeed");

    let result = recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(base_path, &device);

    // Best-effort cleanup, done before the result is inspected so it also runs on failure.
    std::fs::remove_file(recorded_file).ok();

    result.map(|_| ())
}
/// Records a `ModelNewOptionalField` (with `new_field: None`) with `recorder`, then loads
/// the file back as a plain `ModelRecord`, i.e. with the optional field removed.
///
/// `name` only makes the temporary file name unique per recorder.
///
/// # Errors
/// Returns the `RecorderError` produced by the load step, if any.
///
/// # Panics
/// Panics if the record (setup) step fails.
fn deserialize_with_removed_optional_field<R>(
    name: &str,
    recorder: R,
) -> Result<(), RecorderError>
where
    R: FileRecorder<TestBackend>,
{
    let device = Default::default();
    let base_path: PathBuf =
        file_path(format!("deserialize_with_removed_optional_field-{name}"));
    // burn's file recorders set their own extension on the path before writing, so the
    // file on disk is `base_path` plus that extension. Removing `base_path` itself would
    // silently do nothing and leak the temp file.
    let recorded_file = base_path.with_extension(<R as FileRecorder<TestBackend>>::file_extension());

    let model = ModelNewOptionalField {
        single_const: 32.0,
        linear1: Linear::<TestBackend>::new(20, 20, &device),
        array_const: [2, 2],
        linear2: Linear::<TestBackend>::new(20, 20, &device),
        new_field: None,
    };
    recorder
        .record(model.into_record(), base_path.clone())
        .expect("recording the model to a temporary file should succeed");

    let result = recorder.load::<ModelRecord<TestBackend>>(base_path, &device);

    // Best-effort cleanup, done before the result is inspected so it also runs on failure.
    std::fs::remove_file(recorded_file).ok();

    result.map(|_| ())
}
/// Records a `Model` with `recorder`, then loads the file back as a
/// `ModelNewConstantFieldRecord` (the same model plus an extra `usize` field).
///
/// `name` only makes the temporary file name unique per recorder.
///
/// # Errors
/// Returns the `RecorderError` produced by the load step, if any.
///
/// # Panics
/// Panics if the record (setup) step fails.
fn deserialize_with_new_constant_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
where
    R: FileRecorder<TestBackend>,
{
    let device = Default::default();
    let base_path: PathBuf = file_path(format!("deserialize_with_new_constant_field-{name}"));
    // burn's file recorders set their own extension on the path before writing, so the
    // file on disk is `base_path` plus that extension. Removing `base_path` itself would
    // silently do nothing and leak the temp file.
    let recorded_file = base_path.with_extension(<R as FileRecorder<TestBackend>>::file_extension());

    let model = Model {
        single_const: 32.0,
        array_const: [2, 2],
        linear1: Linear::<TestBackend>::new(20, 20, &device),
        linear2: Linear::<TestBackend>::new(20, 20, &device),
    };
    recorder
        .record(model.into_record(), base_path.clone())
        .expect("recording the model to a temporary file should succeed");

    let result = recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(base_path, &device);

    // Best-effort cleanup, done before the result is inspected so it also runs on failure.
    std::fs::remove_file(recorded_file).ok();

    result.map(|_| ())
}
/// Records a `ModelNewConstantField` (with `new_field: 0`) with `recorder`, then loads the
/// file back as a plain `ModelRecord`, i.e. with the constant field removed.
///
/// `name` only makes the temporary file name unique per recorder.
///
/// # Errors
/// Returns the `RecorderError` produced by the load step, if any.
///
/// # Panics
/// Panics if the record (setup) step fails.
fn deserialize_with_removed_constant_field<R>(
    name: &str,
    recorder: R,
) -> Result<(), RecorderError>
where
    R: FileRecorder<TestBackend>,
{
    let device = Default::default();
    let base_path: PathBuf =
        file_path(format!("deserialize_with_removed_constant_field-{name}"));
    // burn's file recorders set their own extension on the path before writing, so the
    // file on disk is `base_path` plus that extension. Removing `base_path` itself would
    // silently do nothing and leak the temp file.
    let recorded_file = base_path.with_extension(<R as FileRecorder<TestBackend>>::file_extension());

    let model = ModelNewConstantField {
        single_const: 32.0,
        array_const: [2, 2],
        linear1: Linear::<TestBackend>::new(20, 20, &device),
        linear2: Linear::<TestBackend>::new(20, 20, &device),
        new_field: 0,
    };
    recorder
        .record(model.into_record(), base_path.clone())
        .expect("recording the model to a temporary file should succeed");

    let result = recorder.load::<ModelRecord<TestBackend>>(base_path, &device);

    // Best-effort cleanup, done before the result is inspected so it also runs on failure.
    std::fs::remove_file(recorded_file).ok();

    result.map(|_| ())
}
/// Records a `Model` with `recorder`, then loads the file back as a
/// `ModelNewFieldOrdersRecord` (the same fields declared in a different order).
///
/// `name` only makes the temporary file name unique per recorder.
///
/// # Errors
/// Returns the `RecorderError` produced by the load step, if any.
///
/// # Panics
/// Panics if the record (setup) step fails.
fn deserialize_with_new_field_order<R>(name: &str, recorder: R) -> Result<(), RecorderError>
where
    R: FileRecorder<TestBackend>,
{
    let device = Default::default();
    let base_path: PathBuf = file_path(format!("deserialize_with_new_field_order-{name}"));
    // burn's file recorders set their own extension on the path before writing, so the
    // file on disk is `base_path` plus that extension. Removing `base_path` itself would
    // silently do nothing and leak the temp file.
    let recorded_file = base_path.with_extension(<R as FileRecorder<TestBackend>>::file_extension());

    let model = Model {
        array_const: [2, 2],
        single_const: 32.0,
        linear1: Linear::<TestBackend>::new(20, 20, &device),
        linear2: Linear::<TestBackend>::new(20, 20, &device),
    };
    recorder
        .record(model.into_record(), base_path.clone())
        .expect("recording the model to a temporary file should succeed");

    let result = recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(base_path, &device);

    // Best-effort cleanup, done before the result is inspected so it also runs on failure.
    std::fs::remove_file(recorded_file).ok();

    result.map(|_| ())
}
}