feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
151
crates/stable-diffusion-burn/burn-crates/burn-core/Cargo.toml
Normal file
151
crates/stable-diffusion-burn/burn-crates/burn-core/Cargo.toml
Normal file
@@ -0,0 +1,151 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Flexible and Comprehensive Deep Learning Framework in Rust"
|
||||
documentation = "https://docs.rs/burn-core"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-core"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-core"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"std",
|
||||
"burn-std/default",
|
||||
"burn-dataset?/default",
|
||||
"burn-tensor/default",
|
||||
]
|
||||
doc = [
|
||||
"std",
|
||||
"dataset",
|
||||
"audio",
|
||||
# Doc features
|
||||
"burn-std/doc",
|
||||
"burn-dataset/doc",
|
||||
"burn-tensor/doc",
|
||||
]
|
||||
tracing = [
|
||||
"burn-std/tracing",
|
||||
"burn-tensor/tracing",
|
||||
"burn-dataset?/tracing",
|
||||
"burn-vision?/tracing",
|
||||
]
|
||||
|
||||
|
||||
dataset = ["burn-dataset"]
|
||||
|
||||
network = ["burn-std/network"]
|
||||
sqlite = ["burn-dataset?/sqlite"]
|
||||
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
|
||||
std = [
|
||||
"bincode/std",
|
||||
"burn-std/std",
|
||||
"burn-tensor/std",
|
||||
"flate2",
|
||||
"half/std",
|
||||
"log",
|
||||
"rand/std",
|
||||
"rmp-serde",
|
||||
"serde/std",
|
||||
"serde_json/std",
|
||||
"num-traits/std",
|
||||
]
|
||||
vision = ["burn-vision", "burn-dataset?/vision"]
|
||||
audio = ["burn-dataset?/audio"]
|
||||
|
||||
# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
|
||||
record-item-custom-serde = ["thiserror"]
|
||||
|
||||
# Serialization formats
|
||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||
|
||||
test-cuda = [
|
||||
"burn-cuda/default",
|
||||
] # To use cuda during testing, default uses ndarray.
|
||||
test-rocm = [
|
||||
"burn-rocm/default",
|
||||
] # To use hip during testing, default uses ndarray.
|
||||
test-tch = [
|
||||
"burn-tch/default",
|
||||
] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = [
|
||||
"burn-wgpu/default",
|
||||
] # To use wgpu during testing, default uses ndarray.
|
||||
test-vulkan = [
|
||||
"test-wgpu",
|
||||
"burn-wgpu/vulkan",
|
||||
] # To use wgpu-spirv during testing, default uses ndarray.
|
||||
test-metal = [
|
||||
"test-wgpu",
|
||||
"burn-wgpu/metal",
|
||||
] # To use wgpu-metal during testing, default uses ndarray.
|
||||
|
||||
# Memory checks are disabled by default
|
||||
test-memory-checks = ["burn-fusion/memory-checks"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-dataset = { path = "../burn-dataset", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-derive = { path = "../burn-derive", version = "=0.21.0-pre.2" }
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-vision = { path = "../burn-vision", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
|
||||
data-encoding = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
derive-new = { workspace = true }
|
||||
log = { workspace = true, optional = true }
|
||||
rand = { workspace = true }
|
||||
|
||||
# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)
|
||||
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
|
||||
|
||||
# Serialize Deserialize
|
||||
flate2 = { workspace = true, optional = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
ahash = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
half = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
rmp-serde = { workspace = true, optional = true }
|
||||
serde_json = { workspace = true, features = ["alloc"] } #Default enables std
|
||||
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled
|
||||
thiserror = { workspace = true, optional = true }
|
||||
|
||||
[target.'cfg(target_has_atomic = "ptr")'.dependencies]
|
||||
regex = { workspace = true }
|
||||
|
||||
# FOR TESTING
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
|
||||
|
||||
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
|
||||
portable-atomic-util = { workspace = true }
|
||||
portable-atomic = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
|
||||
burn-dataset = { path = "../burn-dataset", version = "=0.21.0-pre.2", features = [
|
||||
"fake",
|
||||
] }
|
||||
rstest = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-core/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-core/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
15
crates/stable-diffusion-burn/burn-crates/burn-core/README.md
Normal file
15
crates/stable-diffusion-burn/burn-crates/burn-core/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Burn Core
|
||||
|
||||
This crate should be used with [burn](https://github.com/tracel-ai/burn). It contains the core
|
||||
traits and components for building and training deep learning models with Burn.
|
||||
|
||||
[burn-core on crates.io](https://crates.io/crates/burn-core)
|
||||
[License](https://github.com/tracel-ai/burn-core/blob/master/README.md)
|
||||
|
||||
## Feature Flags
|
||||
|
||||
This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the
|
||||
default `std` feature.
|
||||
|
||||
- `std` - enables the standard library. Enabled by default.
|
||||
- `experimental-named-tensor` - enables experimental named tensor.
|
||||
@@ -0,0 +1,98 @@
|
||||
use alloc::{format, string::String, string::ToString};
|
||||
pub use burn_derive::Config;
|
||||
use core::fmt::Debug;
|
||||
|
||||
/// Error raised while loading or parsing a configuration.
#[derive(Debug)]
pub enum ConfigError {
    /// The content could not be interpreted as a configuration; carries a
    /// description of the problem.
    InvalidFormat(String),

    /// The configuration file could not be read; carries the path that was requested.
    FileNotFound(String),
}

impl core::fmt::Display for ConfigError {
    /// Renders the error as `Config error => <kind>: <detail>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write directly into the formatter rather than assembling intermediate
        // `String`s first; the rendered text is unchanged.
        match self {
            Self::InvalidFormat(err) => write!(f, "Config error => Invalid format: {err}"),
            Self::FileNotFound(err) => write!(f, "Config error => File not found: {err}"),
        }
    }
}

impl core::error::Error for ConfigError {}
|
||||
|
||||
/// Configuration trait.
|
||||
pub trait Config: Debug + serde::Serialize + serde::de::DeserializeOwned {
|
||||
/// Saves the configuration to a file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file` - File to save the configuration to.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output of the save operation.
|
||||
#[cfg(feature = "std")]
|
||||
fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
|
||||
std::fs::write(file, config_to_json(self))
|
||||
}
|
||||
|
||||
/// Loads the configuration from a file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file` - File to load the configuration from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The loaded configuration.
|
||||
#[cfg(feature = "std")]
|
||||
fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
|
||||
let content = std::fs::read_to_string(file.as_ref())
|
||||
.map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;
|
||||
config_from_str(&content)
|
||||
}
|
||||
|
||||
/// Loads the configuration from a binary buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - Binary buffer to load the configuration from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The loaded configuration.
|
||||
fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
|
||||
let content = core::str::from_utf8(data).map_err(|_| {
|
||||
ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
|
||||
})?;
|
||||
config_from_str(content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a configuration to a JSON string.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - Configuration to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The JSON string.
|
||||
pub fn config_to_json<C: Config>(config: &C) -> String {
|
||||
serde_json::to_string_pretty(config).unwrap()
|
||||
}
|
||||
|
||||
fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
|
||||
serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
pub use crate::data::dataset::{Dataset, DatasetIterator};
|
||||
use core::iter::Iterator;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A progress struct that can be used to track the progress of a data loader.
///
/// The `new` derive (presumably from the `derive-new` crate) provides the
/// `Progress::new` constructor.
#[derive(new, Clone, Debug)]
pub struct Progress {
    /// The number of items that have been processed so far.
    pub items_processed: usize,

    /// The total number of items that need to be processed.
    pub items_total: usize,
}
|
||||
|
||||
/// A data loader iterator that can be used to iterate over a data loader.
///
/// It is a regular [`Iterator`] yielding items of type `O` that can additionally
/// report how far it has advanced.
pub trait DataLoaderIterator<O>: Iterator<Item = O> {
    /// Returns the current [`Progress`] of the data loader.
    fn progress(&self) -> Progress;
}
|
||||
|
||||
/// A data loader that can be used to iterate over a dataset.
///
/// `B` is the backend whose device the batches are associated with and `O` is
/// the type of the items yielded by the iterator.
pub trait DataLoader<B: Backend, O>: Send + Sync {
    /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;

    /// The number of items (not the number of batches nor the number of iterations),
    /// corresponding to the items_total of the progress returned by the iterator.
    fn num_items(&self) -> usize;

    /// Move the data loader to the given device, ensuring the batches are assigned to the correct device.
    ///
    /// The receiver is only borrowed; the data loader for `device` is returned
    /// as a new shared (`Arc`) handle.
    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>>;

    /// Returns a new data loader containing a subset of the data.
    ///
    /// The subset includes items from `start` (inclusive) to `end` (exclusive),
    /// preserving the batch size and ordering of the original data loader.
    ///
    /// # Arguments
    ///
    /// * `start` - The starting index of the subset (inclusive).
    /// * `end` - The ending index of the subset (exclusive).
    ///
    /// # Returns
    ///
    /// A shared ([`Arc`]) [`DataLoader`] instance containing only the specified range.
    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>>;
}
|
||||
@@ -0,0 +1,259 @@
|
||||
use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
|
||||
use burn_dataset::{
|
||||
Dataset,
|
||||
transform::{PartialDataset, ShuffledDataset},
|
||||
};
|
||||
use burn_tensor::backend::Backend;
|
||||
use rand::SeedableRng;
|
||||
use std::ops::DerefMut;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A data loader that can be used to iterate over a dataset in batches.
///
/// `I` is the dataset item type and `O` the batch type produced by the batcher.
pub struct BatchDataLoader<B: Backend, I, O> {
    // Strategy that groups individual items into batches.
    strategy: Box<dyn BatchStrategy<I>>,
    // Source of the items.
    dataset: Arc<dyn Dataset<I>>,
    // Turns a group of items into a batch of type `O`.
    batcher: Arc<dyn Batcher<B, I, O>>,
    // Device handed to the batcher when a batch is built.
    device: B::Device,
    // When present, the dataset is shuffled with this rng every time an iterator
    // is created. Stored behind `Arc<Mutex<..>>`, so clones share the same rng.
    rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
}
|
||||
|
||||
impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
strategy: self.strategy.clone_dyn(),
|
||||
dataset: self.dataset.clone(),
|
||||
batcher: self.batcher.clone(),
|
||||
device: self.device.clone(),
|
||||
rng: self.rng.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
|
||||
/// Creates a new batch data loader.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `strategy` - The batch strategy.
|
||||
/// * `dataset` - The dataset.
|
||||
/// * `batcher` - The batcher.
|
||||
/// * `device` - The device to use when loading a batch.
|
||||
/// * `rng` - The rng determining if the dataset is shuffled each time a dataloader
|
||||
/// iterator is created.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The batch data loader.
|
||||
pub fn new(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<B, I, O>>,
|
||||
device: B::Device,
|
||||
rng: Option<rand::rngs::StdRng>,
|
||||
) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
dataset,
|
||||
batcher,
|
||||
device,
|
||||
rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A data loader iterator that can be used to iterate over a data loader.
///
/// It reads dataset items one by one, feeds them to the batch strategy and
/// emits a batch whenever the strategy yields one.
struct BatchDataloaderIterator<B: Backend, I, O> {
    // Index of the next dataset item to read; also reported as the number of
    // processed items in the progress.
    current_index: usize,
    // Strategy accumulating items until a batch is available.
    strategy: Box<dyn BatchStrategy<I>>,
    // Dataset being iterated.
    dataset: Arc<dyn Dataset<I>>,
    // Converts the accumulated items into a batch of type `O`.
    batcher: Arc<dyn Batcher<B, I, O>>,
    // Device handed to the batcher for every batch.
    device: B::Device,
}
|
||||
|
||||
impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
|
||||
where
|
||||
B: Backend,
|
||||
I: Send + Sync + Clone + 'static,
|
||||
O: Send + 'static,
|
||||
{
|
||||
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
|
||||
// When starting a new iteration, we first check if the dataloader was created with an rng,
|
||||
// implying that we should shuffle the dataset beforehand, while advancing the current
|
||||
// rng to ensure that each new iteration shuffles the dataset differently.
|
||||
let dataset = match &self.rng {
|
||||
Some(rng) => Arc::new(ShuffledDataset::new(
|
||||
self.dataset.clone(),
|
||||
rng.lock().deref_mut(),
|
||||
)),
|
||||
None => self.dataset.clone(),
|
||||
};
|
||||
Box::new(BatchDataloaderIterator::new(
|
||||
self.strategy.clone_dyn(),
|
||||
dataset,
|
||||
self.batcher.clone(),
|
||||
self.device.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn num_items(&self) -> usize {
|
||||
self.dataset.len()
|
||||
}
|
||||
|
||||
fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
|
||||
let rng = self.rng.as_ref().map(|rng| {
|
||||
let mut rng = rng.lock();
|
||||
rng.fork()
|
||||
});
|
||||
Arc::new(Self::new(
|
||||
self.strategy.clone_dyn(),
|
||||
self.dataset.clone(),
|
||||
self.batcher.clone(),
|
||||
device.clone(),
|
||||
rng,
|
||||
))
|
||||
}
|
||||
|
||||
fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
|
||||
let rng = self.rng.as_ref().map(|rng| {
|
||||
let mut rng = rng.lock();
|
||||
rng.fork()
|
||||
});
|
||||
let dataloader = Self::new(
|
||||
self.strategy.clone_dyn(),
|
||||
Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
|
||||
self.batcher.clone(),
|
||||
self.device.clone(),
|
||||
rng,
|
||||
);
|
||||
Arc::new(dataloader)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
|
||||
/// Creates a new batch data loader iterator.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `strategy` - The batch strategy.
|
||||
/// * `dataset` - The dataset.
|
||||
/// * `batcher` - The batcher.
|
||||
/// * `device` - The device to use when loading a batch.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The batch data loader iterator.
|
||||
pub fn new(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<B, I, O>>,
|
||||
device: B::Device,
|
||||
) -> Self {
|
||||
BatchDataloaderIterator {
|
||||
current_index: 0,
|
||||
strategy,
|
||||
dataset,
|
||||
batcher,
|
||||
device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
    type Item = O;

    /// Produces the next batch, or `None` once nothing is left to emit.
    fn next(&mut self) -> Option<O> {
        // Feed items to the strategy until it hands back a batch.
        loop {
            let Some(item) = self.dataset.get(self.current_index) else {
                break;
            };

            self.current_index += 1;
            self.strategy.add(item);

            if let Some(ready) = self.strategy.batch(false) {
                return Some(self.batcher.batch(ready, &self.device));
            }
        }

        // The dataset is exhausted: ask the strategy once more with `true`
        // (presumably forcing out a partially filled batch; see `BatchStrategy`).
        self.strategy
            .batch(true)
            .map(|remaining| self.batcher.batch(remaining, &self.device))
    }
}
|
||||
|
||||
impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
|
||||
fn progress(&self) -> Progress {
|
||||
Progress::new(self.current_index, self.dataset.len())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// Every item of the dataset must come out of the data loader.
    #[test]
    fn test_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );

        let items_dataset: HashSet<_> = dataset.iter().collect();

        let mut items_dataloader = HashSet::new();
        for batch in dataloader.iter() {
            items_dataloader.extend(batch);
        }

        assert_eq!(items_dataset, items_dataloader);
    }

    /// A slice of the data loader must yield the items found at the same
    /// positions in the full data loader.
    #[test]
    fn test_batch_dataloader_slice() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );
        let dataloader_slice = dataloader.slice(5, 15);

        // Items at positions 5..15 of the full data loader.
        let mut items_dataloader = HashSet::new();
        let mut position = 0usize;
        for batch in dataloader.iter() {
            for item in batch {
                let in_range = (5..15).contains(&position);
                position += 1;
                if in_range {
                    items_dataloader.insert(item);
                }
            }
        }

        // Everything yielded by the slice.
        let mut items_dataloader_slice = HashSet::new();
        for batch in dataloader_slice.iter() {
            items_dataloader_slice.extend(batch);
        }

        assert_eq!(items_dataloader, items_dataloader_slice);
    }
}
|
||||
@@ -0,0 +1,31 @@
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::TestBackend;
|
||||
|
||||
/// A trait for batching items of type `I` into items of type `O`.
///
/// Implementations must be `Send + Sync`; the data loaders in this module hold
/// batchers behind an `Arc`.
pub trait Batcher<B: Backend, I, O>: Send + Sync {
    /// Batches the given items on the specified device.
    ///
    /// # Arguments
    ///
    /// * `items` - The items to batch.
    /// * `device` - The backend device to use.
    ///
    /// # Returns
    ///
    /// The batched items.
    fn batch(&self, items: Vec<I>, device: &B::Device) -> O;
}
|
||||
|
||||
/// Test batcher: an identity batcher that hands the items back unchanged.
#[cfg(test)]
#[derive(new, Clone)]
pub struct TestBatcher;

#[cfg(test)]
impl<I> Batcher<TestBackend, I, Vec<I>> for TestBatcher {
    /// Returns `items` as the batch; the device is ignored.
    fn batch(&self, items: Vec<I>, _device: &<TestBackend as Backend>::Device) -> Vec<I> {
        items
    }
}
|
||||
@@ -0,0 +1,260 @@
|
||||
use super::{
|
||||
BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
|
||||
batcher::Batcher,
|
||||
};
|
||||
use burn_dataset::Dataset;
|
||||
use burn_tensor::backend::Backend;
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A builder for data loaders.
///
/// Every option except the batcher is optional; unset options are given a
/// default when the data loader is built.
pub struct DataLoaderBuilder<B: Backend, I, O> {
    // Batch strategy; `build` falls back to a fixed batch size of 1 when unset.
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    // Batcher turning items into batches.
    batcher: Arc<dyn Batcher<B, I, O>>,
    // Number of worker threads; `None` or `Some(0)` builds a single-threaded loader.
    num_threads: Option<usize>,
    // Shuffle seed; `None` disables shuffling.
    shuffle: Option<u64>,
    // Device used when loading batches; `build` uses the default device when unset.
    device: Option<B::Device>,
}
|
||||
|
||||
impl<B, I, O> DataLoaderBuilder<B, I, O>
|
||||
where
|
||||
B: Backend,
|
||||
I: Send + Sync + Clone + std::fmt::Debug + 'static,
|
||||
O: Send + Clone + std::fmt::Debug + 'static,
|
||||
{
|
||||
/// Creates a new data loader builder.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `batcher` - The batcher.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data loader builder.
|
||||
pub fn new<Bt>(batcher: Bt) -> Self
|
||||
where
|
||||
Bt: Batcher<B, I, O> + 'static,
|
||||
{
|
||||
Self {
|
||||
batcher: Arc::new(batcher),
|
||||
strategy: None,
|
||||
num_threads: None,
|
||||
shuffle: None,
|
||||
device: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the batch size to a fix number.
|
||||
///
|
||||
/// The [fix batch strategy](FixBatchStrategy) will be used.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `batch_size` - The batch size.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data loader builder.
|
||||
pub fn batch_size(mut self, batch_size: usize) -> Self {
|
||||
self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the seed for shuffling.
|
||||
///
|
||||
/// Each time the dataloader starts a new iteration, the dataset will be shuffled.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `seed` - The seed.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data loader builder.
|
||||
pub fn shuffle(mut self, seed: u64) -> Self {
|
||||
self.shuffle = Some(seed);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the number of workers.
|
||||
///
|
||||
/// - `Some(0)` or `None`: the dataloader will run without work threads.
|
||||
/// - `Some(n); n > 0`: the dataloader will run with `n` background threads.
|
||||
///
|
||||
/// A 1-worker threaded dataloader will run loads in a background thread,
|
||||
/// while a 0-worker threaded dataloader will run loads in the main thread.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `num_workers` - The number of workers.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data loader builder.
|
||||
pub fn num_workers(mut self, num_workers: usize) -> Self {
|
||||
self.num_threads = Some(num_workers);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the data loader device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The device to use when loading a batch.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data loader builder.
|
||||
pub fn set_device(mut self, device: B::Device) -> Self {
|
||||
self.device = Some(device);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the data loader.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The dataset.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data loader.
|
||||
pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
|
||||
where
|
||||
D: Dataset<I> + 'static,
|
||||
{
|
||||
let dataset = Arc::new(dataset);
|
||||
|
||||
let device = self.device.unwrap_or_default();
|
||||
let rng = self.shuffle.map(StdRng::seed_from_u64);
|
||||
let strategy = match self.strategy {
|
||||
Some(strategy) => strategy,
|
||||
None => Box::new(FixBatchStrategy::new(1)),
|
||||
};
|
||||
|
||||
if let Some(num_threads) = self.num_threads
|
||||
&& num_threads > 0
|
||||
{
|
||||
return Arc::new(MultiThreadDataLoader::new(
|
||||
strategy,
|
||||
dataset,
|
||||
self.batcher,
|
||||
num_threads,
|
||||
device,
|
||||
rng,
|
||||
));
|
||||
}
|
||||
|
||||
Arc::new(BatchDataLoader::new(
|
||||
strategy,
|
||||
dataset,
|
||||
self.batcher,
|
||||
device,
|
||||
rng,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::data::dataset::FakeDataset;

    /// Batcher that ignores its items and returns the device it was asked to
    /// batch on, so tests can observe which device a data loader uses.
    #[derive(new, Clone)]
    struct TestBatcherDevice;

    // NOTE(review): this `cfg(test)` is redundant inside a `#[cfg(test)]` module.
    #[cfg(test)]
    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    /// Device type of the backend used by the tests.
    type TestDevice = <TestBackend as Backend>::Device;

    /// Without workers, every batch must be produced on the default device.
    #[test]
    fn test_dataloader_no_workers() {
        // Shadows the module-level alias of the same name with an identical definition.
        type TestDevice = <TestBackend as Backend>::Device;

        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    /// With one worker and no explicit device, every batch must be produced on
    /// the default device.
    #[test]
    fn test_dataloader_default_device() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    /// Slices of a data loader moved to different devices must batch on their
    /// own device and together cover all the items.
    #[test]
    fn test_dataloader_slice_multi_device() {
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(11));

        // Exactly one of the `let (device1, device2)` statements below is expected
        // to be compiled in, selected by the enabled test feature.
        // NOTE(review): this fallback is not disabled by other `test-*` features
        // (e.g. `test-rocm`) — confirm against the `TestBackend` definition.
        #[cfg(all(
            test,
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda")
        ))]
        // Only one device exists...
        let (device1, device2) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(all(test, feature = "test-tch"))]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(all(test, feature = "test-wgpu"))]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(all(test, feature = "test-cuda"))]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(dataloader.num_items(), 11);
        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);

        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());

        // With a batch size of 1, the first five batches of each slice come in lockstep.
        for _ in 0..5 {
            assert_eq!(iterator_1.next(), Some(device1));
            assert_eq!(iterator_2.next(), Some(device2));
        }

        assert_eq!(iterator_1.next(), None);
        // For uneven split, the last dataloader (partial dataset) will have the remaining item
        assert_eq!(iterator_2.next(), Some(device2));
        assert_eq!(iterator_2.next(), None);
    }
}
|
||||
@@ -0,0 +1,16 @@
|
||||
mod base;
|
||||
mod batch;
|
||||
mod builder;
|
||||
mod multithread;
|
||||
mod strategy;
|
||||
|
||||
/// Module for batching items.
|
||||
pub mod batcher;
|
||||
/// Module to split a dataloader.
|
||||
pub mod split;
|
||||
|
||||
pub use base::*;
|
||||
pub use batch::*;
|
||||
pub use builder::*;
|
||||
pub use multithread::*;
|
||||
pub use strategy::*;
|
||||
@@ -0,0 +1,441 @@
|
||||
use burn_dataset::Dataset;
|
||||
use burn_dataset::transform::PartialDataset;
|
||||
use burn_tensor::backend::Backend;
|
||||
use rand::distr::{Distribution, StandardUniform};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::batcher::Batcher;
|
||||
use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
|
||||
use std::sync::{Arc, OnceLock, mpsc};
|
||||
use std::thread;
|
||||
|
||||
// Upper bound on queued items. Its use is outside the visible part of this
// file — presumably the capacity of the channel between the worker threads and
// the iterator (TODO confirm).
const MAX_QUEUED_ITEMS: usize = 100;

// Seed type accepted by `StdRng::from_seed`.
type RngSeed = <StdRng as SeedableRng>::Seed;
|
||||
|
||||
/// A multi-threaded data loader that can be used to iterate over a dataset.
///
/// The per-thread data loaders are not created by the constructor; they are
/// built lazily, once, from the stored configuration.
pub struct MultiThreadDataLoader<B: Backend, I, O> {
    // Configuration parameters needed for initialization
    strategy: Box<dyn BatchStrategy<I>>,
    dataset: Arc<dyn Dataset<I>>,
    batcher: Arc<dyn Batcher<B, I, O>>,
    device: B::Device,
    // Seed derived from the rng given at construction; `None` disables shuffling.
    seed: Option<RngSeed>,
    num_threads: usize,

    // The lazily initialized data loaders
    dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,
}
|
||||
|
||||
/// A message that can be sent between threads.
#[derive(Debug)]
pub enum Message<O> {
    /// A batch of items.
    ///
    /// Carries a `usize` (presumably the index of the worker that produced the
    /// batch — TODO confirm against the sender), the batch itself and the
    /// [`Progress`] reported with it.
    Batch(usize, O, Progress),

    /// The thread is done.
    Done,
}
|
||||
|
||||
/// Iterator state for the multi-threaded data loader.
///
/// NOTE(review): the code using these fields is outside the visible part of
/// this file; the field notes below are inferred from names and types only.
struct MultiThreadsDataloaderIterator<O> {
    // Presumably the number of workers that have sent `Message::Done`.
    num_done: usize,
    // Join handles of the spawned worker threads.
    workers: Vec<thread::JoinHandle<()>>,
    // Receiving end of the channel carrying `Message<O>` values.
    receiver: mpsc::Receiver<Message<O>>,
    // Presumably the latest progress reported per worker.
    progresses: Vec<Progress>,
}
|
||||
|
||||
impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
|
||||
where
|
||||
I: Send + Sync + Clone + 'static,
|
||||
O: Send + 'static,
|
||||
{
|
||||
/// Creates a new multi-threaded batch data loader.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `strategy` - The batch strategy.
|
||||
/// * `dataset` - The dataset.
|
||||
/// * `batcher` - The batcher.
|
||||
/// * `num_threads` - The number of threads.
|
||||
/// * `device` - The device to use when loading a batch.
|
||||
/// * `rng` - The rng determining if the dataset is shuffled each time a dataloader
|
||||
/// iterator is created.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The multi-threaded batch data loader.
|
||||
pub fn new(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<B, I, O>>,
|
||||
num_threads: usize,
|
||||
device: B::Device,
|
||||
rng: Option<rand::rngs::StdRng>,
|
||||
) -> Self {
|
||||
let mut seed = None;
|
||||
if let Some(mut rng) = rng {
|
||||
// RNG stream splitting (not state cloning): derive a new seed from the RNG's output.
|
||||
// This is exactly what `rng.fork()` does.
|
||||
let mut s = RngSeed::default();
|
||||
rng.fill_bytes(&mut s);
|
||||
|
||||
seed = Some(s);
|
||||
}
|
||||
Self::from_seed(strategy, dataset, batcher, num_threads, device, seed)
|
||||
}
|
||||
|
||||
fn from_seed(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<B, I, O>>,
|
||||
num_threads: usize,
|
||||
device: B::Device,
|
||||
seed: Option<RngSeed>,
|
||||
) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
dataset,
|
||||
batcher,
|
||||
num_threads,
|
||||
device,
|
||||
seed,
|
||||
dataloaders: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Force initialization if needed.
|
||||
fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
|
||||
self.dataloaders
|
||||
.get_or_init(|| {
|
||||
let mut dataset = self.dataset.clone();
|
||||
if let Some(seed) = self.seed.as_ref() {
|
||||
// Pre-shuffle the dataset before split if shuffle is enabled.
|
||||
// This ensures that each thread gets a uniform random sample of the dataset.
|
||||
let mut rng = StdRng::from_seed(*seed);
|
||||
dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new(
|
||||
dataset, &mut rng,
|
||||
));
|
||||
}
|
||||
|
||||
let datasets = match self.strategy.batch_size() {
|
||||
Some(batch_size) => {
|
||||
PartialDataset::split_chunks(dataset, self.num_threads, batch_size)
|
||||
}
|
||||
None => PartialDataset::split(dataset, self.num_threads),
|
||||
};
|
||||
|
||||
// Create more rngs from the first one, one for each new dataloader.
|
||||
let mut rng = self.seed.map(StdRng::from_seed);
|
||||
let rngs = (0..self.num_threads).map(|_| {
|
||||
rng.as_mut().map(|rng| {
|
||||
StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
|
||||
})
|
||||
});
|
||||
|
||||
datasets
|
||||
.into_iter()
|
||||
.zip(rngs)
|
||||
.map(|(dataset, rng)| {
|
||||
let strategy = self.strategy.clone_dyn();
|
||||
BatchDataLoader::new(
|
||||
strategy,
|
||||
Arc::new(dataset),
|
||||
self.batcher.clone(),
|
||||
self.device.clone(),
|
||||
rng,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
|
||||
where
|
||||
I: Send + Sync + Clone + 'static,
|
||||
O: Send + 'static + std::fmt::Debug,
|
||||
{
|
||||
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
|
||||
// This will initialize the loader if it hasn't been initialized yet
|
||||
let dataloaders = self.initialize();
|
||||
|
||||
let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);
|
||||
|
||||
let mut progresses = Vec::with_capacity(dataloaders.len());
|
||||
|
||||
let handlers: Vec<_> = dataloaders
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, dataloader)| {
|
||||
let dataloader_cloned = dataloader.clone();
|
||||
let sender_cloned = sender.clone();
|
||||
progresses.push(Progress::new(0, dataloader_cloned.num_items()));
|
||||
|
||||
thread::spawn(move || {
|
||||
let mut iterator = dataloader_cloned.iter();
|
||||
while let Some(item) = iterator.next() {
|
||||
let progress = iterator.progress();
|
||||
|
||||
match sender_cloned.send(Message::Batch(index, item, progress)) {
|
||||
Ok(_) => {}
|
||||
// The receiver is probably gone, no need to panic, just need to stop
|
||||
// iterating.
|
||||
Err(_) => return,
|
||||
};
|
||||
}
|
||||
// Same thing.
|
||||
sender_cloned.send(Message::Done).ok();
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Box::new(MultiThreadsDataloaderIterator::new(
|
||||
receiver, handlers, progresses,
|
||||
))
|
||||
}
|
||||
|
||||
fn num_items(&self) -> usize {
|
||||
// For num_items, we can directly use the dataset size without
|
||||
// necessarily initializing the full loader
|
||||
self.dataset.len()
|
||||
}
|
||||
|
||||
fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
|
||||
Arc::new(Self::from_seed(
|
||||
self.strategy.clone_dyn(),
|
||||
self.dataset.clone(),
|
||||
self.batcher.clone(),
|
||||
self.num_threads,
|
||||
device.clone(),
|
||||
self.seed,
|
||||
))
|
||||
}
|
||||
|
||||
fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
|
||||
let dataloader = Self::from_seed(
|
||||
self.strategy.clone_dyn(),
|
||||
Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
|
||||
self.batcher.clone(),
|
||||
self.num_threads,
|
||||
self.device.clone(),
|
||||
self.seed,
|
||||
);
|
||||
Arc::new(dataloader)
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> MultiThreadsDataloaderIterator<O> {
|
||||
pub fn new(
|
||||
receiver: mpsc::Receiver<Message<O>>,
|
||||
workers: Vec<thread::JoinHandle<()>>,
|
||||
progresses: Vec<Progress>,
|
||||
) -> Self {
|
||||
MultiThreadsDataloaderIterator {
|
||||
num_done: 0,
|
||||
workers,
|
||||
receiver,
|
||||
progresses,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
|
||||
fn progress(&self) -> Progress {
|
||||
let mut items_total = 0;
|
||||
let mut items_processed = 0;
|
||||
|
||||
for progress in self.progresses.iter() {
|
||||
items_total += progress.items_total;
|
||||
items_processed += progress.items_processed;
|
||||
}
|
||||
|
||||
Progress::new(items_processed, items_total)
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
|
||||
type Item = O;
|
||||
|
||||
fn next(&mut self) -> Option<O> {
|
||||
if self.workers.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
loop {
|
||||
let item = self.receiver.recv();
|
||||
let item = item.unwrap();
|
||||
|
||||
match item {
|
||||
Message::Batch(index, item, progress) => {
|
||||
if let Some(current) = self.progresses.get_mut(index) {
|
||||
*current = progress;
|
||||
}
|
||||
return Some(item);
|
||||
}
|
||||
Message::Done => {
|
||||
self.num_done += 1;
|
||||
}
|
||||
};
|
||||
|
||||
if self.num_done == self.workers.len() {
|
||||
while let Some(worker) = self.workers.pop() {
|
||||
worker.join().unwrap();
|
||||
}
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;
    use burn_dataset::InMemDataset;
    use std::collections::HashSet;

    /// The multi-threaded loader must yield the same set of items as the single-threaded one.
    #[test]
    fn test_multi_thread_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let items_single_thread: HashSet<_> = dataloader_single_thread.iter().flatten().collect();
        let items_multi_thread: HashSet<_> = dataloader_multi_thread.iter().flatten().collect();

        assert_eq!(items_single_thread, items_multi_thread);
    }

    /// Without an rng batches stay homogeneous; with an rng every batch mixes the classes.
    #[test]
    fn test_multi_thread_batch_dataloader_shuffle() {
        let num_classes = 2;
        let class_size = 100;
        let batch_size = 10;

        // Items is a deliberately ordered dataset.
        let items: Vec<_> = (0..num_classes)
            .flat_map(|class| vec![class; class_size])
            .collect();

        {
            // Unshuffled multithreaded loader
            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                Arc::new(InMemDataset::new(items.clone())),
                Arc::new(TestBatcher::new()),
                num_classes,
                Default::default(),
                // No rng means no shuffling.
                None,
            );

            for batch in loader.iter() {
                let batch_items: HashSet<_> = batch.into_iter().collect();

                // Since the dataset is not shuffled, we expect each batch to contain the same item.
                assert_eq!(batch_items.len(), 1);
            }
        }

        {
            // Shuffled multithreaded loader
            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                Arc::new(InMemDataset::new(items.clone())),
                Arc::new(TestBatcher::new()),
                num_classes,
                Default::default(),
                // The rng enables shuffling.
                Some(StdRng::seed_from_u64(42)),
            );

            for batch in loader.iter() {
                let batch_items: HashSet<_> = batch.into_iter().collect();

                // Since the dataset is shuffled, we expect to see all items.
                assert_eq!(batch_items.len(), num_classes);
            }
        }
    }

    /// Both loaders must produce the same number of batches and the same batches,
    /// including the incomplete ones.
    #[test]
    fn test_multi_thread_batch_dataloader_incomplete_batches() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let batches_single_thread: Vec<_> = dataloader_single_thread.iter().collect();
        let batches_multi_thread: Vec<_> = dataloader_multi_thread.iter().collect();

        let single_thread_cnt = batches_single_thread.len();
        let multi_thread_cnt = batches_multi_thread.len();

        let items_single_thread: HashSet<_> = batches_single_thread.into_iter().collect();
        let items_multi_thread: HashSet<_> = batches_multi_thread.into_iter().collect();

        assert_eq!(single_thread_cnt, multi_thread_cnt);
        assert_eq!(items_single_thread, items_multi_thread);
    }
}
|
||||
@@ -0,0 +1,134 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use super::DataLoader;
|
||||
|
||||
/// Splits a dataloader into multiple partial dataloaders (one per device).
|
||||
pub fn split_dataloader<B: Backend, O>(
|
||||
dataloader: Arc<dyn DataLoader<B, O>>,
|
||||
devices: &[B::Device],
|
||||
) -> Vec<Arc<dyn DataLoader<B, O>>> {
|
||||
let num_splits = devices.len();
|
||||
if num_splits > 1 {
|
||||
let num_items = dataloader.num_items();
|
||||
let mut dataloaders = Vec::with_capacity(num_splits);
|
||||
|
||||
let mut start = 0;
|
||||
let step = num_items / num_splits;
|
||||
for (i, device) in devices.iter().enumerate() {
|
||||
let end = if i == (num_splits - 1) {
|
||||
num_items
|
||||
} else {
|
||||
start + step
|
||||
};
|
||||
let dataloader = dataloader.slice(start, end).to_device(device);
|
||||
dataloaders.push(dataloader);
|
||||
start = end;
|
||||
}
|
||||
dataloaders
|
||||
} else {
|
||||
vec![dataloader]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::TestBackend;
    use crate::data::dataloader::batcher::Batcher;
    use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};
    use crate::data::dataset::FakeDataset;

    /// Splitting over two devices must give two loaders that together yield exactly the
    /// items of the original loader, each batch being created on the loader's device.
    #[test]
    fn test_split_batch_dataloader() {
        type TestDevice = <TestBackend as Backend>::Device;

        // Batcher that returns the items together with the device it was asked to use,
        // so the test can check which device each split targets.
        #[derive(new, Clone)]
        pub struct TestBatcher;

        impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {
            fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {
                (items, *device)
            }
        }

        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(11));

        #[allow(clippy::arc_with_non_send_sync)]
        let dataloader = Arc::new(BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        ));

        // The device pair must follow the same feature selection as `TestBackend`,
        // including `test-rocm`.
        #[cfg(all(
            test,
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda"),
            not(feature = "test-rocm")
        ))]
        // Only one device exists...
        let (device1, device2) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(all(test, feature = "test-tch"))]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(all(test, feature = "test-wgpu"))]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(all(test, feature = "test-cuda"))]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        // No specific device pair is assumed for ROCm: use the backend's default device twice.
        #[cfg(all(test, feature = "test-rocm"))]
        let (device1, device2) = (TestDevice::default(), TestDevice::default());

        let dataloaders = split_dataloader(dataloader.clone(), &[device1, device2]);

        assert_eq!(dataloaders.len(), 2);

        let [dataloader_1, dataloader_2] = match dataloaders.try_into() {
            Ok(arr) => arr,
            Err(_) => unreachable!(),
        };
        // 11 items over 2 splits: 5 for the first, the last one takes the remainder.
        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let mut items_dataloader = HashSet::new();
        let mut items_dataloader_split = HashSet::new();

        for (items, _device) in dataloader.iter() {
            items_dataloader.extend(items);
        }

        for (items, device) in dataloader_1.iter() {
            assert_eq!(device, device1);
            items_dataloader_split.extend(items);
        }

        for (items, device) in dataloader_2.iter() {
            assert_eq!(device, device2);
            items_dataloader_split.extend(items);
        }

        assert_eq!(items_dataloader, items_dataloader_split);
    }
}
|
||||
@@ -0,0 +1,87 @@
|
||||
/// A strategy to batch items.
pub trait BatchStrategy<I>: Send + Sync {
    /// Adds an item to the strategy.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to add.
    fn add(&mut self, item: I);

    /// Batches the items.
    ///
    /// # Arguments
    ///
    /// * `force` - Whether to force batching.
    ///
    /// # Returns
    ///
    /// The batched items.
    fn batch(&mut self, force: bool) -> Option<Vec<I>>;

    /// Creates a new strategy of the same type.
    ///
    /// # Returns
    ///
    /// The new strategy.
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;

    /// Returns the expected batch size for this strategy.
    ///
    /// # Returns
    ///
    /// The batch size, or None if the strategy doesn't have a fixed batch size.
    fn batch_size(&self) -> Option<usize>;
}

/// A strategy to batch items with a fixed batch size.
pub struct FixBatchStrategy<I> {
    // Items accumulated since the last emitted batch.
    items: Vec<I>,
    batch_size: usize,
}

impl<I> FixBatchStrategy<I> {
    /// Creates a new strategy to batch items with a fixed batch size.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - The batch size.
    ///
    /// # Returns
    ///
    /// The strategy.
    pub fn new(batch_size: usize) -> Self {
        FixBatchStrategy {
            items: Vec::with_capacity(batch_size),
            batch_size,
        }
    }
}

impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
    fn add(&mut self, item: I) {
        self.items.push(item);
    }

    /// Emits the buffered items once `batch_size` of them are available, or earlier when
    /// `force` is set. Returns `None` when nothing is buffered, forced or not.
    fn batch(&mut self, force: bool) -> Option<Vec<I>> {
        // Check for an empty buffer first so that no replacement buffer is allocated
        // when there is nothing to emit.
        if self.items.is_empty() {
            return None;
        }

        if self.items.len() < self.batch_size && !force {
            return None;
        }

        // Hand out the filled buffer and leave a fresh one, sized for the next batch.
        Some(std::mem::replace(
            &mut self.items,
            Vec::with_capacity(self.batch_size),
        ))
    }

    /// Returns an empty strategy with the same batch size; buffered items are not copied.
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
        Box::new(Self::new(self.batch_size))
    }

    fn batch_size(&self) -> Option<usize> {
        Some(self.batch_size)
    }
}
|
||||
@@ -0,0 +1,15 @@
|
||||
/// Dataloader module (requires the `dataset` feature).
#[cfg(feature = "dataset")]
pub mod dataloader;

/// Dataset module: re-exports everything from `burn_dataset` (requires the `dataset` feature).
#[cfg(feature = "dataset")]
pub mod dataset {
    pub use burn_dataset::*;
}

/// Network module: re-exports everything from `burn_std::network` (requires the `network` feature).
#[cfg(feature = "network")]
pub mod network {
    pub use burn_std::network::*;
}
|
||||
118
crates/stable-diffusion-burn/burn-crates/burn-core/src/lib.rs
Normal file
118
crates/stable-diffusion-burn/burn-crates/burn-core/src/lib.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![recursion_limit = "135"]
|
||||
|
||||
//! The core crate of Burn.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// Re-export serde for proc macros.
|
||||
pub use serde;
|
||||
|
||||
/// The configuration module.
|
||||
pub mod config;
|
||||
|
||||
/// Data module.
|
||||
#[cfg(feature = "std")]
|
||||
pub mod data;
|
||||
|
||||
/// Module for the neural network module.
|
||||
pub mod module;
|
||||
|
||||
/// Module for the recorder.
|
||||
pub mod record;
|
||||
|
||||
/// Module for the tensor.
|
||||
pub mod tensor;
|
||||
// Tensor at root: `burn::Tensor`
|
||||
pub use tensor::Tensor;
|
||||
|
||||
/// Module for visual operations
|
||||
#[cfg(feature = "vision")]
|
||||
pub mod vision;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Backend for test cases (used when no other test backend feature is enabled).
#[cfg(all(
    test,
    not(any(
        feature = "test-tch",
        feature = "test-wgpu",
        feature = "test-cuda",
        feature = "test-rocm"
    ))
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

/// Backend for test cases
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend for test cases
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend for test cases
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend for test cases
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// Backend for autodiff test cases
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
|
||||
|
||||
// Only compiled for test builds with the `test-memory-checks` feature.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}
|
||||
|
||||
#[cfg(test)]
mod test_utils {
    use crate as burn;
    use crate::module::Module;
    use crate::module::Param;
    use burn_tensor::Tensor;
    use burn_tensor::backend::Backend;

    /// Simple linear module.
    #[derive(Module, Debug)]
    pub struct SimpleLinear<B: Backend> {
        pub weight: Param<Tensor<B, 2>>,
        pub bias: Option<Param<Tensor<B, 1>>>,
    }

    impl<B: Backend> SimpleLinear<B> {
        /// Creates a linear module whose `[out_features, in_features]` weight and
        /// `[out_features]` bias are randomly initialized on `device`.
        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
            let distribution = burn_tensor::Distribution::Default;

            // The weight is sampled before the bias.
            let weight = Tensor::random([out_features, in_features], distribution, device);
            let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);

            Self {
                weight: Param::from_tensor(weight),
                bias: Some(Param::from_tensor(bias)),
            }
        }
    }
}
|
||||
|
||||
pub mod prelude {
|
||||
//! Structs and macros used by most projects. Add `use
|
||||
//! burn::prelude::*` to your code to quickly get started with
|
||||
//! Burn.
|
||||
pub use crate::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
tensor::{
|
||||
Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,
|
||||
backend::Backend, cast::ToElement, s,
|
||||
},
|
||||
};
|
||||
pub use burn_std::device::Device as DeviceOps;
|
||||
}
|
||||
@@ -0,0 +1,470 @@
|
||||
use super::{Param, ParamId, Quantizer};
|
||||
use crate::{
|
||||
record::Record,
|
||||
tensor::backend::{AutodiffBackend, Backend},
|
||||
};
|
||||
use alloc::{string::String, vec::Vec};
|
||||
pub use burn_derive::Module;
|
||||
use burn_tensor::{Bool, Int, Tensor, ops::Device};
|
||||
|
||||
/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
|
||||
/// the `alloc` crate.
|
||||
pub type Devices<B> = Vec<Device<B>>;
|
||||
|
||||
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
//
// Two forms are supported:
// * `map = <module>, ops = <closure>`: applies the closure to every float tensor of the
//   module through a `ModuleMapper` and returns the mapped module.
// * `visit_float = <module>, ops = <closure>, state = <type>, init = <closure>`: creates the
//   state with `init`, lets the closure update it for every float tensor through a
//   `ModuleVisitor`, and returns the final state.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                // Take the parameter apart, transform the tensor, and rebuild the
                // parameter with the same id and mapper.
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                func(&param.val(), &mut self.state)
            }
        }
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}
|
||||
|
||||
/// Trait for all neural network modules.
|
||||
///
|
||||
/// Modules should be created using the [derive](burn_derive::Module) attribute.
|
||||
/// This will make your module trainable, savable and loadable via
|
||||
/// `state` and `load`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
|
||||
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
|
||||
/// necessary to optimize and train the module on any backend.
|
||||
///
|
||||
/// ```rust, ignore
|
||||
/// // Not necessary when using the burn crate directly.
|
||||
/// use burn_core as burn;
|
||||
///
|
||||
/// use burn::{
|
||||
/// module::Module,
|
||||
/// nn::Linear,
|
||||
/// tensor::Tensor,
|
||||
/// tensor::backend::Backend,
|
||||
/// };
|
||||
///
|
||||
/// #[derive(Module, Debug)]
|
||||
/// struct MyModule<B: Backend> {
|
||||
/// my_param: Linear<B>,
|
||||
/// my_other_field: usize,
|
||||
/// }
|
||||
/// ```
|
||||
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
|
||||
/// Type to save and load the module.
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Return all the devices found in the underneath module tree added to the given vector
|
||||
/// without duplicates.
|
||||
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
|
||||
|
||||
/// Return all the devices found in the underneath module tree without duplicates.
|
||||
fn devices(&self) -> Devices<B> {
|
||||
self.collect_devices(Devices::<B>::new())
|
||||
}
|
||||
|
||||
/// Fork the module and all of its sub-modules to the given device.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This is similar to [to_device](Module::to_device), but it ensures the output module on the
|
||||
/// new device will have its own autodiff graph.
|
||||
fn fork(self, device: &B::Device) -> Self;
|
||||
|
||||
/// Move the module and all of its sub-modules to the given device.
|
||||
///
|
||||
/// # Warnings
|
||||
///
|
||||
/// The operation supports autodiff and it will be registered when activated. However, this may
|
||||
/// not be what you want. The output model will be an intermediary model, meaning that you
|
||||
/// can't optimize it with gradient descent. If you want to optimize the output network on the
|
||||
/// target device, use [fork](Module::fork) instead.
|
||||
fn to_device(self, device: &B::Device) -> Self;
|
||||
|
||||
/// Each tensor in the module tree will not require grad.
|
||||
///
|
||||
/// # Warnings
|
||||
///
|
||||
/// This should not be used for inference, use [valid](AutodiffModule::valid) when using
|
||||
/// AD modules. This is mostly useful when performing partial finetuning, which is updating only
|
||||
/// a small fraction of the parameters instead of finetuning all of them.
|
||||
fn no_grad(self) -> Self {
|
||||
module!(
|
||||
map = self,
|
||||
ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
|
||||
)
|
||||
}
|
||||
|
||||
/// Move the module and all of its sub-modules to the autodiff backend.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// * Only plain modules (not already on an autodiff backend) can be moved.
|
||||
/// * Calling `train()` on a module that is already on an autodiff backend
|
||||
/// will result in a type error, because the module's inner backend does not match.
|
||||
fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule
|
||||
where
|
||||
AB: AutodiffBackend<InnerBackend = B>,
|
||||
Self: HasAutodiffModule<AB>,
|
||||
{
|
||||
<Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)
|
||||
}
|
||||
|
||||
/// Get the number of parameters the module has, including all of its sub-modules.
|
||||
fn num_params(&self) -> usize {
|
||||
module!(
|
||||
visit_float = self,
|
||||
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
|
||||
*state += tensor.shape().num_elements();
|
||||
},
|
||||
state = usize,
|
||||
init = || 0
|
||||
)
|
||||
}
|
||||
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
|
||||
fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);
|
||||
|
||||
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
|
||||
fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;
|
||||
|
||||
/// Load the module state from a record.
|
||||
fn load_record(self, record: Self::Record) -> Self;
|
||||
|
||||
/// Convert the module into a record containing the state.
|
||||
fn into_record(self) -> Self::Record;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
/// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
|
||||
///
|
||||
/// List of supported file recorders:
|
||||
///
|
||||
/// * [default](crate::record::DefaultFileRecorder)
|
||||
/// * [bincode](crate::record::BinFileRecorder)
|
||||
/// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
|
||||
/// * [json pretty](crate::record::PrettyJsonFileRecorder)
|
||||
/// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
|
||||
/// * [named mpk](crate::record::NamedMpkFileRecorder)
|
||||
/// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
|
||||
///
|
||||
/// ## Notes
|
||||
///
|
||||
/// The file extension is automatically added depending on the file recorder provided, you
|
||||
/// don't have to specify it.
|
||||
fn save_file<FR, PB>(
|
||||
self,
|
||||
file_path: PB,
|
||||
recorder: &FR,
|
||||
) -> Result<(), crate::record::RecorderError>
|
||||
where
|
||||
FR: crate::record::FileRecorder<B>,
|
||||
PB: Into<std::path::PathBuf>,
|
||||
{
|
||||
let record = Self::into_record(self);
|
||||
recorder.record(record, file_path.into())
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
/// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
|
||||
///
|
||||
/// The recorder should be the same as the one used to save the module, see
|
||||
/// [save_file](Self::save_file).
|
||||
///
|
||||
/// ## Notes
|
||||
///
|
||||
/// The file extension is automatically added depending on the file recorder provided, you
|
||||
/// don't have to specify it.
|
||||
fn load_file<FR, PB>(
|
||||
self,
|
||||
file_path: PB,
|
||||
recorder: &FR,
|
||||
device: &B::Device,
|
||||
) -> Result<Self, crate::record::RecorderError>
|
||||
where
|
||||
FR: crate::record::FileRecorder<B>,
|
||||
PB: Into<std::path::PathBuf>,
|
||||
{
|
||||
let record = recorder.load(file_path.into(), device)?;
|
||||
|
||||
Ok(self.load_record(record))
|
||||
}
|
||||
|
||||
/// Quantize the weights of the module.
|
||||
fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
|
||||
self.map(quantizer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Module visitor trait for traversing and inspecting module parameters.
///
/// Every method has an empty default body, so an implementor only needs to override
/// the hooks it is interested in.
pub trait ModuleVisitor<B: Backend> {
    /// Visit a float parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The float parameter to visit
    #[allow(unused_variables)]
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}

    /// Visit an int parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to visit
    #[allow(unused_variables)]
    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}

    /// Visit a bool parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to visit
    #[allow(unused_variables)]
    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}

    /// Called when entering a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Called when exiting a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Visit a float tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The float tensor to visit
    #[allow(unused_variables)]
    fn visit_float_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D>,
    ) {
    }

    /// Visit an int tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The integer tensor to visit
    #[allow(unused_variables)]
    fn visit_int_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Int>,
    ) {
    }

    /// Visit a bool tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The boolean tensor to visit
    #[allow(unused_variables)]
    fn visit_bool_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Bool>,
    ) {
    }
}
|
||||
|
||||
/// Module mapper trait for transforming module parameters.
|
||||
pub trait ModuleMapper<B: Backend> {
|
||||
/// Called when entering a submodule.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `name`: The name of the submodule being entered
|
||||
/// - `container_type`: The type of the container with format:
|
||||
/// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
|
||||
/// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
|
||||
/// - For Vec containers: "Vec" (name is the index)
|
||||
/// - For Tuple containers: "Tuple" (name is the index)
|
||||
/// - For Array containers: "Array" (name is the index)
|
||||
///
|
||||
/// Note: Option containers do not call enter_module/exit_module to preserve
|
||||
/// the field name in the path (e.g., "bias" instead of "bias.Some")
|
||||
#[allow(unused_variables)]
|
||||
fn enter_module(&mut self, name: &str, container_type: &str) {}
|
||||
|
||||
/// Called when exiting a submodule.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `name`: The name of the submodule being exited
|
||||
/// - `container_type`: The type of the container with format:
|
||||
/// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
|
||||
/// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
|
||||
/// - For Vec containers: "Vec" (name is the index)
|
||||
/// - For Tuple containers: "Tuple" (name is the index)
|
||||
/// - For Array containers: "Array" (name is the index)
|
||||
///
|
||||
/// Note: Option containers do not call enter_module/exit_module to preserve
|
||||
/// the field name in the path (e.g., "bias" instead of "bias.Some")
|
||||
#[allow(unused_variables)]
|
||||
fn exit_module(&mut self, name: &str, container_type: &str) {}
|
||||
|
||||
/// Map a float parameter in the module.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `param`: The float parameter to transform
|
||||
///
|
||||
/// # Returns
|
||||
/// The transformed parameter
|
||||
#[allow(unused_variables)]
|
||||
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
|
||||
/// Map an int parameter in the module.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `param`: The integer parameter to transform
|
||||
///
|
||||
/// # Returns
|
||||
/// The transformed parameter
|
||||
#[allow(unused_variables)]
|
||||
fn map_int<const D: usize>(
|
||||
&mut self,
|
||||
param: Param<Tensor<B, D, Int>>,
|
||||
) -> Param<Tensor<B, D, Int>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
|
||||
/// Map a bool parameter in the module.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `param`: The boolean parameter to transform
|
||||
///
|
||||
/// # Returns
|
||||
/// The transformed parameter
|
||||
#[allow(unused_variables)]
|
||||
fn map_bool<const D: usize>(
|
||||
&mut self,
|
||||
param: Param<Tensor<B, D, Bool>>,
|
||||
) -> Param<Tensor<B, D, Bool>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
}
|
||||
|
||||
/// Module with auto-differentiation backend.
///
/// Pairs a module running on an autodiff backend `B` with the equivalent module on
/// `B::InnerBackend`, and converts between the two in both directions.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
    /// Inner module without auto-differentiation.
    type InnerModule: Module<B::InnerBackend>;

    /// Returns the same module, but on the inner backend without auto-differentiation.
    fn valid(&self) -> Self::InnerModule;

    /// Wraps an inner module back into an auto-diff module.
    fn from_inner(module: Self::InnerModule) -> Self;
}
|
||||
|
||||
/// Helper trait to associate a module with its autodiff version.
///
/// Implemented on the inner (non-autodiff) module type: `TrainModule` must be an
/// [`AutodiffModule`] whose `InnerModule` is `Self`.
pub trait HasAutodiffModule<B: AutodiffBackend> {
    /// The module with auto-differentiation.
    type TrainModule: AutodiffModule<B, InnerModule = Self>;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use crate::TestAutodiffBackend;
    use crate::test_utils::SimpleLinear;

    /// Walks a module through `valid` / `train` / `no_grad` and checks, at every step,
    /// both the live `is_require_grad()` answer and the stored `require_grad` flag.
    #[test]
    fn test_module_val_train_stateful() {
        let device = Default::default();
        let trainable = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);

        assert!(trainable.weight.is_require_grad());
        assert!(trainable.weight.require_grad);

        let inner = trainable.valid();
        assert!(!inner.weight.is_require_grad());
        assert!(inner.weight.require_grad); // stateful

        // Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying
        // let module: SimpleLinear<TestAutodiffBackend> = module.train();
        let retrained = inner.train::<TestAutodiffBackend>();
        assert!(retrained.weight.is_require_grad());
        assert!(retrained.weight.require_grad); // stateful

        let frozen = retrained.no_grad();
        assert!(!frozen.weight.is_require_grad());
        assert!(!frozen.weight.require_grad); // stateful

        let frozen_inner = frozen.valid();
        assert!(!frozen_inner.weight.is_require_grad()); // always
        assert!(!frozen_inner.weight.require_grad); // stateful

        let frozen_retrained = frozen_inner.train::<TestAutodiffBackend>();
        assert!(!frozen_retrained.weight.is_require_grad());
        assert!(!frozen_retrained.weight.require_grad); // stateful
    }
}
|
||||
@@ -0,0 +1,543 @@
|
||||
use alloc::{
|
||||
borrow::ToOwned,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
vec::Vec,
|
||||
};
|
||||
use core::any;
|
||||
use core::fmt::{Display, Write};
|
||||
|
||||
/// Default display settings for a module.
pub trait ModuleDisplayDefault {
    /// Attributes of the module used for display purposes.
    ///
    /// # Arguments
    ///
    /// * `_content` - The content object that contains display settings and attributes.
    ///
    /// # Returns
    ///
    /// An optional content object containing the display attributes.
    fn content(&self, _content: Content) -> Option<Content>;

    /// Gets the number of the parameters of the module.
    ///
    /// The default implementation reports `0`.
    fn num_params(&self) -> usize {
        0
    }
}
|
||||
|
||||
/// Trait to implement custom display settings for a module.
|
||||
///
|
||||
/// In order to implement custom display settings for a module,
|
||||
/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
|
||||
/// 2. Implement ModuleDisplay trait for the module
|
||||
pub trait ModuleDisplay: ModuleDisplayDefault {
|
||||
/// Formats the module with provided display settings.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `passed_settings` - Display settings passed to the module.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A string representation of the formatted module.
|
||||
fn format(&self, passed_settings: DisplaySettings) -> String {
|
||||
let settings = if let Some(custom_settings) = self.custom_settings() {
|
||||
custom_settings.inherit(passed_settings)
|
||||
} else {
|
||||
passed_settings
|
||||
};
|
||||
|
||||
let indent = " ".repeat(settings.level * settings.indentation_size());
|
||||
let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
|
||||
|
||||
let settings = settings.level_up();
|
||||
|
||||
let self_type = extract_type_name::<Self>();
|
||||
|
||||
// Use custom content if it is implemented and show_all_attributes is false,
|
||||
// otherwise use default content
|
||||
let content = if !settings.show_all_attributes() {
|
||||
self.custom_content(Content::new(settings.clone()))
|
||||
.unwrap_or_else(|| {
|
||||
self.content(Content::new(settings.clone()))
|
||||
.unwrap_or_else(|| {
|
||||
panic!("Default content should be implemented for {self_type}.")
|
||||
})
|
||||
})
|
||||
} else {
|
||||
self.content(Content::new(settings.clone()))
|
||||
.unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
|
||||
};
|
||||
|
||||
let top_level_type = if let Some(top_level_type) = content.top_level_type {
|
||||
top_level_type.to_owned()
|
||||
} else {
|
||||
self_type.to_owned()
|
||||
};
|
||||
|
||||
// If there is only one item in the content, return it or no attributes
|
||||
if let Some(item) = content.single_item {
|
||||
return item;
|
||||
} else if content.attributes.is_empty() {
|
||||
return top_level_type.to_string();
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
|
||||
// Print the struct name
|
||||
if settings.new_line_after_attribute() {
|
||||
writeln!(result, "{top_level_type} {{").unwrap();
|
||||
} else {
|
||||
write!(result, "{top_level_type} {{").unwrap();
|
||||
}
|
||||
|
||||
for (i, attribute) in content.attributes.iter().enumerate() {
|
||||
if settings.new_line_after_attribute() {
|
||||
writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
|
||||
} else if i == 0 {
|
||||
write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
|
||||
} else {
|
||||
write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if settings.show_num_parameters() {
|
||||
let num_params = self.num_params();
|
||||
if num_params > 0 {
|
||||
if settings.new_line_after_attribute() {
|
||||
writeln!(result, "{indent}params: {num_params}").unwrap();
|
||||
} else {
|
||||
write!(result, ", params: {num_params}").unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if settings.new_line_after_attribute() {
|
||||
write!(result, "{indent_close_braces}}}").unwrap();
|
||||
} else {
|
||||
write!(result, "}}").unwrap();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Custom display settings for the module.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional display settings object.
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Custom attributes for the module.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `_content` - The content object that contains display settings and attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional content object containing the custom attributes.
|
||||
fn custom_content(&self, _content: Content) -> Option<Content> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom module display settings.
///
/// Every optional field left as `None` is "not set"; the getter of the same name
/// substitutes a fallback value in that case.
#[derive(Debug, Clone)]
pub struct DisplaySettings {
    /// Whether to print the module parameter ids.
    show_param_id: Option<bool>,

    /// Whether to print the module attributes.
    show_all_attributes: Option<bool>,

    /// Whether to print the module number of parameters.
    show_num_parameters: Option<bool>,

    /// Print new line after an attribute.
    new_line_after_attribute: Option<bool>,

    /// Indentation size.
    indentation_size: Option<usize>,

    /// Level of indentation.
    level: usize,
}
|
||||
|
||||
impl Default for DisplaySettings {
|
||||
fn default() -> Self {
|
||||
DisplaySettings {
|
||||
show_param_id: None,
|
||||
show_all_attributes: None,
|
||||
show_num_parameters: None,
|
||||
new_line_after_attribute: None,
|
||||
indentation_size: None,
|
||||
level: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplaySettings {
|
||||
/// Create a new format settings.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new instance of `DisplaySettings`.
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// Sets a flag to show module parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to show module parameters.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_show_param_id(mut self, flag: bool) -> Self {
|
||||
self.show_param_id = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a flag to show module attributes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to show all module attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
|
||||
self.show_all_attributes = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a flag to show the number of module parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to show the number of module parameters.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
|
||||
self.show_num_parameters = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a flag to print a new line after an attribute.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to print a new line after an attribute.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
|
||||
self.new_line_after_attribute = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the indentation size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `size` - The size of the indentation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_indentation_size(mut self, size: usize) -> Self {
|
||||
self.indentation_size = Some(size);
|
||||
self
|
||||
}
|
||||
|
||||
/// Inherits settings from the provided settings and return a new settings object.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `top` - The top level `DisplaySettings` to inherit from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn inherit(self, top: Self) -> Self {
|
||||
let mut updated = self.clone();
|
||||
|
||||
if let Some(show_param_id) = top.show_param_id {
|
||||
updated.show_param_id = Some(show_param_id);
|
||||
};
|
||||
|
||||
if let Some(show_all_attributes) = top.show_all_attributes {
|
||||
updated.show_all_attributes = Some(show_all_attributes);
|
||||
}
|
||||
|
||||
if let Some(show_num_parameters) = top.show_num_parameters {
|
||||
updated.show_num_parameters = Some(show_num_parameters);
|
||||
}
|
||||
|
||||
if let Some(new_line_after_attribute) = top.new_line_after_attribute {
|
||||
updated.new_line_after_attribute = Some(new_line_after_attribute);
|
||||
}
|
||||
|
||||
if let Some(indentation_size) = top.indentation_size {
|
||||
updated.indentation_size = Some(indentation_size);
|
||||
}
|
||||
|
||||
updated.level = top.level;
|
||||
|
||||
updated
|
||||
}
|
||||
|
||||
/// A convenience method to wrap the DisplaySettings struct in an option.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional `DisplaySettings`.
|
||||
pub fn optional(self) -> Option<Self> {
|
||||
Some(self)
|
||||
}
|
||||
|
||||
/// Increases the level of indentation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance with increased indentation level.
|
||||
pub fn level_up(mut self) -> Self {
|
||||
self.level += 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets `show_param_id` flag, substitutes false if not set.
|
||||
///
|
||||
/// This flag is used to print the module parameter ids.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to show parameter ids.
|
||||
pub fn show_param_id(&self) -> bool {
|
||||
self.show_param_id.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Gets `show_all_attributes`, substitutes false if not set.
|
||||
///
|
||||
/// This flag is used to force to print all module attributes, overriding custom attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to show all attributes.
|
||||
pub fn show_all_attributes(&self) -> bool {
|
||||
self.show_all_attributes.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Gets `show_num_parameters`, substitutes true if not set.
|
||||
///
|
||||
/// This flag is used to print the number of module parameters.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to show the number of parameters.
|
||||
pub fn show_num_parameters(&self) -> bool {
|
||||
self.show_num_parameters.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Gets `new_line_after_attribute`, substitutes true if not set.
|
||||
///
|
||||
/// This flag is used to print a new line after an attribute.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to print a new line after an attribute.
|
||||
pub fn new_line_after_attribute(&self) -> bool {
|
||||
self.new_line_after_attribute.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Gets `indentation_size`, substitutes 2 if not set.
|
||||
///
|
||||
/// This flag is used to set the size of indentation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An integer value indicating the size of indentation.
|
||||
pub fn indentation_size(&self) -> usize {
|
||||
self.indentation_size.unwrap_or(2)
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct to store the attributes of a module for formatting.
///
/// A content holds either a list of named attributes or one pre-rendered single item;
/// the `add*` methods panic when the two are mixed.
#[derive(Clone, Debug)]
pub struct Content {
    /// List of attributes.
    pub attributes: Vec<Attribute>,

    /// Single item content.
    pub single_item: Option<String>,

    /// Display settings.
    pub display_settings: DisplaySettings,

    /// Top level type name.
    pub top_level_type: Option<String>,
}
|
||||
|
||||
impl Content {
|
||||
/// Creates a new attributes struct.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `display_settings` - Display settings for the content.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new instance of `Content`.
|
||||
pub fn new(display_settings: DisplaySettings) -> Self {
|
||||
Content {
|
||||
attributes: Vec::new(),
|
||||
single_item: None,
|
||||
display_settings,
|
||||
top_level_type: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds an attribute to the format settings. The value will be formatted and stored as a string.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - Name of the attribute.
|
||||
/// * `value` - Value of the attribute.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the new attribute added.
|
||||
pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
|
||||
if self.single_item.is_some() {
|
||||
panic!("Cannot add multiple attributes when single item is set.");
|
||||
}
|
||||
|
||||
let attribute = Attribute {
|
||||
name: name.to_owned(),
|
||||
value: value.format(self.display_settings.clone()), // TODO level + 1
|
||||
ty: any::type_name::<T>().to_string(),
|
||||
};
|
||||
self.attributes.push(attribute);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a single item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `value` - Rendered string of the single item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the single item added.
|
||||
pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
|
||||
if !self.attributes.is_empty() {
|
||||
panic!("Cannot add single item when attributes are set.");
|
||||
}
|
||||
|
||||
self.single_item = Some(value.format(self.display_settings.clone()));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a single item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `value` - Formatted display value.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the formatted single item added.
|
||||
pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
|
||||
if !self.attributes.is_empty() {
|
||||
panic!("Cannot add single item when attributes are set.");
|
||||
}
|
||||
|
||||
self.single_item = Some(format!("{value}"));
|
||||
self
|
||||
}
|
||||
|
||||
/// A convenience method to wrap the Attributes struct in an option
|
||||
/// because it is often used as an optional field.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional `Content`.
|
||||
pub fn optional(self) -> Option<Self> {
|
||||
if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
|
||||
{
|
||||
None
|
||||
} else {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the top level type name.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `ty` - The type name to set.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the top level type name set.
|
||||
pub fn set_top_level_type(mut self, ty: &str) -> Self {
|
||||
self.top_level_type = Some(ty.to_owned());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Attribute to print in the display method.
#[derive(Clone, Debug)]
pub struct Attribute {
    /// Name of the attribute.
    pub name: String,

    /// Value of the attribute, already rendered as a string.
    pub value: String,

    /// Type of the attribute, as reported by `any::type_name`.
    pub ty: String,
}
|
||||
|
||||
/// Extracts the short name of a type T
///
/// The full type name is cut at the first `<` (dropping generic parameters) and then
/// reduced to the text after the last `::` (dropping the module path).
///
/// # Returns
///
/// A string slice representing the short name of the type.
pub fn extract_type_name<T: ?Sized>() -> &'static str {
    // Get the full type name of T, including module path and generic parameters
    let full = any::type_name::<T>();

    // Keep only what precedes the first '<'; without one, keep the whole name.
    let base = full.find('<').map_or(full, |generics_at| &full[..generics_at]);

    // Skip past the last "::" when something follows it; otherwise return `base` as is.
    match base.rfind("::") {
        Some(sep_at) if sep_at + 2 < base.len() => &base[sep_at + 2..],
        _ => base,
    }
}
|
||||
@@ -0,0 +1,627 @@
|
||||
use crate::tensor::Shape;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Param, ParamId};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::{Distribution, Tensor, s};
|
||||
|
||||
use crate as burn;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float as _;
|
||||
|
||||
/// Enum specifying with what values a tensor should be initialized
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
    Constant {
        /// The value to fill the tensor with
        value: f64,
    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
    Uniform {
        /// The minimum value to draw from
        min: f64,

        /// The maximum value to draw from
        max: f64,
    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
    Normal {
        /// The mean of the normal distribution
        mean: f64,

        /// The standard deviation of the normal distribution
        std: f64,
    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
    KaimingUniform {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the normal version of Kaiming initialization
    KaimingNormal {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierUniform {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierNormal {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the (semi) orthogonal initialization
    /// described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural
    /// networks - Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120)
    Orthogonal {
        /// The gain to use in initialization formula
        gain: f64,
    },
}
|
||||
|
||||
impl Initializer {
|
||||
/// Inits a tensor parameter of given shape with values depending on initializer kind.
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - shape: Shape of the initiated tensor.
|
||||
pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
|
||||
&self,
|
||||
shape: S,
|
||||
device: &B::Device,
|
||||
) -> Param<Tensor<B, D>> {
|
||||
self.init_with(shape, None, None, device)
|
||||
}
|
||||
|
||||
/// Inits a tensor parameter of given shape with values depending on initializer kind.
|
||||
///
|
||||
/// # Params
|
||||
///
|
||||
/// - shape: Shape of the initiated tensor.
|
||||
pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
|
||||
&self,
|
||||
shape: S,
|
||||
fan_in: Option<usize>,
|
||||
fan_out: Option<usize>,
|
||||
device: &B::Device,
|
||||
) -> Param<Tensor<B, D>> {
|
||||
let device = device.clone();
|
||||
let shape: Shape = shape.into();
|
||||
let config = self.clone();
|
||||
let shape_for_closure = shape.clone();
|
||||
|
||||
Param::uninitialized(
|
||||
ParamId::new(),
|
||||
move |device, require_grad| {
|
||||
B::memory_persistent_allocations(device, (), move |_| {
|
||||
let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);
|
||||
|
||||
if require_grad {
|
||||
tensor = tensor.require_grad();
|
||||
}
|
||||
|
||||
tensor
|
||||
})
|
||||
},
|
||||
device,
|
||||
true,
|
||||
shape_for_closure,
|
||||
)
|
||||
}
|
||||
|
||||
fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
|
||||
&self,
|
||||
shape: S,
|
||||
fan_in: Option<usize>,
|
||||
fan_out: Option<usize>,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D> {
|
||||
let shape = shape.into();
|
||||
match self {
|
||||
Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
|
||||
Initializer::Ones => Tensor::<B, D>::ones(shape, device),
|
||||
Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
|
||||
Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
|
||||
Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
|
||||
Initializer::KaimingUniform { gain, fan_out_only } => {
|
||||
let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
|
||||
uniform_draw(shape, -a, a, device)
|
||||
}
|
||||
Initializer::KaimingNormal { gain, fan_out_only } => {
|
||||
let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
|
||||
normal_draw(shape, 0.0, std, device)
|
||||
}
|
||||
Initializer::XavierUniform { gain } => {
|
||||
let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
|
||||
uniform_draw(shape, -a, a, device)
|
||||
}
|
||||
Initializer::XavierNormal { gain } => {
|
||||
let std = *gain * self.xavier_std(fan_in, fan_out);
|
||||
normal_draw(shape, 0.0, std, device)
|
||||
}
|
||||
Initializer::Orthogonal { gain } => {
|
||||
// following the implementation in pytorch:
|
||||
// https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574
|
||||
|
||||
assert!(
|
||||
D >= 2,
|
||||
"Expected D (in Tensor<B, D>) to be greater or equal 2; (D >= 2)"
|
||||
);
|
||||
|
||||
let rows: usize = shape.dims::<D>()[0];
|
||||
let cols: usize = shape.num_elements() / rows;
|
||||
|
||||
let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);
|
||||
|
||||
if rows < cols {
|
||||
t = t.transpose();
|
||||
}
|
||||
|
||||
let (q, r) = qr_decomposition(t, device);
|
||||
let [r_rows, r_cols] = r.clone().dims();
|
||||
|
||||
let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
|
||||
.matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));
|
||||
|
||||
let ph = diag_r.clone().sign();
|
||||
|
||||
let mut q = q.mul(ph);
|
||||
|
||||
if rows < cols {
|
||||
q = q.transpose();
|
||||
}
|
||||
|
||||
q.reshape(shape).mul_scalar(*gain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn kaiming_std(
|
||||
&self,
|
||||
fan_out_only: bool,
|
||||
fan_in: Option<usize>,
|
||||
fan_out: Option<usize>,
|
||||
) -> f64 {
|
||||
let fan = if fan_out_only { fan_out } else { fan_in };
|
||||
let fan = fan.expect(
|
||||
"Can't use Kaiming initialization without specifying fan. Use init_with method.",
|
||||
);
|
||||
|
||||
1.0 / (fan as f64).sqrt()
|
||||
}
|
||||
|
||||
fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
|
||||
let fan_in = fan_in.expect(
|
||||
"Can't use Xavier initialization without specifying fan in. Use init_with method and \
|
||||
provide fan_in.",
|
||||
);
|
||||
let fan_out = fan_out.expect(
|
||||
"Can't use Xavier initialization without specifying fan out. Use init_with method and \
|
||||
provide fan_out.",
|
||||
);
|
||||
(2.0 / (fan_in + fan_out) as f64).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
|
||||
shape: S,
|
||||
low: f64,
|
||||
high: f64,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D> {
|
||||
let distribution = Distribution::Uniform(low, high);
|
||||
Tensor::<B, D>::random(shape, distribution, device)
|
||||
}
|
||||
|
||||
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
|
||||
shape: S,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D> {
|
||||
let distribution = Distribution::Normal(mean, std);
|
||||
Tensor::<B, D>::random(shape, distribution, device)
|
||||
}
|
||||
|
||||
fn qr_decomposition<B: Backend>(
|
||||
a: Tensor<B, 2>,
|
||||
device: &B::Device,
|
||||
) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
||||
// Calculate the QR decomposition using Gram-Schmidt-process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
|
||||
|
||||
let [m, n] = a.clone().dims();
|
||||
let mut q = Tensor::<B, 2>::zeros([m, n], device);
|
||||
let mut r = Tensor::<B, 2>::zeros([n, n], device);
|
||||
|
||||
for j in 0..n {
|
||||
let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);
|
||||
|
||||
for i in 0..j {
|
||||
let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
|
||||
let r_ij = q_i.clone().mul(v.clone()).sum();
|
||||
|
||||
r = r
|
||||
.clone()
|
||||
.slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());
|
||||
|
||||
v = v - q_i.mul(r_ij);
|
||||
}
|
||||
|
||||
// norm of v
|
||||
let r_jj = v
|
||||
.clone()
|
||||
.powf(Tensor::from_floats([2.0], device))
|
||||
.sum()
|
||||
.sqrt();
|
||||
|
||||
r = r
|
||||
.clone()
|
||||
.slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());
|
||||
|
||||
let q_j = v / r_jj;
|
||||
|
||||
q = q
|
||||
.clone()
|
||||
.slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
|
||||
}
|
||||
|
||||
(q, r)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use burn_tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

    /// Asserts that every row of `tensor` has a sample mean and variance within 0.1 of the
    /// expected values.
    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        // Reduce along dim 1 so that each row yields one (variance, mean) pair; the loop
        // below walks the rows, so the reduction must not collapse them.
        let (actual_vars, actual_means) = tensor.clone().var_mean(1);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape()[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &device).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        let device = Default::default();
        TB::seed(&device, 0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &device)
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        // `init` supplies no fan, so Kaiming initialization must panic.
        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &device)
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        // `init` supplies no fan, so Xavier initialization must panic.
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &device)
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        let device = Default::default();
        TB::seed(&device, 0);

        // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr
        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &device,
        );
        let (q, r) = qr_decomposition(a.clone(), &device);

        // Q @ R should reconstruct input `a`
        let q_matmul_r = q.matmul(r);

        // assert that the difference between input (`a`) and Q @ R is (almost) zero
        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &device)
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &device);

        // Q.T @ Q should be close to identity matrix
        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
        );

        // test 3D tensor
        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 1D tensor: orthogonal initialization requires D >= 2, so this must panic
        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
    }
}
|
||||
@@ -0,0 +1,16 @@
|
||||
mod base;
|
||||
mod display;
|
||||
mod initializer;
|
||||
mod param;
|
||||
mod quantize;
|
||||
#[cfg(feature = "std")]
|
||||
mod reinit;
|
||||
|
||||
pub use base::*;
|
||||
pub use display::*;
|
||||
pub use initializer::*;
|
||||
pub use param::*;
|
||||
pub use quantize::*;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub use reinit::*;
|
||||
@@ -0,0 +1,424 @@
|
||||
use super::ParamId;
|
||||
use alloc::{boxed::Box, format};
|
||||
use burn_std::stub::RwLock;
|
||||
use burn_tensor::Shape;
|
||||
use core::cell::OnceCell;
|
||||
use core::ops::Deref;
|
||||
|
||||
#[cfg(target_has_atomic = "ptr")]
|
||||
use alloc::sync::Arc;
|
||||
|
||||
#[cfg(not(target_has_atomic = "ptr"))]
|
||||
use portable_atomic_util::Arc;
|
||||
|
||||
#[cfg(target_has_atomic = "ptr")]
|
||||
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;
|
||||
|
||||
#[cfg(not(target_has_atomic = "ptr"))]
|
||||
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;
|
||||
|
||||
#[cfg(target_has_atomic = "ptr")]
|
||||
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
|
||||
Arc::new(func)
|
||||
}
|
||||
|
||||
#[cfg(not(target_has_atomic = "ptr"))]
|
||||
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
|
||||
Arc::new(Box::new(func))
|
||||
}
|
||||
|
||||
/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
|
||||
/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during
|
||||
/// training, and loaded during inference. If you don't want to save the tensors
|
||||
/// and/or don't want to update it during training, you don't need this type to wrap your tensor.
|
||||
///
|
||||
/// # Core Lazy Initialization Architecture
|
||||
///
|
||||
/// `Param<T>` has a dual-state design using `OnceCell<T>`:
|
||||
///
|
||||
/// ## State Management
|
||||
///
|
||||
/// **Two possible states:**
|
||||
///
|
||||
/// 1. **Initialized**: `state: OnceCell<T>` contains value, `initialization: None`
|
||||
/// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock<Option<Uninitialized<T>>>)`
|
||||
pub struct Param<T: Parameter> {
|
||||
/// The unique ID of this parameter. This is used by eg. optimizers to associate a gradient with a specific parameter.
|
||||
pub id: ParamId,
|
||||
/// The OnceCell holding the initialized parameter value.
|
||||
/// Empty for uninitialized parameters, populated after first access or explicit initialization.
|
||||
pub(crate) state: OnceCell<T>,
|
||||
/// The deferred initialization state for lazy parameters.
|
||||
///
|
||||
/// **State Transitions:**
|
||||
/// - Initialized params: `None`
|
||||
/// - Uninitialized params: `Some(RwLock<Some(Uninitialized<T>)>)`
|
||||
/// - After lazy init triggers: `Some(RwLock<None>)` (inner Option is taken)
|
||||
pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
|
||||
pub(crate) param_mapper: ParamMapper<T>,
|
||||
// For stateful `module.valid()` <> `module.train()`
|
||||
pub(crate) require_grad: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
/// Applies transformations when loading and saving parameters.
|
||||
///
|
||||
/// # Mapper System
|
||||
///
|
||||
/// `ParamMapper<T>` allows applying transformations during serialization and deserialization:
|
||||
/// - `load: Option<Mapper<T>>` - transformation during deserialization (applied in `transform_for_load()`)
|
||||
/// - `save: Option<Mapper<T>>` - transformation during serialization (applied in `transform_for_save()`)
|
||||
///
|
||||
/// These are commonly used for:
|
||||
/// - Quantization/dequantization
|
||||
/// - Precision conversion (e.g., FP32 ↔ FP16)
|
||||
/// - Custom parameter transformations
|
||||
pub struct ParamMapper<T: Parameter> {
|
||||
load: Option<Mapper<T>>,
|
||||
save: Option<Mapper<T>>,
|
||||
}
|
||||
|
||||
impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
    /// Prints only whether each mapper is set; the closures themselves are not printable.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let has_load = self.load.is_some();
        let has_save = self.save.is_some();
        write!(f, "ParamMapper {{ load: {}, save: {} }}", has_load, has_save)
    }
}
|
||||
|
||||
impl<T: Parameter> ParamMapper<T> {
    /// Applies the load transformation to `param`, or returns it unchanged when none is set.
    pub fn on_load(&self, param: T) -> T {
        if let Some(transform) = &self.load {
            transform(param)
        } else {
            param
        }
    }

    /// Applies the save transformation to `param`, or returns it unchanged when none is set.
    pub fn on_save(&self, param: T) -> T {
        if let Some(transform) = &self.save {
            transform(param)
        } else {
            param
        }
    }
}
|
||||
|
||||
impl<T: Parameter> Default for ParamMapper<T> {
    /// A mapper with no load and no save transformation.
    fn default() -> Self {
        let (load, save) = (None, None);
        Self { load, save }
    }
}
|
||||
|
||||
impl<T: Parameter> core::fmt::Display for Param<T> {
    /// Writes `Param: <id>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write straight into the formatter instead of building an intermediate `String`.
        write!(f, "Param: {}", self.id)
    }
}
|
||||
|
||||
impl<T: Parameter> core::fmt::Debug for Param<T> {
    /// Writes `Param: <id> - <param mapper debug output>`; the value itself is not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("Param: {} - {:?}", self.id, self.param_mapper);
        f.write_str(&rendered)
    }
}
|
||||
|
||||
/// Trait that defines what is necessary for a type to be a parameter.
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// The device type to be used.
    type Device: Clone;

    /// Returns the device of this parameter.
    fn device(&self) -> Self::Device;

    /// Returns whether this parameter requires gradients.
    fn is_require_grad(&self) -> bool;

    /// Consumes the parameter and returns it with the given gradient requirement.
    fn set_require_grad(self, require_grad: bool) -> Self;
}
|
||||
|
||||
/// The deferred initialization state for lazy parameters.
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) struct Uninitialized<P: Parameter> {
|
||||
/// The initialization function. Called with `(device, is_require_grad) -> Parameter`.
|
||||
/// This function is consumed during initialization via `FnOnce`.
|
||||
init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
|
||||
/// The target device on which the parameter should be initialized.
|
||||
/// Used by `lazy_device()` to provide device information without triggering initialization.
|
||||
pub(crate) device: P::Device,
|
||||
/// The gradient requirement for the parameter.
|
||||
/// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization.
|
||||
pub(crate) is_require_grad: bool,
|
||||
/// The shape of the tensor parameter.
|
||||
/// Used by `lazy_shape()` to provide shape information without triggering initialization.
|
||||
pub(crate) shape: Shape,
|
||||
}
|
||||
|
||||
impl<P: Parameter> Uninitialized<P> {
    /// Consumes the uninitialized state and runs the initialization function.
    ///
    /// This is called by [Param::val] when accessing an uninitialized parameter for the first time.
    /// The stored device and gradient requirement are handed to the function, whose result is
    /// the initialized parameter.
    fn initialize(self) -> P {
        let Self {
            init,
            device,
            is_require_grad,
            ..
        } = self;
        init(&device, is_require_grad)
    }
}
|
||||
|
||||
impl<T: Parameter> Param<T> {
|
||||
/// Create a new parameter that is already initialized.
|
||||
pub fn initialized(id: ParamId, value: T) -> Self {
|
||||
let require_grad = value.is_require_grad();
|
||||
Self {
|
||||
id,
|
||||
state: OnceCell::from(value),
|
||||
initialization: None,
|
||||
param_mapper: Default::default(),
|
||||
require_grad,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new parameter that is not already initialized.
|
||||
pub fn uninitialized<F>(
|
||||
id: ParamId,
|
||||
init: F,
|
||||
device: T::Device,
|
||||
is_require_grad: bool,
|
||||
shape: Shape,
|
||||
) -> Self
|
||||
where
|
||||
F: FnOnce(&T::Device, bool) -> T + Send + 'static,
|
||||
{
|
||||
Self {
|
||||
id,
|
||||
state: OnceCell::new(),
|
||||
initialization: Some(RwLock::new(Some(Uninitialized {
|
||||
init: Box::new(init),
|
||||
device,
|
||||
is_require_grad,
|
||||
shape,
|
||||
}))),
|
||||
param_mapper: Default::default(),
|
||||
require_grad: is_require_grad,
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the parameter value, initializing it lazily if needed.
|
||||
///
|
||||
/// For initialized parameters, this returns a clone of the cached value.
|
||||
/// For uninitialized parameters, this triggers initialization:
|
||||
pub fn val(&self) -> T {
|
||||
self.state
|
||||
.get_or_init(|| {
|
||||
let mut result = self
|
||||
.initialization
|
||||
.as_ref()
|
||||
.expect("Should have an initialization when no state provided.")
|
||||
.write()
|
||||
.unwrap();
|
||||
let state = result.take().expect("Should exist when not initialized");
|
||||
state.initialize()
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Check if the parameter has been initialized.
|
||||
///
|
||||
/// Returns `true` if the parameter's value has been computed and cached,
|
||||
/// `false` if it's still lazy and will be initialized on first access.
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.state.get().is_some()
|
||||
}
|
||||
|
||||
/// Gets the parameter's value while consuming the parameter.
|
||||
pub fn into_value(self) -> T {
|
||||
self.consume().1
|
||||
}
|
||||
|
||||
/// Gets the parameter id and value while consuming the parameter.
|
||||
pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
|
||||
let tensor = self.val();
|
||||
|
||||
core::mem::drop(self.state);
|
||||
|
||||
(self.id, tensor, self.param_mapper)
|
||||
}
|
||||
|
||||
/// Execute the given function on the inner value.
|
||||
pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
|
||||
let (id, tensor, param_mapper) = self.consume();
|
||||
let tensor = func(tensor);
|
||||
let require_grad = tensor.is_require_grad();
|
||||
|
||||
Self {
|
||||
id,
|
||||
state: OnceCell::from(tensor),
|
||||
initialization: None,
|
||||
param_mapper,
|
||||
require_grad,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an initialized parameter with the given id, value, and param mapper.
|
||||
///
|
||||
/// This is a helper method for creating parameters while preserving the param mapper,
|
||||
/// typically used in ModuleMapper implementations.
|
||||
pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
|
||||
let require_grad = value.is_require_grad();
|
||||
Self {
|
||||
id,
|
||||
state: OnceCell::from(value),
|
||||
initialization: None,
|
||||
param_mapper,
|
||||
require_grad,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs a transformation on the parameter when loading.
|
||||
pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
|
||||
self.param_mapper.load = Some(new_mapper(func));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Runs a transformation on the parameter when saving.
|
||||
pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
|
||||
self.param_mapper.save = Some(new_mapper(func));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute the given function on the inner value.
|
||||
pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
|
||||
where
|
||||
T: 'static,
|
||||
{
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.map(func),
|
||||
};
|
||||
|
||||
let mut init = initialization.write().unwrap();
|
||||
|
||||
match init.as_mut() {
|
||||
Some(value) => {
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut prev: Box<dyn FnOnce(&T::Device, bool) -> T + Send> =
|
||||
Box::new(|_, _| panic!("Fake func to not have null ref."));
|
||||
core::mem::swap(&mut prev, &mut value.init);
|
||||
|
||||
value.init = Box::new(|a, b| {
|
||||
let tensor = prev(a, b);
|
||||
func(tensor)
|
||||
});
|
||||
core::mem::drop(init);
|
||||
self
|
||||
}
|
||||
None => {
|
||||
core::mem::drop(init);
|
||||
self.map(func)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The device on which the parameter is or will be initialized, **without triggering initialization**.
|
||||
///
|
||||
/// This is critical for the load optimization: when loading tensors into an uninitialized parameter,
|
||||
/// we need to know the target device to move the loaded tensor appropriately, but we don't want to
|
||||
/// trigger the initialization function (which would allocate an unnecessary tensor).
|
||||
///
|
||||
/// Use this instead of [crate::tensor::Tensor::device] when you need the device but want to
|
||||
/// preserve lazy initialization.
|
||||
pub fn lazy_device(&self) -> T::Device {
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.device(),
|
||||
};
|
||||
|
||||
let init = initialization.read().unwrap();
|
||||
|
||||
match init.as_ref() {
|
||||
Some(value) => value.device.clone(),
|
||||
None => self.device(),
|
||||
}
|
||||
}
|
||||
|
||||
/// The gradient requirement on which the parameter is or will be initialized, **without triggering initialization**.
|
||||
///
|
||||
/// Similar to [lazy_device](Self::lazy_device), this is critical for the load optimization.
|
||||
/// When loading tensors into an uninitialized parameter, we need to apply the correct gradient
|
||||
/// setting to the loaded tensor without triggering the initialization function.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This is a crate-private function, since users are not expected to use `is_require_grad` of an
|
||||
/// uninitialized module to then override its value. All low-level functions should be provided
|
||||
/// by `burn` and should handle those details.
|
||||
pub(crate) fn lazy_is_require_grad(&self) -> bool {
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.is_require_grad(),
|
||||
};
|
||||
|
||||
let init = initialization.read().unwrap();
|
||||
|
||||
match init.as_ref() {
|
||||
Some(value) => value.is_require_grad,
|
||||
None => self.is_require_grad(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Override the gradient requirement for the current parameter.
|
||||
pub fn set_require_grad(self, require_grad: bool) -> Self {
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
|
||||
};
|
||||
|
||||
let mut init = initialization.write().unwrap();
|
||||
let mut is_lazy = false;
|
||||
|
||||
if let Some(value) = init.as_mut() {
|
||||
is_lazy = true;
|
||||
value.is_require_grad = require_grad;
|
||||
};
|
||||
|
||||
core::mem::drop(init);
|
||||
|
||||
if is_lazy {
|
||||
return self;
|
||||
}
|
||||
|
||||
self.map(|tensor| tensor.set_require_grad(require_grad))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Parameter> Clone for Param<T> {
|
||||
fn clone(&self) -> Self {
|
||||
let mut param = Param::initialized(self.id, self.val());
|
||||
param.param_mapper = self.param_mapper.clone();
|
||||
param
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Parameter> Deref for Param<T> {
    type Target = T;

    /// Borrows the inner value, running the deferred initialization on first access.
    ///
    /// # Panics
    ///
    /// Panics if the cell is empty and no pending initialization is available.
    fn deref(&self) -> &Self::Target {
        self.state.get_or_init(|| {
            let lock = self
                .initialization
                .as_ref()
                .expect("Should have an initialization when no state provided.");
            let mut guard = lock.write().unwrap();

            // Take the pending state out so the `FnOnce` initializer can be consumed.
            let pending = guard.take().expect("Should exist when not initialized");
            pending.initialize()
        })
    }
}
|
||||
@@ -0,0 +1,408 @@
|
||||
use alloc::{format, string::ToString};
|
||||
use core::{fmt::Display, marker::PhantomData};
|
||||
|
||||
use crate as burn;
|
||||
use crate::{
|
||||
module::{
|
||||
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
|
||||
ModuleMapper, ModuleVisitor,
|
||||
},
|
||||
record::{PrecisionSettings, Record},
|
||||
};
|
||||
use burn_tensor::{
|
||||
BasicAutodiffOps, BasicOps, Tensor,
|
||||
backend::{AutodiffBackend, Backend},
|
||||
ops::Device,
|
||||
};
|
||||
|
||||
/// Record used for constant type implementing the [module](crate::module::Module) trait.
|
||||
#[derive(Debug, Clone, Copy, new, Default)]
|
||||
pub struct ConstantRecord;
|
||||
|
||||
impl serde::Serialize for ConstantRecord {
    /// Serializes the record as `none`, since it has no content.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // nothing to serialize
        serializer.serialize_none()
    }
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for ConstantRecord {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_option(serde::de::IgnoredAny).ok();
|
||||
Ok(ConstantRecord::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Record<B> for ConstantRecord {
|
||||
type Item<S: PrecisionSettings> = ConstantRecord;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
|
||||
item
|
||||
}
|
||||
}
|
||||
/// Constant macro.
///
/// `constant!(Type)` implements `Module`, `AutodiffModule`, `ModuleDisplayDefault` and
/// `ModuleDisplay` for `Type`, treating it as a value without parameters. The `module` and
/// `ad_module` arms expand to the bodies of those impls and are used by the `$type` arm.
#[macro_export]
macro_rules! constant {
    // Body of a `Module<B>` impl: every operation is a no-op that returns the value as-is.
    (module) => {
        type Record = burn::module::ConstantRecord;

        fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to do
        }

        fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            burn::module::ConstantRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
            devices
        }
    };

    // Body of an `AutodiffModule<B>` impl: the inner module is the type itself.
    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }

        fn from_inner(module: Self::InnerModule) -> Self {
            module
        }
    };

    // Full set of impls for `$type`.
    ($type:ty) => {
        impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {
            constant!(module);
        }

        impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
            constant!(ad_module, $type);
        }

        impl burn::module::ModuleDisplayDefault for $type {
            fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
                // Displayed through the type's `Display` implementation.
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl burn::module::ModuleDisplay for $type {}
    };
}
|
||||
|
||||
// General Types
|
||||
constant!(alloc::string::String);
|
||||
constant!(bool);
|
||||
|
||||
// Float Types
|
||||
constant!(f64);
|
||||
constant!(f32);
|
||||
constant!(half::bf16);
|
||||
constant!(half::f16);
|
||||
|
||||
// Unsigned Integer Types
|
||||
constant!(usize);
|
||||
constant!(u64);
|
||||
constant!(u32);
|
||||
constant!(u16);
|
||||
constant!(u8);
|
||||
|
||||
// Signed Integer Types
|
||||
constant!(isize);
|
||||
constant!(i64);
|
||||
constant!(i32);
|
||||
constant!(i16);
|
||||
constant!(i8);
|
||||
|
||||
impl burn::module::ModuleDisplay for str {}
|
||||
impl burn::module::ModuleDisplayDefault for str {
|
||||
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
content.add_formatted(&self).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
|
||||
type Record = ConstantRecord;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
ConstantRecord
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
self.to_device(device)
|
||||
}
|
||||
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
self.to_device(device)
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().as_slice());
|
||||
content.add_single(&string).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
|
||||
for Tensor<B, D, K>
|
||||
{
|
||||
type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.clone().inner()
|
||||
}
|
||||
|
||||
fn from_inner(tensor: Self::InnerModule) -> Self {
|
||||
Tensor::from_inner(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Module<B> for PhantomData<B> {
|
||||
type Record = ConstantRecord;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
ConstantRecord::new()
|
||||
}
|
||||
|
||||
fn to_device(self, _: &Device<B>) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, _: &Device<B>) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
content.add_single(&"PhantomData".to_string()).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PhantomData<B> {}
|
||||
|
||||
impl<B> AutodiffModule<B> for PhantomData<B>
where
    B: AutodiffBackend,
{
    /// The marker re-targeted at the inner backend.
    type InnerModule = PhantomData<B::InnerBackend>;

    fn valid(&self) -> Self::InnerModule {
        PhantomData
    }

    fn from_inner(_module: Self::InnerModule) -> Self {
        // A marker carries no data, so a new one is equivalent.
        PhantomData
    }
}
|
||||
|
||||
/// Container to satisfy the Module trait for types that are not modules.
///
/// The wrapped value is public (`.0`) and also reachable through `Deref`.
#[derive(Debug, Clone)]
pub struct Ignored<T>(pub T);
|
||||
|
||||
impl<B, T> Module<B> for Ignored<T>
|
||||
where
|
||||
B: Backend,
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
type Record = ConstantRecord;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
ConstantRecord::new()
|
||||
}
|
||||
|
||||
fn to_device(self, _: &Device<B>) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, _: &Device<B>) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ModuleDisplayDefault for Ignored<T>
|
||||
where
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
// For now, just print the debug representation of the ignored value
|
||||
content.add_single(&format!("{:?}", self.0)).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
|
||||
|
||||
impl<T> Display for Ignored<T>
|
||||
where
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "{:?}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
type InnerModule = Ignored<T>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module
|
||||
}
|
||||
}
|
||||
|
||||
// Implement deref for Ignored
|
||||
impl<T> core::ops::Deref for Ignored<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::{Device, Tensor};

    use crate::TestBackend;
    use crate::{
        TestAutodiffBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    use crate as burn;

    /// Loading a record into a bare tensor must leave it not requiring gradients,
    /// whether or not `no_grad` was called beforehand.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let record = tensor.clone().into_record();
        let bytes = Recorder::<TestAutodiffBackend>::record(&recorder, record, ()).unwrap();

        // Case 1: the tensor is explicitly detached before loading.
        let detached = tensor.clone().no_grad();
        let detached_loaded = detached.load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap(),
        );
        let no_grad_is_require_grad = detached_loaded.is_require_grad();

        // Case 2: the tensor is loaded as created.
        let default_loaded = tensor.load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap(),
        );
        let with_default_is_require_grad = default_loaded.is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(!with_default_is_require_grad);
    }

    /// A derived module holding only a `PhantomData` field occupies no memory.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        let size = core::mem::size_of::<EmptyModule<TestBackend>>();
        assert_eq!(size, 0);
    }
}
|
||||
@@ -0,0 +1,116 @@
|
||||
use core::hash::{BuildHasher, Hasher};
|
||||
|
||||
use alloc::string::String;
|
||||
use burn_std::id::IdGenerator;
|
||||
use data_encoding::BASE32_DNSSEC;
|
||||
|
||||
// Hashbrown changed its default hasher in 0.15, but there are some issues
|
||||
// https://github.com/rust-lang/hashbrown/issues/577
|
||||
// Also, `param_serde_deserialize_legacy_uuid` doesn't pass with the default hasher.
|
||||
type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
|
||||
|
||||
/// Parameter ID.
///
/// A thin wrapper around a 64-bit integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ParamId {
    // Raw 64-bit identifier.
    value: u64,
}
|
||||
|
||||
impl From<u64> for ParamId {
|
||||
fn from(value: u64) -> Self {
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParamId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ParamId {
|
||||
/// Create a new parameter ID.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
value: IdGenerator::generate(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the internal value of the id.
|
||||
pub fn val(&self) -> u64 {
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Convert the parameter ID into a string.
|
||||
pub fn serialize(self) -> String {
|
||||
BASE32_DNSSEC.encode(&self.value.to_le_bytes())
|
||||
}
|
||||
|
||||
/// Deserialize a param id.
|
||||
///
|
||||
/// Preserves compatibility with previous formats (6 bytes, 16-byte uuid).
|
||||
pub fn deserialize(encoded: &str) -> ParamId {
|
||||
let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
|
||||
Ok(bytes) => {
|
||||
let mut buffer = [0u8; 8];
|
||||
buffer[..bytes.len()].copy_from_slice(&bytes);
|
||||
u64::from_le_bytes(buffer)
|
||||
}
|
||||
Err(err) => match uuid::Uuid::try_parse(encoded) {
|
||||
// Backward compatibility with uuid parameter identifiers
|
||||
Ok(id) => {
|
||||
// Hash the 128-bit uuid to 64-bit
|
||||
// Though not *theoretically* unique, the probability of a collision should be extremely low
|
||||
let mut hasher = DefaultHashBuilder::default().build_hasher();
|
||||
// let mut hasher = DefaultHasher::new();
|
||||
hasher.write(id.as_bytes());
|
||||
hasher.finish()
|
||||
}
|
||||
Err(_) => panic!("Invalid id. {err}"),
|
||||
},
|
||||
};
|
||||
|
||||
ParamId::from(u64_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for ParamId {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(&self.serialize())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// An id survives a serialize/deserialize round trip.
    #[test]
    fn param_serde_deserialize() {
        let original = ParamId::from(123456u64);
        let encoded = original.serialize();
        let restored = ParamId::deserialize(&encoded);
        assert_eq!(original, restored);
    }

    /// The legacy 6-byte format fills the low bytes and leaves the top two at zero.
    #[test]
    fn param_serde_deserialize_legacy() {
        let legacy_val = [45u8; 6];
        let encoded = BASE32_DNSSEC.encode(&legacy_val);
        let param_id = ParamId::deserialize(&encoded);
        assert_eq!(param_id.val().to_le_bytes()[0..6], legacy_val);
        assert_eq!(param_id.val().to_le_bytes()[6..], [0, 0]);
    }

    #[test]
    fn param_serde_deserialize_legacy_uuid() {
        // Legacy uuid identifiers must still load, and loading the same one twice
        // must produce the same id.
        let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
        let first = ParamId::deserialize(legacy_id);
        let second = ParamId::deserialize(legacy_id);
        assert_eq!(first, second);
    }

    /// Input that is neither base32 nor a uuid is rejected with a panic.
    #[test]
    #[should_panic = "Invalid id."]
    fn param_serde_deserialize_invalid_id() {
        let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c";
        let _ = ParamId::deserialize(invalid_uuid);
    }
}
|
||||
@@ -0,0 +1,13 @@
|
||||
mod base;
|
||||
mod constant;
|
||||
mod id;
|
||||
mod primitive;
|
||||
mod running;
|
||||
mod tensor;
|
||||
mod visitor;
|
||||
|
||||
pub use base::*;
|
||||
pub use constant::*;
|
||||
pub use id::*;
|
||||
pub use running::*;
|
||||
pub use visitor::*;
|
||||
@@ -0,0 +1,426 @@
|
||||
use crate::module::{
|
||||
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
|
||||
ModuleVisitor,
|
||||
};
|
||||
|
||||
use alloc::{format, string::ToString, vec::Vec};
|
||||
|
||||
use burn_tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
ops::Device,
|
||||
};
|
||||
use core::fmt::Debug;
|
||||
|
||||
impl<T, B> Module<B> for Option<T>
|
||||
where
|
||||
T: Module<B> + Debug + Send + Clone,
|
||||
B: Backend,
|
||||
{
|
||||
type Record = Option<T::Record>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
if let Some(module) = self {
|
||||
module.visit(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
self.map(|module| module.map(mapper))
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let is_constant = self.num_params() == 0;
|
||||
|
||||
if is_constant {
|
||||
return self;
|
||||
}
|
||||
|
||||
self.zip(record)
|
||||
.map(|(module, record)| module.load_record(record))
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.map(Module::into_record)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
self.map(|module| module.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.map(|module| module.fork(device))
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
if let Some(module) = self.as_ref() {
|
||||
devices = module.collect_devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
match self {
|
||||
Some(module) => content.add_single(module).optional(),
|
||||
None => content.add_single("None").optional(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
|
||||
|
||||
impl<T, B> AutodiffModule<B> for Option<T>
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
type InnerModule = Option<T::InnerModule>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.as_ref().map(|module| module.valid())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module.map(|module| T::from_inner(module))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> Module<B> for Vec<T>
|
||||
where
|
||||
T: Module<B> + Debug + Send + Clone,
|
||||
B: Backend,
|
||||
{
|
||||
type Record = Vec<T::Record>;
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
for module in self.iter() {
|
||||
num_params += module.num_params();
|
||||
}
|
||||
|
||||
num_params
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
for (i, module) in self.iter().enumerate() {
|
||||
let index_str = alloc::format!("{}", i);
|
||||
visitor.enter_module(&index_str, "Vec");
|
||||
module.visit(visitor);
|
||||
visitor.exit_module(&index_str, "Vec");
|
||||
}
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
self.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, module)| {
|
||||
let index_str = alloc::format!("{}", i);
|
||||
mapper.enter_module(&index_str, "Vec");
|
||||
let mapped = module.map(mapper);
|
||||
mapper.exit_module(&index_str, "Vec");
|
||||
mapped
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.into_iter().map(Module::into_record).collect()
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
assert_eq!(
|
||||
self.len(),
|
||||
record.len(),
|
||||
r#"[Load Record Error] The vec record does not the same length as the module.
|
||||
Make sure you module initialization is compatible with the record being loaded.
|
||||
"#,
|
||||
);
|
||||
|
||||
self.into_iter()
|
||||
.zip(record)
|
||||
.map(|(module, record)| module.load_record(record))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
self.into_iter()
|
||||
.map(|module| module.to_device(device))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.into_iter().map(|module| module.fork(device)).collect()
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
for module in self.iter() {
|
||||
devices = module.collect_devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
self.iter()
|
||||
.enumerate()
|
||||
.fold(content, |acc, (i, module)| {
|
||||
let index = format!("{i}");
|
||||
acc.add(&index, module)
|
||||
})
|
||||
.set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
|
||||
|
||||
impl<T, B> AutodiffModule<B> for Vec<T>
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
type InnerModule = Vec<T::InnerModule>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.iter().map(|module| module.valid()).collect()
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module
|
||||
.into_iter()
|
||||
.map(|module| T::from_inner(module))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T, B> Module<B> for [T; N]
|
||||
where
|
||||
T: Module<B> + Debug + Send + Clone,
|
||||
B: Backend,
|
||||
{
|
||||
type Record = [T::Record; N];
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
for module in self.iter() {
|
||||
devices = module.collect_devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
for module in self.iter() {
|
||||
num_params += module.num_params();
|
||||
}
|
||||
|
||||
num_params
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
for (i, module) in self.iter().enumerate() {
|
||||
let index_str = alloc::format!("{}", i);
|
||||
visitor.enter_module(&index_str, "Array");
|
||||
module.visit(visitor);
|
||||
visitor.exit_module(&index_str, "Array");
|
||||
}
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let mut result = Vec::with_capacity(N);
|
||||
for (i, module) in IntoIterator::into_iter(self).enumerate() {
|
||||
let index_str = alloc::format!("{}", i);
|
||||
mapper.enter_module(&index_str, "Array");
|
||||
let mapped = module.map(mapper);
|
||||
mapper.exit_module(&index_str, "Array");
|
||||
result.push(mapped);
|
||||
}
|
||||
result
|
||||
.try_into()
|
||||
.unwrap_or_else(|v: Vec<T>| panic!("Expected array of length {}, got {}", N, v.len()))
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
self.into_iter()
|
||||
.zip(record)
|
||||
.map(|(module, record)| module.load_record(record))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.map(Module::into_record)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
self.map(|module| module.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.map(|module| module.fork(device))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
self.iter()
|
||||
.enumerate()
|
||||
.fold(content, |acc, (i, module)| {
|
||||
let index = format!("{i}");
|
||||
acc.add(&index, module)
|
||||
})
|
||||
.set_top_level_type(format!("[0..{}]", self.len()).as_str())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
|
||||
|
||||
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||
T::InnerModule: Debug,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
type InnerModule = [T::InnerModule; N];
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.clone().map(|module| module.valid())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module.map(|module| T::from_inner(module))
|
||||
}
|
||||
}
|
||||
|
||||
/// A macro for generating implementations for tuple modules of different sizes.
|
||||
/// For example: `impl_module_tuple!([L0, L1][0, 1])`.
|
||||
/// Would generate an implementation for a tuple of size 2.
|
||||
/// For this macro to work properly, please adhere to the convention:
|
||||
/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.
|
||||
macro_rules! impl_module_tuple {
|
||||
// `$l` represents the generic modules.
|
||||
// `$i` represents the indices of the modules in the tuple.
|
||||
([$($l:ident),*][$($i:tt),*]) => {
|
||||
impl<B, $($l,)*> Module<B> for ($($l,)*)
|
||||
where
|
||||
B: Backend,
|
||||
$($l: Module<B> + Debug + Send + Clone,)*
|
||||
{
|
||||
type Record = ($($l::Record),*);
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
$(devices = self.$i.collect_devices(devices);)*
|
||||
devices
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
($(self.$i.fork(device),)*)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
($(self.$i.to_device(device),)*)
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
$(
|
||||
let index_str = $i.to_string();
|
||||
visitor.enter_module(&index_str, "Tuple");
|
||||
self.$i.visit(visitor);
|
||||
visitor.exit_module(&index_str, "Tuple");
|
||||
)*
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
($(
|
||||
{
|
||||
let index_str = $i.to_string();
|
||||
mapper.enter_module(&index_str, "Tuple");
|
||||
let mapped = self.$i.map(mapper);
|
||||
mapper.exit_module(&index_str, "Tuple");
|
||||
mapped
|
||||
}
|
||||
,)*)
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
($(self.$i.load_record(record.$i),)*)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
($(self.$i.into_record(),)*)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
$($l: AutodiffModule<B> + Debug + Send + Clone,)*
|
||||
{
|
||||
type InnerModule = ($($l::InnerModule,)*);
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
($(self.$i.valid(),)*)
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
($($l::from_inner(module.$i),)*)
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
|
||||
where
|
||||
$($l: ModuleDisplay,)*
|
||||
{
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let content = content
|
||||
$(.add(&format!("{}", $i), &self.$i))*
|
||||
.set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
|
||||
content.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
impl_module_tuple!([L0, L1][0, 1]);
|
||||
impl_module_tuple!([L0, L1, L2][0, 1, 2]);
|
||||
impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Loading its own record into a parameter-free `Some` leaves the value intact.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let module = Some(42);

        let own_record = Module::<TestBackend>::into_record(module);
        let reloaded = Module::<TestBackend>::load_record(module, own_record);

        assert_eq!(reloaded, module);
    }

    /// Loading a `None` record into a parameter-free `Some` also leaves the value intact.
    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let module = Some(42);

        let empty_record = None;
        let reloaded = Module::<TestBackend>::load_record(module, empty_record);

        assert_eq!(reloaded, module);
    }
}
|
||||
@@ -0,0 +1,258 @@
|
||||
use super::ParamId;
|
||||
use crate::module::{
|
||||
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
|
||||
ModuleVisitor, Param,
|
||||
};
|
||||
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[cfg(target_has_atomic = "ptr")]
|
||||
use alloc::sync::Arc;
|
||||
|
||||
#[cfg(not(target_has_atomic = "ptr"))]
|
||||
use portable_atomic_util::Arc;
|
||||
|
||||
use burn_std::stub::Mutex;
|
||||
use burn_tensor::{
|
||||
Tensor,
|
||||
backend::{AutodiffBackend, Backend},
|
||||
ops::Device,
|
||||
};
|
||||
|
||||
// Threading primitives taken from the standard library when `std` is enabled.
#[cfg(feature = "std")]
mod threading {
    pub(super) use std::collections::HashMap;
    pub(super) use std::thread::ThreadId;

    /// Returns the id of the calling thread.
    #[inline(always)]
    pub(super) fn get_thread_current_id() -> ThreadId {
        let current = std::thread::current();
        current.id()
    }
}
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
mod threading {
|
||||
pub(super) use burn_std::stub::ThreadId;
|
||||
pub(super) use hashbrown::HashMap;
|
||||
|
||||
#[inline(always)]
|
||||
pub(super) fn get_thread_current_id() -> ThreadId {
|
||||
panic!("Current thread id is not available")
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export items from the disabled/enabled blocks
|
||||
use threading::*;
|
||||
|
||||
/// A state that can be updated during the forward pass while being thread safe.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The state value is the average of all updates on all threads.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RunningState<V> {
|
||||
id: ParamId,
|
||||
values: Arc<Mutex<HashMap<ThreadId, V>>>,
|
||||
value: Arc<Mutex<V>>,
|
||||
}
|
||||
|
||||
// Implement display for the module
|
||||
|
||||
impl<V> core::fmt::Display for RunningState<V> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "RunningState(id={})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> ModuleDisplayDefault for RunningState<V> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add_formatted(&"RunningState".to_string())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> ModuleDisplay for RunningState<V> {}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
|
||||
type Record = Param<Tensor<B, D>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
let tensor = self.value.lock().unwrap();
|
||||
let param = Param::initialized(self.id, tensor.clone());
|
||||
visitor.visit_float(¶m)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let mut tensor = self.value.lock().unwrap();
|
||||
let param = Param::initialized(self.id, tensor.clone());
|
||||
let param_out = mapper.map_float(param);
|
||||
let (_, tensor_out, _) = param_out.consume();
|
||||
|
||||
*tensor = tensor_out;
|
||||
core::mem::drop(tensor);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.sync();
|
||||
let tensor = self.value.lock().unwrap();
|
||||
|
||||
Param::initialized(self.id, tensor.clone())
|
||||
}
|
||||
|
||||
fn load_record(mut self, record: Self::Record) -> Self {
|
||||
let mut tensor = self.value.lock().unwrap();
|
||||
*tensor = record.val().to_device(&tensor.device());
|
||||
self.id = record.id;
|
||||
|
||||
core::mem::drop(tensor);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
let mut tensor = self.value.lock().unwrap();
|
||||
let tensor_out = tensor.clone().to_device(device);
|
||||
|
||||
*tensor = tensor_out;
|
||||
core::mem::drop(tensor);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.to_device(device) // Same thing here since no grad.
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
|
||||
let device = self.value.lock().unwrap().device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
|
||||
/// Create a new running state.
|
||||
pub fn new(value: Tensor<B, D>) -> Self {
|
||||
Self {
|
||||
id: ParamId::new(),
|
||||
values: Arc::new(Mutex::new(HashMap::new())),
|
||||
value: Arc::new(Mutex::new(value)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new running state.
|
||||
pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
values: Arc::new(Mutex::new(HashMap::new())),
|
||||
value: Arc::new(Mutex::new(value)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new running state from a record.
|
||||
pub fn from_record(record: Param<Tensor<B, D>>) -> Self {
|
||||
let tensor = record.val();
|
||||
Self {
|
||||
id: record.id,
|
||||
values: Arc::new(Mutex::new(HashMap::new())),
|
||||
value: Arc::new(Mutex::new(tensor)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the value on the current thread.
|
||||
pub fn update(&self, value: Tensor<B, D>) {
|
||||
let thread_id = get_thread_current_id();
|
||||
let mut map = self.values.lock().unwrap();
|
||||
|
||||
if map.contains_key(&thread_id) {
|
||||
self.update_value(&mut map);
|
||||
}
|
||||
|
||||
map.insert(thread_id, value);
|
||||
}
|
||||
|
||||
/// Get the current value,
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The current value might be outdated by one update.
|
||||
pub fn value(&self) -> Tensor<B, D> {
|
||||
let value = self.value.lock().unwrap();
|
||||
value.clone()
|
||||
}
|
||||
|
||||
/// Get the current value and make sure it is sync.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// Don't use this function after an update on the same thread where other threads might have to
|
||||
/// register their update before the actual synchronization needs to happen.
|
||||
pub fn value_sync(&self) -> Tensor<B, D> {
|
||||
let thread_id = get_thread_current_id();
|
||||
let mut map = self.values.lock().unwrap();
|
||||
|
||||
if map.contains_key(&thread_id) {
|
||||
self.update_value(&mut map);
|
||||
}
|
||||
|
||||
let value = self.value.lock().unwrap();
|
||||
value.clone()
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
let mut map = self.values.lock().unwrap();
|
||||
|
||||
if !map.is_empty() {
|
||||
self.update_value(&mut map);
|
||||
}
|
||||
}
|
||||
|
||||
/// Drain `map` and replace the shared value with the mean of the drained tensors.
///
/// Each tensor is moved to the device of the running sum before being added to it. When
/// `map` is empty the shared value is left untouched.
fn update_value(&self, map: &mut HashMap<ThreadId, Tensor<B, D>>) {
    let mut counter = 0;

    let summed = map
        .drain()
        .map(|(_thread, tensor)| tensor)
        .inspect(|_| counter += 1)
        .reduce(|running, tensor| {
            let device = running.device();
            tensor.to_device(&device).add(running)
        });

    if let Some(total) = summed {
        let mean = total.div_scalar(counter);
        *self.value.lock().unwrap() = mean;
    }
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>> {
|
||||
type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.sync();
|
||||
let value = self.value();
|
||||
|
||||
RunningState::with_id(self.id, value.inner())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module.sync();
|
||||
let value = module.value();
|
||||
|
||||
RunningState::with_id(module.id, Tensor::from_inner(value))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,571 @@
|
||||
use super::{Param, ParamId, Parameter};
|
||||
use crate::module::{
|
||||
AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,
|
||||
ModuleMapper, ModuleVisitor,
|
||||
};
|
||||
use crate::tensor::{
|
||||
Tensor,
|
||||
backend::{AutodiffBackend, Backend},
|
||||
};
|
||||
use alloc::{format, string::ToString, vec::Vec};
|
||||
use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};
|
||||
|
||||
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
|
||||
type Device = B::Device;
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
Tensor::device(self)
|
||||
}
|
||||
|
||||
fn is_require_grad(&self) -> bool {
|
||||
Tensor::is_require_grad(self)
|
||||
}
|
||||
|
||||
fn set_require_grad(self, require_grad: bool) -> Self {
|
||||
Tensor::set_require_grad(self, require_grad)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
|
||||
type Device = B::Device;
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
Tensor::device(self)
|
||||
}
|
||||
|
||||
fn is_require_grad(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn set_require_grad(self, _require_grad: bool) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
|
||||
type Device = B::Device;
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
Tensor::device(self)
|
||||
}
|
||||
|
||||
fn is_require_grad(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn set_require_grad(self, _require_grad: bool) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
|
||||
/// Create a new parameter from a float tensor.
|
||||
///
|
||||
/// # Warnings
|
||||
///
|
||||
/// We strongly recommend using [Param::uninitialized] if you are using this method to
|
||||
/// initialize parameters inside a module, since the tensor initialization will be lazy,
|
||||
/// making the loading of weights more performant.
|
||||
pub fn from_tensor(value: Tensor<B, D>) -> Self {
|
||||
// When creating a parameter from a float tensor, we automatically mark it as requiring
|
||||
// gradients, so that it can be updated by an optimizer.
|
||||
Param::initialized(ParamId::new(), value.require_grad())
|
||||
}
|
||||
|
||||
/// The shape of the parameter, **without triggering initialization**.
|
||||
///
|
||||
/// This is critical for shape validation during loading: when applying tensors to an
|
||||
/// uninitialized parameter, we need to validate the shape without triggering the
|
||||
/// initialization function (which would allocate an unnecessary tensor).
|
||||
///
|
||||
/// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
|
||||
/// preserve lazy initialization.
|
||||
pub fn lazy_shape(&self) -> burn_tensor::Shape {
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.shape(),
|
||||
};
|
||||
|
||||
let init = initialization.read().unwrap();
|
||||
|
||||
match init.as_ref() {
|
||||
Some(value) => value.shape.clone(),
|
||||
None => self.shape(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new parameter from data.
|
||||
pub fn from_data<T>(data: T, device: &B::Device) -> Self
|
||||
where
|
||||
T: Into<TensorData>,
|
||||
{
|
||||
// When creating a parameter from a float tensor, we automatically mark it as requiring
|
||||
// gradients, so that it can be updated by an optimizer.
|
||||
B::memory_persistent_allocations(device, data, |data| {
|
||||
let value = Tensor::from_data(data, device);
|
||||
Param::initialized(ParamId::new(), value.require_grad())
|
||||
})
|
||||
}
|
||||
|
||||
/// Transform a parameter for loading by applying load transformations.
|
||||
///
|
||||
/// This method is used to restore a parameter from a tensor (typically during deserialization).
|
||||
/// It ensures the tensor is moved to the expected device, applies the param mapper's
|
||||
/// `on_load` transformation, and preserves the autodiff settings (require_grad).
|
||||
pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {
|
||||
let mut new_tensor = tensor;
|
||||
|
||||
let mapper = self.param_mapper.clone();
|
||||
|
||||
let expected_device = self.lazy_device();
|
||||
let expected_require_grad = self.lazy_is_require_grad();
|
||||
|
||||
// Make sure we load the tensor into the same module device.
|
||||
if new_tensor.device() != expected_device {
|
||||
new_tensor = new_tensor.to_device(&expected_device).detach();
|
||||
}
|
||||
|
||||
new_tensor = mapper.on_load(new_tensor);
|
||||
|
||||
// Make sure we load the tensor with the same autodiff setting.
|
||||
new_tensor = new_tensor.set_require_grad(expected_require_grad);
|
||||
|
||||
let mut loaded = Self::initialized(param_id, new_tensor);
|
||||
loaded.param_mapper = mapper;
|
||||
loaded
|
||||
}
|
||||
|
||||
/// Transform a parameter for saving by applying save transformations.
|
||||
///
|
||||
/// This method is used to prepare a parameter for saving (typically during serialization).
|
||||
/// It applies the param mapper's `on_save` transformation, which can be used
|
||||
/// to modify the tensor before serialization (e.g., quantization, precision conversion).
|
||||
pub fn transform_for_save(&self) -> Self {
|
||||
let mut tensor = self.val();
|
||||
let mapper = self.param_mapper.clone();
|
||||
|
||||
tensor = mapper.on_save(tensor);
|
||||
|
||||
Self::initialized(self.id, tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {
|
||||
/// The shape of the parameter, **without triggering initialization**.
|
||||
///
|
||||
/// This is critical for shape validation during loading: when applying tensors to an
|
||||
/// uninitialized parameter, we need to validate the shape without triggering the
|
||||
/// initialization function (which would allocate an unnecessary tensor).
|
||||
///
|
||||
/// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
|
||||
/// preserve lazy initialization.
|
||||
pub fn lazy_shape(&self) -> burn_tensor::Shape {
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.shape(),
|
||||
};
|
||||
|
||||
let init = initialization.read().unwrap();
|
||||
|
||||
match init.as_ref() {
|
||||
Some(value) => value.shape.clone(),
|
||||
None => self.shape(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Transform a parameter for loading by applying load transformations.
|
||||
///
|
||||
/// This method is used to restore a parameter from a tensor (typically during deserialization).
|
||||
/// It ensures the tensor is moved to the expected device and applies the param mapper's
|
||||
/// `on_load` transformation.
|
||||
pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {
|
||||
let mut new_tensor = tensor;
|
||||
|
||||
let mapper = self.param_mapper.clone();
|
||||
|
||||
let expected_device = self.lazy_device();
|
||||
|
||||
// Make sure we load the tensor into the same module device.
|
||||
if new_tensor.device() != expected_device {
|
||||
new_tensor = new_tensor.to_device(&expected_device);
|
||||
}
|
||||
|
||||
new_tensor = mapper.on_load(new_tensor);
|
||||
|
||||
let mut loaded = Self::initialized(param_id, new_tensor);
|
||||
loaded.param_mapper = mapper;
|
||||
loaded
|
||||
}
|
||||
|
||||
/// Transform a parameter for saving by applying save transformations.
|
||||
///
|
||||
/// This method is used to prepare a parameter for saving (typically during serialization).
|
||||
/// It applies the param mapper's `on_save` transformation, which can be used
|
||||
/// to modify the tensor before serialization (e.g., quantization, precision conversion).
|
||||
pub fn transform_for_save(&self) -> Self {
|
||||
let mut tensor = self.val();
|
||||
let mapper = self.param_mapper.clone();
|
||||
|
||||
tensor = mapper.on_save(tensor);
|
||||
|
||||
Self::initialized(self.id, tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {
|
||||
/// The shape of the parameter, **without triggering initialization**.
|
||||
///
|
||||
/// This is critical for shape validation during loading: when applying tensors to an
|
||||
/// uninitialized parameter, we need to validate the shape without triggering the
|
||||
/// initialization function (which would allocate an unnecessary tensor).
|
||||
///
|
||||
/// **Returns:**
|
||||
/// - For uninitialized params: the shape from the `Uninitialized` struct
|
||||
/// - For initialized params: the actual shape from the tensor
|
||||
///
|
||||
/// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
|
||||
/// preserve lazy initialization.
|
||||
pub fn lazy_shape(&self) -> burn_tensor::Shape {
|
||||
let initialization = match &self.initialization {
|
||||
Some(init) => init,
|
||||
None => return self.shape(),
|
||||
};
|
||||
|
||||
let init = initialization.read().unwrap();
|
||||
|
||||
match init.as_ref() {
|
||||
Some(value) => value.shape.clone(),
|
||||
None => self.shape(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Transform a parameter for loading by applying load transformations.
|
||||
///
|
||||
/// This method is used to restore a parameter from a tensor (typically during deserialization).
|
||||
/// It ensures the tensor is moved to the expected device and applies the param mapper's
|
||||
/// `on_load` transformation.
|
||||
pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {
|
||||
let mut new_tensor = tensor;
|
||||
|
||||
let mapper = self.param_mapper.clone();
|
||||
|
||||
let expected_device = self.lazy_device();
|
||||
|
||||
// Make sure we load the tensor into the same module device.
|
||||
if new_tensor.device() != expected_device {
|
||||
new_tensor = new_tensor.to_device(&expected_device);
|
||||
}
|
||||
|
||||
new_tensor = mapper.on_load(new_tensor);
|
||||
|
||||
let mut loaded = Self::initialized(param_id, new_tensor);
|
||||
loaded.param_mapper = mapper;
|
||||
loaded
|
||||
}
|
||||
|
||||
/// Transform a parameter for saving by applying save transformations.
|
||||
///
|
||||
/// This method is used to prepare a parameter for saving (typically during serialization).
|
||||
/// It applies the param mapper's `on_save` transformation, which can be used
|
||||
/// to modify the tensor before serialization (e.g., quantization, precision conversion).
|
||||
pub fn transform_for_save(&self) -> Self {
|
||||
let mut tensor = self.val();
|
||||
let mapper = self.param_mapper.clone();
|
||||
|
||||
tensor = mapper.on_save(tensor);
|
||||
|
||||
Self::initialized(self.id, tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
|
||||
type Record = Param<Tensor<B, D>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit_float(self)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
mapper.map_float(self)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.transform_for_save()
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let (record_param_id, record_tensor, _) = record.consume();
|
||||
self.transform_for_load(record_tensor, record_param_id)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
self.map(|tensor| tensor.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.map(|tensor| {
|
||||
let is_require_grad = tensor.is_require_grad();
|
||||
let mut tensor = tensor.to_device(device).detach();
|
||||
|
||||
if is_require_grad {
|
||||
tensor = tensor.require_grad();
|
||||
}
|
||||
|
||||
tensor
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
|
||||
let device = self.val().device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let id = if content.display_settings.show_param_id() {
|
||||
format!(", id: {}", self.id)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
let string = format!(
|
||||
"ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
|
||||
self.shape().as_slice()
|
||||
);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
|
||||
type Record = Param<Tensor<B, D, Int>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit_int(self)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
mapper.map_int(self)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.transform_for_save()
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let (record_param_id, record_tensor, _) = record.consume();
|
||||
self.transform_for_load(record_tensor, record_param_id)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
self.map(|tensor| tensor.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.to_device(device) // Don't support autodiff.
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
|
||||
let device = self.val().device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let id = if content.display_settings.show_param_id() {
|
||||
format!(", id: {}", self.id)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
let string = format!(
|
||||
"ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
|
||||
self.shape().as_slice()
|
||||
);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
|
||||
type Record = Param<Tensor<B, D, Bool>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit_bool(self)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
mapper.map_bool(self)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.transform_for_save()
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let (record_param_id, record_tensor, _) = record.consume();
|
||||
self.transform_for_load(record_tensor, record_param_id)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &Device<B>) -> Self {
|
||||
self.map(|tensor| tensor.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &Device<B>) -> Self {
|
||||
self.to_device(device) // Don't support autodiff.
|
||||
}
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
|
||||
let device = self.val().device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let id = if content.display_settings.show_param_id() {
|
||||
format!(", id: {}", self.id)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let string = format!(
|
||||
"ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
|
||||
self.shape().as_slice()
|
||||
);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D>>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
// Preserve initialized param `require_grad` state, but reset the inner value's
|
||||
let require_grad = self.require_grad;
|
||||
let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false));
|
||||
param.require_grad = require_grad;
|
||||
param
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
// Reinstate the param's `require_grad` state
|
||||
let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);
|
||||
Param::initialized(module.id, tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>
|
||||
for Param<Tensor<B::InnerBackend, D>>
|
||||
{
|
||||
type TrainModule = Param<Tensor<B, D>>;
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
Param::initialized(self.id, self.val().inner())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
Param::initialized(module.id, Tensor::from_inner(module.val()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
Param::initialized(self.id, self.val().inner())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
Param::initialized(module.id, Tensor::from_inner(module.val()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

    /// Loading a record must keep the `require_grad` setting of the receiving parameter.
    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let saved = Param::initialized(ParamId::new(), tensor.clone()).into_record();
        let bytes = byte_recorder.record(saved, ()).unwrap();

        // Receiver with gradients disabled.
        let frozen = Param::initialized(ParamId::new(), tensor.clone()).no_grad();
        let no_grad_is_require_grad = frozen
            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        // Receiver left with its default setting.
        let tracked = Param::initialized(ParamId::new(), tensor);
        let with_default_is_require_grad = tracked
            .load_record(byte_recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }

    /// The param-level `require_grad` flag must survive `valid()` / `train()` round trips.
    #[test]
    fn test_param_require_grad_stateful() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let param = Param::initialized(ParamId::new(), tensor);
        assert!(param.is_require_grad());
        assert!(param.require_grad);

        let param = param.valid();
        assert!(!param.is_require_grad());
        assert!(param.require_grad); // stateful

        // Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying:
        // let param: Param<Tensor<TestAutodiffBackend, _>> = param.train();
        let param = param.train::<TestAutodiffBackend>();
        assert!(param.is_require_grad());
        assert!(param.require_grad); // stateful

        let param = param.no_grad();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad); // stateful

        let param = param.valid();
        assert!(!param.is_require_grad()); // always
        assert!(!param.require_grad); // stateful

        let param = param.train::<TestAutodiffBackend>();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad); // stateful
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use super::{Param, ParamId};
|
||||
use crate::module::{Module, ModuleVisitor};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Bool, Int, Tensor, backend::Backend};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
struct ParamIdCollector<'a, M> {
|
||||
param_ids: &'a mut Vec<ParamId>,
|
||||
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<B, M> ModuleVisitor<B> for ParamIdCollector<'_, M>
|
||||
where
|
||||
B: Backend,
|
||||
M: Module<B>,
|
||||
{
|
||||
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
|
||||
self.param_ids.push(param.id);
|
||||
}
|
||||
fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
|
||||
self.param_ids.push(param.id);
|
||||
}
|
||||
fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
|
||||
self.param_ids.push(param.id);
|
||||
}
|
||||
}
|
||||
|
||||
/// List all the parameter ids in a module.
|
||||
pub fn list_param_ids<M: Module<B>, B: Backend>(module: &M) -> Vec<ParamId> {
|
||||
let mut params_ids = Vec::new();
|
||||
let mut visitor = ParamIdCollector {
|
||||
param_ids: &mut params_ids,
|
||||
phantom: PhantomData::<M>,
|
||||
};
|
||||
module.visit(&mut visitor);
|
||||
|
||||
params_ids
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
use burn_tensor::{
|
||||
Tensor,
|
||||
backend::Backend,
|
||||
quantization::{Calibration, QuantScheme, compute_q_params, compute_range},
|
||||
};
|
||||
|
||||
use crate::module::{ModuleMapper, Param};
|
||||
|
||||
/// Describes how to quantize a module.
|
||||
pub struct Quantizer {
|
||||
/// The calibration method used in quantization.
|
||||
pub calibration: Calibration,
|
||||
/// The quantization scheme.
|
||||
pub scheme: QuantScheme,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleMapper<B> for Quantizer {
|
||||
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
let range = compute_range(&self.scheme, &tensor, &self.calibration);
|
||||
let qparams = compute_q_params(&self.scheme, range);
|
||||
let tensor = tensor.quantize(&self.scheme, qparams);
|
||||
Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(feature = "test-tch")))]
mod tests {
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend,
        module::{Module, Quantizer},
    };
    use burn_tensor::{
        Device, Tolerance,
        ops::QuantizedTensor,
        quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue},
    };

    type B = TestBackend;

    /// Quantizing and then dequantizing the weights must stay close to the original weights.
    #[test]
    fn should_quantize_module() {
        let device: Device<B> = Default::default();
        let module = SimpleLinear::<B>::new(32, 32, &device);
        let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Tensor)
            .with_param(QuantParam::F32);

        // Reference weights, read before quantization.
        let expected = module.weight.val();

        let mut quantizer = Quantizer {
            calibration: Calibration::MinMax,
            scheme,
        };
        let q_module = module.quantize_weights(&mut quantizer);
        let actual = q_module.weight.val().dequantize();

        expected
            .into_data()
            .assert_approx_eq::<f32>(&actual.into_data(), Tolerance::permissive());
    }
}
|
||||
@@ -0,0 +1,203 @@
|
||||
use super::{Module, ModuleMapper};
|
||||
use burn_tensor::{
|
||||
Element, ElementConversion, Tensor, TensorData,
|
||||
backend::Backend,
|
||||
ops::{FloatElem, IntElem},
|
||||
};
|
||||
use rand::{RngExt, SeedableRng};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Overrides float and int tensors of [burn modules](super::Module).
|
||||
///
|
||||
/// This is useful for testing.
|
||||
pub struct Reinitializer<B: Backend> {
|
||||
float: ReinitStrategy<FloatElem<B>>,
|
||||
int: ReinitStrategy<IntElem<B>>,
|
||||
}
|
||||
|
||||
// How the values of a reinitialized tensor are produced.
#[derive(Debug)]
#[allow(missing_docs)]
enum ReinitStrategy<E> {
    // Values derived from the element index, scaled between `min` and `max`.
    Range { min: E, max: E },
    // The same `value` for every element.
    Constant { value: E },
    // Seeded random values drawn between `min` and `max`.
    Random { seed: u64, min: E, max: E },
}
|
||||
|
||||
impl<B: Backend> Default for Reinitializer<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Reinitializer<B> {
|
||||
/// Create a new [reinitializer](Reinitializer).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
float: ReinitStrategy::Constant {
|
||||
value: 0.elem::<FloatElem<B>>(),
|
||||
},
|
||||
int: ReinitStrategy::Constant {
|
||||
value: 0.elem::<IntElem<B>>(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply the reinitialization to the given [module](Module).
|
||||
pub fn apply<M: Module<B>>(mut self, module: M) -> M {
|
||||
module.map(&mut self)
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to constant for all tensors.
|
||||
pub fn constant(self, constant: f64) -> Self {
|
||||
self.constant_float(constant).constant_int(constant as i64)
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to constant for float tensors.
|
||||
pub fn constant_float(mut self, constant: f64) -> Self {
|
||||
self.float = ReinitStrategy::Constant {
|
||||
value: constant.elem(),
|
||||
};
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to constant for int tensors.
|
||||
pub fn constant_int(mut self, constant: i64) -> Self {
|
||||
self.int = ReinitStrategy::Constant {
|
||||
value: constant.elem(),
|
||||
};
|
||||
self
|
||||
}
|
||||
/// Set the reinitialization strategy to random for all tensors.
|
||||
pub fn random(self, seed: u64, min: f64, max: f64) -> Self {
|
||||
self.random_float(seed, min, max)
|
||||
.random_int(seed, min as i64, max as i64)
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to random for float tensors.
|
||||
pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self {
|
||||
self.float = ReinitStrategy::Random {
|
||||
seed,
|
||||
min: min.elem(),
|
||||
max: max.elem(),
|
||||
};
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to random for int tensors.
|
||||
pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self {
|
||||
self.int = ReinitStrategy::Random {
|
||||
seed,
|
||||
min: min.elem(),
|
||||
max: max.elem(),
|
||||
};
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to range for all tensors.
|
||||
pub fn range(self, min: f64, max: f64) -> Self {
|
||||
self.range_float(min, max).range_int(min as i64, max as i64)
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to range for float tensors.
|
||||
pub fn range_float(mut self, min: f64, max: f64) -> Self {
|
||||
self.float = ReinitStrategy::Range {
|
||||
min: min.elem(),
|
||||
max: max.elem(),
|
||||
};
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the reinitialization strategy to range for int tensors.
|
||||
pub fn range_int(mut self, min: i64, max: i64) -> Self {
|
||||
self.int = ReinitStrategy::Range {
|
||||
min: min.elem(),
|
||||
max: max.elem(),
|
||||
};
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleMapper<B> for Reinitializer<B> {
|
||||
fn map_float<const D: usize>(
|
||||
&mut self,
|
||||
param: super::Param<Tensor<B, D>>,
|
||||
) -> super::Param<Tensor<B, D>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
let device = tensor.device();
|
||||
let shape = tensor.shape();
|
||||
let num_elements = shape.num_elements();
|
||||
|
||||
let tensor = match &self.float {
|
||||
ReinitStrategy::Range { min, max } => {
|
||||
let tensor = Tensor::arange(0..num_elements as i64, &device)
|
||||
.reshape(shape)
|
||||
.float();
|
||||
let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);
|
||||
tensor * factor + bias
|
||||
}
|
||||
ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
|
||||
ReinitStrategy::Random { seed, min, max } => {
|
||||
let data = TensorData::new(
|
||||
random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),
|
||||
shape,
|
||||
);
|
||||
Tensor::from_data(data, &device)
|
||||
}
|
||||
};
|
||||
|
||||
super::Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
|
||||
fn map_int<const D: usize>(
|
||||
&mut self,
|
||||
param: super::Param<Tensor<B, D, burn_tensor::Int>>,
|
||||
) -> super::Param<Tensor<B, D, burn_tensor::Int>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
let device = tensor.device();
|
||||
let shape = tensor.shape();
|
||||
let num_elements = shape.num_elements();
|
||||
|
||||
let tensor = match &self.int {
|
||||
ReinitStrategy::Range { min, max } => {
|
||||
let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);
|
||||
let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);
|
||||
tensor * factor + bias
|
||||
}
|
||||
ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
|
||||
ReinitStrategy::Random { seed, min, max } => {
|
||||
let data = TensorData::new(
|
||||
random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),
|
||||
shape,
|
||||
);
|
||||
Tensor::from_data(data, &device)
|
||||
}
|
||||
};
|
||||
|
||||
super::Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
|
||||
fn map_bool<const D: usize>(
|
||||
&mut self,
|
||||
param: super::Param<Tensor<B, D, burn_tensor::Bool>>,
|
||||
) -> super::Param<Tensor<B, D, burn_tensor::Bool>> {
|
||||
let (id, tensor, mapper) = param.consume();
|
||||
super::Param::from_mapped_value(id, tensor, mapper)
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {
|
||||
let range = max.elem::<f64>() - min.elem::<f64>();
|
||||
let factor = range / num_elements as f64;
|
||||
let bias = min.elem::<f64>();
|
||||
|
||||
(factor.elem(), bias.elem())
|
||||
}
|
||||
|
||||
fn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
let dist = rand::distr::Uniform::new(min, max).unwrap();
|
||||
(0..num_elements)
|
||||
.map(|_| rng.sample(dist))
|
||||
.map(|e| e.elem::<E>())
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
pub use burn_derive::Record;
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use super::PrecisionSettings;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
|
||||
pub trait Record<B: Backend>: Send {
|
||||
/// Type of the item that can be serialized and deserialized.
|
||||
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned + Clone;
|
||||
|
||||
/// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
|
||||
|
||||
/// Convert the given item into a record.
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self;
|
||||
}
|
||||
@@ -0,0 +1,421 @@
|
||||
use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
|
||||
use burn_tensor::backend::Backend;
|
||||
use core::marker::PhantomData;
|
||||
use flate2::{Compression, read::GzDecoder, write::GzEncoder};
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
/// Recorder trait specialized to save and load data to and from files.
|
||||
pub trait FileRecorder<B: Backend>:
|
||||
Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
|
||||
{
|
||||
/// File extension of the format used by the recorder.
|
||||
fn file_extension() -> &'static str;
|
||||
}
|
||||
|
||||
/// Default [file recorder](FileRecorder).
|
||||
pub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;
|
||||
|
||||
/// File recorder using the [bincode format](bincode).
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct BinFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
/// File recorder using the [bincode format](bincode) compressed with gzip.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct BinGzFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
/// File recorder using the [json format](serde_json) compressed with gzip.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct JsonGzFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
/// File recorder using [pretty json format](serde_json) for easy readability.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
/// File recorder using the [named msgpack](rmp_serde) format compressed with gzip.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
/// File recorder using the [named msgpack](rmp_serde) format.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<S, B> FileRecorder<B> for BinGzFileRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    /// Extension used for gzip-compressed bincode files.
    fn file_extension() -> &'static str {
        "bin.gz"
    }
}
|
||||
impl<S, B> FileRecorder<B> for BinFileRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    /// Extension used for bincode files.
    fn file_extension() -> &'static str {
        "bin"
    }
}
|
||||
impl<S, B> FileRecorder<B> for JsonGzFileRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    /// Extension used for gzip-compressed json files.
    fn file_extension() -> &'static str {
        "json.gz"
    }
}
|
||||
impl<S, B> FileRecorder<B> for PrettyJsonFileRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    /// Extension used for pretty-printed json files.
    fn file_extension() -> &'static str {
        "json"
    }
}
|
||||
|
||||
impl<S, B> FileRecorder<B> for NamedMpkGzFileRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    /// Extension used for gzip-compressed named msgpack files.
    fn file_extension() -> &'static str {
        "mpk.gz"
    }
}
|
||||
|
||||
impl<S, B> FileRecorder<B> for NamedMpkFileRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    /// Extension used for named msgpack files.
    fn file_extension() -> &'static str {
        "mpk"
    }
}
|
||||
|
||||
// Sets the recorder's file extension on `$file`, opens it and evaluates to
// `Result<BufReader<File>, RecorderError>`.
//
// A missing file maps to `RecorderError::FileNotFound`; every other I/O error maps to
// `RecorderError::Unknown`. Must be expanded inside a `FileRecorder<B>` impl because it
// names `Self` and `B`.
macro_rules! str2reader {
    (
        $file:expr
    ) => {{
        let extension = <Self as FileRecorder<B>>::file_extension();
        $file.set_extension(extension);

        match File::open($file.as_path()) {
            Ok(handle) => Ok(BufReader::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
|
||||
|
||||
// Sets the recorder's file extension on `$file`, prepares the destination and evaluates
// to `Result<BufWriter<File>, RecorderError>`.
//
// Must be expanded inside a `FileRecorder<B>` impl (it names `Self` and `B`) within a
// function returning `Result<_, RecorderError>`, because a failed removal of an existing
// file returns early through `?`.
macro_rules! str2writer {
    (
        $file:expr
    ) => {{
        let extension = <Self as FileRecorder<B>>::file_extension();
        $file.set_extension(extension);
        let target = $file.as_path();

        log::debug!("Writing to file: {:?}", target);

        // Create missing parent directories; a failure here is ignored on purpose and
        // surfaces when the file itself is created below.
        if let Some(directory) = target.parent() {
            std::fs::create_dir_all(directory).ok();
        }

        if target.exists() {
            log::warn!("File exists, replacing");
            std::fs::remove_file(target).map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        match File::create(target) {
            Ok(handle) => Ok(BufWriter::new(handle)),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
|
||||
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let config = bin_config();
|
||||
let writer = str2writer!(file)?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
|
||||
bincode::serde::encode_into_std_write(&item, &mut writer, config)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let reader = str2reader!(file)?;
|
||||
let mut reader = GzDecoder::new(reader);
|
||||
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let config = bin_config();
|
||||
let mut writer = str2writer!(file)?;
|
||||
bincode::serde::encode_into_std_write(&item, &mut writer, config)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let mut reader = str2reader!(file)?;
|
||||
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let writer = str2writer!(file)?;
|
||||
let writer = GzEncoder::new(writer, Compression::default());
|
||||
serde_json::to_writer(writer, &item)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let reader = str2reader!(file)?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = serde_json::from_reader(reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let writer = str2writer!(file)?;
|
||||
serde_json::to_writer_pretty(writer, &item)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let reader = str2reader!(file)?;
|
||||
let state = serde_json::from_reader(reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let writer = str2writer!(file)?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
rmp_serde::encode::write_named(&mut writer, &item)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let reader = str2reader!(file)?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = rmp_serde::decode::from_read(reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let mut writer = str2writer!(file)?;
|
||||
|
||||
rmp_serde::encode::write_named(&mut writer, &item)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let reader = str2reader!(file)?;
|
||||
let state = rmp_serde::decode::from_read(reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate as burn;
    use crate::config::Config;
    use crate::module::Ignored;
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings},
    };
    use burn_tensor::backend::Backend;

    /// Path, without extension, used by every round-trip test in this module.
    #[inline(always)]
    fn file_path() -> PathBuf {
        let mut path = std::env::temp_dir();
        path.push("burn_test_file_recorder");
        path
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        test_can_save_and_load(JsonGzFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load(BinFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        test_can_save_and_load(BinGzFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_pretty_json_format() {
        test_can_save_and_load(PrettyJsonFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        test_can_save_and_load(NamedMpkGzFileRecorder::<FullPrecisionSettings>::default())
    }

    #[test]
    fn test_can_save_and_load_mpk_format() {
        test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
    }

    /// Saves a model with `recorder`, loads it into a second model and checks that both
    /// models produce the same bytes when recorded in memory.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: FileRecorder<TestBackend>,
    {
        let device = Default::default();
        let original = create_model(&device);
        recorder
            .record(original.clone().into_record(), file_path())
            .unwrap();

        let restored =
            create_model(&device).load_record(recorder.load(file_path(), &device).unwrap());

        // Compare through an in-memory recorder so the check is byte-for-byte.
        let in_memory = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes_original = in_memory.record(original.into_record(), ()).unwrap();
        let bytes_restored = in_memory.record(restored.into_record(), ()).unwrap();

        assert_eq!(bytes_restored, bytes_original);
    }

    #[derive(Config, Debug)]
    pub enum PaddingConfig2d {
        Same,
        Valid,
        Explicit(usize, usize),
    }

    // Test model mixing several kinds of fields, each with its own record type.
    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        linear1: SimpleLinear<B>,
        phantom: PhantomData<B>,
        arr: [usize; 2],
        int: usize,
        ignore: Ignored<PaddingConfig2d>,
    }

    /// Builds a [Model] around a 32x32 [SimpleLinear] layer on `device`.
    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {
        Model {
            linear1: SimpleLinear::new(32, 32, device),
            phantom: PhantomData,
            arr: [2, 2],
            int: 0,
            ignore: Ignored(PaddingConfig2d::Same),
        }
    }
}
|
||||
@@ -0,0 +1,141 @@
|
||||
use super::{PrecisionSettings, Recorder, RecorderError, bin_config};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::backend::Backend;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
/// Recorder trait specialized to save and load data to and from bytes.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This is especially useful in no_std environment where weights are stored directly in
|
||||
/// compiled binaries.
|
||||
pub trait BytesRecorder<
|
||||
B: Backend,
|
||||
L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
|
||||
>: Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = L>
|
||||
{
|
||||
}
|
||||
|
||||
/// In memory recorder using the [bincode format](bincode).
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct BinBytesRecorder<
|
||||
S: PrecisionSettings,
|
||||
L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default = Vec<u8>,
|
||||
> {
|
||||
_settings: core::marker::PhantomData<S>,
|
||||
_loadargs: core::marker::PhantomData<L>,
|
||||
}
|
||||
|
||||
impl<
|
||||
S: PrecisionSettings,
|
||||
B: Backend,
|
||||
L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
|
||||
> BytesRecorder<B, L> for BinBytesRecorder<S, L>
|
||||
{
|
||||
}
|
||||
|
||||
impl<
|
||||
S: PrecisionSettings,
|
||||
B: Backend,
|
||||
L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,
|
||||
> Recorder<B> for BinBytesRecorder<S, L>
|
||||
{
|
||||
type Settings = S;
|
||||
type RecordArgs = ();
|
||||
type RecordOutput = Vec<u8>;
|
||||
type LoadArgs = L;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
_args: Self::RecordArgs,
|
||||
) -> Result<Self::RecordOutput, RecorderError> {
|
||||
Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
args: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
let state = bincode::borrow_decode_from_slice::<'_, bincode::serde::BorrowCompat<I>, _>(
|
||||
args.as_ref(),
|
||||
bin_config(),
|
||||
)
|
||||
.unwrap()
|
||||
.0;
|
||||
Ok(state.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory recorder based on [Named MessagePack](rmp_serde).
///
/// Only available with the `std` feature.
#[cfg(feature = "std")]
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
{
    // Zero-sized marker tying the recorder to its precision settings.
    _settings: core::marker::PhantomData<S>,
}
|
||||
|
||||
// Marker impl: the trait has no items of its own.
#[cfg(feature = "std")]
impl<S, B> BytesRecorder<B, Vec<u8>> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
}
|
||||
|
||||
#[cfg(feature = "std")]
impl<S, B> Recorder<B> for NamedMpkBytesRecorder<S>
where
    S: PrecisionSettings,
    B: Backend,
{
    type Settings = S;
    type RecordArgs = ();
    type RecordOutput = Vec<u8>;
    type LoadArgs = Vec<u8>;

    /// Encodes `item` as named msgpack and returns the bytes; an encoding failure is
    /// reported as [RecorderError::Unknown].
    fn save_item<I: Serialize>(
        &self,
        item: I,
        _args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError> {
        match rmp_serde::encode::to_vec_named(&item) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(RecorderError::Unknown(err.to_string())),
        }
    }

    /// Decodes an item from the msgpack bytes in `args`; a decoding failure is reported
    /// as [RecorderError::Unknown].
    fn load_item<I: DeserializeOwned>(
        &self,
        args: &mut Self::LoadArgs,
    ) -> Result<I, RecorderError> {
        match rmp_serde::decode::from_slice(args) {
            Ok(state) => Ok(state),
            Err(err) => Err(RecorderError::Unknown(err.to_string())),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend, module::Module, record::FullPrecisionSettings, tensor::backend::Backend,
    };

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load(BinBytesRecorder::<FullPrecisionSettings>::default())
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_can_save_and_load_named_mpk_format() {
        test_can_save_and_load(NamedMpkBytesRecorder::<FullPrecisionSettings>::default())
    }

    /// Records two separately created models, loads the first model's bytes into the
    /// second and checks that the second then records to the same bytes as the first.
    fn test_can_save_and_load<R>(recorder: R)
    where
        R: BytesRecorder<TestBackend, Vec<u8>>,
    {
        let device = Default::default();
        let first = create_model::<TestBackend>(&device);
        let second = create_model::<TestBackend>(&device);

        let first_bytes = recorder.record(first.into_record(), ()).unwrap();
        let second_bytes = recorder.record(second.clone().into_record(), ()).unwrap();

        let second_loaded =
            second.load_record(recorder.load(first_bytes.clone(), &device).unwrap());
        let second_bytes_after = recorder.record(second_loaded.into_record(), ()).unwrap();

        // The two models must differ before loading and match afterwards.
        assert_ne!(first_bytes, second_bytes);
        assert_eq!(first_bytes, second_bytes_after);
    }

    /// Builds a 32x32 [SimpleLinear] layer on `device`.
    pub fn create_model<B: Backend>(device: &B::Device) -> SimpleLinear<B> {
        SimpleLinear::new(32, 32, device)
    }
}
|
||||
@@ -0,0 +1,22 @@
|
||||
mod primitive;
|
||||
mod tensor;
|
||||
|
||||
mod base;
|
||||
mod memory;
|
||||
mod recorder;
|
||||
mod settings;
|
||||
|
||||
pub use base::*;
|
||||
pub use memory::*;
|
||||
pub use recorder::*;
|
||||
pub use settings::*;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod file;
|
||||
#[cfg(feature = "std")]
|
||||
pub use file::*;
|
||||
|
||||
pub use primitive::ParamSerde;
|
||||
|
||||
#[cfg(feature = "record-item-custom-serde")]
|
||||
pub mod serde;
|
||||
@@ -0,0 +1,336 @@
|
||||
use alloc::{string::String, vec, vec::Vec};
|
||||
use core::{fmt, marker::PhantomData};
|
||||
|
||||
use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
|
||||
use super::{PrecisionSettings, Record};
|
||||
use crate::module::{Param, ParamId};
|
||||
|
||||
use burn_tensor::{Bool, Int, Tensor, backend::Backend};
|
||||
|
||||
use hashbrown::HashMap;
|
||||
use serde::{
|
||||
Deserialize, Serialize,
|
||||
de::{Error, SeqAccess, Visitor},
|
||||
ser::SerializeTuple,
|
||||
};
|
||||
|
||||
impl<B> Record<B> for ()
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ();
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
|
||||
}
|
||||
|
||||
impl<T, B> Record<B> for Vec<T>
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self.into_iter().map(Record::into_item).collect()
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
item.into_iter()
|
||||
.map(|i| Record::from_item(i, device))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> Record<B> for Option<T>
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = Option<T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self.map(Record::into_item)
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
item.map(|i| Record::from_item(i, device))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T, B> Record<B> for [T; N]
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
/// The record item is an array of the record item of the elements.
|
||||
/// The reason why we wrap the array in a struct is because serde does not support
|
||||
/// deserializing arrays of variable size,
|
||||
/// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
|
||||
/// for backward compatibility reasons. Serde APIs were created before const generics.
|
||||
type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
Array(self.map(Record::into_item))
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
item.0.map(|i| Record::from_item(i, device))
|
||||
}
|
||||
}
|
||||
|
||||
/// A macro for generating implementations for tuple records of different sizes.
|
||||
/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
|
||||
/// Would generate an implementation for a tuple of size 2.
|
||||
/// For this macro to work properly, please adhere to the convention:
|
||||
/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
|
||||
macro_rules! impl_record_tuple {
|
||||
// `$r` represents the generic records.
|
||||
// `$i` represents the indices of the records in the tuple.
|
||||
([$($r:ident),*][$($i:tt),*]) => {
|
||||
impl<B, $($r,)*> Record<B> for ($($r,)*)
|
||||
where
|
||||
B: Backend,
|
||||
$($r: Record<B>),*
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
($(self.$i.into_item(),)*)
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
($(Record::from_item(item.$i, device),)*)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_record_tuple!([R0, R1][0, 1]);
|
||||
impl_record_tuple!([R0, R1, R2][0, 1, 2]);
|
||||
impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
||||
impl<T, B> Record<B> for HashMap<ParamId, T>
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
let mut items = HashMap::with_capacity(self.len());
|
||||
self.into_iter().for_each(|(id, record)| {
|
||||
items.insert(id.serialize(), record.into_item());
|
||||
});
|
||||
items
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
let mut record = HashMap::with_capacity(item.len());
|
||||
item.into_iter().for_each(|(id, item)| {
|
||||
record.insert(ParamId::deserialize(&id), T::from_item(item, device));
|
||||
});
|
||||
record
|
||||
}
|
||||
}
|
||||
|
||||
/// (De)serialize parameters into a clean format.
|
||||
#[derive(new, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ParamSerde<T> {
|
||||
id: String,
|
||||
param: T,
|
||||
}
|
||||
|
||||
impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
let (id, tensor, mapper) = self.consume();
|
||||
let tensor = mapper.on_save(tensor);
|
||||
ParamSerde::new(id.serialize(), tensor.into_item())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
B::memory_persistent_allocations(device, item, |item| {
|
||||
Param::initialized(
|
||||
ParamId::deserialize(&item.id),
|
||||
Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new
|
||||
// Param from a tensor.
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
let (id, tensor, mapper) = self.consume();
|
||||
let tensor = mapper.on_save(tensor);
|
||||
ParamSerde::new(id.serialize(), tensor.into_item())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
B::memory_persistent_allocations(device, item, |item| {
|
||||
Param::initialized(
|
||||
ParamId::deserialize(&item.id),
|
||||
Tensor::from_item(item.param, device),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
let (id, tensor, mapper) = self.consume();
|
||||
let tensor = mapper.on_save(tensor);
|
||||
ParamSerde::new(id.serialize(), tensor.into_item::<S>())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
B::memory_persistent_allocations(device, item, |item| {
|
||||
Param::initialized(
|
||||
ParamId::deserialize(&item.id),
|
||||
Tensor::from_item::<S>(item.param, device),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Type that can be serialized as is without any conversion.
|
||||
macro_rules! primitive {
|
||||
($type:ty) => {
|
||||
impl<B: Backend> Record<B> for $type {
|
||||
type Item<S: PrecisionSettings> = $type;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
|
||||
item
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// General Types
|
||||
primitive!(alloc::string::String);
|
||||
primitive!(bool);
|
||||
|
||||
// Float Types
|
||||
primitive!(f64);
|
||||
primitive!(f32);
|
||||
|
||||
primitive!(half::bf16);
|
||||
primitive!(half::f16);
|
||||
|
||||
// Unsigned Integer Types
|
||||
primitive!(usize);
|
||||
primitive!(u64);
|
||||
primitive!(u32);
|
||||
primitive!(u16);
|
||||
primitive!(u8);
|
||||
|
||||
// Signed Integer Types
|
||||
primitive!(isize);
|
||||
primitive!(i64);
|
||||
primitive!(i32);
|
||||
primitive!(i16);
|
||||
primitive!(i8);
|
||||
|
||||
/// Wrapper around an array of size `N` that can be serialized and deserialized with serde.
///
/// The wrapper exists because serde does not support deserializing arrays of arbitrary
/// size: its APIs predate const generics and were kept as they are for backward
/// compatibility, see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
#[derive(Clone)]
pub struct Array<const N: usize, T>([T; N]);
|
||||
|
||||
impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut seq = serializer.serialize_tuple(self.0.len())?;
|
||||
for element in &self.0 {
|
||||
seq.serialize_element(element)?;
|
||||
}
|
||||
seq.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct ArrayVisitor<T, const N: usize> {
|
||||
marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
{
|
||||
type Value = Array<N, T>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a fixed size array")
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: SeqAccess<'de>,
|
||||
{
|
||||
let mut items = vec![];
|
||||
|
||||
for i in 0..N {
|
||||
let item = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(i, &self))?;
|
||||
items.push(item);
|
||||
}
|
||||
|
||||
let array: [T; N] = items
|
||||
.into_iter()
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.map_err(|_| "An array of size {N}")
|
||||
.unwrap();
|
||||
|
||||
Ok(Array(array))
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_tuple(
|
||||
N,
|
||||
ArrayVisitor {
|
||||
marker: PhantomData,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
use core::any::type_name;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use alloc::format;
|
||||
use alloc::string::{String, ToString};
|
||||
use burn_tensor::backend::Backend;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
|
||||
use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use super::{
|
||||
BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,
|
||||
PrettyJsonFileRecorder,
|
||||
};
|
||||
|
||||
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
|
||||
pub trait Recorder<B: Backend>:
|
||||
Send + Sync + core::default::Default + core::fmt::Debug + Clone
|
||||
{
|
||||
/// Type of the settings used by the recorder.
|
||||
type Settings: PrecisionSettings;
|
||||
|
||||
/// Arguments used to record objects.
|
||||
type RecordArgs: Clone;
|
||||
|
||||
/// Record output type.
|
||||
type RecordOutput;
|
||||
|
||||
/// Arguments used to load recorded objects.
|
||||
type LoadArgs;
|
||||
|
||||
/// Records an item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `record` - The item to record.
|
||||
/// * `args` - Arguments used to record the item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output of the recording.
|
||||
fn record<R>(
|
||||
&self,
|
||||
record: R,
|
||||
args: Self::RecordArgs,
|
||||
) -> Result<Self::RecordOutput, RecorderError>
|
||||
where
|
||||
R: Record<B>,
|
||||
{
|
||||
let item = record.into_item::<Self::Settings>();
|
||||
let item = BurnRecord::new::<Self>(item);
|
||||
|
||||
self.save_item(item, args)
|
||||
}
|
||||
|
||||
/// Load an item from the given arguments.
|
||||
fn load<R>(&self, mut args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
|
||||
where
|
||||
R: Record<B>,
|
||||
{
|
||||
let item: BurnRecord<R::Item<Self::Settings>, B> =
|
||||
self.load_item(&mut args).map_err(|err| {
|
||||
if let Ok(record) = self.load_item::<BurnRecordNoItem>(&mut args) {
|
||||
let mut message = "Unable to load record.".to_string();
|
||||
let metadata = recorder_metadata::<Self, B>();
|
||||
if metadata.float != record.metadata.float {
|
||||
message += format!(
|
||||
"\nMetadata has a different float type: Actual {:?}, Expected {:?}",
|
||||
record.metadata.float, metadata.float
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
if metadata.int != record.metadata.int {
|
||||
message += format!(
|
||||
"\nMetadata has a different int type: Actual {:?}, Expected {:?}",
|
||||
record.metadata.int, metadata.int
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
if metadata.format != record.metadata.format {
|
||||
message += format!(
|
||||
"\nMetadata has a different format: Actual {:?}, Expected {:?}",
|
||||
record.metadata.format, metadata.format
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
if metadata.version != record.metadata.version {
|
||||
message += format!(
|
||||
"\nMetadata has a different Burn version: Actual {:?}, Expected {:?}",
|
||||
record.metadata.version, metadata.version
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
|
||||
message += format!("\nError: {err:?}").as_str();
|
||||
|
||||
return RecorderError::Unknown(message);
|
||||
}
|
||||
|
||||
err
|
||||
})?;
|
||||
|
||||
Ok(R::from_item(item.item, device))
|
||||
}
|
||||
|
||||
/// Saves an item.
|
||||
///
|
||||
/// This method is used by [record](Recorder::record) to save the item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - Item to save.
|
||||
/// * `args` - Arguments to use to save the item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output of the save operation.
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
args: Self::RecordArgs,
|
||||
) -> Result<Self::RecordOutput, RecorderError>;
|
||||
|
||||
/// Loads an item.
|
||||
///
|
||||
/// This method is used by [load](Recorder::load) to load the item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `args` - Arguments to use to load the item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The loaded item.
|
||||
fn load_item<I>(&self, args: &mut Self::LoadArgs) -> Result<I, RecorderError>
|
||||
where
|
||||
I: DeserializeOwned;
|
||||
}
|
||||
|
||||
fn recorder_metadata<R, B>() -> BurnMetadata
|
||||
where
|
||||
R: Recorder<B>,
|
||||
B: Backend,
|
||||
{
|
||||
BurnMetadata::new(
|
||||
type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),
|
||||
type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),
|
||||
type_name::<R>().to_string(),
|
||||
env!("CARGO_PKG_VERSION").to_string(),
|
||||
format!("{:?}", R::Settings::default()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Error that can occur when using a [Recorder](Recorder).
#[derive(Debug)]
pub enum RecorderError {
    /// File not found.
    FileNotFound(String),

    /// Failed to deserialize the recorded data.
    DeserializeError(String),

    /// Other error.
    Unknown(String),
}

impl core::fmt::Display for RecorderError {
    /// Writes the `Debug` representation of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Stream straight into the formatter instead of building a temporary `String`.
        write!(f, "{self:?}")
    }
}

impl core::error::Error for RecorderError {}
|
||||
|
||||
/// Returns the bincode configuration shared within this crate.
///
/// This is `bincode::config::standard()`, exposed through a single function.
pub(crate) fn bin_config() -> bincode::config::Configuration {
    bincode::config::standard()
}
|
||||
|
||||
/// Metadata of a record.
///
/// Every field is a plain string, and two metadata values compare equal only when all
/// fields match.
#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct BurnMetadata {
    /// Float type used to record the item.
    pub float: String,

    /// Int type used to record the item.
    pub int: String,

    /// Format used to record the item.
    pub format: String,

    /// Burn record version used to record the item.
    pub version: String,

    /// Settings used to record the item.
    pub settings: String,
}
|
||||
|
||||
/// Record that can be saved by a [Recorder](Recorder).
///
/// It pairs the recorded item with the metadata describing how it was recorded.
#[derive(Serialize, Deserialize, Debug)]
pub struct BurnRecord<I, B: Backend> {
    /// Metadata of the record.
    pub metadata: BurnMetadata,

    /// Item to record.
    pub item: I,

    // Marker tying the record to the backend type `B` without storing a value of it.
    _b: PhantomData<B>,
}
|
||||
|
||||
impl<I, B: Backend> BurnRecord<I, B> {
|
||||
/// Creates a new record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - Item to record.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The new record.
|
||||
pub fn new<R: Recorder<B>>(item: I) -> Self {
|
||||
let metadata = recorder_metadata::<R, B>();
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
item,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record that can be saved by a [Recorder](Recorder) without the item.
///
/// [Recorder::load] deserializes into this type after a failed load, so that the stored
/// metadata can still be read and compared against the expected metadata.
#[derive(new, Debug, Serialize, Deserialize)]
pub struct BurnRecordNoItem {
    /// Metadata of the record.
    pub metadata: BurnMetadata,
}
|
||||
|
||||
/// Default recorder.
///
/// It uses the [named msgpack](rmp_serde) format for serialization with full precision.
///
/// Only available when the `std` feature is enabled.
#[cfg(feature = "std")]
pub type DefaultRecorder = DefaultFileRecorder<FullPrecisionSettings>;

/// Recorder optimized for compactness.
///
/// It uses the [named msgpack](rmp_serde) format for serialization with half precision.
/// If you are looking for the recorder that offers the smallest file size, have a look at
/// [sensitive compact recorder](SensitiveCompactRecorder).
///
/// Only available when the `std` feature is enabled.
#[cfg(feature = "std")]
pub type CompactRecorder = DefaultFileRecorder<HalfPrecisionSettings>;

/// Recorder optimized for compactness making it a good choice for model deployment.
///
/// It uses the [bincode](bincode) format for serialization and half precision.
/// This format is not resilient to type changes since no metadata is encoded.
/// Favor [default recorder](DefaultRecorder) or [compact recorder](CompactRecorder)
/// for long term data storage.
///
/// Only available when the `std` feature is enabled.
#[cfg(feature = "std")]
pub type SensitiveCompactRecorder = BinGzFileRecorder<HalfPrecisionSettings>;

/// Training recorder compatible with no-std inference.
///
/// Only available when the `std` feature is enabled.
#[cfg(feature = "std")]
pub type NoStdTrainingRecorder = BinFileRecorder<FullPrecisionSettings>;

/// Inference recorder compatible with no-std.
///
/// It works on a `&'static [u8]` with full precision and is not gated behind the `std`
/// feature.
pub type NoStdInferenceRecorder = BinBytesRecorder<FullPrecisionSettings, &'static [u8]>;

/// Debug recorder.
///
/// It uses the [pretty json](serde_json) format for serialization with full precision making it
/// human readable.
///
/// Only available when the `std` feature is enabled.
#[cfg(feature = "std")]
pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::TestBackend;

    use super::*;
    use burn_tensor::{Device, ElementConversion};

    /// Path of the scratch record file, placed in the platform's temporary directory.
    ///
    /// A hard-coded `/tmp/...` path does not exist on every platform. There the recording
    /// step would panic and the `#[should_panic]` test below would pass for the wrong reason.
    fn file_path() -> std::path::PathBuf {
        std::env::temp_dir().join("burn_test_record")
    }

    /// Loading a record written with full precision through a half-precision recorder
    /// must fail.
    #[test]
    #[should_panic]
    fn err_when_invalid_item() {
        #[derive(new, Serialize, Deserialize, Clone)]
        struct Item<S: PrecisionSettings> {
            value: S::FloatElem,
        }

        impl<D, B> Record<B> for Item<D>
        where
            D: PrecisionSettings,
            B: Backend,
        {
            type Item<S: PrecisionSettings> = Item<S>;

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                Item {
                    value: self.value.elem(),
                }
            }

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
                Item {
                    value: item.value.elem(),
                }
            }
        }

        let item = Item::<FullPrecisionSettings>::new(16.elem());
        let device: Device<TestBackend> = Default::default();

        // Serialize in f32.
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        Recorder::<TestBackend>::record(&recorder, item, file_path()).unwrap();

        // Can't deserialize f32 into f16.
        let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
        Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(
            &recorder,
            file_path(),
            &device,
        )
        .unwrap();
    }
}
|
||||
@@ -0,0 +1,83 @@
|
||||
use super::data::NestedValue;
|
||||
|
||||
/// A trait that defines the adapter for a Burn module.
///
/// This is used to adapt an incoming module to a Burn module.
///
/// Every method has a provided implementation that returns its input unchanged, so an
/// implementor only overrides the module kinds it needs to transform.
pub trait BurnModuleAdapter: Sized {
    /// Adapts a module.
    ///
    /// Dispatches on the module `name` to the matching `adapt_*` method. Data for a name
    /// that is not listed below is returned unchanged.
    fn adapt(name: &str, data: NestedValue) -> NestedValue {
        match name {
            "BatchNorm" => Self::adapt_batch_norm(data),
            "Conv1d" => Self::adapt_conv1d(data),
            "Conv2d" => Self::adapt_conv2d(data),
            "Conv3d" => Self::adapt_conv3d(data),
            "ConvTranspose1d" => Self::adapt_conv_transpose_1d(data),
            "ConvTranspose2d" => Self::adapt_conv_transpose_2d(data),
            "ConvTranspose3d" => Self::adapt_conv_transpose_3d(data),
            "Embedding" => Self::adapt_embedding(data),
            "GroupNorm" => Self::adapt_group_norm(data),
            "LayerNorm" => Self::adapt_layer_norm(data),
            "Linear" => Self::adapt_linear(data),
            _ => data,
        }
    }

    /// Adapts a linear module.
    fn adapt_linear(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts a Convolution 1D module.
    fn adapt_conv1d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts a Convolution 2D module.
    fn adapt_conv2d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts a Convolution 3D module.
    fn adapt_conv3d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts convolution transpose 1D module.
    fn adapt_conv_transpose_1d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts convolution transpose 2D module.
    fn adapt_conv_transpose_2d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts convolution transpose 3D module.
    fn adapt_conv_transpose_3d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts embedding module.
    fn adapt_embedding(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts group normalization module.
    fn adapt_group_norm(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts layer normalization module.
    fn adapt_layer_norm(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts batch normalization module.
    fn adapt_batch_norm(data: NestedValue) -> NestedValue {
        data
    }
}
|
||||
|
||||
/// Default adapter that takes no action.
///
/// It overrides none of the [BurnModuleAdapter] methods, so all module data passes
/// through unchanged.
pub struct DefaultAdapter;

impl BurnModuleAdapter for DefaultAdapter {}
|
||||
@@ -0,0 +1,399 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::adapter::BurnModuleAdapter;
|
||||
use super::de::Deserializer;
|
||||
use super::error::Error;
|
||||
use super::ser::Serializer;
|
||||
use crate::record::{PrecisionSettings, Record};
|
||||
use crate::tensor::backend::Backend;
|
||||
|
||||
use alloc::fmt;
|
||||
use burn_tensor::Bytes;
|
||||
use num_traits::cast::ToPrimitive;
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
|
||||
/// The main data structure used for deserialization.
///
/// It can hold tree-like structures of nested maps and vectors.
#[derive(Clone)]
pub enum NestedValue {
    /// The default value, which actually does not hold any value and it is used to indicate that
    /// the value should be populated with the default value. It contains an optional string with
    /// the originator field name.
    Default(Option<String>),

    /// A boolean value.
    Bool(bool),

    /// A string value.
    String(String),

    /// Floating point 32-bit value.
    F32(f32),

    /// Floating point 64-bit value.
    F64(f64),

    /// Signed 16-bit integer value.
    I16(i16),

    /// Signed 32-bit integer value.
    I32(i32),

    /// Signed 64-bit integer value.
    I64(i64),

    /// Unsigned 8-bit integer value.
    U8(u8),

    /// Unsigned 16-bit integer value used for bf16 and f16 serialization
    U16(u16),

    /// Unsigned 64-bit integer value.
    U64(u64),

    /// A map of nested values (typically used for structs)
    Map(HashMap<String, NestedValue>),

    /// A vector of nested values (typically used for vector of structs or numbers)
    Vec(Vec<NestedValue>),

    /// A vector of 8-bit unsigned integer values.
    U8s(Vec<u8>),

    /// A vector of 16-bit unsigned integer values.
    U16s(Vec<u16>),

    /// A vector of 32-bit floating point values.
    F32s(Vec<f32>),

    /// An opaque vector of bytes, with alignment.
    Bytes(Bytes),
}
|
||||
|
||||
impl NestedValue {
    /// Get the nested value as a map.
    ///
    /// Consumes the value and returns `None` for every variant other than `Map`.
    pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
        match self {
            NestedValue::Map(map) => Some(map),
            _ => None,
        }
    }

    /// Get the nested value as a boolean.
    ///
    /// Returns `None` for every variant other than `Bool`.
    pub fn as_bool(self) -> Option<bool> {
        match self {
            NestedValue::Bool(bool) => Some(bool),
            _ => None,
        }
    }

    /// Get the nested value as a string.
    ///
    /// Returns `None` for every variant other than `String`.
    pub fn as_string(self) -> Option<String> {
        match self {
            NestedValue::String(string) => Some(string),
            _ => None,
        }
    }

    /// Get the nested value as a f32.
    ///
    /// An `F64` value is converted with `to_f32`. Other variants yield `None`.
    pub fn as_f32(self) -> Option<f32> {
        match self {
            NestedValue::F32(f32) => Some(f32),
            NestedValue::F64(f) => f.to_f32(),
            _ => None,
        }
    }

    /// Get the nested value as a f64.
    ///
    /// An `F32` value is converted with `to_f64`. Other variants yield `None`.
    pub fn as_f64(self) -> Option<f64> {
        match self {
            NestedValue::F64(f64) => Some(f64),
            NestedValue::F32(f) => f.to_f64(),
            _ => None,
        }
    }

    /// Get the nested value as an i16.
    ///
    /// `I32`, `I64`, `U16` and `U64` values are converted with `to_i16`. Other variants
    /// yield `None`.
    pub fn as_i16(self) -> Option<i16> {
        match self {
            NestedValue::I16(i16) => Some(i16),
            NestedValue::I32(i) => i.to_i16(),
            NestedValue::I64(i) => i.to_i16(),
            NestedValue::U16(u) => u.to_i16(),
            NestedValue::U64(u) => u.to_i16(),
            _ => None,
        }
    }

    /// Get the nested value as an i32.
    ///
    /// `I16`, `I64`, `U16` and `U64` values are converted with `to_i32`. Other variants
    /// yield `None`.
    pub fn as_i32(self) -> Option<i32> {
        match self {
            NestedValue::I32(i32) => Some(i32),
            NestedValue::I16(i) => i.to_i32(),
            NestedValue::I64(i) => i.to_i32(),
            NestedValue::U16(u) => u.to_i32(),
            NestedValue::U64(u) => u.to_i32(),
            _ => None,
        }
    }

    /// Get the nested value as an i64.
    ///
    /// `I16`, `I32`, `U16` and `U64` values are converted with `to_i64`. Other variants
    /// yield `None`.
    pub fn as_i64(self) -> Option<i64> {
        match self {
            NestedValue::I64(i64) => Some(i64),
            NestedValue::I16(i) => i.to_i64(),
            NestedValue::I32(i) => i.to_i64(),
            NestedValue::U16(u) => u.to_i64(),
            NestedValue::U64(u) => u.to_i64(),
            _ => None,
        }
    }

    /// Get the nested value as a u8.
    ///
    /// `I16`, `I32`, `I64`, `U16` and `U64` values are converted with `to_u8`. Other
    /// variants yield `None`.
    pub fn as_u8(self) -> Option<u8> {
        match self {
            NestedValue::U8(u8) => Some(u8),
            NestedValue::I16(i) => i.to_u8(),
            NestedValue::I32(i) => i.to_u8(),
            NestedValue::I64(i) => i.to_u8(),
            NestedValue::U16(u) => u.to_u8(),
            NestedValue::U64(u) => u.to_u8(),
            _ => None,
        }
    }

    /// Get the nested value as a u16.
    ///
    /// `I16`, `I32`, `I64` and `U64` values are converted with `to_u16`. Other variants
    /// (including `U8`) yield `None`.
    pub fn as_u16(self) -> Option<u16> {
        match self {
            NestedValue::U16(u16) => Some(u16),
            NestedValue::I16(i) => i.to_u16(),
            NestedValue::I32(i) => i.to_u16(),
            NestedValue::I64(i) => i.to_u16(),
            NestedValue::U64(u) => u.to_u16(),
            _ => None,
        }
    }

    /// Get the nested value as a u64.
    ///
    /// `I16`, `I32`, `I64` and `U16` values are converted with `to_u64`. Other variants
    /// (including `U8`) yield `None`.
    pub fn as_u64(self) -> Option<u64> {
        match self {
            NestedValue::U64(u64) => Some(u64),
            NestedValue::I16(i) => i.to_u64(),
            NestedValue::I32(i) => i.to_u64(),
            NestedValue::I64(i) => i.to_u64(),
            NestedValue::U16(u) => u.to_u64(),
            _ => None,
        }
    }

    /// Get the nested value as a vector of bytes.
    ///
    /// A `U8s` vector is wrapped with `Bytes::from_elems`. Variants other than `Bytes`
    /// and `U8s` yield `None`.
    pub fn as_bytes(self) -> Option<Bytes> {
        match self {
            NestedValue::Bytes(u) => Some(u),
            NestedValue::U8s(u) => Some(Bytes::from_elems(u)),
            _ => None,
        }
    }

    /// Deserialize a nested value into a record type.
    ///
    /// The value is first deserialized into the record's item type through a
    /// [Deserializer] parameterized by the adapter `A`, then converted into the record
    /// with the precision settings `PS` on the given `device`.
    ///
    /// # Errors
    ///
    /// Returns the error produced while deserializing the item.
    pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
    where
        B: Backend,
        T: Record<B>,
        PS: PrecisionSettings,
        A: BurnModuleAdapter,
    {
        // NOTE(review): the meaning of the `false` flag is defined by `Deserializer::new`,
        // which is not visible here - confirm against that constructor.
        let deserializer = Deserializer::<A>::new(self, false);

        let item = T::Item::deserialize(deserializer)?;

        // Convert the deserialized item into a Record instance
        Ok(T::from_item::<PS>(item, device))
    }
}
|
||||
|
||||
/// Remap the tensor locations according to the key remapping.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - A map of tensors.
|
||||
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
|
||||
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
|
||||
/// for more information.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A map of tensors with the remapped keys and
|
||||
/// a vector of tuples containing the remapped and original.
|
||||
pub fn remap<T>(
|
||||
mut tensors: HashMap<String, T>,
|
||||
key_remap: Vec<(Regex, String)>,
|
||||
) -> (HashMap<String, T>, Vec<(String, String)>) {
|
||||
if key_remap.is_empty() {
|
||||
let remapped_names = tensors
|
||||
.keys()
|
||||
.cloned()
|
||||
.map(|s| (s.clone(), s)) // Name is the same as the remapped name
|
||||
.collect();
|
||||
return (tensors, remapped_names);
|
||||
}
|
||||
|
||||
let mut remapped = HashMap::new();
|
||||
let mut remapped_names = Vec::new();
|
||||
|
||||
for (name, tensor) in tensors.drain() {
|
||||
let mut new_name = name.clone();
|
||||
for (pattern, replacement) in &key_remap {
|
||||
if pattern.is_match(&new_name) {
|
||||
new_name = pattern
|
||||
.replace_all(&new_name, replacement.as_str())
|
||||
.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
remapped_names.push((new_name.clone(), name));
|
||||
remapped.insert(new_name, tensor);
|
||||
}
|
||||
|
||||
(remapped, remapped_names)
|
||||
}
|
||||
|
||||
/// Helper function to insert a value into a nested map/vector of tensors.
|
||||
fn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) {
|
||||
if keys.is_empty() {
|
||||
*current = value;
|
||||
return;
|
||||
}
|
||||
|
||||
match current {
|
||||
NestedValue::Map(map) => {
|
||||
if !map.contains_key(keys[0]) {
|
||||
let next = if keys[1..]
|
||||
.first()
|
||||
.and_then(|k| k.parse::<usize>().ok())
|
||||
.is_some()
|
||||
{
|
||||
NestedValue::Vec(Vec::new())
|
||||
} else {
|
||||
NestedValue::Map(HashMap::new())
|
||||
};
|
||||
map.insert(keys[0].to_string(), next);
|
||||
}
|
||||
insert_nested_value(map.get_mut(keys[0]).unwrap(), &keys[1..], value);
|
||||
}
|
||||
NestedValue::Vec(vec) => {
|
||||
let index = keys[0].parse::<usize>().unwrap();
|
||||
if index >= vec.len() {
|
||||
vec.resize_with(index + 1, || NestedValue::Map(HashMap::new()));
|
||||
}
|
||||
insert_nested_value(&mut vec[index], &keys[1..], value);
|
||||
}
|
||||
_ => panic!("Invalid structure encountered"),
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait for encapsulating the serialization logic.
///
/// [unflatten] calls it on each value of its input map to obtain the [NestedValue]
/// inserted into the resulting tree.
pub trait Serializable {
    /// Serializes the object into a `NestedValue` using the provided `Serializer`.
    /// This method is generic over the precision settings `PS`.
    ///
    /// # Parameters
    /// - `serializer`: The `Serializer` to use for serializing the object.
    ///
    /// # Returns
    /// - `Result<NestedValue, Error>`: The result of serialization.
    ///   Returns a `NestedValue` on success,
    ///   or an `Error` on failure.
    ///
    /// # Type Parameters
    /// - `PS`: The precision settings to use during serialization.
    ///   This is a generic parameter and can be any type
    ///   that implements the `PrecisionSettings` trait.
    fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, Error>
    where
        PS: PrecisionSettings;
}
|
||||
|
||||
/// Convert a vector of tensors to a nested value.
|
||||
pub fn unflatten<PS, T>(input: HashMap<String, T>) -> Result<NestedValue, Error>
|
||||
where
|
||||
PS: PrecisionSettings,
|
||||
T: Serializable,
|
||||
{
|
||||
let mut result = NestedValue::Map(HashMap::new());
|
||||
|
||||
for (key, value) in input {
|
||||
let parts: Vec<&str> = key.split('.').collect();
|
||||
let st = value.serialize::<PS>(Serializer::new())?;
|
||||
|
||||
insert_nested_value(&mut result, &parts, st);
|
||||
}
|
||||
|
||||
cleanup_empty_maps(&mut result);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Removes empty maps from the nested value.
|
||||
///
|
||||
/// We need to clean up empty maps from the nested value
|
||||
/// in some cases when there is non-contiguous indices in keys.
|
||||
fn cleanup_empty_maps(current: &mut NestedValue) {
|
||||
match current {
|
||||
NestedValue::Map(map) => {
|
||||
map.values_mut().for_each(cleanup_empty_maps);
|
||||
}
|
||||
NestedValue::Vec(vec) => {
|
||||
vec.iter_mut().for_each(cleanup_empty_maps);
|
||||
vec.retain(|v| !matches!(v, NestedValue::Map(m) if m.is_empty()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes a shortened `Debug` view of a slice: at most the first three elements, followed
/// by `, ...]` and the total length.
fn write_vec_truncated<T: core::fmt::Debug>(
    vec: &[T],
    f: &mut core::fmt::Formatter,
) -> fmt::Result {
    f.write_str("Vec([")?;

    let mut shown = vec.iter().take(3);
    if let Some(first) = shown.next() {
        write!(f, "{first:?}")?;
        for item in shown {
            f.write_str(", ")?;
            write!(f, "{item:?}")?;
        }
    }

    write!(f, ", ...] len={})", vec.len())
}
|
||||
|
||||
/// Custom `Debug` output that shortens long sequences.
///
/// Sequence-like variants holding more than three elements are printed through
/// `write_vec_truncated`. Everything else is printed in full.
impl fmt::Debug for NestedValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Truncate values for vector
            NestedValue::Vec(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
            NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
            NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
            NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
            NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f),
            // Handle other variants as usual
            NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
            NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(),
            NestedValue::String(s) => f.debug_tuple("String").field(s).finish(),
            NestedValue::F32(val) => f.debug_tuple("F32").field(val).finish(),
            NestedValue::F64(val) => f.debug_tuple("F64").field(val).finish(),
            NestedValue::I16(val) => f.debug_tuple("I16").field(val).finish(),
            NestedValue::I32(val) => f.debug_tuple("I32").field(val).finish(),
            NestedValue::I64(val) => f.debug_tuple("I64").field(val).finish(),
            NestedValue::U8(val) => f.debug_tuple("U8").field(val).finish(),
            NestedValue::U16(val) => f.debug_tuple("U16").field(val).finish(),
            NestedValue::U64(val) => f.debug_tuple("U64").field(val).finish(),
            NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(),
            // The sequence arms below are only reached with three elements or fewer,
            // because the guarded arms above catch the longer ones.
            NestedValue::Vec(vec) => f.debug_list().entries(vec.iter()).finish(),
            NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(),
            NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(),
            NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(),
            NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(),
        }
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,40 @@
|
||||
use crate::record::RecorderError;
|
||||
|
||||
/// The error type for Record serde.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
/// Failed to deserialize.
|
||||
#[error("failed to deserialize: {0}")]
|
||||
Deserialize(#[from] serde::de::value::Error),
|
||||
|
||||
/// Failed to serialize.
|
||||
#[error("failed to serialize")]
|
||||
Serialize(String),
|
||||
|
||||
/// Encountered an invalid state.
|
||||
#[error("invalid state")]
|
||||
InvalidState,
|
||||
|
||||
/// Other error.
|
||||
#[error("other error: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl serde::de::Error for Error {
    /// Wraps a custom deserialization message into [Error::Deserialize].
    fn custom<T: std::fmt::Display>(msg: T) -> Self {
        Error::Deserialize(serde::de::value::Error::custom(msg.to_string()))
    }
}
|
||||
|
||||
impl serde::ser::Error for Error {
    /// Stores a custom serialization message in [Error::Serialize].
    fn custom<T: std::fmt::Display>(msg: T) -> Self {
        Error::Serialize(msg.to_string())
    }
}
|
||||
|
||||
// Implement From trait for Error to RecorderError
impl From<Error> for RecorderError {
    /// Converts any record serde error into [RecorderError::DeserializeError], keeping
    /// only its displayed text.
    fn from(error: Error) -> Self {
        RecorderError::DeserializeError(error.to_string())
    }
}
|
||||
@@ -0,0 +1,17 @@
|
||||
//! Module contains the serde implementation for the record module
|
||||
//! useful for custom importing model weights, such as PyTorch's pt file format.
|
||||
|
||||
/// The adapter trait that is used to convert the nested value to the module type.
|
||||
pub mod adapter;
|
||||
|
||||
/// The main data structure used for deserialization.
|
||||
pub mod data;
|
||||
|
||||
/// The serializer that is used to convert a struct into nested values.
|
||||
pub mod ser;
|
||||
|
||||
/// The deserializer that is used to convert the nested value to the record.
|
||||
pub mod de;
|
||||
|
||||
/// Error types.
|
||||
pub mod error;
|
||||
@@ -0,0 +1,387 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{
|
||||
data::NestedValue,
|
||||
error::{self, Error},
|
||||
};
|
||||
|
||||
use serde::{
|
||||
Serialize,
|
||||
ser::{self, SerializeSeq, SerializeStruct, Serializer as SerializerTrait},
|
||||
};
|
||||
|
||||
/// Simple struct serializer that converts a struct into NestedValues.
///
/// NOTE: This is used to serialize Param structs into NestedValues and not so much for
/// the actual serialization of modules (although it could be used for that as well if all
/// primitive types are implemented).
#[derive(Clone)]
pub struct Serializer {
    /// The state of the serialization process
    state: Option<NestedValue>,
}

impl Serializer {
    /// Creates a new serializer.
    ///
    /// The serializer starts with no state.
    pub fn new() -> Self {
        Serializer { state: None }
    }
}

impl Default for Serializer {
    /// Equivalent to [Serializer::new].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl SerializerTrait for Serializer {
|
||||
type Ok = NestedValue;
|
||||
type Error = Error;
|
||||
type SerializeSeq = Self;
|
||||
type SerializeTuple = ser::Impossible<NestedValue, Self::Error>;
|
||||
type SerializeTupleStruct = ser::Impossible<NestedValue, Self::Error>;
|
||||
type SerializeTupleVariant = ser::Impossible<NestedValue, Self::Error>;
|
||||
type SerializeMap = ser::Impossible<NestedValue, Self::Error>;
|
||||
type SerializeStruct = Self;
|
||||
type SerializeStructVariant = ser::Impossible<NestedValue, Self::Error>;
|
||||
|
||||
fn serialize_struct(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_len: usize,
|
||||
) -> Result<Self::SerializeStruct, Self::Error> {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
fn serialize_newtype_struct<T>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
value: &T,
|
||||
) -> Result<Self::Ok, Self::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
value.serialize(self)
|
||||
}
|
||||
|
||||
fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::I32(v))
|
||||
}
|
||||
|
||||
fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::String(v.to_string()))
|
||||
}
|
||||
|
||||
fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::I16(v))
|
||||
}
|
||||
|
||||
fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::I64(v))
|
||||
}
|
||||
|
||||
fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::U16(v))
|
||||
}
|
||||
|
||||
fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::U64(v))
|
||||
}
|
||||
|
||||
fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::F32(v))
|
||||
}
|
||||
|
||||
fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::F64(v))
|
||||
}
|
||||
|
||||
// The following methods are not implemented because they are not needed for the
|
||||
// serialization of Param structs.
|
||||
|
||||
fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::U8s(v.to_vec()))
|
||||
}
|
||||
|
||||
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::Default(None))
|
||||
}
|
||||
fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::U8(v))
|
||||
}
|
||||
|
||||
fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
value.serialize(self)
|
||||
}
|
||||
|
||||
fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_unit_variant(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_variant_index: u32,
|
||||
_variant: &'static str,
|
||||
) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::Map(HashMap::from([(
|
||||
_name.to_string(),
|
||||
NestedValue::String(_variant.to_string()),
|
||||
)])))
|
||||
}
|
||||
|
||||
fn serialize_newtype_variant<T>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_variant_index: u32,
|
||||
_variant: &'static str,
|
||||
_value: &T,
|
||||
) -> Result<Self::Ok, Self::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_tuple_struct(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_len: usize,
|
||||
) -> Result<Self::SerializeTupleStruct, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_tuple_variant(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_variant_index: u32,
|
||||
_variant: &'static str,
|
||||
_len: usize,
|
||||
) -> Result<Self::SerializeTupleVariant, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_struct_variant(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_variant_index: u32,
|
||||
_variant: &'static str,
|
||||
_len: usize,
|
||||
) -> Result<Self::SerializeStructVariant, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
// Implementing the SerializeStruct trait for Serializer
|
||||
impl SerializeStruct for Serializer {
|
||||
type Ok = NestedValue;
|
||||
type Error = Error;
|
||||
|
||||
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
let serialized_value = value.serialize(Serializer::new())?;
|
||||
|
||||
match self.state {
|
||||
Some(NestedValue::Map(ref mut map)) => {
|
||||
map.insert(key.to_string(), serialized_value); // Inserting into the state
|
||||
}
|
||||
Some(_) => {
|
||||
panic!("Invalid state encountered");
|
||||
}
|
||||
None => {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(key.to_string(), serialized_value); // Inserting into the state
|
||||
self.state = Some(NestedValue::Map(map));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn end(self) -> Result<Self::Ok, Self::Error> {
|
||||
if self.state.is_none() {
|
||||
// If the state is empty, return an empty map
|
||||
Ok(NestedValue::Map(HashMap::new()))
|
||||
} else {
|
||||
self.state.ok_or(error::Error::InvalidState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeSeq for Serializer {
|
||||
type Ok = NestedValue;
|
||||
type Error = Error;
|
||||
|
||||
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
let serialized_value = value.serialize(Serializer::new())?;
|
||||
|
||||
match self.state {
|
||||
Some(NestedValue::Vec(ref mut vec)) => {
|
||||
vec.push(serialized_value); // Inserting into the state
|
||||
}
|
||||
Some(NestedValue::U8s(ref mut vec)) => {
|
||||
if let NestedValue::U8(val) = serialized_value {
|
||||
vec.push(val);
|
||||
} else {
|
||||
panic!("Invalid value type encountered");
|
||||
}
|
||||
}
|
||||
Some(NestedValue::U16s(ref mut vec)) => {
|
||||
if let NestedValue::U16(val) = serialized_value {
|
||||
vec.push(val);
|
||||
} else {
|
||||
panic!("Invalid value type encountered");
|
||||
}
|
||||
}
|
||||
Some(NestedValue::F32s(ref mut vec)) => {
|
||||
if let NestedValue::F32(val) = serialized_value {
|
||||
vec.push(val);
|
||||
} else {
|
||||
panic!("Invalid value type encountered");
|
||||
}
|
||||
}
|
||||
Some(_) => {
|
||||
panic!("Invalid state encountered");
|
||||
}
|
||||
None => {
|
||||
let val = match serialized_value {
|
||||
NestedValue::U8(val) => NestedValue::U8s(vec![val]),
|
||||
NestedValue::U16(val) => NestedValue::U16s(vec![val]),
|
||||
NestedValue::F32(val) => NestedValue::F32s(vec![val]),
|
||||
_ => NestedValue::Vec(vec![serialized_value]),
|
||||
};
|
||||
self.state = Some(val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn end(self) -> Result<Self::Ok, Self::Error> {
|
||||
if self.state.is_none() {
|
||||
// If the state is empty, return an empty vector
|
||||
Ok(NestedValue::Vec(Vec::new()))
|
||||
} else {
|
||||
self.state.ok_or(error::Error::InvalidState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::{
        TestBackend,
        module::{Param, ParamId},
        record::{FullPrecisionSettings, Record},
        tensor::Tensor,
    };
    use serde::Deserialize;

    use super::*;

    /// Outer fixture nesting two other structs.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct1 {
        a: MyStruct3,
        b: MyStruct2,
    }

    /// Fixture mixing plain and optional fields.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct2 {
        a: i32,
        b: Option<i32>,
        c: String,
        d: Option<String>,
    }

    /// Fixture holding two strings.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct3 {
        x: String,
        y: String,
    }

    #[test]
    fn test_serialize() {
        let inner_strings = MyStruct3 {
            x: "Hello".to_owned(),
            y: "World".to_owned(),
        };
        let inner_mixed = MyStruct2 {
            a: 1,
            b: None,
            c: "Hello".to_owned(),
            d: Some("World".to_owned()),
        };
        let fixture = MyStruct1 {
            a: inner_strings,
            b: inner_mixed,
        };

        let nested = fixture
            .serialize(Serializer::new())
            .expect("Should serialize item successfully");

        let debug_repr = format!("{nested:?}");

        // HashMap iteration order is not guaranteed, so only the length of the debug
        // representation is compared rather than its exact content.
        assert_eq!(debug_repr.len(), 135);
    }

    #[test]
    fn test_param_serde() {
        let device = Default::default();
        let ones: Tensor<TestBackend, 2> = Tensor::ones([2, 2], &device);
        let item = Param::initialized(ParamId::new(), ones).into_item::<FullPrecisionSettings>();

        let nested = item
            .serialize(Serializer::new())
            .expect("Should serialize item successfully");

        // Walk down to the raw tensor bytes: root map -> "param" map -> "bytes".
        let root = nested.as_map().expect("is a map");
        let param = root["param"].clone().as_map().expect("param is a map");
        let bytes = param["bytes"].clone().as_bytes().expect("has bytes vec");

        // Four `1.0f32` values laid out as little-endian bytes.
        assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
    }
}
|
||||
@@ -0,0 +1,40 @@
|
||||
use burn_tensor::Element;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
/// Settings allowing to control the precision when (de)serializing items.
|
||||
pub trait PrecisionSettings:
|
||||
Send + Sync + core::fmt::Debug + core::default::Default + Clone
|
||||
{
|
||||
/// Float element type.
|
||||
type FloatElem: Element + Serialize + DeserializeOwned;
|
||||
|
||||
/// Integer element type.
|
||||
type IntElem: Element + Serialize + DeserializeOwned;
|
||||
}
|
||||
|
||||
/// Default precision settings (`f32` floats and `i32` integers).
#[derive(Clone, Debug, Default)]
pub struct FullPrecisionSettings;
|
||||
|
||||
/// Precision settings optimized for compactness (`half::f16` floats and `i16` integers).
#[derive(Clone, Debug, Default)]
pub struct HalfPrecisionSettings;
|
||||
|
||||
/// Precision settings optimized for precision (`f64` floats and `i64` integers).
#[derive(Clone, Debug, Default)]
pub struct DoublePrecisionSettings;
|
||||
|
||||
// Full precision: 32-bit floats and 32-bit integers.
impl PrecisionSettings for FullPrecisionSettings {
    type IntElem = i32;
    type FloatElem = f32;
}
|
||||
|
||||
// Double precision: 64-bit floats and 64-bit integers.
impl PrecisionSettings for DoublePrecisionSettings {
    type IntElem = i64;
    type FloatElem = f64;
}
|
||||
|
||||
impl PrecisionSettings for HalfPrecisionSettings {
|
||||
type FloatElem = half::f16;
|
||||
type IntElem = i16;
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use super::{PrecisionSettings, Record};
|
||||
use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use alloc::format;
|
||||
|
||||
/// Deserialize the value into [`TensorData`].
|
||||
fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
|
||||
where
|
||||
E: Element + Deserialize<'de>,
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = TensorData::deserialize(deserializer).map_err(|e| {
|
||||
serde::de::Error::custom(format!(
|
||||
"{e:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n"
|
||||
))
|
||||
})?;
|
||||
let data = if let DType::QFloat(_) = data.dtype {
|
||||
data // do not convert quantized tensors
|
||||
} else {
|
||||
data.convert::<E>()
|
||||
};
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
/// This struct implements serde to lazily serialize and deserialize a float tensor
|
||||
/// using the given [record settings](RecordSettings).
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct FloatTensorSerde<S: PrecisionSettings> {
|
||||
data: TensorData,
|
||||
_e: PhantomData<S::FloatElem>,
|
||||
}
|
||||
|
||||
/// This struct implements serde to lazily serialize and deserialize an int tensor
|
||||
/// using the given [record settings](RecordSettings).
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct IntTensorSerde<S: PrecisionSettings> {
|
||||
data: TensorData,
|
||||
_e: PhantomData<S::IntElem>,
|
||||
}
|
||||
|
||||
/// This struct implements serde to lazily serialize and deserialize an bool tensor.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct BoolTensorSerde {
|
||||
data: TensorData,
|
||||
}
|
||||
|
||||
// --- SERDE IMPLEMENTATIONS --- //
|
||||
|
||||
impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
|
||||
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
|
||||
where
|
||||
Se: serde::Serializer,
|
||||
{
|
||||
self.data.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
|
||||
|
||||
Ok(Self::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
|
||||
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
|
||||
where
|
||||
Se: serde::Serializer,
|
||||
{
|
||||
self.data.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = deserialize_data::<S::IntElem, De>(deserializer)?;
|
||||
|
||||
Ok(Self::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for BoolTensorSerde {
|
||||
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
|
||||
where
|
||||
Se: serde::Serializer,
|
||||
{
|
||||
self.data.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for BoolTensorSerde {
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = deserialize_data::<bool, De>(deserializer)?;
|
||||
|
||||
Ok(Self::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
// --- RECORD IMPLEMENTATIONS --- //
|
||||
|
||||
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
|
||||
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
let data = self.into_data();
|
||||
let data = if let DType::QFloat(_) = data.dtype {
|
||||
data // do not convert quantized tensors
|
||||
} else {
|
||||
data.convert::<S::FloatElem>()
|
||||
};
|
||||
FloatTensorSerde::new(data)
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
let data = if let DType::QFloat(_) = item.data.dtype {
|
||||
item.data // do not convert quantized tensors
|
||||
} else {
|
||||
item.data.convert::<B::FloatElem>()
|
||||
};
|
||||
Tensor::from_data(data, device)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
|
||||
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Tensor::from_data(item.data.convert::<B::IntElem>(), device)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
|
||||
type Item<S: PrecisionSettings> = BoolTensorSerde;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
BoolTensorSerde::new(self.into_data())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Tensor::from_data(item.data, device)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
pub use burn_tensor::*;
|
||||
@@ -0,0 +1 @@
|
||||
pub use burn_vision::*;
|
||||
@@ -0,0 +1,113 @@
|
||||
use burn::config::{Config, config_to_json};
|
||||
use burn_core as burn;
|
||||
|
||||
#[derive(Config, Debug, PartialEq, Eq)]
|
||||
pub struct TestEmptyStructConfig {}
|
||||
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub struct TestStructConfig {
|
||||
int: i32,
|
||||
#[config(default = 2)]
|
||||
int_default: i32,
|
||||
float: f32,
|
||||
#[config(default = 2.0)]
|
||||
float_default: f32,
|
||||
string: String,
|
||||
other_config: TestEmptyStructConfig,
|
||||
}
|
||||
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum TestEnumConfig {
|
||||
None,
|
||||
Single(f32),
|
||||
Multiple(f32, String),
|
||||
Named { first: f32, second: String },
|
||||
}
|
||||
|
||||
/// Builds the path of `file_name` inside the system temporary directory.
#[cfg(feature = "std")]
#[inline(always)]
fn file_path(file_name: &str) -> std::path::PathBuf {
    let mut path = std::env::temp_dir();
    path.push(file_name);
    path
}
|
||||
|
||||
/// A struct config saved to disk must load back equal to the original.
#[cfg(feature = "std")]
#[test]
fn struct_config_should_impl_serde() {
    let original =
        TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
    let path = file_path("test_struct_config.json");

    original.save(&path).unwrap();
    let reloaded = TestStructConfig::load(&path).unwrap();

    assert_eq!(original, reloaded);
}
|
||||
|
||||
/// A cloned struct config must compare equal to its source.
#[test]
fn struct_config_should_impl_clone() {
    let original =
        TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
    let copy = original.clone();

    assert_eq!(original, copy);
}
|
||||
|
||||
/// The `Display` output of a struct config must match its JSON rendering.
#[test]
fn struct_config_should_impl_display() {
    let subject =
        TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
    let as_json = config_to_json(&subject);
    let displayed = subject.to_string();

    assert_eq!(as_json, displayed);
}
|
||||
|
||||
/// A unit enum variant saved to disk must load back equal to the original.
#[cfg(feature = "std")]
#[test]
fn enum_config_no_value_should_impl_serde() {
    let original = TestEnumConfig::None;
    let path = file_path("test_enum_no_value_config.json");

    original.save(&path).unwrap();
    let reloaded = TestEnumConfig::load(&path).unwrap();

    assert_eq!(original, reloaded);
}
|
||||
|
||||
/// A single-value enum variant saved to disk must load back equal to the original.
#[cfg(feature = "std")]
#[test]
fn enum_config_one_value_should_impl_serde() {
    let original = TestEnumConfig::Single(42.0);
    let path = file_path("test_enum_one_value_config.json");

    original.save(&path).unwrap();
    let reloaded = TestEnumConfig::load(&path).unwrap();

    assert_eq!(original, reloaded);
}
|
||||
|
||||
/// A multi-value enum variant saved to disk must load back equal to the original.
#[cfg(feature = "std")]
#[test]
fn enum_config_multiple_values_should_impl_serde() {
    let original = TestEnumConfig::Multiple(42.0, "Allow".to_string());
    let path = file_path("test_enum_multiple_values_config.json");

    original.save(&path).unwrap();
    let reloaded = TestEnumConfig::load(&path).unwrap();

    assert_eq!(original, reloaded);
}
|
||||
|
||||
/// A cloned enum config must compare equal to its source.
#[test]
fn enum_config_should_impl_clone() {
    let original = TestEnumConfig::Multiple(42.0, "Allow".to_string());
    let copy = original.clone();

    assert_eq!(original, copy);
}
|
||||
|
||||
/// The `Display` output of an enum config must match its JSON rendering.
#[test]
fn enum_config_should_impl_display() {
    let subject = TestEnumConfig::Multiple(42.0, "Allow".to_string());
    let as_json = config_to_json(&subject);
    let displayed = subject.to_string();

    assert_eq!(as_json, displayed);
}
|
||||
|
||||
/// A struct config must load from the bytes of its own JSON rendering.
#[test]
fn struct_config_can_load_binary() {
    let original =
        TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());

    let json = config_to_json(&original);
    let payload = json.as_bytes().to_vec();

    let reloaded = TestStructConfig::load_binary(&payload).unwrap();
    assert_eq!(original, reloaded);
}
|
||||
@@ -0,0 +1,323 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn::module::Initializer;
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Int, Tensor};
|
||||
use burn_core as burn;
|
||||
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
#[cfg(feature = "std")]
|
||||
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ModuleBasic<B: Backend> {
|
||||
weight_basic: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
struct ModuleTensorConstInt<B: Backend> {
|
||||
weight_basic: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleBasic<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
weight_basic: Initializer::Normal {
|
||||
std: 1.0,
|
||||
mean: 0.0,
|
||||
}
|
||||
.init([20, 20], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct ModuleWithConstGeneric<B: Backend, const N: usize> {
|
||||
modules: [ModuleBasic<B>; N],
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct ModuleWithGenericModule<B: Backend, M> {
|
||||
module: M,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
enum ModuleEnum<B: Backend> {
|
||||
Basic(ModuleBasic<B>),
|
||||
Composed(ModuleComposed<B>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
enum ModuleEnumNested<B: Backend> {
|
||||
AnotherEnum(ModuleEnum<B>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
enum ModuleEnumWithGenericModule<B: Backend, M: Module<B>> {
|
||||
Basic(ModuleBasic<B>),
|
||||
Generic(ModuleWithGenericModule<B, M>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ModuleComposed<B: Backend> {
|
||||
weight: Param<Tensor<B, 2>>,
|
||||
basic: ModuleBasic<B>,
|
||||
tuple: (ModuleBasic<B>, ModuleBasic<B>),
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleComposed<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
let weight = Initializer::Normal {
|
||||
std: 1.0,
|
||||
mean: 0.0,
|
||||
}
|
||||
.init([20, 20], device);
|
||||
|
||||
Self {
|
||||
weight,
|
||||
basic: ModuleBasic::new(device),
|
||||
tuple: (ModuleBasic::new(device), ModuleBasic::new(device)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
mod compiletime_clone_impl_check {
|
||||
use burn_core::{
|
||||
module::{Module, ModuleDisplay},
|
||||
prelude::Backend,
|
||||
record::{PrecisionSettings, Record},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
type RecordItem<M, B, S> = <<M as Module<B>>::Record as Record<B>>::Item<S>;
|
||||
|
||||
fn implements_clone<T: Clone>() {}
|
||||
|
||||
fn basic_implements_clone<B: Backend, S: PrecisionSettings>() {
|
||||
implements_clone::<RecordItem<ModuleBasic<B>, B, S>>();
|
||||
implements_clone::<RecordItem<ModuleComposed<B>, B, S>>();
|
||||
}
|
||||
|
||||
fn generic_implements_clone<B, S, M>()
|
||||
where
|
||||
B: Backend,
|
||||
S: PrecisionSettings,
|
||||
M: Module<B> + ModuleDisplay,
|
||||
RecordItem<M, B, S>: Clone,
|
||||
{
|
||||
implements_clone::<RecordItem<ModuleWithGenericModule<B, M>, B, S>>();
|
||||
implements_clone::<RecordItem<ModuleEnumWithGenericModule<B, M>, B, S>>();
|
||||
}
|
||||
}
|
||||
|
||||
/// Tests that records extracted from one module can be loaded into another.
mod state {
    use super::*;

    #[test]
    fn should_load_from_record_basic() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleBasic::<TestBackend>::new(&device);
        let target = ModuleBasic::<TestBackend>::new(&device);
        let record = source.clone().into_record();

        // Both modules are randomly initialized, so they start out different.
        assert_ne!(
            source.weight_basic.to_data(),
            target.weight_basic.to_data()
        );

        let target = target.load_record(record);

        assert_eq!(
            source.weight_basic.to_data(),
            target.weight_basic.to_data()
        );
    }

    #[test]
    fn should_load_from_record_compose() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleComposed::<TestBackend>::new(&device);
        let target = ModuleComposed::<TestBackend>::new(&device);

        assert_ne!(source.weight.to_data(), target.weight.to_data());
        assert_ne!(
            source.basic.weight_basic.to_data(),
            target.basic.weight_basic.to_data()
        );

        let record = source.clone().into_record();
        let target = target.load_record(record);

        assert_eq!(source.weight.to_data(), target.weight.to_data());
        assert_eq!(
            source.basic.weight_basic.to_data(),
            target.basic.weight_basic.to_data()
        );
    }

    #[test]
    fn should_load_from_record_enum() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        let target = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        let record = source.clone().into_record();

        let source_basic = match source {
            ModuleEnum::Basic(inner) => inner,
            _ => panic!("Invalid module type"),
        };
        let target_basic_before = match target.clone() {
            ModuleEnum::Basic(inner) => inner,
            _ => panic!("Invalid module type"),
        };
        assert_ne!(
            source_basic.weight_basic.to_data(),
            target_basic_before.weight_basic.to_data()
        );

        let target = target.load_record(record);

        let target_basic_after = match target {
            ModuleEnum::Basic(inner) => inner,
            _ => panic!("Invalid module type"),
        };
        assert_eq!(
            source_basic.weight_basic.to_data(),
            target_basic_after.weight_basic.to_data()
        );
    }

    #[test]
    fn should_load_from_record_const_generic() {
        let device = <TestBackend as Backend>::Device::default();
        let source = ModuleWithConstGeneric {
            modules: [
                ModuleBasic::<TestBackend>::new(&device),
                ModuleBasic::<TestBackend>::new(&device),
            ],
        };
        let target = ModuleWithConstGeneric {
            modules: [
                ModuleBasic::<TestBackend>::new(&device),
                ModuleBasic::<TestBackend>::new(&device),
            ],
        };
        let record = source.clone().into_record();

        assert_ne!(
            source.modules[0].weight_basic.to_data(),
            target.modules[0].weight_basic.to_data(),
        );
        assert_ne!(
            source.modules[1].weight_basic.to_data(),
            target.modules[1].weight_basic.to_data(),
        );

        let target = target.load_record(record);

        assert_eq!(
            source.modules[0].weight_basic.to_data(),
            target.modules[0].weight_basic.to_data(),
        );
        assert_eq!(
            source.modules[1].weight_basic.to_data(),
            target.modules[1].weight_basic.to_data(),
        );
    }

    #[test]
    #[should_panic(expected = "Can't parse record from a different variant")]
    fn should_panic_load_from_incorrect_enum_variant() {
        let device = <TestBackend as Backend>::Device::default();
        let basic_variant = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        let composed_variant = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
        let basic_record = basic_variant.clone().into_record();

        // Loading a `Basic` record into a `Composed` module is expected to panic.
        composed_variant.load_record(basic_record);
    }
}
|
||||
|
||||
/// Tests for `num_params` on the module fixtures.
mod num_params {
    use super::*;

    #[test]
    fn should_calculate_num_params_basic() {
        let device = <TestBackend as Backend>::Device::default();
        let basic = ModuleBasic::<TestBackend>::new(&device);

        // One `[20, 20]` weight.
        assert_eq!(20 * 20, basic.num_params());
    }

    #[test]
    fn should_output_state_composed() {
        let device = <TestBackend as Backend>::Device::default();
        let composed = ModuleComposed::<TestBackend>::new(&device);

        // Own weight plus three basic sub-modules, each `[20, 20]`.
        assert_eq!(4 * 20 * 20, composed.num_params());
    }

    #[test]
    fn should_calculate_num_params_enum() {
        let device = <TestBackend as Backend>::Device::default();

        let basic_variant = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
        assert_eq!(20 * 20, basic_variant.num_params());

        let composed_variant = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
        assert_eq!(4 * 20 * 20, composed_variant.num_params());
    }
}
|
||||
|
||||
/// Tests checking whether a module's parameter receives a gradient after `backward`.
#[cfg(feature = "std")]
mod require_grad {
    use burn_tensor::backend::AutodiffBackend;

    use super::*;

    #[test]
    fn should_have_grad_by_default() {
        // Take the device from the autodiff backend, like the sibling tests, since the
        // module under test lives on `TestAutodiffBackend`.
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let module = ModuleBasic::<TestAutodiffBackend>::new(&device);
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_some());
    }

    #[test]
    fn should_have_no_grad_after_no_grad() {
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let module = ModuleBasic::<TestAutodiffBackend>::new(&device).no_grad();
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_none());
    }

    #[test]
    fn should_have_grad_when_from_record() {
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let module = ModuleBasic::<TestAutodiffBackend>::new(&device);
        let record = ModuleBasicRecord {
            // The record reuses the module's own parameter.
            weight_basic: module.weight_basic.clone(),
        };
        let module = module.load_record(record);
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_some());
    }

    /// Multiplies the module's weight by a `[20, 20]` tensor of ones and returns the
    /// gradients produced by calling `backward` on the result.
    fn calculate_grads(
        module: &ModuleBasic<TestAutodiffBackend>,
    ) -> <TestAutodiffBackend as AutodiffBackend>::Gradients {
        let device = module.weight_basic.device();
        let x = Tensor::ones([20, 20], &device).require_grad();
        let y = module.weight_basic.val().matmul(x);

        y.backward()
    }
}
|
||||
@@ -0,0 +1,17 @@
|
||||
use burn_core as burn;
|
||||
use burn_core::record::Record;
|
||||
|
||||
use burn_tensor::Tensor;
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
// It compiles
|
||||
#[derive(Record)]
|
||||
pub struct TestWithBackendRecord<B: Backend> {
|
||||
tensor: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
// It compiles
|
||||
#[derive(Record)]
|
||||
pub struct TestWithoutBackendRecord {
|
||||
_tensor: usize,
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
#[cfg(feature = "std")]
|
||||
mod tests {
|
||||
use burn::{
|
||||
module::{Module, Param},
|
||||
record::{
|
||||
BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings,
|
||||
PrettyJsonFileRecorder, RecorderError,
|
||||
},
|
||||
};
|
||||
use burn_core as burn;
|
||||
use burn_ndarray::NdArrayDevice;
|
||||
use burn_tensor::{Tensor, backend::Backend};
|
||||
use std::path::PathBuf;
|
||||
|
||||
type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
/// Simple linear module.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Linear<B: Backend> {
|
||||
pub weight: Param<Tensor<B, 2>>,
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Linear<B> {
|
||||
pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
|
||||
let weight = Tensor::random(
|
||||
[out_features, in_features],
|
||||
burn_tensor::Distribution::Default,
|
||||
device,
|
||||
);
|
||||
let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);
|
||||
|
||||
Self {
|
||||
weight: Param::from_tensor(weight),
|
||||
bias: Some(Param::from_tensor(bias)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
single_const: f32,
|
||||
linear1: Linear<B>,
|
||||
array_const: [usize; 2],
|
||||
linear2: Linear<B>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ModelNewOptionalField<B: Backend> {
|
||||
single_const: f32,
|
||||
linear1: Linear<B>,
|
||||
array_const: [usize; 2],
|
||||
linear2: Linear<B>,
|
||||
new_field: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ModelNewConstantField<B: Backend> {
|
||||
single_const: f32,
|
||||
linear1: Linear<B>,
|
||||
array_const: [usize; 2],
|
||||
linear2: Linear<B>,
|
||||
new_field: usize,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
pub struct ModelNewFieldOrders<B: Backend> {
|
||||
array_const: [usize; 2],
|
||||
linear2: Linear<B>,
|
||||
single_const: f32,
|
||||
linear1: Linear<B>,
|
||||
}
|
||||
|
||||
/// The new-optional-field scenario must succeed with `DefaultFileRecorder`.
#[test]
fn deserialize_with_new_optional_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_optional_field("default", recorder).unwrap();
}
|
||||
|
||||
/// The removed-optional-field scenario must succeed with `DefaultFileRecorder`.
#[test]
fn deserialize_with_removed_optional_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_optional_field("default", recorder).unwrap();
}
|
||||
|
||||
/// The new-constant-field scenario must succeed with `DefaultFileRecorder`.
#[test]
fn deserialize_with_new_constant_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_constant_field("default", recorder).unwrap();
}
|
||||
|
||||
/// The removed-constant-field scenario must succeed with `DefaultFileRecorder`.
#[test]
fn deserialize_with_removed_constant_field_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_removed_constant_field("default", recorder).unwrap();
}
|
||||
|
||||
/// The reordered-fields scenario must succeed with `DefaultFileRecorder`.
#[test]
fn deserialize_with_new_field_order_works_with_default_file_recorder() {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_field_order("default", recorder).unwrap();
}
|
||||
/// The new-optional-field scenario must succeed with `PrettyJsonFileRecorder`.
#[test]
fn deserialize_with_new_optional_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    deserialize_with_new_optional_field("pretty-json", recorder).unwrap();
}
|
||||
|
||||
/// A file saved from `ModelNewOptionalField` must load into the plain
/// `Model` record when the pretty JSON file recorder is used.
#[test]
fn deserialize_with_removed_optional_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_removed_optional_field("pretty-json", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// A file saved from `Model` must load into a record that has an additional
/// constant field when the pretty JSON file recorder is used.
#[test]
fn deserialize_with_new_constant_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_new_constant_field("pretty-json", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// A file saved from `ModelNewConstantField` must load into the plain
/// `Model` record when the pretty JSON file recorder is used.
#[test]
fn deserialize_with_removed_constant_field_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_removed_constant_field("pretty-json", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// A file saved from `Model` must load into `ModelNewFieldOrdersRecord`
/// (same fields, different declaration order) when the pretty JSON file
/// recorder is used.
#[test]
fn deserialize_with_new_field_order_works_with_pretty_json() {
    let recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_new_field_order("pretty-json", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// With the binary file recorder, loading a `Model` file into a record that
/// has an additional optional field is expected to fail: the helper's result
/// is unwrapped and the test only passes if that panics.
#[test]
#[should_panic]
fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_new_optional_field("bin", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// A file saved from `ModelNewOptionalField` must load into the plain
/// `Model` record when the binary file recorder is used.
#[test]
fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_removed_optional_field("bin", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// A file saved from `Model` must load into a record that has an additional
/// constant field when the binary file recorder is used.
#[test]
fn deserialize_with_new_constant_field_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_new_constant_field("bin", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// A file saved from `ModelNewConstantField` must load into the plain
/// `Model` record when the binary file recorder is used.
#[test]
fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_removed_constant_field("bin", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// With the binary file recorder, loading a `Model` file into
/// `ModelNewFieldOrdersRecord` is expected to fail: the helper's result is
/// unwrapped and the test only passes if that panics.
///
/// NOTE(review): the function name says "works" although the test is marked
/// `#[should_panic]`; the name is kept unchanged here.
#[test]
#[should_panic]
fn deserialize_with_new_field_order_works_with_bin_file_recorder() {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    let outcome = deserialize_with_new_field_order("bin", recorder);
    outcome.unwrap();
}
|
||||
|
||||
/// Returns the location of `filename` inside the system temporary directory
/// (as reported by `std::env::temp_dir`).
#[inline(always)]
fn file_path(filename: String) -> PathBuf {
    let mut location = std::env::temp_dir();
    location.push(filename);
    location
}
|
||||
|
||||
/// A rank-1 tensor of ones must survive a JSON serialize/deserialize round
/// trip with identical data.
#[test]
fn test_tensor_serde() {
    type Tensor1 = burn_tensor::Tensor<TestBackend, 1>;

    let device = NdArrayDevice::default();
    let original: Tensor1 = Tensor1::ones([1], &device);

    let json = serde_json::to_string(&original).unwrap();
    let restored: Tensor1 = serde_json::from_str(&json).unwrap();

    assert_eq!(original.into_data(), restored.into_data());
}
|
||||
|
||||
fn deserialize_with_new_optional_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf = file_path(format!("deserialize_with_new_optional_field-{name}"));
|
||||
let model = Model {
|
||||
single_const: 32.0,
|
||||
linear1: Linear::<TestBackend>::new(20, 20, &device),
|
||||
array_const: [2, 2],
|
||||
linear2: Linear::<TestBackend>::new(20, 20, &device),
|
||||
};
|
||||
|
||||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result =
|
||||
recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn deserialize_with_removed_optional_field<R>(
|
||||
name: &str,
|
||||
recorder: R,
|
||||
) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf =
|
||||
file_path(format!("deserialize_with_removed_optional_field-{name}"));
|
||||
let model = ModelNewOptionalField {
|
||||
single_const: 32.0,
|
||||
linear1: Linear::<TestBackend>::new(20, 20, &device),
|
||||
array_const: [2, 2],
|
||||
linear2: Linear::<TestBackend>::new(20, 20, &device),
|
||||
new_field: None,
|
||||
};
|
||||
|
||||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn deserialize_with_new_constant_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf = file_path(format!("deserialize_with_new_constant_field-{name}"));
|
||||
let model = Model {
|
||||
single_const: 32.0,
|
||||
array_const: [2, 2],
|
||||
linear1: Linear::<TestBackend>::new(20, 20, &device),
|
||||
linear2: Linear::<TestBackend>::new(20, 20, &device),
|
||||
};
|
||||
|
||||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result =
|
||||
recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn deserialize_with_removed_constant_field<R>(
|
||||
name: &str,
|
||||
recorder: R,
|
||||
) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf =
|
||||
file_path(format!("deserialize_with_removed_constant_field-{name}"));
|
||||
let model = ModelNewConstantField {
|
||||
single_const: 32.0,
|
||||
array_const: [2, 2],
|
||||
linear1: Linear::<TestBackend>::new(20, 20, &device),
|
||||
linear2: Linear::<TestBackend>::new(20, 20, &device),
|
||||
new_field: 0,
|
||||
};
|
||||
|
||||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn deserialize_with_new_field_order<R>(name: &str, recorder: R) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf = file_path(format!("deserialize_with_new_field_order-{name}"));
|
||||
let model = Model {
|
||||
array_const: [2, 2],
|
||||
single_const: 32.0,
|
||||
linear1: Linear::<TestBackend>::new(20, 20, &device),
|
||||
linear2: Linear::<TestBackend>::new(20, 20, &device),
|
||||
};
|
||||
|
||||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
|
||||
let result =
|
||||
recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user