feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,80 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Training crate for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-train"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-train"
documentation = "https://docs.rs/burn-train"
version.workspace = true
[lints]
workspace = true
[features]
default = ["sys-metrics", "tui", "rl"]
doc = ["default"]
vision = ["burn-nn", "burn-store/pytorch", "burn-std/network", "dirs"]
tracing = [
"burn-core/tracing",
"burn-optim/tracing",
"burn-collective?/tracing",
]
sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"]
tui = ["ratatui"]
rl = ["burn-rl"]
# Distributed Data Parallel
ddp = ["burn-collective", "burn-optim/collective"]
[dependencies]
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [
"dataset",
"std",
], default-features = false }
burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [
"std",
], default-features = false }
burn-rl = { path = "../burn-rl", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true }
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
dirs = { workspace = true, optional = true }
log = { workspace = true }
tracing-subscriber = { workspace = true }
tracing-appender = { workspace = true }
tracing-core = { workspace = true }
# System Metrics
nvml-wrapper = { workspace = true, optional = true }
sysinfo = { workspace = true, optional = true }
systemstat = { workspace = true, optional = true }
# Text UI
ratatui = { workspace = true, optional = true, features = [
"all-widgets",
"crossterm",
] }
# Utilities
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
async-channel = { workspace = true }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
rstest.workspace = true
thiserror.workspace = true
rand.workspace = true
[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,6 @@
# Burn Train
This crate should be used with [burn](https://github.com/tracel-ai/burn).
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-train.svg)](https://crates.io/crates/burn-train)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-train/blob/master/README.md)

View File

@@ -0,0 +1,171 @@
use super::{Checkpointer, CheckpointerError};
use crate::Interrupter;
use burn_core::{record::Record, tensor::backend::Backend};
use std::sync::mpsc;
/// Requests sent over the channel to the checkpointer worker thread.
///
/// Every operation carries an optional [`Interrupter`]; the worker calls `stop` on it when the
/// operation fails, and panics when none is provided.
enum Message<R, B: Backend> {
/// Restore the record of an epoch on a device; the result is sent back through the sender.
Restore(
usize,
B::Device,
mpsc::SyncSender<Result<R, CheckpointerError>>,
Option<Interrupter>,
),
/// Save the record for the given epoch.
Save(usize, R, Option<Interrupter>),
/// Delete the record of the given epoch.
Delete(usize, Option<Interrupter>),
/// Ask the worker loop to return.
End,
}
/// State owned by the worker thread: the wrapped checkpointer and the receiving end of the
/// message channel.
#[derive(new)]
struct CheckpointerThread<C, R, B: Backend> {
// The checkpointer that actually performs save / restore / delete.
checkpointer: C,
// Incoming requests.
receiver: mpsc::Receiver<Message<R, B>>,
}
impl<C, R, B> CheckpointerThread<C, R, B>
where
    C: Checkpointer<R, B>,
    R: Record<B>,
    B: Backend,
{
    /// Handles incoming messages until [`Message::End`] is received or the channel is closed.
    ///
    /// When an operation fails, the interrupter attached to the message (if any) is asked to
    /// stop with the error text; without an interrupter the worker panics.
    fn run(self) {
        // `recv` returns `Err` once every sender is gone, which ends the loop.
        while let Ok(message) = self.receiver.recv() {
            match message {
                Message::End => break,
                Message::Restore(epoch, device, callback, interrupter) => {
                    let record = self.checkpointer.restore(epoch, &device);
                    if let Err(err) = callback.send(record) {
                        match interrupter {
                            Some(int) => int.stop(Some(&err.to_string())),
                            None => panic!(
                                "Error when sending response through callback channel: {err}"
                            ),
                        }
                    }
                }
                Message::Save(epoch, state, interrupter) => {
                    if let Err(err) = self.checkpointer.save(epoch, state) {
                        match interrupter {
                            Some(int) => int.stop(Some(&err.to_string())),
                            None => panic!("Error when saving the state: {err}"),
                        }
                    }
                }
                Message::Delete(epoch, interrupter) => {
                    if let Err(err) = self.checkpointer.delete(epoch) {
                        match interrupter {
                            Some(int) => int.stop(Some(&err.to_string())),
                            None => panic!("Error when deleting the state: {err}"),
                        }
                    }
                }
            }
        }
    }
}
/// Async checkpointer.
///
/// Forwards every operation as a message to a dedicated worker thread.
pub struct AsyncCheckpointer<Record, B: Backend> {
// Sending end of the channel read by the worker thread.
sender: mpsc::SyncSender<Message<Record, B>>,
// Join handle of the worker thread; taken (set to `None`) when the checkpointer is dropped.
handler: Option<std::thread::JoinHandle<()>>,
// Optional handle attached to every message so the worker can interrupt training on error.
interrupter: Option<Interrupter>,
}
impl<R, B> AsyncCheckpointer<R, B>
where
    R: Record<B> + 'static,
    B: Backend,
{
    /// Wraps a checkpointer so that its operations run on a dedicated thread.
    ///
    /// # Arguments
    ///
    /// * `checkpointer` - The checkpointer executed by the worker thread.
    ///
    /// # Returns
    ///
    /// The async checkpointer, without any interrupter attached.
    pub fn new<C>(checkpointer: C) -> Self
    where
        C: Checkpointer<R, B> + Send + 'static,
    {
        // Zero-capacity (rendezvous) channel: a send only completes once the worker thread is
        // ready to receive, so at most one request is pending while another is processed.
        let (sender, receiver) = mpsc::sync_channel(0);
        let worker = CheckpointerThread::new(checkpointer, receiver);
        let join_handle = std::thread::spawn(move || worker.run());

        Self {
            sender,
            handler: Some(join_handle),
            interrupter: None,
        }
    }

    /// Assign a handle used to interrupt training in case of checkpointing error.
    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
        let handle = Some(interrupter);
        self.interrupter = handle;
        self
    }
}
impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
where
R: Record<B> + 'static,
B: Backend,
{
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
self.sender
.send(Message::Save(epoch, record, self.interrupter.clone()))
.expect("Can send message to checkpointer thread.");
Ok(())
}
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::Restore(
epoch,
device.clone(),
sender,
self.interrupter.clone(),
))
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
if let Ok(record) = receiver.recv() {
return record;
};
Err(CheckpointerError::Unknown("Channel error.".to_string()))
}
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
self.sender
.send(Message::Delete(epoch, self.interrupter.clone()))
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
Ok(())
}
}
impl<E, B> Drop for AsyncCheckpointer<E, B>
where
    B: Backend,
{
    /// Stops the worker thread and waits for it to finish.
    ///
    /// # Panics
    ///
    /// Panics when the end message cannot be delivered or when the worker thread panicked,
    /// unless the current thread is already unwinding: a second panic raised from `drop`
    /// during unwinding would abort the whole process.
    fn drop(&mut self) {
        // A failed send means the worker already exited and dropped its receiver.
        let end_sent = self.sender.send(Message::End);
        // Always join so the worker is never left detached, even when the send failed.
        let joined = self.handler.take().map(|handler| handler.join());

        if std::thread::panicking() {
            // Already unwinding: report nothing further rather than aborting.
            return;
        }

        end_sent.expect("Can send the end message to the checkpointer thread.");
        if let Some(joined) = joined {
            joined.expect("The checkpointer thread should stop.");
        }
    }
}

View File

@@ -0,0 +1,51 @@
use burn_core::{
record::{Record, RecorderError},
tensor::backend::Backend,
};
use thiserror::Error;
/// The error type returned by [`Checkpointer`] operations.
#[derive(Error, Debug)]
pub enum CheckpointerError {
/// An I/O error, wrapping the underlying [`std::io::Error`].
#[error("I/O Error: `{0}`")]
IOError(std::io::Error),
/// An error reported by the recorder, wrapping the underlying [`RecorderError`].
#[error("Recorder error: `{0}`")]
RecorderError(RecorderError),
/// Any other error, described by a message.
#[error("Unknown error: `{0}`")]
Unknown(String),
}
/// The trait for checkpointers: types able to save, restore and delete a record per epoch.
pub trait Checkpointer<R, B>: Send + Sync
where
R: Record<B>,
B: Backend,
{
/// Save the record.
///
/// # Arguments
///
/// * `epoch` - The epoch.
/// * `record` - The record.
///
/// # Errors
///
/// Returns a [`CheckpointerError`] when the record cannot be saved.
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
/// Delete the record at the given epoch if present.
///
/// # Errors
///
/// Returns a [`CheckpointerError`] when the deletion fails.
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;
/// Restore the record.
///
/// # Arguments
///
/// * `epoch` - The epoch.
/// * `device` - The device used to restore the record.
///
/// # Returns
///
/// The record.
///
/// # Errors
///
/// Returns a [`CheckpointerError`] when the record cannot be restored.
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
}

View File

@@ -0,0 +1,86 @@
use std::path::{Path, PathBuf};
use super::{Checkpointer, CheckpointerError};
use burn_core::{
record::{FileRecorder, Record},
tensor::backend::Backend,
};
/// The file checkpointer.
///
/// Stores one checkpoint per epoch under `directory`, in a file named `{name}-{epoch}`.
pub struct FileCheckpointer<FR> {
// Directory containing the checkpoint files.
directory: PathBuf,
// Prefix of every checkpoint file name.
name: String,
// Recorder used to write and read the records.
recorder: FR,
}
impl<FR> FileCheckpointer<FR> {
    /// Builds a checkpointer writing into `directory`.
    ///
    /// # Arguments
    ///
    /// * `recorder` - The file recorder.
    /// * `directory` - The directory to save the checkpoints.
    /// * `name` - The name of the checkpoint.
    pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
        let directory = PathBuf::from(directory.as_ref());
        // Best effort: a failure to create the directory is deliberately ignored here.
        let _ = std::fs::create_dir_all(&directory);

        Self {
            directory,
            name: String::from(name),
            recorder,
        }
    }

    /// Path of the checkpoint for `epoch`, without extension: `{directory}/{name}-{epoch}`.
    fn path_for_epoch(&self, epoch: usize) -> PathBuf {
        let file_name = format!("{}-{}", self.name, epoch);
        self.directory.join(file_name)
    }
}
impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
where
    R: Record<B>,
    FR: FileRecorder<B>,
    B: Backend,
{
    /// Records `record` into the checkpoint file of `epoch`.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::RecorderError`] when the recorder fails.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
        let file_path = self.path_for_epoch(epoch);
        log::trace!("Saving checkpoint {} to {}", epoch, file_path.display());

        self.recorder
            .record(record, file_path)
            .map_err(CheckpointerError::RecorderError)?;

        Ok(())
    }

    /// Loads the record stored for `epoch` onto `device`.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::RecorderError`] when the recorder fails.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
        let file_path = self.path_for_epoch(epoch);
        log::info!(
            "Restoring checkpoint {} from {}",
            epoch,
            file_path.display()
        );
        let record = self
            .recorder
            .load(file_path, device)
            .map_err(CheckpointerError::RecorderError)?;

        Ok(record)
    }

    /// Removes the checkpoint file of `epoch`; a missing file is not an error.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::IOError`] when the removal fails for any reason other than
    /// the file not existing.
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
        // Append ".{extension}" on the OS string itself: formatting through `Path::display`
        // is lossy for non-UTF-8 paths, and `with_extension` would replace anything after a
        // dot already present in the checkpoint name.
        let mut file_to_remove = self.path_for_epoch(epoch).into_os_string();
        file_to_remove.push(".");
        file_to_remove.push(FR::file_extension());
        let file_to_remove = PathBuf::from(file_to_remove);

        log::trace!("Removing checkpoint {}", file_to_remove.display());

        // Attempt the removal directly instead of checking `exists()` first, which would be
        // racy; an absent file keeps the documented "delete if present" contract.
        match std::fs::remove_file(&file_to_remove) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(CheckpointerError::IOError(err)),
        }
    }
}

View File

@@ -0,0 +1,9 @@
mod async_checkpoint;
mod base;
mod file;
mod strategy;
pub use async_checkpoint::*;
pub use base::*;
pub use file::*;
pub use strategy::*;

View File

@@ -0,0 +1,34 @@
use std::ops::DerefMut;
use crate::metric::store::EventStoreClient;
/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).
///
/// Produced by a [`CheckpointingStrategy`].
#[derive(Clone, PartialEq, Debug)]
pub enum CheckpointingAction {
/// Delete the given epoch.
Delete(usize),
/// Save the current record.
Save,
}
/// Define when checkpoints should be saved and deleted.
pub trait CheckpointingStrategy: Send {
/// Based on the epoch, determine which checkpointing actions should be taken.
///
/// # Arguments
///
/// * `epoch` - The current epoch.
/// * `collector` - Client of the event store, available to strategies that need metrics.
///
/// # Returns
///
/// The actions to perform; may be empty.
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction>;
}
// Implementing the strategy for the boxed trait object lets a dynamically chosen strategy be
// passed wherever a generic `CheckpointingStrategy` is expected.
impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
    /// Forwards the call to the boxed strategy.
    fn checkpointing(
        &mut self,
        epoch: usize,
        collector: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        let inner = DerefMut::deref_mut(self);
        inner.checkpointing(epoch, collector)
    }
}

View File

@@ -0,0 +1,146 @@
use crate::metric::store::EventStoreClient;
use super::{CheckpointingAction, CheckpointingStrategy};
use std::collections::HashSet;
/// Compose multiple checkpointing strategies and only delete a checkpoint once every strategy
/// has flagged its epoch to be deleted.
pub struct ComposedCheckpointingStrategy {
// The composed strategies.
strategies: Vec<Box<dyn CheckpointingStrategy>>,
// For each strategy (same index), the epochs it has flagged as deletable so far.
deleted: Vec<HashSet<usize>>,
}
/// Helps building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.
#[derive(Default)]
pub struct ComposedCheckpointingStrategyBuilder {
// Strategies added so far, in insertion order.
strategies: Vec<Box<dyn CheckpointingStrategy>>,
}
impl ComposedCheckpointingStrategyBuilder {
    /// Appends a [checkpointing strategy](CheckpointingStrategy) to the composition.
    // The name mirrors `std::ops::Add::add`, but this is a plain builder method.
    #[allow(clippy::should_implement_trait)]
    pub fn add<S>(mut self, strategy: S) -> Self
    where
        S: CheckpointingStrategy + 'static,
    {
        let boxed: Box<dyn CheckpointingStrategy> = Box::new(strategy);
        self.strategies.push(boxed);
        self
    }

    /// Consumes the builder and creates the
    /// [composed checkpointing strategy](ComposedCheckpointingStrategy).
    pub fn build(self) -> ComposedCheckpointingStrategy {
        let Self { strategies } = self;
        ComposedCheckpointingStrategy::new(strategies)
    }
}
impl ComposedCheckpointingStrategy {
    /// Creates the composed strategy with one empty set of flagged epochs per strategy.
    fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
        let deleted = vec![HashSet::new(); strategies.len()];

        Self {
            strategies,
            deleted,
        }
    }

    /// Returns an empty builder used to compose multiple
    /// [checkpointing strategies](CheckpointingStrategy).
    pub fn builder() -> ComposedCheckpointingStrategyBuilder {
        Default::default()
    }
}
impl CheckpointingStrategy for ComposedCheckpointingStrategy {
    /// Runs every composed strategy and merges their decisions.
    ///
    /// * `Save` is emitted (once, first) when at least one strategy wants to save.
    /// * `Delete(epoch)` is emitted only once every strategy has flagged `epoch` as deletable,
    ///   either through an explicit `Delete` or by returning no action for that epoch.
    fn checkpointing(
        &mut self,
        epoch: usize,
        collector: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        let mut saved = false;
        let mut epochs_to_check = Vec::new();

        // `strategies` and `deleted` always have the same length (see `new`), so zipping
        // pairs each strategy with its own set of flagged epochs.
        for (strategy, deleted) in self.strategies.iter_mut().zip(self.deleted.iter_mut()) {
            let actions = strategy.checkpointing(epoch, collector);

            // We assume that the strategy would not want the current epoch to be saved.
            // So we flag it as deleted.
            if actions.is_empty() {
                deleted.insert(epoch);
            }

            for action in actions {
                match action {
                    CheckpointingAction::Delete(epoch) => {
                        deleted.insert(epoch);
                        epochs_to_check.push(epoch);
                    }
                    CheckpointingAction::Save => saved = true,
                }
            }
        }

        let mut actions = Vec::new();

        if saved {
            actions.push(CheckpointingAction::Save);
        }

        for epoch in epochs_to_check {
            let flagged_by_all = self.deleted.iter().all(|deleted| deleted.contains(&epoch));

            if flagged_by_all {
                actions.push(CheckpointingAction::Delete(epoch));

                // Forget the epoch so it is not reported again.
                for deleted in self.deleted.iter_mut() {
                    deleted.remove(&epoch);
                }
            }
        }

        actions
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};
// Composes "keep last 1" with "keep last 2": epoch 1 may only be deleted once both
// strategies have flagged it, which happens at epoch 3.
#[test]
fn should_delete_when_both_deletes() {
let store = EventStoreClient::new(LogEventStore::default());
let mut strategy = ComposedCheckpointingStrategy::builder()
.add(KeepLastNCheckpoints::new(1))
.add(KeepLastNCheckpoints::new(2))
.build();
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(1, &store)
);
// Only the first strategy flags epoch 1 here, so nothing is deleted yet.
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(2, &store)
);
assert_eq!(
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
strategy.checkpointing(3, &store)
);
}
}

View File

@@ -0,0 +1,56 @@
use super::CheckpointingStrategy;
use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};
/// Keep the last N checkpoints.
///
/// Very useful when training, minimizing disk space while ensuring that the training can be
/// resumed even if something goes wrong.
#[derive(new)]
pub struct KeepLastNCheckpoints {
// Number of most recent checkpoints to keep (the "N").
num_keep: usize,
}
impl CheckpointingStrategy for KeepLastNCheckpoints {
    /// Always saves the current epoch, and deletes epoch `epoch - num_keep` when that
    /// subtraction does not underflow and yields an epoch greater than zero.
    fn checkpointing(
        &mut self,
        epoch: usize,
        _store: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        let expired = epoch.checked_sub(self.num_keep).filter(|&old| old > 0);

        std::iter::once(CheckpointingAction::Save)
            .chain(expired.map(CheckpointingAction::Delete))
            .collect()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::metric::store::LogEventStore;
// With `num_keep = 2`, the first deletion happens at epoch 3 and targets epoch 1.
#[test]
fn should_always_delete_lastn_epoch_if_higher_than_one() {
let mut strategy = KeepLastNCheckpoints::new(2);
let store = EventStoreClient::new(LogEventStore::default());
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(1, &store)
);
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(2, &store)
);
assert_eq!(
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
strategy.checkpointing(3, &store)
);
}
}

View File

@@ -0,0 +1,136 @@
use super::CheckpointingStrategy;
use crate::{
checkpoint::CheckpointingAction,
metric::{
Metric, MetricName,
store::{Aggregate, Direction, EventStoreClient, Split},
},
};
/// Keep the best checkpoint based on a metric.
pub struct MetricCheckpointingStrategy {
// Best epoch found by the previous call, if any.
current: Option<usize>,
// How metric values are aggregated when searching for the best epoch.
aggregate: Aggregate,
// Direction passed to the store lookup (e.g. `Direction::Lowest`).
direction: Direction,
// Split whose metric values are considered.
split: Split,
// Name of the tracked metric.
name: MetricName,
}
impl MetricCheckpointingStrategy {
    /// Creates a strategy tracking `metric`, with no best epoch recorded yet.
    ///
    /// Only the name of the metric is kept; the metric itself is not stored.
    pub fn new<M>(metric: &M, aggregate: Aggregate, direction: Direction, split: Split) -> Self
    where
        M: Metric,
    {
        let name = metric.name();

        Self {
            name,
            split,
            direction,
            aggregate,
            current: None,
        }
    }
}
impl CheckpointingStrategy for MetricCheckpointingStrategy {
    /// Keeps only the checkpoint of the best epoch.
    ///
    /// The best epoch is looked up in the store and defaults to the current epoch when the
    /// store has no answer. The previously kept epoch is deleted when it is no longer the
    /// best one, and the current epoch is saved when it is the best one.
    fn checkpointing(
        &mut self,
        epoch: usize,
        store: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        let best_epoch = store
            .find_epoch(&self.name, self.aggregate, self.direction, &self.split)
            .unwrap_or(epoch);

        // The previously kept epoch, only if it lost its "best" status.
        let outdated = self.current.filter(|&kept| kept != best_epoch);
        self.current = Some(best_epoch);

        let mut actions = Vec::new();
        actions.extend(outdated.map(CheckpointingAction::Delete));
        if epoch == best_epoch {
            actions.push(CheckpointingAction::Save);
        }

        actions
    }
}
#[cfg(test)]
mod tests {
use crate::{
EventProcessorTraining, TestBackend,
logger::InMemoryMetricLogger,
metric::{
LossMetric,
processor::{
MetricsTraining, MinimalEventProcessor,
test_utils::{end_epoch, process_train},
},
store::LogEventStore,
},
};
use super::*;
use std::sync::Arc;
// Feeds three epochs of loss values through an event processor and checks that the
// strategy keeps only the checkpoint of the epoch with the lowest mean training loss.
#[test]
fn always_keep_the_best_epoch() {
let loss = LossMetric::<TestBackend>::new();
let mut store = LogEventStore::default();
let mut strategy = MetricCheckpointingStrategy::new(
&loss,
Aggregate::Mean,
Direction::Lowest,
Split::Train,
);
let mut metrics = MetricsTraining::<f64, f64>::default();
// Register an in memory logger.
store.register_logger(InMemoryMetricLogger::default());
// Register the loss metric.
metrics.register_train_metric_numeric(loss);
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
processor.process_train(crate::LearnerEvent::Start);
// Two points for the first epoch. Mean 0.75
let mut epoch = 1;
process_train(&mut processor, 1.0, epoch);
process_train(&mut processor, 0.5, epoch);
end_epoch(&mut processor, epoch);
// Should save the current record.
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(epoch, &store)
);
// Two points for the second epoch. Mean 0.4
epoch += 1;
process_train(&mut processor, 0.5, epoch);
process_train(&mut processor, 0.3, epoch);
end_epoch(&mut processor, epoch);
// Should save the current record and delete the previous one.
assert_eq!(
vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
strategy.checkpointing(epoch, &store)
);
// Two points for the last epoch. Mean 2.0
epoch += 1;
process_train(&mut processor, 1.0, epoch);
process_train(&mut processor, 3.0, epoch);
end_epoch(&mut processor, epoch);
// Should not delete the previous record, since it's the best one, and should not save a
// new one.
assert!(strategy.checkpointing(epoch, &store).is_empty());
}
}

View File

@@ -0,0 +1,9 @@
mod base;
mod composed;
mod lastn;
mod metric;
pub use base::*;
pub use composed::*;
pub use lastn::*;
pub use metric::*;

View File

@@ -0,0 +1,66 @@
use crate::{InferenceStep, TrainStep};
use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};
use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
use std::marker::PhantomData;
/// Components used for a model to learn, grouped in one trait.
///
/// Grouping the types as associated types lets other items take a single generic parameter
/// instead of one per component.
pub trait LearningComponentsTypes {
/// The backend used for training.
type Backend: AutodiffBackend;
/// The learning rate scheduler used for training.
type LrScheduler: LrScheduler + 'static;
/// The model to train.
///
/// Its inner (non-autodiff) module must be [`Self::InferenceModel`].
type TrainingModel: TrainStep
+ AutodiffModule<Self::Backend, InnerModule = Self::InferenceModel>
+ core::fmt::Display
+ 'static;
/// The non-autodiff type of the model.
type InferenceModel: InferenceStep;
/// The optimizer used for training.
type Optimizer: Optimizer<Self::TrainingModel, Self::Backend> + 'static;
}
/// Concrete type that implements the [LearningComponentsTypes](LearningComponentsTypes) trait.
///
/// It holds no data: the fields are [`PhantomData`] markers that only carry the type parameters.
pub struct LearningComponentsMarker<B, LR, M, O> {
// Backend type marker.
_backend: PhantomData<B>,
// Learning rate scheduler type marker.
_lr_scheduler: PhantomData<LR>,
// Training model type marker.
_model: PhantomData<M>,
// Optimizer type marker.
_optimizer: PhantomData<O>,
}
// Maps the marker's type parameters onto the associated types of the trait; the inference
// model is derived from the training model's inner module.
impl<B, LR, M, O> LearningComponentsTypes for LearningComponentsMarker<B, LR, M, O>
where
B: AutodiffBackend,
LR: LrScheduler + 'static,
M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
M::InnerModule: InferenceStep,
O: Optimizer<M, B> + 'static,
{
type Backend = B;
type LrScheduler = LR;
type TrainingModel = M;
type InferenceModel = M::InnerModule;
type Optimizer = O;
}
/// The training backend.
pub type TrainingBackend<LC> = <LC as LearningComponentsTypes>::Backend;
/// The inference backend, i.e. the inner backend of the training (autodiff) backend.
pub(crate) type InferenceBackend<LC> =
<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend;
/// The model used for training.
pub type TrainingModel<LC> = <LC as LearningComponentsTypes>::TrainingModel;
/// The non-autodiff model.
pub(crate) type InferenceModel<LC> = <LC as LearningComponentsTypes>::InferenceModel;
/// Type for training input, as defined by the training model's [`TrainStep`] implementation.
pub(crate) type TrainingModelInput<LC> =
<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input;
/// Type for inference input, as defined by the inference model's [`InferenceStep`] implementation.
pub(crate) type InferenceModelInput<LC> =
<<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input;
/// Type for training output, as defined by the training model's [`TrainStep`] implementation.
pub(crate) type TrainingModelOutput<LC> =
<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output;
/// Type for inference output, as defined by the inference model's [`InferenceStep`] implementation.
pub(crate) type InferenceModelOutput<LC> =
<<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output;

View File

@@ -0,0 +1,72 @@
use crate::{
AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep,
Interrupter, LearnerSummaryConfig,
evaluator::components::EvaluatorComponentTypes,
metric::processor::{EvaluatorEvent, EventProcessorEvaluation},
renderer::{EvaluationName, MetricsRenderer},
};
use burn_core::{data::dataloader::DataLoader, module::Module};
use std::sync::Arc;
/// The backend of the evaluator components.
pub(crate) type TestBackend<EC> = <EC as EvaluatorComponentTypes>::Backend;
/// The input type of the evaluated model's inference step.
pub(crate) type TestInput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Input;
/// The output type of the evaluated model's inference step.
pub(crate) type TestOutput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Output;
/// A shared dataloader producing inference inputs on the evaluation backend.
pub(crate) type TestLoader<EC> = Arc<dyn DataLoader<TestBackend<EC>, TestInput<EC>>>;
/// Evaluates a model on a specific dataset.
pub struct Evaluator<EC: EvaluatorComponentTypes> {
// The model being evaluated.
pub(crate) model: EC::Model,
// Checked after every processed item to stop the evaluation early.
pub(crate) interrupter: Interrupter,
// Receives the start / processed-item / end events of the evaluation.
pub(crate) event_processor:
AsyncProcessorEvaluation<FullEventProcessorEvaluation<TestOutput<EC>>>,
/// Config for creating a summary of the evaluation
pub summary: Option<LearnerSummaryConfig>,
}
impl<EC: EvaluatorComponentTypes> Evaluator<EC> {
    /// Run the evaluation on the given dataset.
    ///
    /// The data will be stored and displayed under the provided name. Every item produced by
    /// the dataloader goes through the model's inference step and is forwarded to the event
    /// processor; the loop stops early when the interrupter requests it.
    ///
    /// # Returns
    ///
    /// The renderer held by the event processor.
    ///
    /// # Panics
    ///
    /// Panics if the model reports no device.
    pub fn eval<S: core::fmt::Display>(
        mut self,
        name: S,
        dataloader: TestLoader<EC>,
    ) -> Box<dyn MetricsRenderer> {
        // Move dataloader to the model device
        let devices = self.model.devices();
        let dataloader = dataloader.to_device(devices.first().unwrap());
        let name = EvaluationName::new(name);
        let mut iterator = dataloader.iter();
        let mut iteration = 0;

        self.event_processor.process_test(EvaluatorEvent::Start);

        // A manual loop is required: the progress is read from the iterator after each `next`.
        loop {
            let Some(batch) = iterator.next() else {
                break;
            };
            let progress = iterator.progress();
            iteration += 1;

            let output = self.model.step(batch);
            let item = EvaluationItem::new(output, progress, Some(iteration));
            let event = EvaluatorEvent::ProcessedItem(name.clone(), item);
            self.event_processor.process_test(event);

            if self.interrupter.should_stop() {
                log::info!("Testing interrupted.");
                break;
            }
        }

        // The summary is optional, and silently dropped when it fails to initialize.
        let summary = match self.summary {
            Some(config) => match config.init() {
                Ok(summary) => Some(summary.with_model(self.model.to_string())),
                Err(_) => None,
            },
            None => None,
        };

        self.event_processor
            .process_test(EvaluatorEvent::End(summary));
        self.event_processor.renderer()
    }
}

View File

@@ -0,0 +1,212 @@
use crate::{
ApplicationLoggerInstaller, Evaluator, FileApplicationLoggerInstaller, InferenceStep,
Interrupter, LearnerSummaryConfig, TestOutput,
evaluator::components::{EvaluatorComponentTypes, EvaluatorComponentTypesMarker},
logger::FileMetricLogger,
metric::{
Adaptor, ItemLazy, Metric, Numeric,
processor::{AsyncProcessorEvaluation, FullEventProcessorEvaluation, MetricsEvaluation},
store::{EventStoreClient, LogEventStore},
},
renderer::{MetricsRenderer, default_renderer},
};
use burn_core::{module::Module, prelude::Backend};
use std::{
collections::BTreeSet,
path::{Path, PathBuf},
sync::Arc,
};
/// Struct to configure and create an [evaluator](Evaluator).
///
/// The generics components of the builder should probably not be set manually, as they are
/// optimized for Rust type inference.
pub struct EvaluatorBuilder<EC: EvaluatorComponentTypes> {
// Installer for the application logger; `None` disables it.
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
// Event store on which the metric loggers are registered.
event_store: LogEventStore,
// Names of the registered metrics, used for the summary.
summary_metrics: BTreeSet<String>,
// Custom renderer; a default one is created at build time when `None`.
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
// Interrupter handed over to the built evaluator.
interrupter: Interrupter,
// Registered test metrics.
metrics: MetricsEvaluation<TestOutput<EC>>,
// Directory receiving the evaluation artifacts.
directory: PathBuf,
// Whether the summary report is enabled.
summary: bool,
}
impl<B, M> EvaluatorBuilder<EvaluatorComponentTypesMarker<B, M>>
where
    B: Backend,
    M: Module<B> + InferenceStep + core::fmt::Display + 'static,
{
    /// Creates a new evaluator builder.
    ///
    /// The builder starts with a file application logger targeting `evaluation.log` inside
    /// `directory`, no custom renderer, no metric and the summary disabled.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory to save the checkpoints.
    pub fn new(directory: impl AsRef<Path>) -> Self {
        let directory = PathBuf::from(directory.as_ref());
        let logger = FileApplicationLoggerInstaller::new(directory.join("evaluation.log"));

        Self {
            tracing_logger: Some(Box::new(logger)),
            event_store: LogEventStore::default(),
            summary_metrics: BTreeSet::default(),
            renderer: None,
            interrupter: Interrupter::new(),
            summary: false,
            metrics: MetricsEvaluation::default(),
            directory,
        }
    }
}
impl<EC: EvaluatorComponentTypes> EvaluatorBuilder<EC> {
    /// Registers [numeric](crate::metric::Numeric) test [metrics](Metric).
    ///
    /// Accepts a tuple of metrics, each one being registered in turn.
    pub fn metrics<Me: EvalMetricRegistration<EC>>(self, metrics: Me) -> Self {
        Me::register(metrics, self)
    }

    /// Registers text [metrics](Metric).
    ///
    /// Accepts a tuple of metrics, each one being registered in turn.
    pub fn metrics_text<Me: EvalTextMetricRegistration<EC>>(self, metrics: Me) -> Self {
        Me::register(metrics, self)
    }

    /// By default, Rust logs are captured and written into
    /// `evaluation.log`. If disabled, standard Rust log handling
    /// will apply.
    ///
    /// NOTE(review): within this impl, `build` never reads `tracing_logger`, so nothing here
    /// installs the logger; confirm where (or whether) installation is meant to happen.
    pub fn with_application_logger(
        self,
        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    ) -> Self {
        Self {
            tracing_logger: logger,
            ..self
        }
    }

    /// Register a [numeric](crate::metric::Numeric) test [metric](Metric).
    ///
    /// The metric name is also recorded for the summary.
    pub fn metric_numeric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + Numeric + 'static,
        <TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        let name = metric.name().to_string();
        self.summary_metrics.insert(name);
        self.metrics.register_test_metric_numeric(metric);
        self
    }

    /// Register a text test [metric](Metric).
    ///
    /// The metric name is also recorded for the summary.
    pub fn metric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + 'static,
        <TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        let name = metric.name().to_string();
        self.summary_metrics.insert(name);
        self.metrics.register_test_metric(metric);
        self
    }

    /// Replace the default CLI renderer with a custom one.
    ///
    /// # Arguments
    ///
    /// * `renderer` - The custom renderer.
    pub fn renderer(self, renderer: Box<dyn MetricsRenderer + 'static>) -> Self {
        Self {
            renderer: Some(renderer),
            ..self
        }
    }

    /// Enable the evaluation summary report.
    ///
    /// The summary will be displayed at the end of `.eval()`.
    pub fn summary(self) -> Self {
        Self {
            summary: true,
            ..self
        }
    }

    /// Builds the evaluator.
    ///
    /// Uses the custom renderer when one was provided and the default renderer otherwise,
    /// registers a file metric logger writing into the builder directory, and attaches a
    /// summary configuration only when the summary was enabled.
    #[allow(clippy::type_complexity)]
    pub fn build(mut self, model: EC::Model) -> Evaluator<EC> {
        let renderer = match self.renderer {
            Some(custom) => custom,
            None => default_renderer(self.interrupter.clone(), None),
        };

        let metric_logger = FileMetricLogger::new_eval(self.directory.clone());
        self.event_store.register_logger(metric_logger);
        let event_store = Arc::new(EventStoreClient::new(self.event_store));

        let processor = FullEventProcessorEvaluation::new(self.metrics, renderer, event_store);
        let event_processor = AsyncProcessorEvaluation::new(processor);

        let summary = self.summary.then(|| LearnerSummaryConfig {
            directory: self.directory,
            metrics: Vec::from_iter(self.summary_metrics),
        });

        Evaluator {
            model,
            interrupter: self.interrupter,
            event_processor,
            summary,
        }
    }
}
/// Trait to fake variadic generics.
///
/// Implemented for tuples of numeric metrics so that several of them can be registered on an
/// [`EvaluatorBuilder`] with a single call.
pub trait EvalMetricRegistration<EC: EvaluatorComponentTypes>: Sized {
/// Register the metrics on the builder and return the updated builder.
fn register(self, builder: EvaluatorBuilder<EC>) -> EvaluatorBuilder<EC>;
}
/// Trait to fake variadic generics.
///
/// Implemented for tuples of text metrics so that several of them can be registered on an
/// [`EvaluatorBuilder`] with a single call.
pub trait EvalTextMetricRegistration<EC: EvaluatorComponentTypes>: Sized {
/// Register the metrics on the builder and return the updated builder.
fn register(self, builder: EvaluatorBuilder<EC>) -> EvaluatorBuilder<EC>;
}
// Implements both registration traits for a tuple of metric types: the text variant registers
// each element through `EvaluatorBuilder::metric`, the numeric variant through
// `EvaluatorBuilder::metric_numeric`.
macro_rules! gen_tuple {
($($M:ident),*) => {
impl<$($M,)* EC: EvaluatorComponentTypes> EvalTextMetricRegistration<EC> for ($($M,)*)
where
$(<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + 'static,)*
{
// The type identifiers are reused as variable names when destructuring the tuple.
#[allow(non_snake_case)]
fn register(
self,
builder: EvaluatorBuilder<EC>,
) -> EvaluatorBuilder<EC> {
let ($($M,)*) = self;
$(let builder = builder.metric($M);)*
builder
}
}
impl<$($M,)* EC: EvaluatorComponentTypes> EvalMetricRegistration<EC> for ($($M,)*)
where
$(<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + $crate::metric::Numeric + 'static,)*
{
// The type identifiers are reused as variable names when destructuring the tuple.
#[allow(non_snake_case)]
fn register(
self,
builder: EvaluatorBuilder<EC>,
) -> EvaluatorBuilder<EC> {
let ($($M,)*) = self;
$(let builder = builder.metric_numeric($M);)*
builder
}
}
};
}
// Implementations for tuples of one to six metrics.
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);

View File

@@ -0,0 +1,25 @@
use crate::InferenceStep;
use burn_core::{module::Module, prelude::Backend};
use std::marker::PhantomData;
/// All components necessary to evaluate a model grouped in one trait.
pub trait EvaluatorComponentTypes {
/// The backend used for the evaluation.
type Backend: Backend;
/// The model to evaluate.
type Model: Module<Self::Backend> + InferenceStep + core::fmt::Display + 'static;
}
/// A marker type used to provide [evaluation components](EvaluatorComponentTypes).
///
/// It holds no data; the [`PhantomData`] field only carries the type parameters.
pub struct EvaluatorComponentTypesMarker<B, M> {
_p: PhantomData<(B, M)>,
}
// Maps the marker's type parameters onto the associated types of the trait.
impl<B, M> EvaluatorComponentTypes for EvaluatorComponentTypesMarker<B, M>
where
B: Backend,
M: Module<B> + InferenceStep + core::fmt::Display + 'static,
{
type Backend = B;
type Model = M;
}

View File

@@ -0,0 +1,7 @@
mod base;
mod builder;
pub(crate) mod components;
pub use base::*;
pub use builder::*;

View File

@@ -0,0 +1,69 @@
use std::path::{Path, PathBuf};
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{Layer, registry};
/// This trait is used to install an application logger.
pub trait ApplicationLoggerInstaller {
/// Install the application logger.
///
/// # Errors
///
/// Returns a message describing the failure when the logger cannot be installed.
fn install(&self) -> Result<(), String>;
}
/// This struct is used to install a local file application logger to output logs to a given file path.
pub struct FileApplicationLoggerInstaller {
// Path of the log file.
path: PathBuf,
}
impl FileApplicationLoggerInstaller {
/// Create a new file application logger.
pub fn new(path: impl AsRef<Path>) -> Self {
Self {
path: path.as_ref().to_path_buf(),
}
}
}
impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
    /// Installs a tracing subscriber writing to the configured file, plus a panic hook that
    /// logs the panic and points the user to the log file before running the previous hook.
    ///
    /// # Errors
    ///
    /// Returns an error message when the configured path has no file name, or when the
    /// subscriber cannot be initialized.
    fn install(&self) -> Result<(), String> {
        let path = self.path.as_path();

        // Report an unusable path through the `Result` instead of panicking.
        let Some(file_name) = path.file_name() else {
            return Err(format!(
                "The path '{}' should point to a file.",
                path.display()
            ));
        };
        let directory = path.parent().unwrap_or_else(|| Path::new("."));

        // `never`: a single log file, without rotation.
        let writer = tracing_appender::rolling::never(directory, file_name);
        let layer = tracing_subscriber::fmt::layer()
            .with_ansi(false)
            .with_writer(writer)
            .with_filter(LevelFilter::INFO)
            .with_filter(filter_fn(|m| {
                if let Some(path) = m.module_path() {
                    // The wgpu crate is logging too much, so we skip `info` level.
                    if path.starts_with("wgpu") && *m.level() >= Level::INFO {
                        return false;
                    }
                }
                true
            }));

        if registry().with(layer).try_init().is_err() {
            return Err("Failed to install the file logger.".to_string());
        }

        // Chain onto the existing panic hook so its behavior is preserved.
        let hook = std::panic::take_hook();
        let file_path = self.path.clone();

        std::panic::set_hook(Box::new(move |info| {
            log::error!("PANIC => {info}");
            eprintln!(
                "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
                '{}'\n=============",
                file_path.display()
            );
            hook(info);
        }));

        Ok(())
    }
}

View File

@@ -0,0 +1,255 @@
use crate::LearningComponentsMarker;
use crate::checkpoint::{
AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
};
use crate::components::{LearningComponentsTypes, TrainingBackend};
use crate::metric::store::EventStoreClient;
use crate::{
CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
TrainingModelOutput,
};
use burn_core::module::{AutodiffModule, Module};
use burn_core::prelude::Backend;
use burn_core::tensor::Device;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::lr_scheduler::LrScheduler;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
/// The record of the learner's model.
///
/// This is the [`Module`] record type of the training model on the training backend.
pub type LearnerModelRecord<LC> =
<<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
/// The record of the optimizer.
///
/// This is the [`Optimizer`] record type for the training model on the training backend.
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
<LC as LearningComponentsTypes>::TrainingModel,
TrainingBackend<LC>,
>>::Record;
/// The record of the LR scheduler.
///
/// This is the [`LrScheduler`] record type, instantiated for the training backend.
pub type LearnerSchedulerRecord<LC> =
<<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;
/// Learner struct encapsulating all components necessary to train a Neural Network model.
pub struct Learner<LC: LearningComponentsTypes> {
/// The model being trained.
pub(crate) model: LC::TrainingModel,
/// The optimizer applied to the model during the optimizer steps.
optim: LC::Optimizer,
/// The learning rate scheduler, advanced by `lr_step`.
lr_scheduler: LC::LrScheduler,
/// The learning rate passed to the optimizer; refreshed each time `lr_step` is called.
lr: f64,
}
impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
    /// Clone the model, the optimizer and the scheduler, and copy the current learning rate.
    fn clone(&self) -> Self {
        let Self {
            model,
            optim,
            lr_scheduler,
            lr,
        } = self;

        Self {
            model: model.clone(),
            optim: optim.clone(),
            lr_scheduler: lr_scheduler.clone(),
            lr: *lr,
        }
    }
}
impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
where
B: AutodiffBackend,
LR: LrScheduler + 'static,
M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
M::InnerModule: InferenceStep,
O: Optimizer<M, B> + 'static,
{
/// Create a learner.
///
/// # Arguments
///
/// * `model` - The model to train.
/// * `optim` - The optimizer used to update the model.
/// * `lr_scheduler` - The learning rate scheduler.
///
/// # Notes
///
/// The learning rate starts at `0.0`; it is only updated when `lr_step` is called.
pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
Self {
model,
optim,
lr_scheduler,
lr: 0.0,
}
}
}
impl<LC: LearningComponentsTypes> Learner<LC> {
/// Fork the learner's model to the given device.
///
/// The current model is cloned, forked to `device`, and the fork replaces the stored model.
pub fn fork(&mut self, device: &<TrainingBackend<LC> as Backend>::Device) {
self.model = self.model().fork(device);
}
/// Returns a clone of the current model.
pub fn model(&self) -> LC::TrainingModel {
self.model.clone()
}
/// Returns the current learning rate.
pub fn lr_current(&self) -> f64 {
self.lr
}
/// Executes a step of the learning rate scheduler and stores the learning rate it returns.
pub fn lr_step(&mut self) {
self.lr = self.lr_scheduler.step();
}
/// Runs a step of the model for training, which executes the forward and backward passes.
///
/// # Arguments
///
/// * `item` - The input for the model.
///
/// # Returns
///
/// The output containing the model output and the gradients.
pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
self.model.step(item)
}
/// Optimize the current module with the provided gradients.
///
/// The learner's own optimizer and current learning rate are used for the update.
///
/// # Arguments
///
/// * `grads`: The gradients of each parameter in the current model.
pub fn optimizer_step(&mut self, grads: GradientsParams) {
self.model = self.model().optimize(&mut self.optim, self.lr, grads);
}
/// Optimize the current module with multiple sets of gradients.
///
/// The learner's own optimizer and current learning rate are used for the update.
///
/// # Arguments
///
/// * `grads`: Multiple gradients associated to each parameter in the current model.
pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);
}
/// Load the module state from a [record](LearnerModelRecord).
pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
self.model = self.model.clone().load_record(record);
}
/// Load the state of the learner's optimizer from a [record](LearnerOptimizerRecord).
pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
self.optim = self.optim.clone().load_record(record);
}
/// Load the state of the learner's scheduler from a [record](LearnerSchedulerRecord).
pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
}
}
#[derive(new)]
/// Used to create, delete, or load checkpoints of the training process.
///
/// One asynchronous checkpointer is kept per learner component (model, optimizer and learning
/// rate scheduler); the strategy decides which checkpoints are saved or deleted.
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
strategy: Box<dyn CheckpointingStrategy>,
}
impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
    /// Apply the checkpointing actions chosen by the strategy for `epoch`.
    ///
    /// A `Save` action stores the learner's model, optimizer and scheduler records under
    /// `epoch`; a `Delete` action removes the three checkpoints of the epoch it carries.
    ///
    /// # Panics
    ///
    /// Panics if any checkpoint can't be saved or deleted.
    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
        let pending = self.strategy.checkpointing(epoch, store);

        for action in pending {
            match action {
                CheckpointingAction::Save => {
                    let model_record = learner.model.clone().into_record();
                    self.model
                        .save(epoch, model_record)
                        .expect("Can save model checkpoint.");

                    let optim_record = learner.optim.to_record();
                    self.optim
                        .save(epoch, optim_record)
                        .expect("Can save optimizer checkpoint.");

                    let scheduler_record = learner.lr_scheduler.to_record();
                    self.lr_scheduler
                        .save(epoch, scheduler_record)
                        .expect("Can save learning rate scheduler checkpoint.");
                }
                CheckpointingAction::Delete(stale_epoch) => {
                    self.model
                        .delete(stale_epoch)
                        .expect("Can delete model checkpoint.");
                    self.optim
                        .delete(stale_epoch)
                        .expect("Can delete optimizer checkpoint.");
                    self.lr_scheduler
                        .delete(stale_epoch)
                        .expect("Can delete learning rate scheduler checkpoint.");
                }
            }
        }
    }

    /// Restore the model, optimizer and scheduler checkpoints of `epoch` into `learner`.
    ///
    /// # Panics
    ///
    /// Panics if any of the three checkpoints can't be restored.
    pub fn load_checkpoint(
        &self,
        mut learner: Learner<LC>,
        device: &Device<LC::Backend>,
        epoch: usize,
    ) -> Learner<LC> {
        let model_record = self
            .model
            .restore(epoch, device)
            .expect("Can load model checkpoint.");
        learner.load_model(model_record);

        let optim_record = self
            .optim
            .restore(epoch, device)
            .expect("Can load optimizer checkpoint.");
        learner.load_optim(optim_record);

        let scheduler_record = self
            .lr_scheduler
            .restore(epoch, device)
            .expect("Can load learning rate scheduler checkpoint.");
        learner.load_scheduler(scheduler_record);

        learner
    }
}
/// Cloneable, boxed early stopping strategy.
///
/// The box holds a [`CloneEarlyStoppingStrategy`] trait object.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
/// A handle that allows aborting the training/evaluation process early.
///
/// Clones share the same stop flag and message.
#[derive(Clone, Default)]
pub struct Interrupter {
    state: Arc<AtomicBool>,
    message: Arc<Mutex<Option<String>>>,
}

impl Interrupter {
    /// Create a new instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Notify the learner that it should stop.
    ///
    /// # Arguments
    ///
    /// * `reason` - A string describing the reason the training was stopped. When `None`, any
    ///   previously stored message is left untouched.
    pub fn stop(&self, reason: Option<&str>) {
        self.state.store(true, Ordering::Relaxed);

        if let Some(text) = reason {
            let mut slot = self.message.lock().unwrap();
            *slot = Some(text.to_owned());
        }
    }

    /// Reset the interrupter.
    ///
    /// Only the stop flag is cleared; the stored message is kept.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

    /// True if .stop() has been called.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

    /// Get the message associated with the interrupt.
    pub fn get_message(&self) -> Option<String> {
        let slot = self.message.lock().unwrap();
        slot.as_ref().cloned()
    }
}

View File

@@ -0,0 +1,159 @@
use crate::metric::{
AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,
PerplexityInput, TopKAccuracyInput, processor::ItemLazy,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor, Transaction};
use burn_ndarray::NdArray;
/// Simple classification output adapted for multiple metrics.
///
/// Supported metrics:
/// - Accuracy
/// - AUROC
/// - TopKAccuracy
/// - Perplexity
/// - Precision (via ConfusionStatsInput)
/// - Recall (via ConfusionStatsInput)
/// - FBetaScore (via ConfusionStatsInput)
/// - Loss
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The class logits or probabilities. Shape: \[batch_size, num_classes\].
pub output: Tensor<B, 2>,
/// The ground truth class index for each sample. Shape: \[batch_size\].
pub targets: Tensor<B, 1, Int>,
}
impl<B: Backend> ItemLazy for ClassificationOutput<B> {
    type ItemSync = ClassificationOutput<NdArray>;

    /// Read the three tensors back through a single [`Transaction`] and rebuild them as
    /// `NdArray` tensors on the default device.
    fn sync(self) -> Self::ItemSync {
        let Self {
            loss,
            output,
            targets,
        } = self;

        // Registration order defines the order of the returned data.
        let transaction = Transaction::default()
            .register(output)
            .register(loss)
            .register(targets);
        let [output_data, loss_data, targets_data] = transaction
            .execute()
            .try_into()
            .expect("Correct amount of tensor data");

        let device = &Default::default();

        ClassificationOutput {
            output: Tensor::from_data(output_data, device),
            loss: Tensor::from_data(loss_data, device),
            targets: Tensor::from_data(targets_data, device),
        }
    }
}
// Accuracy: built from clones of the output and the targets.
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> AccuracyInput<B> {
AccuracyInput::new(self.output.clone(), self.targets.clone())
}
}
// AUROC: built from clones of the output and the targets.
impl<B: Backend> Adaptor<AurocInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> AurocInput<B> {
AurocInput::new(self.output.clone(), self.targets.clone())
}
}
// Loss: built from a clone of the loss tensor only.
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
// Top-k accuracy: built from clones of the output and the targets.
impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> TopKAccuracyInput<B> {
TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
}
}
// Perplexity: built from clones of the output and the targets.
impl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> PerplexityInput<B> {
PerplexityInput::new(self.output.clone(), self.targets.clone())
}
}
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
    /// Build the confusion statistics input from the output and boolean targets.
    ///
    /// With more than one class the class indices are one-hot encoded; otherwise the targets
    /// only get an extra dimension at index 1 before the boolean conversion.
    fn adapt(&self) -> ConfusionStatsInput<B> {
        let [_, num_classes] = self.output.dims();
        let predictions = self.output.clone();
        let class_indices = self.targets.clone();

        let bool_targets = if num_classes > 1 {
            class_indices.one_hot(num_classes).bool()
        } else {
            class_indices.unsqueeze_dim(1).bool()
        };

        ConfusionStatsInput::new(predictions, bool_targets)
    }
}
/// Multi-label classification output adapted for multiple metrics.
///
/// Unlike [`ClassificationOutput`], the targets are a 2D tensor with one entry per label.
///
/// Supported metrics:
/// - HammingScore
/// - Precision (via ConfusionStatsInput)
/// - Recall (via ConfusionStatsInput)
/// - FBetaScore (via ConfusionStatsInput)
/// - Loss
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The label logits or probabilities. Shape: \[batch_size, num_classes\].
pub output: Tensor<B, 2>,
/// The ground truth labels. Shape: \[batch_size, num_classes\].
pub targets: Tensor<B, 2, Int>,
}
impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
    type ItemSync = MultiLabelClassificationOutput<NdArray>;

    /// Read the three tensors back through a single [`Transaction`] and rebuild them as
    /// `NdArray` tensors on the default device.
    fn sync(self) -> Self::ItemSync {
        let Self {
            loss,
            output,
            targets,
        } = self;

        // Registration order defines the order of the returned data.
        let transaction = Transaction::default()
            .register(output)
            .register(loss)
            .register(targets);
        let [output_data, loss_data, targets_data] = transaction
            .execute()
            .try_into()
            .expect("Correct amount of tensor data");

        let device = &Default::default();

        MultiLabelClassificationOutput {
            output: Tensor::from_data(output_data, device),
            loss: Tensor::from_data(loss_data, device),
            targets: Tensor::from_data(targets_data, device),
        }
    }
}
// Hamming score: built from clones of the output and the targets.
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
HammingScoreInput::new(self.output.clone(), self.targets.clone())
}
}
// Loss: built from a clone of the loss tensor only.
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
// Confusion statistics: the integer targets are converted to booleans as-is
// (no one-hot encoding, they already have one entry per label).
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> ConfusionStatsInput<B> {
ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
}
}

View File

@@ -0,0 +1,298 @@
use crate::metric::{
Metric, MetricName,
store::{Aggregate, Direction, EventStoreClient, Split},
};
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
#[derive(Clone)]
pub enum StoppingCondition {
/// When no improvement has happened since the given number of epochs.
NoImprovementSince {
/// The number of epochs without improvement after which training is stopped.
n_epochs: usize,
},
}
/// A strategy that checks if the training should be stopped.
pub trait EarlyStoppingStrategy: Send {
/// Updates its current state and returns whether the training should be stopped.
///
/// # Arguments
///
/// * `epoch` - The epoch that just finished.
/// * `store` - The event store client used to look up the collected metrics.
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
/// A helper trait to provide type-erased cloning.
pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
/// Clone into a boxed trait object.
fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
}
/// Blanket implementation of `CloneEarlyStoppingStrategy` for any `T` that implements
/// [`EarlyStoppingStrategy`] + `Clone` + `Send` + `'static`.
impl<T> CloneEarlyStoppingStrategy for T
where
T: EarlyStoppingStrategy + Clone + Send + 'static,
{
fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
Box::new(self.clone())
}
}
/// Makes the boxed trait object cloneable by delegating to
/// [`clone_box`](CloneEarlyStoppingStrategy::clone_box).
impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
self.clone_box()
}
}
/// An [early stopping strategy](EarlyStoppingStrategy) based on a metric collected
/// during training or validation.
#[derive(Clone)]
pub struct MetricEarlyStoppingStrategy {
/// The condition that triggers the stop.
condition: StoppingCondition,
/// The name of the metric to look up in the event store.
metric_name: MetricName,
/// The aggregation applied to the metric over an epoch.
aggregate: Aggregate,
/// Whether lower or higher metric values count as an improvement.
direction: Direction,
/// The dataset split the metric is read from.
split: Split,
/// The epoch at which the best value was observed.
best_epoch: usize,
/// The best metric value observed so far.
best_value: f64,
/// Optional number of warmup epochs during which early stopping never triggers.
warmup_epochs: Option<usize>,
}
impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
    /// Compare the metric of `epoch` against the best value seen so far.
    ///
    /// Returns `false` when the metric can't be found, when the epoch is a new best (the best
    /// value and epoch are then updated), or while still inside the warmup period. Otherwise
    /// the stopping condition decides.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
        let Some(current_value) =
            store.find_metric(&self.metric_name, epoch, self.aggregate, &self.split)
        else {
            log::warn!("Can't find metric for early stopping.");
            return false;
        };

        let improved = match self.direction {
            Direction::Lowest => current_value < self.best_value,
            Direction::Highest => current_value > self.best_value,
        };

        if improved {
            log::info!(
                "New best epoch found {} {}: {}",
                epoch,
                self.metric_name,
                current_value
            );
            self.best_value = current_value;
            self.best_epoch = epoch;
            return false;
        }

        // The warmup only suppresses stopping; best values are still tracked above.
        let in_warmup = self
            .warmup_epochs
            .is_some_and(|warmup_epochs| epoch <= warmup_epochs);
        if in_warmup {
            return false;
        }

        match self.condition {
            StoppingCondition::NoImprovementSince { n_epochs } => {
                let epochs_without_improvement = epoch - self.best_epoch;
                if epochs_without_improvement < n_epochs {
                    return false;
                }

                log::info!(
                    "Stopping training loop, no improvement since epoch {}, {}: {}, current \
                     epoch {}, {}: {}",
                    self.best_epoch,
                    self.metric_name,
                    self.best_value,
                    epoch,
                    self.metric_name,
                    current_value
                );
                true
            }
        }
    }
}
impl MetricEarlyStoppingStrategy {
    /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metric collected
    /// during training or validation.
    ///
    /// The best value starts at the worst possible value for the chosen direction and the best
    /// epoch starts at 1; no warmup is configured.
    ///
    /// # Notes
    ///
    /// The metric should be registered for early stopping to work, otherwise no data is collected.
    pub fn new<Me: Metric>(
        metric: &Me,
        aggregate: Aggregate,
        direction: Direction,
        split: Split,
        condition: StoppingCondition,
    ) -> Self {
        let metric_name = metric.name();
        let best_value = match direction {
            Direction::Lowest => f64::MAX,
            Direction::Highest => f64::MIN,
        };

        Self {
            condition,
            metric_name,
            aggregate,
            direction,
            split,
            best_epoch: 1,
            best_value,
            warmup_epochs: None,
        }
    }

    /// Get the warmup period.
    ///
    /// Early stopping will not trigger during the warmup epochs.
    pub fn warmup_epochs(&self) -> Option<usize> {
        self.warmup_epochs
    }

    /// Set the warmup epochs.
    ///
    /// Early stopping will not trigger during the warmup epochs.
    ///
    /// # Arguments
    /// - `warmup`: the number of warmup epochs, or None.
    pub fn with_warmup_epochs(mut self, warmup: Option<usize>) -> Self {
        self.warmup_epochs = warmup;
        self
    }
}
// Tests for `MetricEarlyStoppingStrategy`, driven by loss values fed through a minimal
// event processor.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
EventProcessorTraining, TestBackend,
logger::InMemoryMetricLogger,
metric::{
LossMetric,
processor::{
MetricsTraining, MinimalEventProcessor,
test_utils::{end_epoch, process_train},
},
store::LogEventStore,
},
};
use super::*;
#[test]
fn never_early_stop_while_it_is_improving() {
test_early_stopping(
None,
1,
&[
(&[0.5, 0.3], false, "Should not stop first epoch"),
(&[0.4, 0.3], false, "Should not stop when improving"),
(&[0.3, 0.3], false, "Should not stop when improving"),
(&[0.2, 0.3], false, "Should not stop when improving"),
],
);
}
#[test]
fn early_stop_when_no_improvement_since_two_epochs() {
test_early_stopping(
None,
2,
&[
(&[1.0, 0.5], false, "Should not stop first epoch"),
(&[0.5, 0.3], false, "Should not stop when improving"),
(
&[1.0, 3.0],
false,
"Should not stop first time it gets worse",
),
(
&[1.0, 2.0],
true,
"Should stop since two following epochs didn't improve",
),
],
);
}
#[test]
fn early_stopping_with_warmup() {
// Three warmup epochs: the identical epochs 2 and 3 must not stop the training.
test_early_stopping(
Some(3),
2,
&[
(&[1.0, 0.5], false, "Should not stop during warmup"),
(&[1.0, 0.5], false, "Should not stop during warmup"),
(&[1.0, 0.5], false, "Should not stop during warmup"),
(
&[1.0, 0.5],
true,
"Should stop when not improving after warmup",
),
],
)
}
#[test]
fn early_stop_when_stays_equal() {
// An equal value is not an improvement (the comparison is strict).
test_early_stopping(
None,
2,
&[
(&[0.5, 0.3], false, "Should not stop first epoch"),
(
&[0.5, 0.3],
false,
"Should not stop first time it stars the same",
),
(
&[0.5, 0.3],
true,
"Should stop since two following epochs didn't improve",
),
],
);
}
// Shared driver: each entry of `data` is one epoch, given as
// `(loss values of the epoch, expected result of should_stop, assertion message)`.
// Epochs are numbered from 1. The strategy watches the mean train loss (lowest is best).
fn test_early_stopping(warmup: Option<usize>, n_epochs: usize, data: &[(&[f64], bool, &str)]) {
let loss = LossMetric::<TestBackend>::new();
let mut early_stopping = MetricEarlyStoppingStrategy::new(
&loss,
Aggregate::Mean,
Direction::Lowest,
Split::Train,
StoppingCondition::NoImprovementSince { n_epochs },
)
.with_warmup_epochs(warmup);
let mut store = LogEventStore::default();
let mut metrics = MetricsTraining::<f64, f64>::default();
store.register_logger(InMemoryMetricLogger::default());
metrics.register_train_metric_numeric(loss);
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
let mut epoch = 1;
processor.process_train(crate::LearnerEvent::Start);
// NOTE: `should_start` holds the expected return value of `should_stop` for the epoch.
for (points, should_start, comment) in data {
for point in points.iter() {
process_train(&mut processor, *point, epoch);
}
end_epoch(&mut processor, epoch);
assert_eq!(
*should_start,
early_stopping.should_stop(epoch, &store),
"{comment}"
);
epoch += 1;
}
}
}

View File

@@ -0,0 +1,24 @@
#[cfg(feature = "rl")]
mod rl;
#[cfg(feature = "rl")]
pub use rl::*;
mod application_logger;
mod base;
mod classification;
mod early_stopping;
mod regression;
mod sequence;
mod summary;
mod supervised;
mod train_val;
pub use application_logger::*;
pub use base::*;
pub use classification::*;
pub use early_stopping::*;
pub use regression::*;
pub use sequence::*;
pub use summary::*;
pub use supervised::*;
pub use train_val::*;

View File

@@ -0,0 +1,46 @@
use crate::metric::processor::ItemLazy;
use crate::metric::{Adaptor, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Tensor, Transaction};
use burn_ndarray::NdArray;
/// Regression output adapted for the loss metric.
#[derive(new)]
pub struct RegressionOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The predicted values. Shape: \[batch_size, num_targets\].
pub output: Tensor<B, 2>,
/// The ground truth values. Shape: \[batch_size, num_targets\].
pub targets: Tensor<B, 2>,
}
// Loss: built from a clone of the loss tensor only.
impl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
impl<B: Backend> ItemLazy for RegressionOutput<B> {
    type ItemSync = RegressionOutput<NdArray>;

    /// Read the three tensors back through a single [`Transaction`] and rebuild them as
    /// `NdArray` tensors on the default device.
    fn sync(self) -> Self::ItemSync {
        let Self {
            loss,
            output,
            targets,
        } = self;

        // Registration order defines the order of the returned data.
        let transaction = Transaction::default()
            .register(output)
            .register(loss)
            .register(targets);
        let [output_data, loss_data, targets_data] = transaction
            .execute()
            .try_into()
            .expect("Correct amount of tensor data");

        let device = &Default::default();

        RegressionOutput {
            output: Tensor::from_data(output_data, device),
            loss: Tensor::from_data(loss_data, device),
            targets: Tensor::from_data(targets_data, device),
        }
    }
}

View File

@@ -0,0 +1,75 @@
use burn_core::tensor::Device;
use burn_rl::{Policy, PolicyLearner, PolicyState};
use crate::RLAgentRecord;
use crate::{
RLComponentsTypes, RLPolicyRecord,
checkpoint::Checkpointer,
checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},
metric::store::EventStoreClient,
};
#[derive(new)]
/// Used to create, delete, or load checkpoints of the training process.
///
/// One asynchronous checkpointer is kept for the policy and one for the learning agent; the
/// strategy decides which checkpoints are saved or deleted.
pub struct RLCheckpointer<RLC: RLComponentsTypes> {
policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
strategy: Box<dyn CheckpointingStrategy>,
}
impl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {
    /// Create checkpoint for the training process.
    ///
    /// The strategy is asked which actions apply to `epoch`. A `Delete` action removes the
    /// policy and learning agent checkpoints of the epoch it carries; a `Save` action stores
    /// the given policy state and learning agent records under `epoch`.
    ///
    /// # Panics
    ///
    /// Panics if a checkpoint can't be saved or deleted.
    pub fn checkpoint(
        &mut self,
        policy: &RLC::PolicyState,
        learning_agent: &RLC::LearningAgent,
        epoch: usize,
        store: &EventStoreClient,
    ) {
        let actions = self.strategy.checkpointing(epoch, store);

        for action in actions {
            match action {
                CheckpointingAction::Delete(epoch) => {
                    self.policy
                        .delete(epoch)
                        .expect("Can delete policy checkpoint.");
                    self.learning_agent
                        .delete(epoch)
                        .expect("Can delete learning agent checkpoint.");
                }
                CheckpointingAction::Save => {
                    self.policy
                        .save(epoch, policy.clone().into_record())
                        .expect("Can save policy checkpoint.");
                    self.learning_agent
                        .save(epoch, learning_agent.record())
                        .expect("Can save learning agent checkpoint.");
                }
            }
        }
    }

    /// Load a training checkpoint.
    ///
    /// The policy record is restored first and loaded into the agent's current policy, then
    /// the agent record is loaded and the restored policy is set on the resulting agent.
    ///
    /// # Panics
    ///
    /// Panics if the policy or the learning agent checkpoint can't be restored.
    pub fn load_checkpoint(
        &self,
        learning_agent: RLC::LearningAgent,
        device: &Device<RLC::Backend>,
        epoch: usize,
    ) -> RLC::LearningAgent {
        // The message names the policy, matching the save/delete messages above.
        let policy_record = self
            .policy
            .restore(epoch, device)
            .expect("Can load policy checkpoint.");
        let policy = learning_agent.policy().load_record(policy_record);

        let agent_record = self
            .learning_agent
            .restore(epoch, device)
            .expect("Can load learning agent checkpoint.");
        let mut learning_agent = learning_agent.load_record(agent_record);

        learning_agent.update_policy(policy);
        learning_agent
    }
}

View File

@@ -0,0 +1,115 @@
use std::marker::PhantomData;
use burn_core::tensor::backend::AutodiffBackend;
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};
use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};
/// All components used by the reinforcement learning paradigm, grouped in one trait.
///
/// The associated types are tied together: the environment's state and action types must
/// convert to and from the policy's observation and action types, and the learning agent must
/// train the same policy type.
pub trait RLComponentsTypes {
/// The backend used for training.
type Backend: AutodiffBackend;
/// The learning environment.
type Env: Environment<State = Self::State, Action = Self::Action> + 'static;
/// Specifies how to initialize the environment.
type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;
/// The type of the environment state.
///
/// It must be convertible into the policy's observation type.
type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;
/// The type of the environment action.
///
/// It must be convertible from and into the policy's action type.
type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>
+ Into<<Self::Policy as Policy<Self::Backend>>::Action>
+ Clone
+ Send
+ 'static;
/// The policy used to take actions in the environment.
type Policy: Policy<
Self::Backend,
Observation = Self::PolicyObs,
ActionDistribution = Self::PolicyAD,
Action = Self::PolicyAction,
ActionContext = Self::ActionContext,
PolicyState = Self::PolicyState,
> + Send
+ 'static;
/// The policy's observation type.
type PolicyObs: Clone + Send + Batchable + 'static;
/// The policy's action distribution type.
type PolicyAD: Clone + Send + Batchable;
/// The policy's action type.
type PolicyAction: Clone + Send + Batchable;
/// Additional data as context for an agent's action.
type ActionContext: ItemLazy + Clone + Send + 'static;
/// The state of the parameterized policy.
type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;
/// The learning agent.
///
/// Its inner policy is [`Self::Policy`] and its training context is [`Self::TrainingOutput`].
type LearningAgent: PolicyLearner<
Self::Backend,
TrainContext = Self::TrainingOutput,
InnerPolicy = Self::Policy,
> + Send
+ 'static;
/// The output data of a training step.
type TrainingOutput: ItemLazy + Clone + Send;
}
/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait.
///
/// It stores no data; the type parameters are: `B` the backend, `E` the environment, `EI` the
/// environment initializer and `A` the learning agent.
pub struct RLComponentsMarker<B, E, EI, A> {
_backend: PhantomData<B>,
_env: PhantomData<E>,
_env_init: PhantomData<EI>,
_agent: PhantomData<A>,
}
// Every policy-related associated type is derived from the agent's inner policy
// (`A::InnerPolicy`), and the state/action types come from the environment `E`.
impl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>
where
B: AutodiffBackend,
E: Environment + 'static,
EI: EnvironmentInit<E> + Send + 'static,
A: PolicyLearner<B> + Send + 'static,
A::TrainContext: ItemLazy + Clone + Send,
A::InnerPolicy: Policy<B> + Send,
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
+ Into<<A::InnerPolicy as Policy<B>>::Action>
+ Clone
+ Send
+ 'static,
{
type Backend = B;
type Env = E;
type EnvInit = EI;
type LearningAgent = A;
type Policy = A::InnerPolicy;
type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;
type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;
type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;
type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;
type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;
type TrainingOutput = A::TrainContext;
type State = E::State;
type Action = E::Action;
}
/// The inner policy type of the learning agent.
pub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
<RLC as RLComponentsTypes>::Backend,
>>::InnerPolicy;
/// The event processor type for reinforcement learning.
///
/// Training events carry the training output and the action context; evaluation events carry
/// the action context only.
pub type RLEventProcessorType<RLC> = AsyncProcessorTraining<
RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,
AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,
>;
/// The record of the policy.
///
/// This is the record type of the policy's [`PolicyState`].
pub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<
<RLC as RLComponentsTypes>::Backend,
>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;
/// The record of the learning agent.
pub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
<RLC as RLComponentsTypes>::Backend,
>>::Record;

View File

@@ -0,0 +1,703 @@
use rand::prelude::SliceRandom;
use std::{
sync::mpsc::{Receiver, Sender},
thread::spawn,
};
use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};
use burn_rl::EnvironmentInit;
use burn_rl::Policy;
use burn_rl::Transition;
use burn_rl::{AsyncPolicy, Environment};
use crate::{
AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,
RlPolicy, TimeStep, Trajectory,
};
/// Requests sent from the loop handle to the environment thread.
enum RequestMessage {
/// Ask the environment thread for a single time step.
Step(),
/// Ask the environment thread for a whole episode, delivered as one trajectory.
Episode(),
}
/// Configuration for an async agent/environment loop.
pub struct AsyncAgentEnvLoopConfig {
/// If the loop is used for evaluation (as opposed to training).
pub eval: bool,
/// If the agent should take action deterministically.
pub deterministic: bool,
/// An arbitrary ID for the loop, reported as `env_id` on every time step it produces.
pub id: usize,
}
/// An asynchronous agent/environment interface.
///
/// The environment runs on its own thread; this handle sends it requests and receives the
/// resulting time steps and trajectories through channels.
pub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {
/// Whether the loop is used for evaluation (as opposed to training).
eval: bool,
/// Handle to the asynchronous policy taking actions in the loop.
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
/// Receives individual time steps produced by the environment thread.
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
/// Receives complete trajectories produced by the environment thread.
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
/// Sends step/episode requests to the environment thread.
request_sender: Sender<RequestMessage>,
}
impl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {
/// Create a new asynchronous runner.
///
/// A thread is spawned that initializes the environment and then loops forever: it asks the
/// agent for an action, steps the environment, and delivers the result according to the
/// requests received from this handle.
///
/// # Arguments
///
/// * `env_init` - A function returning an environment instance.
/// * `agent` - An [AsyncPolicy](AsyncPolicy) taking actions in the loop.
/// * `config` - An [AsyncAgentEnvLoopConfig](AsyncAgentEnvLoopConfig).
/// * `transition_device` - The device on which the reward and done tensors of each transition are created.
/// * `transition_sender` - Optional sender for transitions if you want to drive the requests from outside of the loop instance.
/// * `trajectory_sender` - Optional sender for trajectories if you want to drive the requests from outside of the loop instance.
///
/// # Returns
///
/// An async Agent/Environment loop.
pub fn new(
env_init: RLC::EnvInit,
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
config: AsyncAgentEnvLoopConfig,
transition_device: &Device<BT>,
transition_sender: Option<Sender<RLTimeStep<BT, RLC>>>,
trajectory_sender: Option<Sender<RLTrajectory<BT, RLC>>>,
) -> Self {
let (loop_transition_sender, transition_receiver) = std::sync::mpsc::channel();
let (loop_trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
let (request_sender, request_receiver) = std::sync::mpsc::channel();
// When an external sender is supplied it replaces the internal one, so the matching
// receiver stored in `Self` will not receive anything from this loop.
let loop_transition_sender = transition_sender.unwrap_or(loop_transition_sender);
let loop_trajectory_sender = trajectory_sender.unwrap_or(loop_trajectory_sender);
let device = transition_device.clone();
let mut loop_agent = agent.clone();
let eval = config.eval;
// Per-episode accumulators, moved into the environment thread.
let mut current_steps = vec![];
let mut current_reward = 0.0;
let mut step_num = 0;
spawn(move || {
let mut env = env_init.init();
env.reset();
let mut request_episode = false;
loop {
// The step is computed first; outside of episode mode it is only delivered
// once a request has been received below.
let state = env.state();
let (action, context) =
loop_agent.action(state.clone().into(), config.deterministic);
let env_action = RLC::Action::from(action);
let step_result = env.step(env_action.clone());
current_reward += step_result.reward;
step_num += 1;
let transition = Transition::new(
state.clone(),
step_result.next_state,
env_action,
Tensor::from_data([step_result.reward], &device),
// The transition's done flag is set when the episode is done or truncated.
Tensor::from_data(
[(step_result.done || step_result.truncated) as i32 as f64],
&device,
),
);
if !request_episode {
// NOTE(review): the agent count is decremented while this thread blocks on the
// request channel, presumably so the async policy does not wait on an idle
// loop; confirm against `AsyncPolicy`.
loop_agent.decrement_agents(1);
let request = match request_receiver.recv() {
Ok(req) => req,
Err(err) => {
// The request sender was dropped: log and terminate the thread.
log::error!("Error in env runner : {}", err);
break;
}
};
loop_agent.increment_agents(1);
match request {
RequestMessage::Step() => (),
RequestMessage::Episode() => request_episode = true,
}
}
let time_step = TimeStep {
env_id: config.id,
transition,
// Unlike the transition's done flag, this one ignores truncation.
done: step_result.done,
ep_len: step_num,
cum_reward: current_reward,
// NOTE(review): assumes the policy returned at least one context entry; indexing
// panics otherwise.
action_context: context[0].clone(),
};
current_steps.push(time_step.clone());
// Individual time steps are only sent in step mode; in episode mode they are
// delivered as part of the trajectory.
if !request_episode && let Err(err) = loop_transition_sender.send(time_step) {
log::error!("Error in env runner : {}", err);
break;
}
if step_result.done || step_result.truncated {
if request_episode {
request_episode = false;
loop_trajectory_sender
.send(Trajectory {
timesteps: current_steps.clone(),
})
.expect("Can send trajectory to main thread.");
}
// Start a fresh episode.
current_steps.clear();
env.reset();
current_reward = 0.;
step_num = 0;
}
}
});
Self {
eval,
agent,
transition_receiver,
trajectory_receiver,
request_sender,
}
}
}
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>
where
BT: Backend,
RLC: RLComponentsTypes,
{
// Requests up to `num_steps` time steps from the environment thread, one at a time, and
// returns the ones received.
//
// Outside of evaluation mode, each received step increments `progress.items_processed` and
// is reported to `processor` as a `TimeStep` event, followed by an `EpisodeEnd` event when
// the step is flagged as done. The loop exits early once `interrupter` signals a stop.
//
// Panics if the request can't be sent or the time step can't be received.
fn run_steps(
&mut self,
num_steps: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTimeStep<BT, RLC>> {
let mut items = vec![];
for _ in 0..num_steps {
self.request_sender
.send(RequestMessage::Step())
.expect("Can request transitions.");
let transition = self
.transition_receiver
.recv()
.expect("Can receive transitions.");
items.push(transition.clone());
if !self.eval {
progress.items_processed += 1;
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
transition.action_context,
progress.clone(),
None,
)));
if transition.done {
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
EpisodeSummary {
episode_length: transition.ep_len,
cum_reward: transition.cum_reward,
},
progress.clone(),
None,
)));
}
}
// Checked after the step is recorded, so the current step is always returned.
if interrupter.should_stop() {
break;
}
}
items
}
fn run_episodes(
&mut self,
num_episodes: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
_progress: &mut Progress,
) -> Vec<RLTrajectory<BT, RLC>> {
let mut items = vec![];
self.agent.increment_agents(1);
for episode_num in 0..num_episodes {
self.request_sender
.send(RequestMessage::Episode())
.expect("Can request episodes.");
let trajectory = self
.trajectory_receiver
.recv()
.expect("Main thread can receive trajectory.");
for (i, step) in trajectory.timesteps.iter().enumerate() {
// TODO : clean this.
if self.eval {
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
step.action_context.clone(),
Progress::new(i, i),
None,
)));
if step.done {
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
EvaluationItem::new(
EpisodeSummary {
episode_length: step.ep_len,
cum_reward: step.cum_reward,
},
Progress::new(episode_num + 1, num_episodes),
None,
),
));
}
} else {
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
step.action_context.clone(),
Progress::new(i, i),
None,
)));
if step.done {
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
EpisodeSummary {
episode_length: step.ep_len,
cum_reward: step.cum_reward,
},
Progress::new(episode_num + 1, num_episodes),
None,
)));
}
}
}
items.push(trajectory);
if interrupter.should_stop() {
break;
}
}
self.agent.decrement_agents(1);
items
}
fn update_policy(&mut self, update: RLC::PolicyState) {
self.agent.update(update);
}
fn policy(&self) -> RLC::PolicyState {
self.agent.state()
}
}
/// An asynchronous runner for multiple agent/environment interfaces.
///
/// Keeps one request sender per environment and receives the produced timesteps and
/// trajectories on two shared channels.
pub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
/// Number of environments driven by this runner.
num_envs: usize,
/// When `false`, `run_steps` advances the progress and reports training events.
eval: bool,
/// The policy wrapper shared with the environment runners.
agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
/// Receiving end for timesteps produced by the environments.
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
/// Receiving end for complete trajectories produced by the environments.
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
/// One request sender per environment, indexed by environment id.
request_senders: Vec<Sender<RequestMessage>>,
}
impl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {
    /// Create a new asynchronous runner for multiple agent/environment interfaces.
    ///
    /// # Arguments
    ///
    /// * `num_envs` - The number of environment runners to create.
    /// * `env_init` - The environment initializer, cloned for every runner.
    /// * `agent` - The policy wrapper shared by every runner.
    /// * `eval` - Whether the runners are used for evaluation.
    /// * `deterministic` - Whether actions are requested as deterministic.
    /// * `device` - The device handed to every runner.
    ///
    /// # Panics
    ///
    /// Panics if the initial step request cannot be sent to a runner.
    pub fn new(
        num_envs: usize,
        env_init: RLC::EnvInit,
        agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
        eval: bool,
        deterministic: bool,
        device: &Device<BT>,
    ) -> Self {
        let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
        let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();

        // Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps.
        agent.increment_agents(num_envs);

        let mut request_senders = Vec::with_capacity(num_envs);
        for id in 0..num_envs {
            let config = AsyncAgentEnvLoopConfig {
                eval,
                deterministic,
                id,
            };
            let runner = AgentEnvAsyncLoop::<BT, RLC>::new(
                env_init.clone(),
                agent.clone(),
                config,
                // `device` is already a reference; no temporary clone is needed.
                device,
                Some(transition_sender.clone()),
                Some(trajectory_sender.clone()),
            );
            request_senders.push(runner.request_sender.clone());
        }

        // Double batching : The environments are always one step ahead.
        for sender in &request_senders {
            sender
                .send(RequestMessage::Step())
                .expect("Main thread can send step requests.");
        }

        Self {
            num_envs,
            eval,
            // Last use of the owned handle: move it instead of cloning.
            agent,
            transition_receiver,
            trajectory_receiver,
            request_senders,
        }
    }
}
impl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    /// Collects up to `num_steps` timesteps from whichever environments deliver them first.
    ///
    /// Each received timestep immediately triggers a new step request for the environment it came
    /// from. When the runner is not in evaluation mode, each timestep increments
    /// `progress.items_processed` and is reported to `processor`, followed by an episode summary
    /// if the timestep is terminal. The loop ends early once `interrupter` asks to stop.
    ///
    /// # Panics
    ///
    /// Panics if a timestep cannot be received or a step request cannot be sent.
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        for _ in 0..num_steps {
            let transition = self
                .transition_receiver
                .recv()
                .expect("Can receive transitions.");
            self.request_senders[transition.env_id]
                .send(RequestMessage::Step())
                .expect("Main thread can request steps.");

            if !self.eval {
                progress.items_processed += 1;
                // Only the action context is cloned; the timestep itself is moved into `items`.
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    transition.action_context.clone(),
                    progress.clone(),
                    None,
                )));
                if transition.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: transition.ep_len,
                            cum_reward: transition.cum_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            items.push(transition);
            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    /// Forwards a policy update to the shared agent.
    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    /// Collects up to `num_episodes` complete trajectories across the environments.
    ///
    /// At most one episode is requested per environment at a time. If fewer episodes than
    /// environments are requested, a random subset of environments is used. Whenever a trajectory
    /// arrives and more episodes are still needed, the environment that produced it is asked for
    /// another one. Timesteps are reported as validation events in evaluation mode and as
    /// training events otherwise. The loop ends early once `interrupter` asks to stop.
    ///
    /// # Panics
    ///
    /// Panics if an episode request cannot be sent, a trajectory cannot be received, or a
    /// received trajectory that needs a follow-up request is empty.
    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        _progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
        // Send the initial requests: one per environment, capped at `num_episodes`.
        let num_requests = self.num_envs.min(num_episodes);
        let initial_envs: Vec<usize> = if num_episodes < self.num_envs {
            let mut rng = rand::rng();
            let mut shuffled: Vec<usize> = (0..self.num_envs).collect();
            shuffled.shuffle(&mut rng);
            shuffled.truncate(num_episodes);
            shuffled
        } else {
            (0..self.num_envs).collect()
        };
        for env_id in initial_envs {
            self.request_senders[env_id]
                .send(RequestMessage::Episode())
                .expect("Main thread can request episodes.");
        }

        let mut items = vec![];
        for episode_num in 0..num_episodes {
            let trajectory = self
                .trajectory_receiver
                .recv()
                .expect("Can receive trajectory.");

            // `episode_num + 1` trajectories have been received so far. Request another episode
            // only while the total number of requests stays within `num_episodes`.
            if episode_num + 1 + num_requests <= num_episodes {
                let env_id = trajectory
                    .timesteps
                    .first()
                    .expect("Trajectories contain at least one timestep.")
                    .env_id;
                self.request_senders[env_id]
                    .send(RequestMessage::Episode())
                    .expect("Main thread can request episodes.");
            }

            for (i, step) in trajectory.timesteps.iter().enumerate() {
                // The only difference between evaluation and training is the event channel.
                let step_item =
                    EvaluationItem::new(step.action_context.clone(), Progress::new(i, i), None);
                if self.eval {
                    processor.process_valid(AgentEvaluationEvent::TimeStep(step_item));
                } else {
                    processor.process_train(RLEvent::TimeStep(step_item));
                }

                if !step.done {
                    continue;
                }
                let episode_item = EvaluationItem::new(
                    EpisodeSummary {
                        episode_length: step.ep_len,
                        cum_reward: step.cum_reward,
                    },
                    Progress::new(episode_num + 1, num_episodes),
                    None,
                );
                if self.eval {
                    processor.process_valid(AgentEvaluationEvent::EpisodeEnd(episode_item));
                } else {
                    processor.process_train(RLEvent::EpisodeEnd(episode_item));
                }
            }

            // The trajectory is moved, not cloned, now that it is no longer needed here.
            items.push(trajectory);
            // NOTE(review): breaking here can leave already-requested episodes in flight; confirm
            // callers never reuse the runner after an interruption.
            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    /// Returns the current state of the shared agent.
    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use burn_core::data::dataloader::Progress;
    use burn_rl::AsyncPolicy;

    use crate::learner::rl::env_runner::async_runner::AsyncAgentEnvLoopConfig;
    use crate::learner::rl::env_runner::base::AgentEnvLoop;
    use crate::learner::tests::{MockPolicyState, MockProcessor};
    use crate::{
        AgentEnvAsyncLoop, TestBackend,
        learner::tests::{MockEnvInit, MockPolicy, MockRLComponents},
    };
    use crate::{AsyncProcessorTraining, Interrupter, MultiAgentEnvLoop};

    /// `(num_envs, autobatch_size)` pairs exercised by the multi-environment tests.
    const BATCHING_CASES: [(usize, usize); 13] = [
        // num_envs == autobatch_size
        (1, 1),
        (4, 4),
        // num_envs < autobatch_size
        (1, 2),
        (1, 3),
        (2, 3),
        (2, 4),
        (5, 19),
        // num_envs > autobatch_size
        (2, 1),
        (8, 1),
        (3, 2),
        (8, 2),
        (8, 3),
        (8, 7),
    ];

    /// Progress tracker shared by every test: nothing processed, one item in total.
    fn fresh_progress() -> Progress {
        Progress {
            items_processed: 0,
            items_total: 1,
        }
    }

    /// Builds a single-environment asynchronous runner around a mock policy.
    fn setup_async_loop(
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> AgentEnvAsyncLoop<TestBackend, MockRLComponents> {
        let config = AsyncAgentEnvLoopConfig {
            eval,
            deterministic,
            id: 0,
        };
        AgentEnvAsyncLoop::<TestBackend, MockRLComponents>::new(
            MockEnvInit,
            AsyncPolicy::new(1, MockPolicy(state)),
            config,
            &Default::default(),
            None,
            None,
        )
    }

    /// Builds a multi-environment runner around a mock policy.
    fn setup_multi_loop(
        num_envs: usize,
        autobatch_size: usize,
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> MultiAgentEnvLoop<TestBackend, MockRLComponents> {
        MultiAgentEnvLoop::<TestBackend, MockRLComponents>::new(
            num_envs,
            MockEnvInit,
            AsyncPolicy::new(autobatch_size, MockPolicy(state)),
            eval,
            deterministic,
            &Default::default(),
        )
    }

    #[test]
    fn test_policy_async_loop() {
        let runner = setup_async_loop(1000, false, false);
        assert_eq!(runner.policy().0, 1000);
    }

    #[test]
    fn test_update_policy_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = fresh_progress();

        for requested in [1, 8] {
            let steps = runner.run_steps(requested, &mut processor, &interrupter, &mut progress);
            assert_eq!(steps.len(), requested);
        }
    }

    #[test]
    fn run_episodes_returns_requested_number_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = fresh_progress();

        for requested in [1, 8] {
            let trajectories =
                runner.run_episodes(requested, &mut processor, &interrupter, &mut progress);
            assert_eq!(trajectories.len(), requested);
            for trajectory in &trajectories {
                assert_ne!(trajectory.timesteps.len(), 0);
            }
        }
    }

    #[test]
    fn test_policy_multi_loop() {
        let runner = setup_multi_loop(4, 4, 1000, false, false);
        assert_eq!(runner.policy().0, 1000);
    }

    #[test]
    fn test_update_policy_multi_loop() {
        let mut runner = setup_multi_loop(4, 4, 0, false, false);
        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number_multi_loop() {
        fn run_test(num_envs: usize, autobatch_size: usize) {
            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
            let mut processor = AsyncProcessorTraining::new(MockProcessor);
            let interrupter = Interrupter::new();
            let mut progress = fresh_progress();

            // Kickstart tests by running some steps to make sure it's not a double batching edge case success.
            let warmup = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
            assert_eq!(warmup.len(), 8);

            for requested in 0..16 {
                let steps =
                    runner.run_steps(requested, &mut processor, &interrupter, &mut progress);
                assert_eq!(steps.len(), requested);
            }
        }

        for (num_envs, autobatch_size) in BATCHING_CASES {
            run_test(num_envs, autobatch_size);
        }
    }

    #[test]
    fn run_episodes_returns_requested_number_multi_loop() {
        fn run_test(num_envs: usize, autobatch_size: usize) {
            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
            let mut processor = AsyncProcessorTraining::new(MockProcessor);
            let interrupter = Interrupter::new();
            let mut progress = fresh_progress();

            // Kickstart tests by running some episodes to make sure it's not a double batching edge case success.
            let warmup = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
            assert_eq!(warmup.len(), 8);
            for trajectory in &warmup {
                assert_ne!(trajectory.timesteps.len(), 0);
            }

            for requested in 0..16 {
                let trajectories =
                    runner.run_episodes(requested, &mut processor, &interrupter, &mut progress);
                assert_eq!(trajectories.len(), requested);
                for trajectory in &trajectories {
                    assert_ne!(trajectory.timesteps.len(), 0);
                }
            }
        }

        for (num_envs, autobatch_size) in BATCHING_CASES {
            run_test(num_envs, autobatch_size);
        }
    }
}

View File

@@ -0,0 +1,343 @@
use std::marker::PhantomData;
use burn_core::data::dataloader::Progress;
use burn_core::{Tensor, prelude::Backend};
use burn_rl::Policy;
use burn_rl::Transition;
use burn_rl::{Environment, EnvironmentInit};
use crate::RLEvent;
use crate::{
AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
RLEventProcessorType,
};
use crate::{Interrupter, RLComponentsTypes};
/// A trajectory, i.e. an ordered list of [`TimeStep`]s.
#[derive(Clone, new)]
pub struct Trajectory<B: Backend, S, A, C> {
/// The ordered [`TimeStep`]s making up the trajectory.
pub timesteps: Vec<TimeStep<B, S, A, C>>,
}
/// A timestep describing one iteration of the state/decision process.
#[derive(Clone)]
pub struct TimeStep<B: Backend, S, A, C> {
/// The id of the environment that produced this timestep.
pub env_id: usize,
/// The [`burn_rl::Transition`] recorded for this timestep.
pub transition: Transition<B, S, A>,
/// True if the environment reaches a terminal state.
pub done: bool,
/// The running length of the current episode.
pub ep_len: usize,
/// The running cumulative reward.
pub cum_reward: f64,
/// The action's context for this timestep.
pub action_context: C,
}
/// A [`TimeStep`] whose state, action and action context types are taken from an
/// [`RLComponentsTypes`] implementation.
pub(crate) type RLTimeStep<B, RLC> = TimeStep<
B,
<RLC as RLComponentsTypes>::State,
<RLC as RLComponentsTypes>::Action,
<RLC as RLComponentsTypes>::ActionContext,
>;
/// A [`Trajectory`] whose state, action and action context types are taken from an
/// [`RLComponentsTypes`] implementation.
pub(crate) type RLTrajectory<B, RLC> = Trajectory<
B,
<RLC as RLComponentsTypes>::State,
<RLC as RLComponentsTypes>::Action,
<RLC as RLComponentsTypes>::ActionContext,
>;
/// Trait for a structure that implements an agent/environment interface.
pub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
/// Run a certain number of timesteps.
///
/// # Arguments
///
/// * `num_steps` - The number of timesteps to run.
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
/// * `progress` - A mutable reference to the learning progress.
///
/// # Returns
///
/// A list of ordered timesteps.
fn run_steps(
&mut self,
num_steps: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTimeStep<BT, RLC>>;
/// Run a certain number of episodes.
///
/// # Arguments
///
/// * `num_episodes` - The number of episodes to run.
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
/// * `progress` - A mutable reference to the learning progress.
///
/// # Returns
///
/// A list of trajectories, one per episode that was run.
fn run_episodes(
&mut self,
num_episodes: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTrajectory<BT, RLC>>;
/// Update the runner's agent.
fn update_policy(&mut self, update: RLC::PolicyState);
/// Get the state of the runner's agent.
fn policy(&self) -> RLC::PolicyState;
}
/// A simple, synchronized agent/environment interface.
pub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {
/// The environment stepped by this runner.
env: RLC::Env,
/// When `false`, `run_steps` advances the progress and reports training events.
eval: bool,
/// The policy queried for actions.
agent: RLC::Policy,
/// Passed to the policy on every action request.
deterministic: bool,
/// Cumulative reward of the episode in progress.
current_reward: f64,
/// Number of episodes finished (done or truncated) inside `run_steps`.
run_num: usize,
/// Number of steps taken in the episode in progress.
step_num: usize,
_backend: PhantomData<B>,
}
impl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {
    /// Create a new base runner.
    ///
    /// The environment is built from `env_init` and reset once, and all episode counters start
    /// at zero.
    ///
    /// # Arguments
    ///
    /// * `env_init` - The initializer used to build the environment.
    /// * `agent` - The policy queried for actions.
    /// * `eval` - Whether the runner is used for evaluation.
    /// * `deterministic` - Whether actions are requested as deterministic.
    pub fn new(
        env_init: RLC::EnvInit,
        agent: RLC::Policy,
        eval: bool,
        deterministic: bool,
    ) -> Self {
        let mut env = env_init.init();
        env.reset();
        Self {
            env,
            eval,
            // `agent` is owned and not used afterwards: move it instead of cloning.
            agent,
            deterministic,
            current_reward: 0.0,
            run_num: 0,
            step_num: 0,
            _backend: PhantomData,
        }
    }
}
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    /// Steps the environment up to `num_steps` times on the calling thread.
    ///
    /// Each step queries the policy, applies the action and records the resulting timestep.
    /// When the runner is not in evaluation mode, each step increments
    /// `progress.items_processed` and is reported to `processor`, followed by an episode summary
    /// if the step is terminal. When an episode ends (done or truncated) the environment and the
    /// per-episode counters are reset. The loop ends early once `interrupter` asks to stop.
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        let device = Default::default();
        for _ in 0..num_steps {
            let state = self.env.state();
            let (action, context) = self.agent.action(state.clone().into(), self.deterministic);
            let step_result = self.env.step(RLC::Action::from(action.clone()));
            self.current_reward += step_result.reward;
            self.step_num += 1;
            let episode_ended = step_result.done || step_result.truncated;

            let transition = Transition::new(
                state,
                step_result.next_state,
                RLC::Action::from(action),
                Tensor::from_data([step_result.reward], &device),
                Tensor::from_data([episode_ended as i32 as f64], &device),
            );
            items.push(TimeStep {
                env_id: 0,
                transition,
                done: step_result.done,
                ep_len: self.step_num,
                cum_reward: self.current_reward,
                action_context: context[0].clone(),
            });

            if !self.eval {
                progress.items_processed += 1;
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    context[0].clone(),
                    progress.clone(),
                    None,
                )));
                if step_result.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: self.step_num,
                            cum_reward: self.current_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            // Reset before honoring an interruption so that a stop request landing on the last
            // step of an episode does not leave a finished environment and stale counters.
            if episode_ended {
                self.env.reset();
                self.current_reward = 0.;
                self.step_num = 0;
                self.run_num += 1;
            }
            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    /// Replaces the policy state of the runner's agent.
    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    /// Runs up to `num_episodes` episodes from a freshly reset environment.
    ///
    /// Each trajectory ends when its episode ends, whether terminal or truncated. In evaluation
    /// mode every timestep is reported to `processor` as a validation event, followed by an
    /// episode summary on terminal timesteps. Training events are emitted by `run_steps`. The
    /// loop ends early once `interrupter` asks to stop; the trajectory in progress is still
    /// returned.
    ///
    /// # Panics
    ///
    /// Panics if `run_steps(1, ..)` does not yield a timestep.
    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
        // Start from a clean episode. The counters must be reset together with the environment,
        // otherwise steps taken by an earlier `run_steps` call leak into the first episode's
        // length and cumulative reward.
        self.env.reset();
        self.current_reward = 0.;
        self.step_num = 0;

        let mut items = vec![];
        for ep in 0..num_episodes {
            let mut steps = vec![];
            loop {
                let step = self
                    .run_steps(1, processor, interrupter, progress)
                    .pop()
                    .expect("run_steps(1) always yields one timestep.");
                // `run_steps` resets `step_num` to zero exactly when the episode ended, which
                // also covers truncation (not visible in `step.done`).
                let episode_ended = self.step_num == 0;

                let (done, ep_len, cum_reward) = (step.done, step.ep_len, step.cum_reward);
                let eval_context = self.eval.then(|| step.action_context.clone());
                steps.push(step);

                if let Some(context) = eval_context {
                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
                        context,
                        Progress::new(steps.len() + 1, steps.len() + 1),
                        None,
                    )));
                    if done {
                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
                            EvaluationItem::new(
                                EpisodeSummary {
                                    episode_length: ep_len,
                                    cum_reward,
                                },
                                Progress::new(ep + 1, num_episodes),
                                None,
                            ),
                        ));
                    }
                }

                if interrupter.should_stop() || episode_ended {
                    break;
                }
            }
            items.push(Trajectory::new(steps));
            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    /// Returns the current state of the runner's agent.
    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use crate::{AsyncProcessorTraining, TestBackend};

    use crate::learner::tests::{
        MockEnvInit, MockPolicy, MockPolicyState, MockProcessor, MockRLComponents,
    };

    use super::*;

    /// Builds a base runner around a mock policy holding `state`.
    fn setup(
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> AgentEnvBaseLoop<TestBackend, MockRLComponents> {
        AgentEnvBaseLoop::<TestBackend, MockRLComponents>::new(
            MockEnvInit,
            MockPolicy(state),
            eval,
            deterministic,
        )
    }

    /// Progress tracker shared by the tests: nothing processed, one item in total.
    fn fresh_progress() -> Progress {
        Progress {
            items_processed: 0,
            items_total: 1,
        }
    }

    #[test]
    fn test_policy_returns_agent_state() {
        let runner = setup(1000, false, false);
        assert_eq!(runner.policy().0, 1000);
    }

    #[test]
    fn test_update_policy() {
        let mut runner = setup(0, false, false);
        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number() {
        let mut runner = setup(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = fresh_progress();

        for requested in [1, 8] {
            let steps = runner.run_steps(requested, &mut processor, &interrupter, &mut progress);
            assert_eq!(steps.len(), requested);
        }
    }

    #[test]
    fn run_episodes_returns_requested_number() {
        let mut runner = setup(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = fresh_progress();

        for requested in [1, 8] {
            let trajectories =
                runner.run_episodes(requested, &mut processor, &interrupter, &mut progress);
            assert_eq!(trajectories.len(), requested);
            for trajectory in &trajectories {
                assert_ne!(trajectory.timesteps.len(), 0);
            }
        }
    }
}

View File

@@ -0,0 +1,302 @@
mod async_runner;
mod base;
pub use async_runner::*;
pub use base::*;
#[cfg(test)]
pub(crate) mod tests {
    use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyState};

    use crate::tests::TestAutodiffBackend;
    use crate::{
        AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLComponentsTypes, RLEvent,
    };
    use burn_rl::{LearnerTransitionBatch, PolicyLearner, RLTrainOutput, StepResult};

    /// Mock policy for testing
    ///
    /// Calling `forward()` with a [MockObservation](MockObservation) (list of f32) returns a [MockActionDistribution](MockActionDistribution)
    /// containing a list of 0s of the same length as the observation.
    ///
    /// Calling `action()` with a [MockObservation](MockObservation) (list of f32) returns a [MockPolicyAction](MockPolicyAction) with a list of actions of the same length as the observation.
    /// The actions are all 1 if the call is requested as deterministic, or else 0.
    #[derive(Clone)]
    pub(crate) struct MockPolicy(pub usize);

    impl Policy<TestAutodiffBackend> for MockPolicy {
        type Observation = MockObservation;
        type ActionDistribution = MockActionDistribution;
        type Action = MockPolicyAction;
        type ActionContext = MockActionContext;
        type PolicyState = MockPolicyState;

        fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {
            let dists: Vec<MockActionDistribution> = obs
                .0
                .iter()
                .map(|_| MockActionDistribution(vec![0.]))
                .collect();
            MockActionDistribution::batch(dists)
        }

        fn action(
            &mut self,
            obs: Self::Observation,
            deterministic: bool,
        ) -> (Self::Action, Vec<Self::ActionContext>) {
            // One action and one context per observation entry.
            let value = i32::from(deterministic);
            let actions: Vec<MockPolicyAction> = obs
                .0
                .iter()
                .map(|_| MockPolicyAction(vec![value]))
                .collect();
            let contexts: Vec<MockActionContext> =
                obs.0.iter().map(|_| MockActionContext).collect();
            (MockPolicyAction::batch(actions), contexts)
        }

        fn update(&mut self, update: Self::PolicyState) {
            self.0 = update.0;
        }

        fn state(&self) -> Self::PolicyState {
            MockPolicyState(self.0)
        }

        fn load_record(
            self,
            _record: <Self::PolicyState as PolicyState<TestAutodiffBackend>>::Record,
        ) -> Self {
            self
        }
    }

    /// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it.
    #[derive(Clone)]
    pub(crate) struct MockObservation(pub Vec<f32>);

    /// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.
    #[derive(Clone)]
    pub(crate) struct MockPolicyAction(pub Vec<i32>);

    /// Mock action distribution for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it.
    #[derive(Clone)]
    pub(crate) struct MockActionDistribution(Vec<f32>);

    /// Mock action context for testing; carries no data.
    #[derive(Clone)]
    pub(crate) struct MockActionContext;

    /// Mock policy state for testing represented as an arbitrary `usize` that has no effect on the policy.
    #[derive(Clone)]
    pub(crate) struct MockPolicyState(pub usize);

    impl PolicyState<TestAutodiffBackend> for MockPolicyState {
        type Record = ();

        fn into_record(self) -> Self::Record {}

        fn load_record(&self, _record: Self::Record) -> Self {
            self.clone()
        }
    }

    impl Batchable for MockObservation {
        fn batch(items: Vec<Self>) -> Self {
            MockObservation(items.iter().flat_map(|m| m.0.clone()).collect())
        }

        /// Splits the batch into one single-element observation per entry, mirroring
        /// `MockPolicyAction::unbatch` so that `unbatch` undoes `batch` for single-element items.
        fn unbatch(self) -> Vec<Self> {
            self.0
                .into_iter()
                .map(|value| MockObservation(vec![value]))
                .collect()
        }
    }

    impl Batchable for MockPolicyAction {
        fn batch(items: Vec<Self>) -> Self {
            MockPolicyAction(items.iter().flat_map(|m| m.0.clone()).collect())
        }

        fn unbatch(self) -> Vec<Self> {
            self.0
                .into_iter()
                .map(|action| MockPolicyAction(vec![action]))
                .collect()
        }
    }

    impl Batchable for MockActionDistribution {
        fn batch(items: Vec<Self>) -> Self {
            MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect())
        }

        /// Splits the batch into one single-element distribution per entry, keeping each value
        /// instead of replacing it with zero.
        fn unbatch(self) -> Vec<Self> {
            self.0
                .into_iter()
                .map(|value| MockActionDistribution(vec![value]))
                .collect()
        }
    }

    /// Mock environment for testing
    #[derive(Clone)]
    pub(crate) struct MockEnv {
        counter: usize,
    }

    /// Mock environment state for testing; carries no data.
    #[derive(Clone, Debug)]
    pub(crate) struct MockState;

    /// Mock environment action for testing, wrapping a single `i32`.
    #[derive(Clone, Debug)]
    pub(crate) struct MockAction(pub i32);

    impl From<MockState> for MockObservation {
        fn from(_value: MockState) -> Self {
            MockObservation(vec![0.])
        }
    }

    impl From<MockPolicyAction> for MockAction {
        fn from(value: MockPolicyAction) -> Self {
            MockAction(value.0[0])
        }
    }

    impl From<MockAction> for MockPolicyAction {
        fn from(value: MockAction) -> Self {
            MockPolicyAction(vec![value.0])
        }
    }

    impl ItemLazy for MockActionContext {
        type ItemSync = MockActionContext;

        fn sync(self) -> Self::ItemSync {
            self
        }
    }

    impl MockEnv {
        fn new() -> Self {
            Self { counter: 0 }
        }
    }

    impl Environment for MockEnv {
        type State = MockState;
        type Action = MockAction;
        const MAX_STEPS: usize = 5;

        fn reset(&mut self) {
            self.counter = 0;
        }

        /// Always rewards 1.0 and reports `done` once `MAX_STEPS` steps were taken since the
        /// last reset. Never truncates.
        fn step(&mut self, _action: Self::Action) -> StepResult<Self::State> {
            self.counter += 1;
            let done = self.counter >= Self::MAX_STEPS;
            burn_rl::StepResult {
                next_state: MockState,
                reward: 1.0,
                done,
                truncated: false,
            }
        }

        fn state(&self) -> Self::State {
            MockState
        }
    }

    /// Mock environment init for testing
    #[derive(Clone)]
    pub(crate) struct MockEnvInit;

    impl EnvironmentInit<MockEnv> for MockEnvInit {
        fn init(&self) -> MockEnv {
            MockEnv::new()
        }
    }

    // Mock RLComponentsTypes for testing
    pub(crate) struct MockRLComponents;

    impl RLComponentsTypes for MockRLComponents {
        type Backend = TestAutodiffBackend;
        type Env = MockEnv;
        type EnvInit = MockEnvInit;
        type State = MockState;
        type Action = MockAction;
        type Policy = MockPolicy;
        type PolicyObs = MockObservation;
        type PolicyAD = MockActionDistribution;
        type PolicyAction = MockPolicyAction;
        type ActionContext = MockActionContext;
        type PolicyState = MockPolicyState;
        type LearningAgent = MockLearningAgent;
        type TrainingOutput = ();
    }

    // Mock learning agent for testing
    #[derive(Clone)]
    pub(crate) struct MockLearningAgent;

    impl PolicyLearner<TestAutodiffBackend> for MockLearningAgent {
        type InnerPolicy = MockPolicy;
        type TrainContext = ();
        type Record = ();

        fn train(
            &mut self,
            _input: LearnerTransitionBatch<TestAutodiffBackend, Self::InnerPolicy>,
        ) -> RLTrainOutput<
            Self::TrainContext,
            <Self::InnerPolicy as Policy<TestAutodiffBackend>>::PolicyState,
        > {
            unimplemented!()
        }

        fn policy(&self) -> Self::InnerPolicy {
            unimplemented!()
        }

        fn update_policy(&mut self, _update: Self::InnerPolicy) {
            unimplemented!()
        }

        fn record(&self) -> Self::Record {
            unimplemented!()
        }

        fn load_record(self, _record: Self::Record) -> Self {
            unimplemented!()
        }
    }

    // Mock event processor for testing
    pub(crate) struct MockProcessor;

    impl
        EventProcessorTraining<
            RLEvent<(), MockActionContext>,
            AgentEvaluationEvent<MockActionContext>,
        > for MockProcessor
    {
        fn process_train(&mut self, _event: RLEvent<(), MockActionContext>) {
            // Mock process train
        }

        fn process_valid(&mut self, _event: AgentEvaluationEvent<MockActionContext>) {
            // Mock process valid
        }

        fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
            unimplemented!()
        }
    }
}

View File

@@ -0,0 +1,15 @@
mod checkpointer;
mod components;
mod env_runner;
mod off_policy;
mod output;
mod paradigm;
mod strategy;
pub use checkpointer::*;
pub use components::*;
pub use env_runner::*;
pub use off_policy::*;
pub use output::*;
pub use paradigm::*;
pub use strategy::*;

View File

@@ -0,0 +1,189 @@
use std::marker::PhantomData;
use crate::{
AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem,
EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent,
RLEventProcessorType, RLStrategy,
};
use burn_core::{self as burn};
use burn_core::{config::Config, data::dataloader::Progress};
use burn_ndarray::NdArray;
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer};
/// Parameters of an off-policy training with multiple environments and double-batching.
#[derive(Config, Debug)]
pub struct OffPolicyConfig {
/// The number of environments to run simultaneously for experience collection.
#[config(default = 1)]
pub num_envs: usize,
/// Number of environment states to accumulate before running one step of inference with the policy.
/// Must be less than or equal to the number of simultaneous environments.
#[config(default = 1)]
pub autobatch_size: usize,
/// Max number of transitions stored in the replay buffer.
#[config(default = 1024)]
pub replay_buffer_size: usize,
/// The number of steps to collect between each step of training.
#[config(default = 1)]
pub train_interval: usize,
/// Number of optimization steps done each `train_interval`.
#[config(default = 1)]
pub train_steps: usize,
/// The number of steps to collect between each evaluation.
#[config(default = 10_000)]
pub eval_interval: usize,
/// The number of episodes to run for each evaluation.
#[config(default = 1)]
pub eval_episodes: usize,
/// The number of transitions to train on.
#[config(default = 32)]
pub train_batch_size: usize,
/// Number of steps to collect before starting to train.
#[config(default = 0)]
pub warmup_steps: usize,
}
/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
/// The parameters driving the training loop.
config: OffPolicyConfig,
/// Ties the strategy to its component types without storing any of them.
_components: PhantomData<RLC>,
}
impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
/// Create a new off-policy base strategy from the given configuration.
pub fn new(config: OffPolicyConfig) -> Self {
Self {
config,
_components: PhantomData,
}
}
}
impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
where
    RLC: RLComponentsTypes,
    RLC::PolicyObs: SliceAccess<RLC::Backend>,
    RLC::PolicyAction: SliceAccess<RLC::Backend>,
{
    /// Runs the off-policy training loop until the configured number of steps is reached or the
    /// interrupter asks to stop.
    ///
    /// Each iteration:
    /// 1. collects `train_interval` steps from the training environments and pushes them into
    ///    the replay buffer;
    /// 2. once the buffer holds at least `train_batch_size` transitions and `warmup_steps` have
    ///    been collected, hands the previous policy update to the environment runner and runs
    ///    `train_steps` optimization steps;
    /// 3. when an evaluation is due, runs `eval_episodes` evaluation episodes and checkpoints
    ///    if a checkpointer is configured.
    ///
    /// An `eval_interval` of zero disables evaluation and checkpointing.
    ///
    /// # Returns
    ///
    /// The learner's policy and the event processor.
    fn train_loop(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>) {
        let mut event_processor = training_components.event_processor;
        let mut checkpointer = training_components.checkpointer;
        let num_steps_total = training_components.num_steps;

        let mut env_runner = MultiAgentEnvLoop::<NdArray, RLC>::new(
            self.config.num_envs,
            env_init.clone(),
            AsyncPolicy::new(
                self.config.num_envs.min(self.config.autobatch_size),
                learner_agent.policy(),
            ),
            false,
            false,
            &Default::default(),
        );
        let runner_config = AsyncAgentEnvLoopConfig {
            eval: true,
            deterministic: true,
            id: 0,
        };
        let mut env_runner_valid = AgentEnvAsyncLoop::<NdArray, RLC>::new(
            env_init,
            AsyncPolicy::new(1, learner_agent.policy()),
            runner_config,
            &Default::default(),
            None,
            None,
        );

        let device: <RLC::Backend as burn_core::prelude::Backend>::Device = Default::default();
        let mut transition_buffer = TransitionBuffer::<
            RLC::Backend,
            RLC::PolicyObs,
            RLC::PolicyAction,
        >::new(self.config.replay_buffer_size, &device);

        // Step count at which the next evaluation is due. Saturating so that a zero interval
        // combined with a zero starting epoch cannot underflow.
        let mut valid_next = (self.config.eval_interval + starting_epoch).saturating_sub(1);
        let mut progress = Progress {
            items_processed: starting_epoch,
            items_total: num_steps_total,
        };
        let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
            None;

        // NOTE(review): a `train_interval` of zero collects no steps, so this loop would never
        // make progress — confirm the configuration is validated upstream.
        while progress.items_processed < num_steps_total {
            if training_components.interrupter.should_stop() {
                let reason = training_components
                    .interrupter
                    .get_message()
                    .unwrap_or_else(|| String::from("Reason unknown"));
                log::info!("Training interrupted: {reason}");
                break;
            }

            let items = env_runner.run_steps(
                self.config.train_interval,
                &mut event_processor,
                &training_components.interrupter,
                &mut progress,
            );
            for item in &items {
                let t = &item.transition;
                let state: RLC::PolicyObs = t.state.clone().into();
                let next_state: RLC::PolicyObs = t.next_state.clone().into();
                let action: RLC::PolicyAction = t.action.clone().into();
                let reward = t.reward.to_data().to_vec::<f32>().unwrap()[0];
                let done = t.done.to_data().to_vec::<f32>().unwrap()[0] > 0.5;
                transition_buffer.push(state, next_state, action, reward, done);
            }

            if transition_buffer.len() >= self.config.train_batch_size
                && progress.items_processed >= self.config.warmup_steps
            {
                // The runner receives the update produced by the previous training round.
                if let Some(ref u) = intermediary_update {
                    env_runner.update_policy(u.clone());
                }
                for _ in 0..self.config.train_steps {
                    let batch = transition_buffer.sample(self.config.train_batch_size);
                    let train_item = learner_agent.train(batch);
                    intermediary_update = Some(train_item.policy);
                    event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
                        train_item.item,
                        progress.clone(),
                        None,
                    )));
                }
            }

            // Evaluate as soon as the due step has been reached. The previous check also required
            // the due step to lie strictly after the step count at the start of the iteration, so
            // once the schedule fell behind (e.g. `train_interval > eval_interval`, or an
            // interval of one) no evaluation ever ran again.
            if self.config.eval_interval > 0 && valid_next <= progress.items_processed {
                env_runner_valid.update_policy(learner_agent.policy().state());
                env_runner_valid.run_episodes(
                    self.config.eval_episodes,
                    &mut event_processor,
                    &training_components.interrupter,
                    &mut progress,
                );
                if let Some(checkpointer) = &mut checkpointer {
                    checkpointer.checkpoint(
                        &env_runner.policy(),
                        learner_agent,
                        valid_next,
                        &training_components.event_store,
                    );
                }
                // Skip any due steps that were already passed so the schedule stays ahead of the
                // number of collected steps.
                while valid_next <= progress.items_processed {
                    valid_next += self.config.eval_interval;
                }
            }
        }

        (learner_agent.policy(), event_processor)
    }
}

View File

@@ -0,0 +1,32 @@
use crate::{
ItemLazy,
metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput},
};
/// Summary of an episode.
///
/// Carries the two values the episode-level metric adaptors read: the episode
/// length and the cumulative reward.
#[derive(Debug, Clone, PartialEq)]
pub struct EpisodeSummary {
    /// The total length of the episode.
    pub episode_length: usize,
    /// The final cumulative reward.
    pub cum_reward: f64,
}
/// An [`EpisodeSummary`] is its own synchronized form.
impl ItemLazy for EpisodeSummary {
    type ItemSync = Self;

    /// Returns the summary unchanged.
    fn sync(self) -> Self::ItemSync {
        self
    }
}
/// Feeds the episode length to metrics consuming [`EpisodeLengthInput`].
impl Adaptor<EpisodeLengthInput> for EpisodeSummary {
    fn adapt(&self) -> EpisodeLengthInput {
        // The metric input is built from a float, so the length is converted first.
        let length = self.episode_length as f64;
        EpisodeLengthInput::new(length)
    }
}
/// Feeds the cumulative reward to metrics consuming [`CumulativeRewardInput`].
impl Adaptor<CumulativeRewardInput> for EpisodeSummary {
    fn adapt(&self) -> CumulativeRewardInput {
        let reward = self.cum_reward;
        CumulativeRewardInput::new(reward)
    }
}

View File

@@ -0,0 +1,525 @@
use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::learner::base::Interrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::{
ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,
LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,
RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,
RLPolicyRecord, RLStrategy,
};
use crate::{EpisodeSummary, RLStrategies};
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, SliceAccess};
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Structure to configure and launch reinforcement learning trainings.
///
/// Built with [`RLTraining::new`], customized through the builder-style methods and
/// consumed by `launch`.
pub struct RLTraining<RLC: RLComponentsTypes> {
// Optional pair of checkpointers: the first for the policy record, the second for
// the learning agent record. `None` until a checkpointer is registered.
// Not that complex. Extracting into yet another type would only make it more confusing.
#[allow(clippy::type_complexity)]
checkpointers: Option<(
AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
)>,
// The number of environment steps to train for.
num_steps: usize,
// The step from which the training must resume, if any.
checkpoint: Option<usize>,
// Root directory for the experiment log, checkpoints and default metric logs.
directory: PathBuf,
// Forwarded as-is to the training components; initialized to `None`.
grad_accumulation: Option<usize>,
// Custom renderer; when `None`, the default renderer is used at launch.
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
// Metrics registered for train steps, agent actions and episodes.
metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,
// Store receiving metric events; loggers are registered on it.
event_store: LogEventStore,
// Handle used to stop the training early.
interrupter: Interrupter,
// Installer for the application logger; `None` disables log capturing.
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
// Strategy handed to the checkpointer at launch.
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
// The strategy that runs the training loop.
learning_strategy: RLStrategies<RLC>,
// Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
summary_metrics: BTreeSet<String>,
// Whether a summary configuration is passed to the training components.
summary: bool,
// Specifies how to initialize the environment.
env_initializer: RLC::EnvInit,
}
impl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>
where
B: AutodiffBackend,
E: Environment + 'static,
EI: EnvironmentInit<E> + Send + 'static,
A: PolicyLearner<B> + Send + 'static,
A::TrainContext: ItemLazy + Clone + Send,
A::InnerPolicy: Policy<B> + Send,
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
+ Into<<A::InnerPolicy as Policy<B>>::Action>
+ Clone
+ Send
+ 'static,
{
/// Creates a new runner for reinforcement learning.
///
/// # Arguments
///
/// * `directory` - The directory to save the checkpoints.
/// * `env_init` - Specifies how to initialize the environment.
pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {
let directory = directory.as_ref().to_path_buf();
let experiment_log_file = directory.join("experiment.log");
Self {
num_steps: 1,
checkpoint: None,
checkpointers: None,
directory,
grad_accumulation: None,
metrics: RLMetrics::default(),
event_store: LogEventStore::default(),
renderer: None,
interrupter: Interrupter::new(),
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
experiment_log_file,
))),
checkpointer_strategy: Box::new(
ComposedCheckpointingStrategy::builder()
.add(KeepLastNCheckpoints::new(2))
.add(MetricCheckpointingStrategy::new(
&EpisodeLengthMetric::new(), // default to evaluations' cumulative reward.
Aggregate::Mean,
Direction::Lowest,
Split::Valid,
))
.build(),
),
learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),
summary_metrics: BTreeSet::new(),
summary: false,
env_initializer,
}
}
}
impl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {
/// Replace the default learning strategy (Off Policy learning) with the provided one.
///
/// # Arguments
///
/// * `training_strategy` - The training strategy.
pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {
self.learning_strategy = learning_strategy;
self
}
/// Replace the default metric loggers with the provided ones.
///
/// # Arguments
///
/// * `logger` - The training logger.
pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
where
ML: MetricLogger + 'static,
{
self.event_store.register_logger(logger);
self
}
/// Update the checkpointing_strategy.
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
mut self,
strategy: CS,
) -> Self {
self.checkpointer_strategy = Box::new(strategy);
self
}
/// Replace the default CLI renderer with a custom one.
///
/// # Arguments
///
/// * `renderer` - The custom renderer.
pub fn renderer<MR>(mut self, renderer: MR) -> Self
where
MR: MetricsRenderer + 'static,
{
self.renderer = Some(Box::new(renderer));
self
}
/// Register numerical metrics for a training step of the agent.
pub fn metrics_train<Me: TrainMetricRegistration<RLC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register textual metrics for a training step of the agent.
pub fn text_metrics_train<Me: TrainTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register numerical metrics for each action of the agent.
pub fn metrics_agent<Me: AgentMetricRegistration<RLC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register textual metrics for each action of the agent.
pub fn text_metrics_agent<Me: AgentTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register numerical metrics for a completed episode.
pub fn metrics_episode<Me: EpisodeMetricRegistration<RLC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register textual metrics for a completed episode.
pub fn text_metrics_episode<Me: EpisodeTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register a textual metric for a training step.
pub fn text_metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.metrics.register_text_metric_train(metric);
self
}
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step.
pub fn metric_train<Me>(mut self, metric: Me) -> Self
where
Me: Metric + Numeric + 'static,
<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.summary_metrics.insert(metric.name().to_string());
self.metrics.register_metric_train(metric);
self
}
/// Register a textual metric for each action taken by the agent.
pub fn text_metric_agent<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.metrics.register_text_metric_agent(metric.clone());
self.metrics.register_text_metric_agent_valid(metric);
self
}
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent.
pub fn metric_agent<Me>(mut self, metric: Me) -> Self
where
Me: Metric + Numeric + 'static,
<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.summary_metrics.insert(metric.name().to_string());
self.metrics.register_agent_metric(metric.clone());
self.metrics.register_agent_metric_valid(metric);
self
}
/// Register a textual metric for a completed episode.
pub fn text_metric_episode<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
self.metrics.register_text_metric_episode(metric.clone());
self.metrics.register_text_metric_episode_valid(metric);
self
}
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode.
pub fn metric_episode<Me>(mut self, metric: Me) -> Self
where
Me: Metric + Numeric + 'static,
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
self.summary_metrics.insert(metric.name().to_string());
self.metrics.register_episode_metric(metric.clone());
self.metrics.register_episode_metric_valid(metric);
self
}
/// The number of environment steps to train for.
pub fn num_steps(mut self, num_steps: usize) -> Self {
self.num_steps = num_steps;
self
}
/// The step from which the training must resume.
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
self.checkpoint = Some(checkpoint);
self
}
/// Provides a handle that can be used to interrupt training.
pub fn interrupter(&self) -> Interrupter {
self.interrupter.clone()
}
/// Override the handle for stopping training with an externally provided handle
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
self.interrupter = interrupter;
self
}
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
pub fn with_application_logger(
mut self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
self.tracing_logger = logger;
self
}
/// Register a checkpointer that will save the environment runner's [policy](Policy)
/// and the [PolicyLearner](PolicyLearner) state to different files.
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
where
FR: FileRecorder<RLC::Backend> + 'static,
FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,
{
let checkpoint_dir = self.directory.join("checkpoint");
let checkpointer_policy =
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy");
let checkpointer_learning =
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent");
self.checkpointers = Some((
AsyncCheckpointer::new(checkpointer_policy),
AsyncCheckpointer::new(checkpointer_learning),
));
self
}
/// Enable the training summary report.
///
/// The summary will be displayed after `.launch()`, when the renderer is dropped.
pub fn summary(mut self) -> Self {
self.summary = true;
self
}
/// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment.
pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy>
where
RLC::PolicyObs: SliceAccess<RLC::Backend>,
RLC::PolicyAction: SliceAccess<RLC::Backend>,
{
if self.tracing_logger.is_some()
&& let Err(e) = self.tracing_logger.as_ref().unwrap().install()
{
log::warn!("Failed to install the experiment logger: {e}");
}
let renderer = self
.renderer
.unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
if !self.event_store.has_loggers() {
self.event_store
.register_logger(FileMetricLogger::new(self.directory.clone()));
}
let event_store = Arc::new(EventStoreClient::new(self.event_store));
let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(
self.metrics,
renderer,
event_store.clone(),
));
let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {
RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)
});
let summary = if self.summary {
Some(LearnerSummaryConfig {
directory: self.directory,
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
})
} else {
None
};
let components = RLComponents::<RLC> {
checkpoint: self.checkpoint,
checkpointer,
interrupter: self.interrupter,
event_processor,
event_store,
num_steps: self.num_steps,
grad_accumulation: self.grad_accumulation,
summary,
};
match self.learning_strategy {
RLStrategies::OffPolicyStrategy(config) => {
let strategy = OffPolicyStrategy::new(config);
strategy.train(learner_agent, components, self.env_initializer)
}
RLStrategies::Custom(strategy) => {
strategy.train(learner_agent, components, self.env_initializer)
}
}
}
}
/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer).
///
/// Returned by the training strategy once the training loop has finished.
pub struct RLResult<P> {
/// The learned policy.
pub policy: P,
/// The renderer that can be used for follow up training and evaluation.
pub renderer: Box<dyn MetricsRenderer>,
}
/// Trait to fake variadic generics for agent (per-action) numeric metrics.
///
/// Accepted by `RLTraining::metrics_agent`.
pub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for agent (per-action) text metrics.
///
/// Accepted by `RLTraining::text_metrics_agent`.
pub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for train step numeric metrics.
///
/// Accepted by `RLTraining::metrics_train`.
pub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for train step text metrics.
///
/// Accepted by `RLTraining::text_metrics_train`.
pub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for episode numeric metrics.
///
/// Accepted by `RLTraining::metrics_episode`.
pub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for episode text metrics.
///
/// Accepted by `RLTraining::text_metrics_episode`.
pub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
// Implements the six metric registration traits for a tuple of metrics.
//
// Each type identifier doubles as the binding name when the tuple is destructured
// (hence `non_snake_case`). Every metric is moved into the builder; the tuple is
// consumed by `register`, so no clone is needed.
macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.text_metric_train($M);)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train($M);)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.text_metric_agent($M);)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_agent($M);)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)
        where
            $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.text_metric_episode($M);)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)
        where
            $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_episode($M);)*
                builder
            }
        }
    };
}
// Generate the registration trait implementations for tuples of one to six metrics.
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);

View File

@@ -0,0 +1,99 @@
use std::sync::Arc;
use crate::{
Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
RLEventProcessorType, RLResult,
metric::{processor::EventProcessorTraining, store::EventStoreClient},
};
/// Struct to minimise parameters passed to [RLStrategy::train].
///
/// Bundles everything a strategy needs besides the learner agent and the
/// environment initializer.
pub struct RLComponents<RLC: RLComponentsTypes> {
/// The total number of environment steps.
pub num_steps: usize,
/// The step number from which to continue the training.
pub checkpoint: Option<usize>,
/// A checkpointer used to load and save learning checkpoints.
pub checkpointer: Option<RLCheckpointer<RLC>>,
/// Enables gradients accumulation.
pub grad_accumulation: Option<usize>,
/// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
pub interrupter: Interrupter,
/// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation.
pub event_processor: RLEventProcessorType<RLC>,
/// A reference to an [EventStoreClient](EventStoreClient).
pub event_store: Arc<EventStoreClient>,
/// Config for creating a summary of the learning
pub summary: Option<LearnerSummaryConfig>,
}
/// The strategy for reinforcement learning.
///
/// Selects which [RLStrategy] implementation runs the training.
#[derive(Clone)]
pub enum RLStrategies<RLC: RLComponentsTypes> {
/// Off-policy training configured by an [OffPolicyConfig].
// NOTE(review): this variant was previously described as "Training on one device";
// confirm whether that single-device wording still applies.
OffPolicyStrategy(OffPolicyConfig),
/// Training using a custom learning strategy
Custom(CustomRLStrategy<RLC>),
}
/// A shared, reference-counted handle ([Arc]) to an implementation of [RLStrategy].
pub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;
/// Provides the `train` function for any learning strategy.
///
/// Implementors only supply [`train_loop`](RLStrategy::train_loop); the provided
/// [`train`](RLStrategy::train) wraps it with checkpoint loading, start/end events
/// and summary creation.
pub trait RLStrategy<RLC: RLComponentsTypes> {
    /// Train the learner agent with this strategy.
    ///
    /// When a checkpoint step is set, the learner agent is first restored through the
    /// checkpointer (if one is available) and training resumes at the following step;
    /// otherwise it starts at step 1.
    fn train(
        &self,
        mut learner_agent: RLC::LearningAgent,
        mut training_components: RLComponents<RLC>,
        env_init: RLC::EnvInit,
    ) -> RLResult<RLC::Policy> {
        let starting_epoch = if let Some(checkpoint) = training_components.checkpoint {
            if let Some(checkpointer) = training_components.checkpointer.as_mut() {
                learner_agent =
                    checkpointer.load_checkpoint(learner_agent, &Default::default(), checkpoint);
            }
            checkpoint + 1
        } else {
            1
        };

        // Keep a copy of the summary config: the components are moved into the loop.
        let summary_config = training_components.summary.clone();

        // Notify the event processor that training starts.
        training_components
            .event_processor
            .process_train(RLEvent::Start);

        // Run the strategy-specific training loop.
        let (policy, mut event_processor) = self.train_loop(
            training_components,
            &mut learner_agent,
            starting_epoch,
            env_init,
        );

        // A summary that fails to initialize is silently dropped.
        let summary = summary_config.and_then(|config| config.init().ok());

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        // TODO: summary makes sense for RL?
        event_processor.process_train(RLEvent::End(summary));

        RLResult {
            policy,
            renderer: event_processor.renderer(),
        }
    }

    /// Training loop for this strategy.
    ///
    /// Returns the resulting policy together with the event processor, which `train`
    /// uses to signal the end of training and to recover the renderer.
    fn train_loop(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>);
}

View File

@@ -0,0 +1,131 @@
use crate::metric::{AccuracyInput, PerplexityInput, TopKAccuracyInput};
use crate::metric::{Adaptor, CerInput, LossInput, WerInput, processor::ItemLazy};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor, Transaction};
use burn_ndarray::NdArray;
/// Sequence prediction output adapted for multiple metrics.
///
/// Supported metrics:
/// - Accuracy
/// - TopKAccuracy
/// - Perplexity
/// - Loss
/// - CER
/// - WER
///
/// Accuracy, TopKAccuracy and Perplexity receive the logits and targets flattened
/// over the batch and sequence dimensions; CER and WER receive the predicted tokens.
#[derive(new)]
pub struct SequenceOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// Raw logits. Shape: `[batch_size, seq_len, vocab_size]`
pub logits: Tensor<B, 3>,
/// Optional predicted token indices. Shape: `[batch_size, seq_len]`.
/// If not provided, predictions default to argmax of `logits` along the last dimension.
pub predictions: Option<Tensor<B, 2, Int>>,
/// The target token indices. Shape: `[batch_size, seq_len]`
pub targets: Tensor<B, 2, Int>,
}
impl<B: Backend> SequenceOutput<B> {
    /// Returns the explicit predictions when present, otherwise the argmax of the
    /// logits over the last dimension.
    fn predicted_tokens(&self) -> Tensor<B, 2, Int> {
        if let Some(preds) = &self.predictions {
            return preds.clone();
        }
        self.logits.clone().argmax(2).squeeze_dim::<2>(2)
    }

    /// Returns the logits with the batch and sequence dimensions merged.
    fn flat_logits(&self) -> Tensor<B, 2> {
        let [batch_size, seq_len, vocab_size] = self.logits.dims();
        let flat_shape = [batch_size * seq_len, vocab_size];
        self.logits.clone().reshape(flat_shape)
    }

    /// Returns the targets flattened into a single dimension.
    fn flat_targets(&self) -> Tensor<B, 1, Int> {
        let [batch_size, seq_len] = self.targets.dims();
        let flat_shape = [batch_size * seq_len];
        self.targets.clone().reshape(flat_shape)
    }
}
impl<B: Backend> ItemLazy for SequenceOutput<B> {
    type ItemSync = SequenceOutput<NdArray>;

    /// Reads every tensor through a single [`Transaction`] and rebuilds the output
    /// on the `NdArray` backend with its default device.
    fn sync(self) -> Self::ItemSync {
        let Self {
            loss,
            logits,
            predictions,
            targets,
        } = self;
        let device = &Default::default();

        // Registration order determines the order of the returned tensor data.
        let transaction = Transaction::default()
            .register(logits)
            .register(loss)
            .register(targets);

        match predictions {
            Some(preds) => {
                let [logits_data, loss_data, targets_data, preds_data] = transaction
                    .register(preds)
                    .execute()
                    .try_into()
                    .expect("Correct amount of tensor data");
                SequenceOutput {
                    logits: Tensor::from_data(logits_data, device),
                    loss: Tensor::from_data(loss_data, device),
                    targets: Tensor::from_data(targets_data, device),
                    predictions: Some(Tensor::from_data(preds_data, device)),
                }
            }
            None => {
                let [logits_data, loss_data, targets_data] = transaction
                    .execute()
                    .try_into()
                    .expect("Correct amount of tensor data");
                SequenceOutput {
                    logits: Tensor::from_data(logits_data, device),
                    loss: Tensor::from_data(loss_data, device),
                    targets: Tensor::from_data(targets_data, device),
                    predictions: None,
                }
            }
        }
    }
}
/// Feeds the loss tensor to the loss metric.
impl<B: Backend> Adaptor<LossInput<B>> for SequenceOutput<B> {
    fn adapt(&self) -> LossInput<B> {
        let loss = self.loss.clone();
        LossInput::new(loss)
    }
}
/// Feeds the predicted tokens and the targets to the CER metric.
impl<B: Backend> Adaptor<CerInput<B>> for SequenceOutput<B> {
    fn adapt(&self) -> CerInput<B> {
        let predicted = self.predicted_tokens();
        let expected = self.targets.clone();
        CerInput::new(predicted, expected)
    }
}
/// Feeds the predicted tokens and the targets to the WER metric.
impl<B: Backend> Adaptor<WerInput<B>> for SequenceOutput<B> {
    fn adapt(&self) -> WerInput<B> {
        let predicted = self.predicted_tokens();
        let expected = self.targets.clone();
        WerInput::new(predicted, expected)
    }
}
/// Feeds the flattened logits and targets to the accuracy metric.
impl<B: Backend> Adaptor<AccuracyInput<B>> for SequenceOutput<B> {
    fn adapt(&self) -> AccuracyInput<B> {
        let logits = self.flat_logits();
        let targets = self.flat_targets();
        AccuracyInput::new(logits, targets)
    }
}
/// Feeds the flattened logits and targets to the top-k accuracy metric.
impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for SequenceOutput<B> {
    fn adapt(&self) -> TopKAccuracyInput<B> {
        let logits = self.flat_logits();
        let targets = self.flat_targets();
        TopKAccuracyInput::new(logits, targets)
    }
}
/// Feeds the flattened logits and targets to the perplexity metric.
impl<B: Backend> Adaptor<PerplexityInput<B>> for SequenceOutput<B> {
    fn adapt(&self) -> PerplexityInput<B> {
        let logits = self.flat_logits();
        let targets = self.flat_targets();
        PerplexityInput::new(logits, targets)
    }
}

View File

@@ -0,0 +1,475 @@
use core::cmp::Ordering;
use std::{
collections::{HashMap, hash_map::Entry},
fmt::Display,
path::{Path, PathBuf},
};
use crate::{
logger::FileMetricLogger,
metric::store::{Aggregate, EventStore, LogEventStore, Split},
};
/// Contains the metric value at a given time.
///
/// One entry is produced for every epoch in which the metric has a recorded value.
#[derive(Debug)]
pub struct MetricEntry {
/// The step at which the metric was recorded (i.e., epoch).
pub step: usize,
/// The metric value.
pub value: f64,
}
/// Contains the summary of recorded values for a given metric.
#[derive(Debug)]
pub struct MetricSummary {
/// The metric name.
pub name: String,
/// The metric entries, one per epoch with a recorded value.
pub entries: Vec<MetricEntry>,
}
impl MetricSummary {
    /// Collects the mean value of `metric` for every epoch in `1..=num_epochs`.
    ///
    /// Epochs without a recorded value are skipped. Returns `None` when no epoch has
    /// a value for the metric on the given split.
    fn collect<E: EventStore>(
        event_store: &mut E,
        metric: &str,
        split: &Split,
        num_epochs: usize,
    ) -> Option<Self> {
        let mut entries = Vec::new();
        for epoch in 1..=num_epochs {
            if let Some(value) = event_store.find_metric(metric, epoch, Aggregate::Mean, split) {
                entries.push(MetricEntry { step: epoch, value });
            }
        }

        if entries.is_empty() {
            return None;
        }
        Some(Self {
            name: metric.to_string(),
            entries,
        })
    }
}
/// Contains the summary of recorded metrics for the training, validation and test steps.
pub struct SummaryMetrics {
/// Training metrics summary.
pub train: Vec<MetricSummary>,
/// Validation metrics summary.
pub valid: Vec<MetricSummary>,
/// Test metrics summary per test split tag.
///
/// Each key corresponds to a `Split::Test(Some(tag))`.
/// The empty string represents `Split::Test(None)`.
pub test: HashMap<String, Vec<MetricSummary>>,
}
/// Detailed training summary.
///
/// Its `Display` implementation renders a table with the minimum and maximum value
/// of every metric, per split.
pub struct LearnerSummary {
/// The number of epochs completed.
pub epochs: usize,
/// The summary of recorded metrics during training.
pub metrics: SummaryMetrics,
/// The model name (only recorded within the learner).
pub(crate) model: Option<String>,
}
impl LearnerSummary {
/// Creates a new learner summary for the specified metrics.
///
/// # Arguments
///
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
/// * `metrics` - The list of metrics to collect for the summary.
pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
let directory = directory.as_ref();
if !directory.exists() {
return Err(format!(
"Artifact directory does not exist at: {}",
directory.display()
));
}
let mut event_store = LogEventStore::default();
let train_split = Split::Train;
let valid_split = Split::Valid;
let logger = FileMetricLogger::new(directory);
let test_split_root = logger.split_dir(&Split::Test(None));
if !logger.split_exists(&train_split)
&& !logger.split_exists(&valid_split)
&& test_split_root.is_none()
{
return Err(format!(
"No training, validation or test artifacts found at: {}",
directory.display()
));
}
// Number of recorded epochs
let epochs = logger.epochs();
event_store.register_logger(logger);
let train_summary = metrics
.iter()
.filter_map(|metric| {
MetricSummary::collect(&mut event_store, metric.as_ref(), &train_split, epochs)
})
.collect::<Vec<_>>();
let valid_summary = metrics
.iter()
.filter_map(|metric| {
MetricSummary::collect(&mut event_store, metric.as_ref(), &valid_split, epochs)
})
.collect::<Vec<_>>();
let test_summary = match test_split_root {
Some(root) => collect_test_split_metrics(root, metrics, &mut event_store, epochs),
None => Default::default(),
};
Ok(Self {
epochs,
metrics: SummaryMetrics {
train: train_summary,
valid: valid_summary,
test: test_summary,
},
model: None,
})
}
pub(crate) fn with_model(mut self, name: String) -> Self {
self.model = Some(name);
self
}
/// Merges another summary into this one, combining all metric entries.
pub(crate) fn merge(mut self, other: LearnerSummary) -> Self {
fn merge_metrics(
base: Vec<MetricSummary>,
incoming: Vec<MetricSummary>,
) -> Vec<MetricSummary> {
let mut map: HashMap<String, MetricSummary> =
base.into_iter().map(|m| (m.name.clone(), m)).collect();
for metric in incoming {
match map.entry(metric.name.clone()) {
Entry::Occupied(mut entry) => {
entry.get_mut().entries.extend(metric.entries);
}
Entry::Vacant(entry) => {
entry.insert(metric);
}
}
}
map.into_values().collect()
}
self.metrics.train = merge_metrics(self.metrics.train, other.metrics.train);
self.metrics.valid = merge_metrics(self.metrics.valid, other.metrics.valid);
for (tag, metrics) in other.metrics.test {
match self.metrics.test.entry(tag) {
Entry::Occupied(mut entry) => {
let current = std::mem::take(entry.get_mut());
let merged = merge_metrics(current, metrics);
*entry.get_mut() = merged;
}
Entry::Vacant(entry) => {
entry.insert(metrics);
}
}
}
if self.model != other.model {
self.model = None;
}
self
}
}
fn collect_test_split_metrics<P: AsRef<Path>, S: AsRef<str>>(
root: P,
metrics: &[S],
event_store: &mut LogEventStore,
epochs: usize,
) -> HashMap<String, Vec<MetricSummary>> {
// Collect immediate child directories
let dirs = match std::fs::read_dir(root) {
Ok(entries) => entries
.filter_map(|entry| {
let entry = entry.ok()?;
let file_type = entry.file_type().ok()?;
if file_type.is_dir() {
Some(entry.file_name().to_string_lossy().to_string())
} else {
None
}
})
.collect::<Vec<_>>(),
Err(_) => Vec::new(),
};
let mut map = HashMap::new();
if dirs.is_empty() {
return map;
}
// Detect if all directories are epoch directories
let all_epochs = dirs.iter().all(FileMetricLogger::is_epoch_dir);
if all_epochs {
// Single untagged test split
let split = Split::Test(None);
let summaries = metrics
.iter()
.filter_map(|metric| {
MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
})
.collect::<Vec<_>>();
// Untagged marked with empty string
map.insert("".to_string(), summaries);
} else {
// Tagged splits
for tag in dirs {
let split = Split::Test(Some(tag.clone().into()));
let summaries = metrics
.iter()
.filter_map(|metric| {
MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
})
.collect::<Vec<_>>();
map.insert(tag, summaries);
}
}
map
}
impl Display for LearnerSummary {
    /// Renders the summary header followed by a table listing, for every split and
    /// metric, the minimum and maximum recorded value with the epoch it occurred at.
    ///
    /// Test splits are listed in alphabetical order of their tag so the output is
    /// stable across runs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display name of a test split; the empty tag is the untagged split.
        fn test_split_name(tag: &str) -> String {
            if tag.is_empty() {
                "Test".to_string()
            } else {
                format!("Test ({tag})")
            }
        }

        // Orders values with NaN treated as greater than any number.
        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                _ => a.partial_cmp(b).unwrap(),
            }
        }

        fn fmt_val(val: f64) -> String {
            // The magnitude decides the notation, so negative values (e.g. rewards)
            // are not forced into scientific notation.
            if val.abs() < 1e-2 {
                // Use scientific notation for small values which would otherwise be truncated
                format!("{val:<9.3e}")
            } else {
                format!("{val:<9.3}")
            }
        }

        // The test splits live in a hash map; sort them by tag for a stable output.
        let mut test_splits: Vec<_> = self.metrics.test.iter().collect();
        test_splits.sort_by(|(a, _), (b, _)| a.cmp(b));

        // Compute the max length for each column
        let mut max_split_len = 5; // "Train"
        let mut max_metric_len = "Metric".len();
        for metric in self.metrics.train.iter().chain(self.metrics.valid.iter()) {
            max_metric_len = max_metric_len.max(metric.name.len());
        }
        for &(tag, metrics) in &test_splits {
            max_split_len = max_split_len.max(test_split_name(tag).len());
            for metric in metrics {
                max_metric_len = max_metric_len.max(metric.name.len());
            }
        }

        // Summary header
        writeln!(
            f,
            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
            "",
            "",
            width_symbol = 24,
        )?;
        if let Some(model) = &self.model {
            writeln!(f, "Model:\n{model}")?;
        }
        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;

        // Metrics table header
        writeln!(
            f,
            "| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
            "Split",
            "Metric",
            "",
            "",
            width_split = max_split_len,
            width_metric = max_metric_len,
        )?;

        // Table entries
        let mut write_metrics_summary =
            |metrics: &[MetricSummary], split: String| -> std::fmt::Result {
                for metric in metrics.iter() {
                    if metric.entries.is_empty() {
                        continue; // skip metrics with no recorded values
                    }
                    // Compute the min & max for each metric; the entries are not
                    // empty here, so both lookups succeed.
                    let metric_min = metric
                        .entries
                        .iter()
                        .min_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();
                    let metric_max = metric
                        .entries
                        .iter()
                        .max_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();
                    writeln!(
                        f,
                        "| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|",
                        split,
                        metric.name,
                        fmt_val(metric_min.value),
                        metric_min.step,
                        fmt_val(metric_max.value),
                        metric_max.step,
                        width_split = max_split_len,
                        width_metric = max_metric_len,
                    )?;
                }
                Ok(())
            };
        write_metrics_summary(&self.metrics.train, format!("{:?}", Split::Train))?;
        write_metrics_summary(&self.metrics.valid, format!("{:?}", Split::Valid))?;
        for (tag, metrics) in test_splits {
            write_metrics_summary(metrics, test_split_name(tag))?;
        }
        Ok(())
    }
}
// TODO: rename to `ExperimentSummary`? Used in learner + evaluator.
#[derive(Clone)]
/// Learning summary config.
///
/// Stores the arguments that [`LearnerSummaryConfig::init`] forwards to `LearnerSummary::new`.
pub struct LearnerSummaryConfig {
/// Directory forwarded to `LearnerSummary::new`.
pub(crate) directory: PathBuf,
/// Metric names forwarded to `LearnerSummary::new`.
pub(crate) metrics: Vec<String>,
}
impl LearnerSummaryConfig {
    /// Build the [`LearnerSummary`] from the configured directory and metric names.
    ///
    /// # Errors
    ///
    /// Returns whatever error message `LearnerSummary::new` reports.
    pub fn init(&self) -> Result<LearnerSummary, String> {
        let Self { directory, metrics } = self;
        LearnerSummary::new(directory, metrics.as_slice())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A directory that does not exist cannot produce a summary.
    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_artifact_dir_should_exist() {
        let dir = "/tmp/learner-summary-not-found";
        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
    }

    /// An existing directory without any split artifacts cannot produce a summary.
    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_train_valid_artifacts_should_exist() {
        let dir = "/tmp/test-learner-summary-empty";
        std::fs::create_dir_all(dir).ok();
        let summary = LearnerSummary::new(dir, &["Loss"]);
        // Clean up *before* the expected panic: nothing placed after `expect` would run, which
        // used to leave the directory behind after every test run.
        std::fs::remove_dir_all(dir).ok();
        let _summary = summary.expect("Summary artifacts should exist");
    }

    /// Epoch directories without metric logs yield a summary with no metric.
    #[test]
    fn test_summary_should_be_empty() {
        let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
        std::fs::create_dir_all(dir).unwrap();
        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);
        assert_eq!(summary.metrics.train.len(), 0);
        assert_eq!(summary.metrics.valid.len(), 0);

        std::fs::remove_dir_all(dir).unwrap();
    }

    /// Logged values are aggregated into one entry per epoch and per split.
    #[test]
    fn test_summary_should_be_collected() {
        let dir = Path::new("/tmp/test-learner-summary");
        let train_dir = dir.join("train/epoch-1");
        let valid_dir = dir.join("valid/epoch-1");
        std::fs::create_dir_all(dir).unwrap();
        std::fs::create_dir_all(&train_dir).unwrap();
        std::fs::create_dir_all(&valid_dir).unwrap();

        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        // Only Loss metric
        assert_eq!(summary.metrics.train.len(), 1);
        assert_eq!(summary.metrics.valid.len(), 1);

        // Aggregated train metric entries for 1 epoch
        let train_metric = &summary.metrics.train[0];
        assert_eq!(train_metric.name, "Loss");
        assert_eq!(train_metric.entries.len(), 1);
        let entry = &train_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.5); // (1 + 2) / 2

        // Aggregated valid metric entries for 1 epoch
        let valid_metric = &summary.metrics.valid[0];
        assert_eq!(valid_metric.name, "Loss");
        assert_eq!(valid_metric.entries.len(), 1);
        let entry = &valid_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.0);

        std::fs::remove_dir_all(dir).unwrap();
    }
}

View File

@@ -0,0 +1,7 @@
mod paradigm;
mod step;
mod strategies;
pub use paradigm::*;
pub use step::*;
pub use strategies::*;

View File

@@ -0,0 +1,488 @@
use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::{InferenceModelOutput, TrainingModelOutput};
use crate::learner::EarlyStoppingStrategy;
use crate::learner::base::Interrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{
AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric, Numeric};
use crate::multi::MultiDeviceLearningStrategy;
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::single::SingleDeviceTrainingStrategy;
use crate::{
ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent,
LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig,
LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult,
TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy,
};
use crate::{Learner, SupervisedLearningStrategy};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::{AutodiffModule, Module};
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::Optimizer;
use burn_optim::lr_scheduler::LrScheduler;
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// A reference to the training split [DataLoader](DataLoader).
///
/// Shared through an [Arc] and typed over the training backend and training model input of `LC`.
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;
/// A reference to the validation split [DataLoader](DataLoader).
///
/// Shared through an [Arc] and typed over the inference backend and inference model input of `LC`.
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
/// The event processor type for supervised learning.
///
/// Handles [LearnerEvent]s carrying the training model output for the training split and the
/// inference model output for the validation split.
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
LearnerEvent<TrainingModelOutput<LC>>,
LearnerEvent<InferenceModelOutput<LC>>,
>;
/// Structure to configure and launch supervised learning trainings.
///
/// Values are set through the builder-style methods and consumed by `launch`.
pub struct SupervisedTraining<LC>
where
LC: LearningComponentsTypes,
{
// Model, optimizer and scheduler checkpointers, in that order. `None` until
// `with_file_checkpointer` is called.
// Not that complex. Extracting into another type would only make it more confusing.
#[allow(clippy::type_complexity)]
checkpointers: Option<(
AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,
AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,
AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,
)>,
// Number of epochs the training should last (`new` sets it to 1).
num_epochs: usize,
// Epoch from which the training must resume, if any.
checkpoint: Option<usize>,
// Base directory: holds `experiment.log`, the `checkpoint` sub-directory and is handed to the
// default file metric logger and to the summary config.
directory: PathBuf,
// Gradient accumulation amount, if enabled.
grad_accumulation: Option<usize>,
// Custom renderer; when `None`, `launch` falls back to `default_renderer`.
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
// Metrics registered for the training and validation splits.
metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
// Store on which the metric loggers are registered.
event_store: LogEventStore,
// Handle used to interrupt the training.
interrupter: Interrupter,
// Application logger installed by `launch`; `None` leaves the standard Rust log handling.
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
// Strategy handed to the learning checkpointer.
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
// Optional early stopping strategy.
early_stopping: Option<EarlyStoppingStrategyRef>,
// Strategy used to run the training; when `None`, `launch` trains on a single device.
training_strategy: Option<TrainingStrategy<LC>>,
// Dataloader for the training split.
dataloader_train: TrainLoader<LC>,
// Dataloader for the validation split.
dataloader_valid: ValidLoader<LC>,
// Names of the numeric metrics registered so far, used to build the summary config.
// Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
summary_metrics: BTreeSet<String>,
// Whether `launch` creates a summary config.
summary: bool,
}
impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
where
B: AutodiffBackend,
LR: LrScheduler + 'static,
M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
M::InnerModule: InferenceStep,
O: Optimizer<M, B> + 'static,
{
/// Creates a new runner for a supervised training.
///
/// # Arguments
///
/// * `directory` - The directory to save the checkpoints.
/// * `dataloader_train` - The dataloader for the training split.
/// * `dataloader_valid` - The dataloader for the validation split.
pub fn new(
directory: impl AsRef<Path>,
dataloader_train: Arc<dyn DataLoader<B, M::Input>>,
dataloader_valid: Arc<
dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,
>,
) -> Self {
let directory = directory.as_ref().to_path_buf();
let experiment_log_file = directory.join("experiment.log");
Self {
num_epochs: 1,
checkpoint: None,
checkpointers: None,
directory,
grad_accumulation: None,
metrics: MetricsTraining::default(),
event_store: LogEventStore::default(),
renderer: None,
interrupter: Interrupter::new(),
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
experiment_log_file,
))),
checkpointer_strategy: Box::new(
ComposedCheckpointingStrategy::builder()
.add(KeepLastNCheckpoints::new(2))
.add(MetricCheckpointingStrategy::new(
&LossMetric::<B>::new(), // default to valid loss
Aggregate::Mean,
Direction::Lowest,
Split::Valid,
))
.build(),
),
early_stopping: None,
training_strategy: None,
summary_metrics: BTreeSet::new(),
summary: false,
dataloader_train,
dataloader_valid,
}
}
}
impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
/// Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.
///
/// # Arguments
///
/// * `training_strategy` - The training strategy.
pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {
self.training_strategy = Some(training_strategy);
self
}
/// Replace the default metric loggers with the provided ones.
///
/// # Arguments
///
/// * `logger` - The training logger.
pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
where
ML: MetricLogger + 'static,
{
self.event_store.register_logger(logger);
self
}
/// Update the checkpointing_strategy.
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
mut self,
strategy: CS,
) -> Self {
self.checkpointer_strategy = Box::new(strategy);
self
}
/// Replace the default CLI renderer with a custom one.
///
/// # Arguments
///
/// * `renderer` - The custom renderer.
pub fn renderer<MR>(mut self, renderer: MR) -> Self
where
MR: MetricsRenderer + 'static,
{
self.renderer = Some(Box::new(renderer));
self
}
/// Register all metrics as numeric for the training and validation set.
pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register all metrics as text for the training and validation set.
pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
metrics.register(self)
}
/// Register a training metric.
pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.metrics.register_train_metric(metric);
self
}
/// Register a validation metric.
pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.metrics.register_valid_metric(metric);
self
}
/// Enable gradients accumulation.
///
/// # Notes
///
/// When you enable gradients accumulation, the gradients object used by the optimizer will be
/// the sum of all gradients generated by each backward pass. It might be a good idea to
/// reduce the learning to compensate.
///
/// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation`
/// amount.
pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
self.grad_accumulation = Some(accumulation);
self
}
/// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
where
Me: Metric + Numeric + 'static,
<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.summary_metrics.insert(metric.name().to_string());
self.metrics.register_train_metric_numeric(metric);
self
}
/// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self
where
<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
self.summary_metrics.insert(metric.name().to_string());
self.metrics.register_valid_metric_numeric(metric);
self
}
/// The number of epochs the training should last.
pub fn num_epochs(mut self, num_epochs: usize) -> Self {
self.num_epochs = num_epochs;
self
}
/// The epoch from which the training must resume.
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
self.checkpoint = Some(checkpoint);
self
}
/// Provides a handle that can be used to interrupt training.
pub fn interrupter(&self) -> Interrupter {
self.interrupter.clone()
}
/// Override the handle for stopping training with an externally provided handle
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
self.interrupter = interrupter;
self
}
/// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
/// conditions are meet.
pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
where
Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
{
self.early_stopping = Some(Box::new(strategy));
self
}
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
pub fn with_application_logger(
mut self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
self.tracing_logger = logger;
self
}
/// Register a checkpointer that will save the [optimizer](Optimizer), the
/// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
where
FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,
FR: FileRecorder<
<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,
> + 'static,
{
let checkpoint_dir = self.directory.join("checkpoint");
let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
let checkpointer_optimizer =
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
let checkpointer_scheduler: FileCheckpointer<FR> =
FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
self.checkpointers = Some((
AsyncCheckpointer::new(checkpointer_model),
AsyncCheckpointer::new(checkpointer_optimizer),
AsyncCheckpointer::new(checkpointer_scheduler),
));
self
}
/// Enable the training summary report.
///
/// The summary will be displayed after `.fit()`, when the renderer is dropped.
pub fn summary(mut self) -> Self {
self.summary = true;
self
}
}
impl<LC: LearningComponentsTypes + Send + 'static> SupervisedTraining<LC> {
    /// Launch this training with the given [Learner](Learner).
    ///
    /// Before delegating to the selected training strategy, this:
    ///
    /// * installs the application logger, if any (a failure is only logged as a warning);
    /// * falls back to the default renderer when no custom renderer was provided;
    /// * registers a file metric logger writing in the training directory when no metric logger
    ///   was registered;
    /// * assembles the checkpointer and the summary config when they are enabled.
    ///
    /// When no training strategy was provided, the training runs on the first device of the
    /// learner's model.
    ///
    /// # Panics
    ///
    /// Panics if no training strategy was provided and the learner's model reports no device.
    pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {
        if let Some(logger) = &self.tracing_logger
            && let Err(e) = logger.install()
        {
            log::warn!("Failed to install the experiment logger: {e}");
        }
        let renderer = self
            .renderer
            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

        if !self.event_store.has_loggers() {
            self.event_store
                .register_logger(FileMetricLogger::new(self.directory.clone()));
        }

        let event_store = Arc::new(EventStoreClient::new(self.event_store));
        let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
            self.metrics,
            renderer,
            Arc::clone(&event_store),
        ));

        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
            LearningCheckpointer::new(
                model.with_interrupter(self.interrupter.clone()),
                optim.with_interrupter(self.interrupter.clone()),
                scheduler.with_interrupter(self.interrupter.clone()),
                self.checkpointer_strategy,
            )
        });

        let summary = if self.summary {
            Some(LearnerSummaryConfig {
                directory: self.directory,
                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
            })
        } else {
            None
        };

        let components = TrainingComponents {
            checkpoint: self.checkpoint,
            checkpointer,
            interrupter: self.interrupter,
            early_stopping: self.early_stopping,
            event_processor,
            event_store,
            num_epochs: self.num_epochs,
            grad_accumulation: self.grad_accumulation,
            summary,
        };

        // Default to single device based on model. The default is built lazily so the model's
        // devices are only queried (and required to be non-empty) when no strategy was provided.
        let training_strategy = self.training_strategy.unwrap_or_else(|| {
            let device = learner
                .model
                .devices()
                .first()
                .cloned()
                .expect("The model should be on at least one device to select a default training strategy");
            TrainingStrategy::SingleDevice(device)
        });

        match training_strategy {
            TrainingStrategy::SingleDevice(device) => {
                let single_device: SingleDeviceTrainingStrategy<LC> =
                    SingleDeviceTrainingStrategy::new(device);
                single_device.train(
                    learner,
                    self.dataloader_train,
                    self.dataloader_valid,
                    components,
                )
            }
            TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(
                learner,
                self.dataloader_train,
                self.dataloader_valid,
                components,
            ),
            TrainingStrategy::MultiDevice(devices, multi_device_optim) => {
                // A single device does not need the multi-device machinery.
                let strategy: Box<dyn SupervisedLearningStrategy<LC>> = if devices.len() == 1 {
                    Box::new(SingleDeviceTrainingStrategy::new(devices[0].clone()))
                } else {
                    Box::new(MultiDeviceLearningStrategy::new(
                        devices,
                        multi_device_optim,
                    ))
                };
                strategy.train(
                    learner,
                    self.dataloader_train,
                    self.dataloader_valid,
                    components,
                )
            }
            #[cfg(feature = "ddp")]
            TrainingStrategy::DistributedDataParallel { devices, config } => {
                use crate::ddp::DdpTrainingStrategy;

                // `devices` and `config` are owned here, no clone required.
                let ddp = DdpTrainingStrategy::new(devices, config);
                ddp.train(
                    learner,
                    self.dataloader_train,
                    self.dataloader_valid,
                    components,
                )
            }
        }
    }
}
/// Trait to fake variadic generics.
///
/// Implemented by the `gen_tuple!` macro for tuples of one to six metrics that are both
/// [Metric] and [Numeric], so that they can all be registered in a single call.
pub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {
/// Register the metrics.
///
/// Takes the builder by value and returns it with the metrics registered.
fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}
/// Trait to fake variadic generics.
///
/// Implemented by the `gen_tuple!` macro for tuples of one to six [Metric]s, so that they can
/// all be registered in a single call.
pub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {
/// Register the metrics.
///
/// Takes the builder by value and returns it with the metrics registered.
fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}
// Implements `TextMetricRegistration` and `MetricRegistration` for a tuple whose element types
// are the given identifiers. Every element is registered for both the training and the
// validation split.
macro_rules! gen_tuple {
($($M:ident),*) => {
impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)
where
$(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + 'static,)*
{
// The type identifiers double as variable names when destructuring the tuple.
#[allow(non_snake_case)]
fn register(
self,
builder: SupervisedTraining<LC>,
) -> SupervisedTraining<LC> {
let ($($M,)*) = self;
// The training split gets a clone, the validation split consumes the metric itself.
$(let builder = builder.metric_train($M.clone());)*
$(let builder = builder.metric_valid($M);)*
builder
}
}
impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)
where
$(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + Numeric + 'static,)*
{
// The type identifiers double as variable names when destructuring the tuple.
#[allow(non_snake_case)]
fn register(
self,
builder: SupervisedTraining<LC>,
) -> SupervisedTraining<LC> {
let ($($M,)*) = self;
// The training split gets a clone, the validation split consumes the metric itself.
$(let builder = builder.metric_train_numeric($M.clone());)*
$(let builder = builder.metric_valid_numeric($M);)*
builder
}
}
};
}
// Tuples of one to six metrics are supported.
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);

View File

@@ -0,0 +1,2 @@
/// The trainer module.
pub mod train;

View File

@@ -0,0 +1,151 @@
use crate::{LearningComponentsTypes, TrainingModel};
use crate::{TrainOutput, TrainStep, TrainingBackend, TrainingModelInput, TrainingModelOutput};
use burn_core::data::dataloader::DataLoaderIterator;
use burn_core::data::dataloader::Progress;
use burn_core::module::Module;
use burn_core::prelude::DeviceOps;
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;
use std::sync::mpsc::{Receiver, Sender};
use std::thread::spawn;
/// Multi devices train step.
///
/// Owns one worker per device and the receiving end of the channel on which the workers send
/// their outputs.
pub struct MultiDevicesTrainStep<LC: LearningComponentsTypes> {
// One worker per device, in the order of the devices given to `new`.
workers: Vec<Worker<LC>>,
// Receives the outputs sent by all the worker threads.
receiver: Receiver<MultiTrainOutput<TrainingModelOutput<LC>>>,
}
// Work item sent to a worker thread: one input item and the model to run it on.
struct Message<M, TI> {
// The input item handed to the model's `step`.
item: TI,
// The model, which the worker forks onto its own device before stepping.
model: M,
}
// Handle to a worker thread bound to one device.
struct Worker<LC: LearningComponentsTypes> {
// Sending end of the channel on which the worker thread receives its work items.
// Not that complex. Extracting into another type would only make it more confusing.
#[allow(clippy::type_complexity)]
sender_input: Sender<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
// Device onto which the worker thread forks the model.
device: Device<TrainingBackend<LC>>,
}
impl<LC: LearningComponentsTypes> Worker<LC> {
    /// Send one input item, together with a clone of the model, to the worker thread.
    ///
    /// Panics if the worker thread is no longer receiving.
    fn register(&self, item: TrainingModelInput<LC>, model: &TrainingModel<LC>) {
        let model = model.clone();
        self.sender_input.send(Message { item, model }).unwrap();
    }

    /// Spawn the worker thread.
    ///
    /// For every received message, the thread forks the model onto this worker's device, runs a
    /// training step and sends the output tagged with the device id. The thread ends once the
    /// input channel is closed.
    // Not that complex. Extracting into another type would only make it more confusing.
    #[allow(clippy::type_complexity)]
    fn start(
        &self,
        sender_output: Sender<MultiTrainOutput<TrainingModelOutput<LC>>>,
        receiver_input: Receiver<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
    ) {
        let device = self.device.clone();

        spawn(move || {
            while let Ok(Message { item, model }) = receiver_input.recv() {
                let model = model.fork(&device);
                let output = model.step(item);
                let result = MultiTrainOutput {
                    output,
                    device: device.to_id(),
                };
                sender_output.send(result).unwrap();
            }
            // The input channel is closed: nothing more will be received.
            log::info!("Closing thread on device {device:?}");
        });
    }
}
/// Multiple output items.
///
/// Pairs the output of one training step with the id of the device that produced it.
pub struct MultiTrainOutput<TO> {
/// The training output.
pub output: TrainOutput<TO>,
/// The device on which the computing happened.
pub device: DeviceId,
}
impl<LC: LearningComponentsTypes> MultiDevicesTrainStep<LC> {
    /// Create a new multi devices train step.
    ///
    /// One worker thread is started per device; all of them send their outputs on the same
    /// channel.
    ///
    /// # Arguments
    ///
    /// * `devices` - Devices.
    ///
    /// # Returns
    ///
    /// MultiDevicesTrainStep instance.
    pub fn new(devices: &[Device<TrainingBackend<LC>>]) -> Self {
        let (sender_output, receiver) = std::sync::mpsc::channel();

        let mut workers = Vec::with_capacity(devices.len());
        for device in devices {
            let (sender_input, receiver_input) = std::sync::mpsc::channel();
            let worker = Worker {
                sender_input,
                device: device.clone(),
            };
            worker.start(sender_output.clone(), receiver_input);
            workers.push(worker);
        }

        Self { workers, receiver }
    }

    /// Collect outputs from workers for one step.
    ///
    /// Every worker whose dataloader still yields an item receives that item and the model;
    /// the call then waits for as many outputs as items were sent.
    ///
    /// # Arguments
    ///
    /// * `dataloaders` - The data loader for each worker.
    /// * `model` - Model.
    ///
    /// # Returns
    ///
    /// Outputs, and the progress summed over the dataloaders that yielded an item.
    pub fn step<'a>(
        &self,
        dataloaders: &mut [Box<dyn DataLoaderIterator<TrainingModelInput<LC>> + 'a>],
        model: &TrainingModel<LC>,
    ) -> (Vec<MultiTrainOutput<TrainingModelOutput<LC>>>, Progress) {
        let mut num_send = 0usize;
        let mut items_total = 0;
        let mut items_processed = 0;

        for (index, worker) in self.workers.iter().enumerate() {
            let dataloader = &mut dataloaders[index];
            let Some(item) = dataloader.next() else {
                continue;
            };

            worker.register(item, model);
            num_send += 1;

            let progress = dataloader.progress();
            items_total += progress.items_total;
            items_processed += progress.items_processed;
        }

        // Outputs arrive in completion order, not in worker order.
        let outputs: Vec<_> = (0..num_send)
            .map(|_| self.receiver.recv().unwrap())
            .collect();

        (outputs, Progress::new(items_processed, items_total))
    }
}

View File

@@ -0,0 +1,154 @@
use std::sync::Arc;
#[cfg(feature = "ddp")]
use burn_collective::CollectiveConfig;
use burn_core::{module::AutodiffModule, prelude::Backend};
use crate::{
EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
TrainingModel, ValidLoader,
components::LearningComponentsTypes,
metric::{
processor::{EventProcessorTraining, LearnerEvent},
store::EventStoreClient,
},
};
// Device type of the backend associated with the learning components `LC`.
type LearnerDevice<LC> = <<LC as LearningComponentsTypes>::Backend as Backend>::Device;
/// A reference to an implementation of SupervisedLearningStrategy.
///
/// The implementation is shared through an [Arc].
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;
#[derive(Clone, Copy, Debug)]
/// Determine how the optimization is performed when training with multiple devices.
///
/// Carried by [TrainingStrategy::MultiDevice].
pub enum MultiDeviceOptim {
/// The optimization is done on an elected device.
OptimMainDevice,
/// The optimization is sharded across all devices.
OptimSharded,
}
/// How should the learner run the learning for the model
#[derive(Clone)]
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
/// Training on one device
SingleDevice(LearnerDevice<LC>),
/// Performs data-parallel training on the given devices. Where the optimization
/// is done is selected by the [MultiDeviceOptim] value.
MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),
/// Training using a custom learning strategy
Custom(CustomLearningStrategy<LC>),
/// Training with input distributed across devices, each device has its own copy of the model.
/// Collective ops are used to sync the gradients after each pass.
#[cfg(feature = "ddp")]
DistributedDataParallel {
/// Devices on this node for the DDP
devices: Vec<LearnerDevice<LC>>,
/// The configuration for collective operations
/// num_devices is ignored
config: CollectiveConfig,
},
}
/// Constructor for a distributed data parallel (DDP) learning strategy
///
/// # Arguments
///
/// * `devices` - Devices on this node for the DDP.
/// * `config` - The configuration for collective operations.
///
/// # Returns
///
/// A [TrainingStrategy::DistributedDataParallel] holding the given values.
#[cfg(feature = "ddp")]
pub fn ddp<LC: LearningComponentsTypes>(
devices: Vec<LearnerDevice<LC>>,
config: CollectiveConfig,
) -> TrainingStrategy<LC> {
TrainingStrategy::DistributedDataParallel { devices, config }
}
impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
fn default() -> Self {
Self::SingleDevice(Default::default())
}
}
/// Struct to minimise parameters passed to [SupervisedLearningStrategy::train].
/// These components are used during training.
pub struct TrainingComponents<LC: LearningComponentsTypes> {
/// The total number of epochs
pub num_epochs: usize,
/// The epoch number from which to continue the training.
pub checkpoint: Option<usize>,
/// A checkpointer used to load and save learner checkpoints.
pub checkpointer: Option<LearningCheckpointer<LC>>,
/// Enables gradients accumulation.
pub grad_accumulation: Option<usize>,
/// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
pub interrupter: Interrupter,
/// Cloneable reference to an early stopping strategy.
pub early_stopping: Option<EarlyStoppingStrategyRef>,
/// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and validation.
pub event_processor: SupervisedTrainingEventProcessor<LC>,
/// A reference to an [EventStoreClient](EventStoreClient).
pub event_store: Arc<EventStoreClient>,
/// Config for creating a summary of the learning
pub summary: Option<LearnerSummaryConfig>,
}
/// Provides the `fit` function for any learning strategy
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
/// Train the learner's model with this strategy.
fn train(
&self,
mut learner: Learner<LC>,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
mut training_components: TrainingComponents<LC>,
) -> LearningResult<InferenceModel<LC>> {
let starting_epoch = match training_components.checkpoint {
Some(checkpoint) => {
if let Some(checkpointer) = &mut training_components.checkpointer {
learner =
checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
}
checkpoint + 1
}
None => 1,
};
let summary_config = training_components.summary.clone();
// Event processor start training
training_components
.event_processor
.process_train(LearnerEvent::Start);
// Training loop
let (model, mut event_processor) = self.fit(
training_components,
learner,
dataloader_train,
dataloader_valid,
starting_epoch,
);
let summary = summary_config.and_then(|summary| {
summary
.init()
.map(|summary| summary.with_model(model.to_string()))
.ok()
});
// Signal training end. For the TUI renderer, this handles the exit & return to main screen.
event_processor.process_train(LearnerEvent::End(summary));
let model = model.valid();
let renderer = event_processor.renderer();
LearningResult::<InferenceModel<LC>> { model, renderer }
}
/// Training loop for this strategy
fn fit(
&self,
training_components: TrainingComponents<LC>,
learner: Learner<LC>,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
starting_epoch: usize,
) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
}

View File

@@ -0,0 +1,17 @@
## DDP
Distributed Data Parallel
The DDP is a learning strategy that trains a replica of the model on each device.
The DDP launches threads for each local device. Each thread on each node will run the model.
After the forward and backward passes, the gradients are synced between all peers on all nodes
with an `all-reduce` operation.
While the DDP launches threads for each local device, it is the user's responsibility to launch the
DDP on each node and to ensure that the collective configuration matches.
## Main device vs secondary devices
The main device is responsible for validation, as well as event processing, which is used in the UI.
The first device is chosen as the main device.

View File

@@ -0,0 +1,234 @@
use burn_collective::{PeerId, ReduceOperation};
use burn_core::data::dataloader::Progress;
use burn_core::module::AutodiffModule;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::GradientsAccumulator;
use burn_optim::GradientsParams;
use std::marker::PhantomData;
use std::sync::mpsc::{Receiver, SyncSender};
use std::sync::{Arc, Mutex};
use crate::SupervisedTrainingEventProcessor;
use crate::learner::base::Interrupter;
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
use crate::{
InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader,
};
/// A validation epoch.
///
/// The constructor is generated by `#[derive(new)]`.
#[derive(new)]
pub struct DdpValidEpoch<LC: LearningComponentsTypes> {
// Dataloader iterated once per call to `run`.
dataloader: ValidLoader<LC>,
}
/// A training epoch.
///
/// The constructor is generated by `#[derive(new)]`.
#[derive(new)]
pub struct DdpTrainEpoch<LC: LearningComponentsTypes> {
// Dataloader iterated once per call to `run`.
dataloader: TrainLoader<LC>,
// Number of accumulated backward passes before the gradients are synced; `None` syncs the
// gradients of every step.
grad_accumulation: Option<usize>,
}
impl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {
    /// Runs the validation epoch.
    ///
    /// Every item of the dataloader is passed through the validation version of the model and
    /// reported to the processor; the loop stops early when the interrupter asks for it. An end
    /// of epoch event is always sent last.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to validate.
    /// * `global_progress` - The global progress; its processed items count is used as the epoch.
    /// * `processor` - The event processor to use.
    /// * `interrupter` - The handle checked after each item to stop early.
    pub fn run(
        &self,
        model: &<LC as LearningComponentsTypes>::TrainingModel,
        global_progress: &Progress,
        processor: &mut SupervisedTrainingEventProcessor<LC>,
        interrupter: &Interrupter,
    ) {
        let epoch = global_progress.items_processed;
        log::info!("Executing validation step for epoch {epoch}");

        let model = model.valid();
        let mut batches = self.dataloader.iter();
        let mut iteration = 0;

        // `progress()` is queried on the iterator itself, hence no `for` loop.
        while let Some(batch) = batches.next() {
            let progress = batches.progress();
            iteration += 1;

            let output = model.step(batch);
            let event = LearnerEvent::ProcessedItem(TrainingItem::new(
                output,
                progress,
                global_progress.clone(),
                Some(iteration),
                None,
            ));
            processor.process_valid(event);

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }

        processor.process_valid(LearnerEvent::EndEpoch(epoch));
    }
}
impl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {
/// Runs the training epoch.
///
/// For every item of the dataloader: the learning rate scheduler and the iteration counter
/// are advanced `peer_count` times, a training step is run, the gradients are synced through
/// a `GradsSyncer` (after accumulation, when enabled) and applied with the optimizer, and the
/// item is reported to the processor. The learner is updated in place.
///
/// # Arguments
///
/// * `learner` - The learner holding the model, the optimizer and the scheduler.
/// * `global_progress` - The global progress; its processed items count is used as the epoch.
/// * `processor` - The event processor to use, shared behind a mutex.
/// * `interrupter` - The handle checked after each item to stop early.
/// * `peer_id` - The id of this peer, given to the gradients syncer.
/// * `peer_count` - The number of peers, used to scale the iteration count, the scheduler
///   steps and the reported progress.
/// * `is_main` - Whether this peer sends the end of epoch event.
#[allow(clippy::too_many_arguments)]
pub fn run(
&self,
learner: &mut Learner<LC>,
global_progress: &Progress,
processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
interrupter: &Interrupter,
peer_id: PeerId,
peer_count: usize,
is_main: bool,
) {
let epoch = global_progress.items_processed;
log::info!("Executing training step for epoch {}", epoch,);
let mut iterator = self.dataloader.iter();
let mut iteration = 0;
let mut accumulator = GradientsAccumulator::new();
let mut accumulation_current = 0;
// Double buffering is disabled here (first argument), so `sync` returns the reduced
// gradients of the call itself.
let grads_syncer = GradsSyncer::<
TrainingBackend<LC>,
<LC as LearningComponentsTypes>::TrainingModel,
>::new(false, peer_id);
while let Some(item) = iterator.next() {
// One local step counts for `peer_count` iterations and scheduler steps.
for _ in 0..peer_count {
iteration += 1;
learner.lr_step();
}
log::info!("Iteration {iteration}");
// Scale the local progress by the number of peers.
let mut progress = iterator.progress();
progress.items_processed *= peer_count;
progress.items_total *= peer_count;
let item = learner.train_step(item);
match self.grad_accumulation {
Some(accumulation) => {
accumulator.accumulate(&learner.model(), item.grads);
accumulation_current += 1;
// Sync and apply only once enough backward passes were accumulated.
if accumulation <= accumulation_current {
let grads = accumulator.grads();
// If double buffering were enabled, these would be the previous sync's gradients
// (`None` on the first call).
let grads = grads_syncer.sync(grads);
if let Some(grads) = grads {
learner.optimizer_step(grads);
}
accumulation_current = 0;
}
}
None => {
// If double buffering were enabled, these would be the previous sync's gradients
// (`None` on the first call).
let grads = grads_syncer.sync(item.grads);
if let Some(grads) = grads {
learner.optimizer_step(grads);
}
}
}
let item = TrainingItem::new(
item.item,
progress,
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);
// Scoped so that the lock is released before checking for interruption.
{
let mut processor = processor.lock().unwrap();
processor.process_train(LearnerEvent::ProcessedItem(item));
}
if interrupter.should_stop() {
log::info!("Training interrupted.");
break;
}
}
// Only the main peer signals the end of the epoch.
if is_main {
let mut processor = processor.lock().unwrap();
processor.process_train(LearnerEvent::EndEpoch(epoch));
}
}
}
/// Worker that is responsible for syncing gradients for the DDP worker. With double buffering,
/// this allows for more optimization.
///
/// The struct itself only holds the two channel ends used to talk to a background thread.
struct GradsSyncer<B: AutodiffBackend, M: AutodiffModule<B> + 'static> {
// Sends the local gradients to the background worker thread.
msg_send: SyncSender<GradientsParams>,
// Optional because with double buffering, the first iteration yields no gradients.
result_recv: Receiver<Option<GradientsParams>>,
// Ties the backend and module type parameters to the struct without storing a value.
_p: PhantomData<(B, M)>,
}
impl<B: AutodiffBackend, M: AutodiffModule<B> + 'static> GradsSyncer<B, M> {
    /// Creates the syncer and spawns the background thread that performs the reductions.
    ///
    /// Requests and replies travel over two bounded channels with a capacity of one.
    fn new(double_buffering: bool, peer_id: PeerId) -> Self {
        let (msg_send, worker_inbox) = std::sync::mpsc::sync_channel::<GradientsParams>(1);
        let (worker_outbox, result_recv) =
            std::sync::mpsc::sync_channel::<Option<GradientsParams>>(1);

        std::thread::spawn(move || {
            Self::run_worker(double_buffering, peer_id, worker_outbox, worker_inbox)
        });

        Self {
            msg_send,
            result_recv,
            _p: PhantomData,
        }
    }

    /// Hands `grads` to the worker thread and blocks until it replies.
    ///
    /// The reply is `None` when the worker has nothing to return yet (first call with double
    /// buffering enabled).
    ///
    /// # Panics
    ///
    /// Panics if either channel to the worker thread is closed.
    fn sync(&self, grads: GradientsParams) -> Option<GradientsParams> {
        self.msg_send.send(grads).unwrap();
        self.result_recv.recv().unwrap()
    }

    /// Body of the worker thread: all-reduces (mean) every received gradient set and replies.
    ///
    /// With `double_buffering`, the reply to a request is the reduced result of the *previous*
    /// request, while the fresh result is kept for the next reply. Otherwise the fresh result
    /// is returned directly.
    fn run_worker(
        double_buffering: bool,
        peer_id: PeerId,
        send: SyncSender<Option<GradientsParams>>,
        recv: Receiver<GradientsParams>,
    ) {
        let mut pending = None;

        while let Ok(local_grads) = recv.recv() {
            // Sync grads with collective
            let reduced = local_grads
                .all_reduce::<B::InnerBackend>(peer_id, ReduceOperation::Mean)
                .expect("DDP worker could not sync gradients!");

            let reply = if double_buffering {
                let previous = pending.take();
                pending = Some(reduced);
                previous
            } else {
                Some(reduced)
            };

            send.send(reply).unwrap();
        }
        // The sending side was dropped together with the `GradsSyncer`; the thread ends here.
    }
}

View File

@@ -0,0 +1,5 @@
mod epoch;
mod strategy;
mod worker;
pub use strategy::*;

View File

@@ -0,0 +1,140 @@
use core::panic;
use std::sync::{Arc, Mutex};
use burn_collective::CollectiveConfig;
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceOps;
use crate::ddp::worker::DdpWorker;
use crate::metric::store::EventStoreClient;
use crate::{
EarlyStoppingStrategyRef, Interrupter, Learner, LearningComponentsTypes,
SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend,
TrainingComponents, TrainingModel, ValidLoader,
};
use burn_core::data::dataloader::split::split_dataloader;
/// Training settings and shared handles that are cloned and handed to each DDP worker.
#[derive(Clone)]
pub(crate) struct WorkerComponents {
/// The total number of epochs
pub num_epochs: usize,
/// Enables gradients accumulation.
pub grad_accumulation: Option<usize>,
/// An [Interrupter] that allows aborting the training/evaluation process early.
pub interrupter: Interrupter,
/// Cloneable reference to an early stopping strategy.
pub early_stopping: Option<EarlyStoppingStrategyRef>,
/// A reference to an [EventStoreClient](EventStoreClient).
pub event_store: Arc<EventStoreClient>,
}
/// Learning strategy that trains with one distributed data parallel worker per device.
pub struct DdpTrainingStrategy<LC: LearningComponentsTypes> {
    /// The devices to train on.
    devices: Vec<Device<TrainingBackend<LC>>>,
    /// The collective configuration, with its device count set to `devices.len()`.
    config: CollectiveConfig,
}

impl<LC: LearningComponentsTypes> DdpTrainingStrategy<LC> {
    /// Creates the strategy for the given devices.
    ///
    /// The number of devices in `config` is overwritten with the length of `devices`.
    pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, config: CollectiveConfig) -> Self {
        let num_devices = devices.len();
        Self {
            config: config.with_num_devices(num_devices),
            devices,
        }
    }
}
impl<LC: LearningComponentsTypes + Send + 'static> SupervisedLearningStrategy<LC>
for DdpTrainingStrategy<LC>
{
/// Trains with one worker thread per device and returns the model of the main worker.
///
/// The training dataloader is split across `self.devices`. The first device is the main
/// device: its worker is the only one that receives the validation dataloader and the
/// checkpointer. All workers share the event processor behind an `Arc<Mutex<_>>`, which is
/// unwrapped again once every worker thread has been joined.
///
/// # Panics
///
/// Panics if no device was provided, if a worker thread panicked, or if the event processor
/// cannot be recovered afterwards (still shared, or its lock is poisoned).
fn fit(
&self,
training_components: TrainingComponents<LC>,
learner: Learner<LC>,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
starting_epoch: usize,
) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
// The reference model is always on the first device provided.
let main_device = self.devices.first().unwrap();
// One worker per device, so we use a fixed device strategy
// for each (worker) data loader. This matches the expected device on the worker, so we
// don't have to move the data between devices.
let mut dataloaders_train = split_dataloader(dataloader_train, &self.devices);
let dataloader_valid = dataloader_valid.to_device(main_device.inner());
// Owned copy of the main device, moved into the main worker below.
let main_device = self.devices[0].clone();
let peer_count = self.devices.len();
// Shared between all workers; every worker gets its own `Arc` clone.
let event_processor = Arc::new(Mutex::new(training_components.event_processor));
let interrupter = training_components.interrupter;
let worker_components = WorkerComponents {
num_epochs: training_components.num_epochs,
grad_accumulation: training_components.grad_accumulation,
interrupter: interrupter.clone(),
early_stopping: training_components.early_stopping,
event_store: training_components.event_store,
};
// Start worker for main device
// First training dataloader corresponds to main device
let main_handle = DdpWorker::<LC>::start(
0.into(),
main_device,
learner.clone(),
event_processor.clone(),
worker_components.clone(),
training_components.checkpointer,
dataloaders_train.remove(0),
Some(dataloader_valid),
self.config.clone(),
starting_epoch,
peer_count,
true,
);
// Spawn other workers for the other devices, starting with peer id 1
let mut peer_id = 1;
let mut secondary_workers = vec![];
for device in &self.devices[1..] {
// Secondary workers get neither a checkpointer nor a validation dataloader.
let handle = DdpWorker::<LC>::start(
peer_id.into(),
device.clone(),
learner.clone(),
event_processor.clone(),
worker_components.clone(),
None,
dataloaders_train.remove(0),
None,
self.config.clone(),
starting_epoch,
peer_count,
false,
);
peer_id += 1;
secondary_workers.push(handle);
}
// Wait for all devices to finish
for worker in secondary_workers {
worker
.join()
.expect("Distributed data parallel worker failed");
}
// Main worker had the event processor
let model = main_handle
.join()
.expect("Distributed data parallel main worker failed");
if interrupter.should_stop() {
let reason = interrupter
.get_message()
.unwrap_or(String::from("Reason unknown"));
log::info!("Training interrupted: {reason}");
}
// Every worker has been joined at this point, so this is expected to be the last
// reference to the event processor.
let Ok(event_processor) = Arc::try_unwrap(event_processor) else {
panic!("Event processor still held!");
};
let Ok(event_processor) = event_processor.into_inner() else {
panic!("Event processor lock poisoned");
};
(model, event_processor)
}
}

View File

@@ -0,0 +1,135 @@
use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch};
use crate::ddp::strategy::WorkerComponents;
use crate::single::TrainingLoop;
use crate::{
Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor,
TrainLoader, TrainingBackend, ValidLoader,
};
use burn_collective::{self, CollectiveConfig, PeerId};
use burn_core::tensor::Device;
use burn_core::tensor::backend::AutodiffBackend;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
/// A worker runs the model, syncing gradients using collective operations.
/// Event processing and validation is optional too.
pub(crate) struct DdpWorker<LC>
where
LC: LearningComponentsTypes + Send + 'static,
{
// Id used to register for collective operations and passed on to the training epoch.
peer_id: PeerId,
// Device used for the collective registration and to fork the learner.
device: Device<TrainingBackend<LC>>,
// The learner trained by this worker.
learner: Learner<LC>,
// Event processor shared with the other workers.
event_processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
// Epoch count, gradient accumulation, interrupter, early stopping and event store.
components: WorkerComponents,
// When present, a checkpoint is written after each epoch.
checkpointer: Option<LearningCheckpointer<LC>>,
// Training data for this worker.
dataloader_train: TrainLoader<LC>,
// When present, a validation epoch runs after each training epoch.
dataloader_valid: Option<ValidLoader<LC>>,
// Configuration used when registering for collective operations.
collective_config: CollectiveConfig,
// First epoch of the training loop.
starting_epoch: usize,
// Total number of workers, passed on to the training epoch.
peer_count: usize,
// Passed on to the training epoch as its `is_main` flag.
is_main: bool,
}
impl<LC> DdpWorker<LC>
where
    LC: LearningComponentsTypes + Send + 'static,
{
    /// Builds a worker from its parts and runs [`fit`](Self::fit) on a new thread.
    ///
    /// The returned handle yields the worker's trained model once the thread finishes.
    #[allow(clippy::too_many_arguments)]
    pub fn start(
        peer_id: PeerId,
        device: Device<TrainingBackend<LC>>,
        learner: Learner<LC>,
        event_processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
        components: WorkerComponents,
        checkpointer: Option<LearningCheckpointer<LC>>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: Option<ValidLoader<LC>>,
        collective_config: CollectiveConfig,
        starting_epoch: usize,
        peer_count: usize,
        is_main: bool,
    ) -> JoinHandle<<LC as LearningComponentsTypes>::TrainingModel> {
        let worker = Self {
            peer_id,
            device,
            learner,
            event_processor,
            components,
            checkpointer,
            dataloader_train,
            dataloader_valid,
            collective_config,
            starting_epoch,
            peer_count,
            is_main,
        };

        std::thread::spawn(move || worker.fit())
    }

    /// Registers for collective operations, then runs the epoch loop and returns the model.
    ///
    /// Each epoch runs training, then (when configured) validation, checkpointing and the early
    /// stopping check. The loop ends early when the interrupter asks to stop after training or
    /// when the early stopping strategy triggers.
    ///
    /// # Panics
    ///
    /// Panics if the collective registration fails or if the event processor lock is poisoned.
    pub fn fit(mut self) -> <LC as LearningComponentsTypes>::TrainingModel {
        burn_collective::register::<<TrainingBackend<LC> as AutodiffBackend>::InnerBackend>(
            self.peer_id,
            self.device.clone(),
            self.collective_config.clone(),
        )
        .expect("Couldn't register for collective operations!");

        let num_epochs = self.components.num_epochs;
        let interrupter = self.components.interrupter;

        // The epoch runners keep their dataloaders for the whole training.
        let epoch_train = DdpTrainEpoch::<LC>::new(
            self.dataloader_train.clone(),
            self.components.grad_accumulation,
        );
        let epoch_valid = self.dataloader_valid.map(DdpValidEpoch::<LC>::new);

        self.learner.fork(&self.device);

        for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) {
            let epoch = training_progress.items_processed;

            epoch_train.run(
                &mut self.learner,
                &training_progress,
                self.event_processor.clone(),
                &interrupter,
                self.peer_id,
                self.peer_count,
                self.is_main,
            );

            if interrupter.should_stop() {
                break;
            }

            // Validation: the shared event processor stays locked for the whole validation epoch.
            if let Some(validation) = &epoch_valid {
                let mut processor_guard = self.event_processor.lock().unwrap();
                validation.run(
                    &self.learner.model(),
                    &training_progress,
                    &mut processor_guard,
                    &interrupter,
                );
            }

            let event_store = &self.components.event_store;

            if let Some(checkpointer) = &mut self.checkpointer {
                checkpointer.checkpoint(&self.learner, epoch, event_store);
            }

            let stop_early = self
                .components
                .early_stopping
                .as_mut()
                .is_some_and(|strategy| strategy.should_stop(epoch, event_store));
            if stop_early {
                break;
            }
        }

        self.learner.model()
    }
}

View File

@@ -0,0 +1,8 @@
mod base;
#[cfg(feature = "ddp")]
pub(crate) mod ddp;
pub(crate) mod multi;
pub(crate) mod single;
pub use base::*;

View File

@@ -0,0 +1,219 @@
use crate::learner::base::Interrupter;
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
use crate::train::MultiDevicesTrainStep;
use crate::{
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor,
TrainLoader, TrainingBackend,
};
use burn_core::data::dataloader::Progress;
use burn_core::prelude::DeviceOps;
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;
use burn_optim::GradientsAccumulator;
use burn_optim::MultiGradientsParams;
use std::collections::HashMap;
/// A training epoch that runs on several devices, with one dataloader per device.
#[derive(new)]
pub struct MultiDeviceTrainEpoch<LC: LearningComponentsTypes> {
// One training dataloader per device.
dataloaders: Vec<TrainLoader<LC>>,
// Number of steps to accumulate gradients over; `None` is treated as 1.
grad_accumulation: Option<usize>,
}
impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
/// Runs the training epoch on multiple devices.
///
/// Dispatches to the implementation that matches `strategy`.
///
/// # Arguments
///
/// * `learner` - The learner whose model is trained.
/// * `global_progress` - The training progress; `items_processed` is used as the epoch number.
/// * `event_processor` - The event processor that receives the training events.
/// * `interrupter` - Polled after every step; the epoch ends early when it asks to stop.
/// * `devices` - The devices to use.
/// * `strategy` - Selects the variant to run: `OptimMainDevice` or `OptimSharded`.
#[allow(clippy::too_many_arguments)]
pub fn run(
&self,
learner: &mut Learner<LC>,
global_progress: &Progress,
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
devices: Vec<Device<TrainingBackend<LC>>>,
strategy: MultiDeviceOptim,
) {
match strategy {
MultiDeviceOptim::OptimMainDevice => self.run_optim_main(
learner,
global_progress,
event_processor,
interrupter,
devices,
),
MultiDeviceOptim::OptimSharded => self.run_optim_distr(
learner,
global_progress,
event_processor,
interrupter,
devices,
),
}
}
/// Runs the epoch with the optimizer step performed on the main (first) device.
///
/// The gradients produced on every device are moved to the first device, accumulated there
/// and applied with `learner.optimizer_step`.
///
/// # Panics
///
/// Panics if `devices` is empty.
fn run_optim_main(
&self,
learner: &mut Learner<LC>,
global_progress: &Progress,
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
devices: Vec<Device<TrainingBackend<LC>>>,
) {
let epoch = global_progress.items_processed;
log::info!(
"Executing training step for epoch {} on devices {:?}",
epoch,
devices
);
let mut iterators = self
.dataloaders
.iter()
.map(|d| d.iter())
.collect::<Vec<_>>();
let mut iteration = 0;
let mut accumulator = GradientsAccumulator::new();
let mut accumulation_current = 0;
// Without an explicit setting, the optimizer runs after every step.
let accumulation = self.grad_accumulation.unwrap_or(1);
let step = MultiDevicesTrainStep::<LC>::new(&devices);
// The main device is always the first in the list.
let device_main = devices.first().expect("A minimum of one device.").clone();
loop {
let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model());
// No items means the epoch is over.
if items.is_empty() {
break;
}
learner.lr_step();
let mut progress_items = Vec::with_capacity(items.len());
for item in items.into_iter() {
// Bring the gradients to the main device before accumulating them.
let grads = item.output.grads.to_device(&device_main, &learner.model());
accumulator.accumulate(&learner.model(), grads);
progress_items.push(item.output.item);
}
// Accumulation is counted per step, not per item.
accumulation_current += 1;
if accumulation <= accumulation_current {
let grads = accumulator.grads();
learner.optimizer_step(grads);
accumulation_current = 0;
}
// One event per item; they all share the progress reported by this step.
for item in progress_items {
iteration += 1;
let item = TrainingItem::new(
item,
progress.clone(),
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);
event_processor.process_train(LearnerEvent::ProcessedItem(item));
}
if interrupter.should_stop() {
break;
}
}
event_processor.process_train(LearnerEvent::EndEpoch(epoch));
}
/// Runs the epoch with one gradient accumulator per device.
///
/// The gradients stay associated with the device that produced them and are handed
/// together to `learner.optimizer_step_multi`.
///
/// # Panics
///
/// Panics if an item comes from a device that is not part of `devices`.
fn run_optim_distr(
&self,
learner: &mut Learner<LC>,
global_progress: &Progress,
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
devices: Vec<Device<TrainingBackend<LC>>>,
) {
let epoch = global_progress.items_processed;
log::info!(
"Executing training step for epoch {} on devices {:?}",
epoch,
devices
);
let mut iterators = self
.dataloaders
.iter()
.map(|d| d.iter())
.collect::<Vec<_>>();
let mut iteration = 0;
// One accumulator per device, keyed by the device id.
let mut accumulators = HashMap::<
DeviceId,
GradientsAccumulator<<LC as LearningComponentsTypes>::TrainingModel>,
>::new();
for device in devices.iter() {
accumulators.insert(device.to_id(), GradientsAccumulator::new());
}
let mut accumulation_current = 0;
// Without an explicit setting, the optimizer runs after every step.
let accumulation = self.grad_accumulation.unwrap_or(1);
let step = MultiDevicesTrainStep::<LC>::new(&devices);
loop {
let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model());
// No items means the epoch is over.
if items.is_empty() {
break;
}
learner.lr_step();
let mut progress_items = Vec::with_capacity(items.len());
for item in items.into_iter() {
let accumulator = accumulators.get_mut(&item.device).unwrap();
accumulator.accumulate(&learner.model(), item.output.grads);
progress_items.push(item.output.item);
}
// Accumulation is counted per step, not per item.
accumulation_current += 1;
if accumulation <= accumulation_current {
// Collect the accumulated gradients of every device, tagged with the device id.
let mut grads = MultiGradientsParams::default();
for (device_id, accumulator) in accumulators.iter_mut() {
let grad = accumulator.grads();
grads.grads.push((grad, *device_id));
}
learner.optimizer_step_multi(grads);
accumulation_current = 0;
}
// One event per item; they all share the progress reported by this step.
for item in progress_items {
iteration += 1;
let item = TrainingItem::new(
item,
progress.clone(),
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);
event_processor.process_train(LearnerEvent::ProcessedItem(item));
}
if interrupter.should_stop() {
break;
}
}
event_processor.process_train(LearnerEvent::EndEpoch(epoch));
}
}

View File

@@ -0,0 +1,4 @@
pub(crate) mod epoch;
mod strategy;
pub use strategy::*;

View File

@@ -0,0 +1,100 @@
use crate::{
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy,
SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents,
TrainingModel, ValidLoader,
multi::epoch::MultiDeviceTrainEpoch,
single::{TrainingLoop, epoch::SingleDeviceValidEpoch},
};
use burn_core::{
data::dataloader::split::split_dataloader,
tensor::{Device, backend::DeviceOps},
};
/// Learning strategy that trains on several devices and validates on the first one.
pub struct MultiDeviceLearningStrategy<LC: LearningComponentsTypes> {
    /// The devices to train on; the first one is used as the main device.
    devices: Vec<Device<TrainingBackend<LC>>>,
    /// How the optimizer step is performed across the devices.
    optim: MultiDeviceOptim,
}

impl<LC: LearningComponentsTypes> MultiDeviceLearningStrategy<LC> {
    /// Creates the strategy from the devices and the multi-device optimization mode.
    pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, optim: MultiDeviceOptim) -> Self {
        MultiDeviceLearningStrategy { devices, optim }
    }
}
impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
    for MultiDeviceLearningStrategy<LC>
{
    /// Trains on all configured devices and validates on the first (main) device.
    ///
    /// Each epoch runs multi-device training, single-device validation, an optional checkpoint
    /// and the optional early stopping check. Returns the trained model together with the event
    /// processor.
    ///
    /// # Panics
    ///
    /// Panics if the strategy was created without any device.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
        let main_device = self.devices.first().unwrap();

        // `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy
        // for each (worker) data loader. This matches the expected device on the worker, so we
        // don't have to move the data between devices.
        let dataloader_train = split_dataloader(dataloader_train, &self.devices);
        let dataloader_valid = dataloader_valid.to_device(main_device.inner());
        learner.fork(main_device);

        let mut event_processor = training_components.event_processor;
        let mut checkpointer = training_components.checkpointer;
        let mut early_stopping = training_components.early_stopping;
        let interrupter = &training_components.interrupter;
        let event_store = &training_components.event_store;

        let epoch_train = MultiDeviceTrainEpoch::<LC>::new(
            dataloader_train.clone(),
            training_components.grad_accumulation,
        );
        let epoch_valid = SingleDeviceValidEpoch::<LC>::new(dataloader_valid.clone());

        for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
            let epoch = training_progress.items_processed;

            epoch_train.run(
                &mut learner,
                &training_progress,
                &mut event_processor,
                interrupter,
                self.devices.to_vec(),
                self.optim,
            );

            if interrupter.should_stop() {
                let reason = interrupter
                    .get_message()
                    .unwrap_or_else(|| String::from("Reason unknown"));
                log::info!("Training interrupted: {reason}");
                break;
            }

            // After OptimSharded training, model parameters are scattered across
            // devices. Fork back to main_device before single-device validation.
            if let MultiDeviceOptim::OptimSharded = self.optim {
                learner.fork(main_device);
            }

            epoch_valid.run(
                &learner,
                &training_progress,
                &mut event_processor,
                interrupter,
            );

            if let Some(checkpointer) = &mut checkpointer {
                checkpointer.checkpoint(&learner, epoch, event_store);
            }

            let stop_early = early_stopping
                .as_mut()
                .is_some_and(|strategy| strategy.should_stop(epoch, event_store));
            if stop_early {
                break;
            }
        }

        (learner.model(), event_processor)
    }
}

View File

@@ -0,0 +1,136 @@
use crate::learner::base::Interrupter;
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
use crate::{
InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader,
ValidLoader,
};
use burn_core::data::dataloader::Progress;
use burn_core::module::AutodiffModule;
use burn_optim::GradientsAccumulator;
/// A validation epoch that iterates over a single validation dataloader.
#[derive(new)]
pub struct SingleDeviceValidEpoch<LC: LearningComponentsTypes> {
// The dataloader providing the validation batches.
dataloader: ValidLoader<LC>,
}
/// A training epoch that iterates over a single training dataloader.
#[derive(new)]
pub struct SingleDeviceTrainEpoch<LC: LearningComponentsTypes> {
// The dataloader providing the training batches.
dataloader: TrainLoader<LC>,
// When set, gradients are accumulated over this many batches before each optimizer step.
grad_accumulation: Option<usize>,
}
impl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {
    /// Runs the validation epoch.
    ///
    /// Every batch is evaluated with the validation version of the learner's model and reported
    /// as a `ProcessedItem` event; an `EndEpoch` event closes the epoch.
    ///
    /// # Arguments
    ///
    /// * `learner` - The learner whose model is validated.
    /// * `global_progress` - The training progress; `items_processed` is used as the epoch number.
    /// * `processor` - The event processor to use.
    /// * `interrupter` - Polled after every batch; validation ends early when it asks to stop.
    pub fn run(
        &self,
        learner: &Learner<LC>,
        global_progress: &Progress,
        processor: &mut SupervisedTrainingEventProcessor<LC>,
        interrupter: &Interrupter,
    ) {
        let epoch = global_progress.items_processed;
        log::info!("Executing validation step for epoch {}", epoch);

        let valid_model = learner.model().valid();
        let mut batches = self.dataloader.iter();
        let mut iteration = 0;

        while let Some(batch) = batches.next() {
            iteration += 1;
            let progress = batches.progress();
            let output = valid_model.step(batch);

            processor.process_valid(LearnerEvent::ProcessedItem(TrainingItem::new(
                output,
                progress,
                global_progress.clone(),
                Some(iteration),
                None,
            )));

            if interrupter.should_stop() {
                break;
            }
        }

        processor.process_valid(LearnerEvent::EndEpoch(epoch));
    }
}
impl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {
    /// Runs the training epoch.
    ///
    /// For every batch the learning rate is stepped, a training step is run and the gradients
    /// are applied, either immediately or after accumulating them over the configured number of
    /// batches. Each batch is reported as a `ProcessedItem` event; an `EndEpoch` event closes
    /// the epoch.
    ///
    /// # Arguments
    ///
    /// * `learner` - The learner to train; it is updated in place.
    /// * `global_progress` - The training progress; `items_processed` is used as the epoch number.
    /// * `processor` - The event processor to use.
    /// * `interrupter` - Polled after every batch; the epoch ends early when it asks to stop.
    pub fn run(
        &self,
        learner: &mut Learner<LC>,
        global_progress: &Progress,
        processor: &mut SupervisedTrainingEventProcessor<LC>,
        interrupter: &Interrupter,
    ) {
        let epoch = global_progress.items_processed;
        log::info!("Executing training step for epoch {}", epoch,);

        // Single device / dataloader
        let mut batches = self.dataloader.iter();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulated_steps = 0;

        while let Some(batch) = batches.next() {
            iteration += 1;
            learner.lr_step();
            log::info!("Iteration {iteration}");

            let progress = batches.progress();
            let output = learner.train_step(batch);

            if let Some(accumulation) = self.grad_accumulation {
                accumulator.accumulate(&learner.model(), output.grads);
                accumulated_steps += 1;

                // Apply the accumulated gradients once enough steps were gathered.
                if accumulation <= accumulated_steps {
                    let grads = accumulator.grads();
                    learner.optimizer_step(grads);
                    accumulated_steps = 0;
                }
            } else {
                learner.optimizer_step(output.grads);
            }

            processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new(
                output.item,
                progress,
                global_progress.clone(),
                Some(iteration),
                Some(learner.lr_current()),
            )));

            if interrupter.should_stop() {
                break;
            }
        }

        processor.process_train(LearnerEvent::EndEpoch(epoch));
    }
}

View File

@@ -0,0 +1,4 @@
pub(crate) mod epoch;
mod strategy;
pub use strategy::*;

View File

@@ -0,0 +1,108 @@
use crate::{
Learner, LearningComponentsTypes, SupervisedLearningStrategy, SupervisedTrainingEventProcessor,
TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader,
single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch},
};
use burn_core::{
data::dataloader::Progress,
tensor::{Device, backend::DeviceOps},
};
/// Simplest learning strategy possible, with only a single device doing both the training and
/// validation.
pub struct SingleDeviceTrainingStrategy<LC: LearningComponentsTypes> {
    /// The device used for both training and validation.
    device: Device<TrainingBackend<LC>>,
}

impl<LC: LearningComponentsTypes> SingleDeviceTrainingStrategy<LC> {
    /// Creates the strategy for the given device.
    pub fn new(device: Device<TrainingBackend<LC>>) -> Self {
        SingleDeviceTrainingStrategy { device }
    }
}
/// Counter over the iterations of a training run, from `next_iteration` up to and including
/// `total_iteration`.
#[derive(new)]
pub(crate) struct TrainingLoop {
// The iteration that the next call to `next` reports.
next_iteration: usize,
// The last iteration (inclusive).
total_iteration: usize,
}
/// Iterates over the remaining training iterations, reporting each one as a [`Progress`].
impl Iterator for TrainingLoop {
    type Item = Progress;

    /// Yields the progress for the current iteration and advances to the next one.
    ///
    /// Returns `None` once the current iteration is past `total_iteration` (the bound is
    /// inclusive).
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.next_iteration;

        if current <= self.total_iteration {
            self.next_iteration = current + 1;
            Some(Progress {
                items_processed: current,
                items_total: self.total_iteration,
            })
        } else {
            None
        }
    }
}
impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
    for SingleDeviceTrainingStrategy<LC>
{
    /// Trains and validates on the strategy's single device.
    ///
    /// Each epoch runs training, validation, an optional checkpoint and the optional early
    /// stopping check. Returns the trained model together with the event processor.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
        let device = &self.device;
        let dataloader_train = dataloader_train.to_device(device);
        let dataloader_valid = dataloader_valid.to_device(device.inner());
        learner.fork(device);

        let mut event_processor = training_components.event_processor;
        let mut checkpointer = training_components.checkpointer;
        let mut early_stopping = training_components.early_stopping;
        let interrupter = &training_components.interrupter;
        let event_store = &training_components.event_store;

        let epoch_train = SingleDeviceTrainEpoch::<LC>::new(
            dataloader_train,
            training_components.grad_accumulation,
        );
        let epoch_valid = SingleDeviceValidEpoch::<LC>::new(dataloader_valid.clone());

        for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
            let epoch = training_progress.items_processed;

            epoch_train.run(
                &mut learner,
                &training_progress,
                &mut event_processor,
                interrupter,
            );

            if interrupter.should_stop() {
                let reason = interrupter
                    .get_message()
                    .unwrap_or_else(|| String::from("Reason unknown"));
                log::info!("Training interrupted: {reason}");
                break;
            }

            epoch_valid.run(
                &learner,
                &training_progress,
                &mut event_processor,
                interrupter,
            );

            if let Some(checkpointer) = &mut checkpointer {
                checkpointer.checkpoint(&learner, epoch, event_store);
            }

            let stop_early = early_stopping
                .as_mut()
                .is_some_and(|strategy| strategy.should_stop(epoch, event_store));
            if stop_early {
                break;
            }
        }

        (learner.model(), event_processor)
    }
}

View File

@@ -0,0 +1,129 @@
use crate::{ItemLazy, renderer::MetricsRenderer};
use burn_core::module::AutodiffModule;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
/// The result of one training step: the gradients plus the item reported to the metrics.
pub struct TrainOutput<TO> {
    /// The gradients.
    pub grads: GradientsParams,
    /// The item.
    pub item: TO,
}

impl<TO> TrainOutput<TO> {
    /// Creates a new training output.
    ///
    /// The raw backend gradients are converted with [`GradientsParams::from_grads`] using the
    /// given module.
    ///
    /// # Arguments
    ///
    /// * `module` - The module.
    /// * `grads` - The gradients.
    /// * `item` - The item.
    ///
    /// # Returns
    ///
    /// A new training output.
    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
        module: &M,
        grads: B::Gradients,
        item: TO,
    ) -> Self {
        TrainOutput {
            grads: GradientsParams::from_grads(grads, module),
            item,
        }
    }
}
/// Trait to be implemented for models to be able to be trained.
///
/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
///
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
/// The same applies to [optimize_multi](TrainStep::optimize_multi).
///
/// # Notes
///
/// To be used with the [Learner](crate::Learner) struct, the struct which implements this trait must
/// also implement the [AutodiffModule] trait, which is done automatically with the
/// [Module](burn_core::module::Module) derive.
pub trait TrainStep {
/// Type of input for a step of the training stage.
type Input: Send + 'static;
/// Type of output for a step of the training stage.
type Output: ItemLazy + 'static;
/// Runs a step for training, which executes the forward and backward passes.
///
/// # Arguments
///
/// * `item` - The input for the model.
///
/// # Returns
///
/// The output containing the model output and the gradients.
fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
/// Optimize the current module with the provided gradients and learning rate.
///
/// The default implementation forwards to the optimizer's `step` method.
///
/// # Arguments
///
/// * `optim`: Optimizer used for learning.
/// * `lr`: The learning rate used for this step.
/// * `grads`: The gradients of each parameter in the current model.
///
/// # Returns
///
/// The updated model.
fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
where
B: AutodiffBackend,
O: Optimizer<Self, B>,
Self: AutodiffModule<B>,
{
optim.step(lr, self, grads)
}
/// Optimize the current module with multiple sets of gradients and the provided learning rate.
///
/// The default implementation forwards to the optimizer's `step_multi` method.
///
/// # Arguments
///
/// * `optim`: Optimizer used for learning.
/// * `lr`: The learning rate used for this step.
/// * `grads`: Multiple gradients associated to each parameter in the current model.
///
/// # Returns
///
/// The updated model.
fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
where
B: AutodiffBackend,
O: Optimizer<Self, B>,
Self: AutodiffModule<B>,
{
optim.step_multi(lr, self, grads)
}
}
/// Trait to be implemented for validating models.
///
/// Unlike [TrainStep], a step only produces an output item and no gradients.
pub trait InferenceStep {
/// Type of input for an inference step.
type Input: Send + 'static;
/// Type of output for an inference step.
type Output: ItemLazy + 'static;
/// Runs a validation step.
///
/// # Arguments
///
/// * `item` - The item to validate on.
///
/// # Returns
///
/// The validation output.
fn step(&self, item: Self::Input) -> Self::Output;
}
/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
///
/// `M` is the type of the trained model.
pub struct LearningResult<M> {
/// The model with the learned weights.
pub model: M,
/// The renderer that can be used for follow up training and evaluation.
pub renderer: Box<dyn MetricsRenderer>,
}

View File

@@ -0,0 +1,118 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! A library for training neural networks using the burn crate.
#[macro_use]
extern crate derive_new;
/// The checkpoint module.
pub mod checkpoint;
pub(crate) mod components;
/// Renderer modules to display metrics and training information.
pub mod renderer;
/// The logger module.
pub mod logger;
/// The metric module.
pub mod metric;
pub use metric::processor::*;
mod learner;
pub use learner::*;
mod evaluator;
pub use evaluator::*;
pub use components::*;
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Autodiff-enabled version of the test backend.
    pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

    /// Probability of tp before adding errors
    pub const THRESHOLD: f64 = 0.5;

    /// The kind of classification problem a dummy input is generated for.
    #[derive(Debug, Default)]
    pub enum ClassificationType {
        #[default]
        Binary,
        Multiclass,
        Multilabel,
    }

    /// Sample x Class shaped matrix for use in
    /// classification metrics testing
    ///
    /// Returns the float score tensor first and the boolean target tensor second.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        match classification_type {
            ClassificationType::Binary => {
                let scores: Tensor<TestBackend, 2> =
                    Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default());
                // predictions @ threshold=0.5
                // [[0], [0], [1], [0], [1]]
                let targets: Tensor<TestBackend, 2, Bool> =
                    Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default());
                (scores, targets)
            }
            ClassificationType::Multiclass => {
                let scores: Tensor<TestBackend, 2> = Tensor::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &Default::default(),
                );
                // predictions @ top_k=1
                // [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // predictions @ top_k=2
                // [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
                let targets: Tensor<TestBackend, 2, Bool> = Tensor::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &Default::default(),
                );
                (scores, targets)
            }
            ClassificationType::Multilabel => {
                let scores: Tensor<TestBackend, 2> = Tensor::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &Default::default(),
                );
                // predictions @ threshold=0.5
                // [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                let targets: Tensor<TestBackend, 2, Bool> = Tensor::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &Default::default(),
                );
                (scores, targets)
            }
        }
    }
}

View File

@@ -0,0 +1,91 @@
use super::Logger;
use std::sync::mpsc;
/// Messages sent from an `AsyncLogger` to its logger thread.
enum Message<T> {
/// Log the contained item.
Log(T),
/// Stop the logger thread.
End,
/// Reply on the contained sender once this message is reached.
Sync(mpsc::Sender<()>),
}
/// Async logger.
///
/// Items are sent over a channel to a dedicated thread that owns the wrapped logger.
pub struct AsyncLogger<T> {
// Channel to the logger thread.
sender: mpsc::Sender<Message<T>>,
// Handle of the logger thread; taken and joined on drop.
handler: Option<std::thread::JoinHandle<()>>,
}
/// State owned by the logger thread: the wrapped logger and the receiving end of the channel.
#[derive(new)]
struct LoggerThread<T, L: Logger<T>> {
logger: L,
receiver: mpsc::Receiver<Message<T>>,
}
impl<T, L> LoggerThread<T, L>
where
    L: Logger<T>,
{
    /// Processes incoming messages until `Message::End` arrives or every sender is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the reply to a `Message::Sync` cannot be delivered.
    fn run(mut self) {
        while let Ok(message) = self.receiver.recv() {
            match message {
                Message::End => break,
                Message::Log(item) => self.logger.log(item),
                Message::Sync(callback) => callback
                    .send(())
                    .expect("Can return result with the callback channel."),
            }
        }
    }
}
impl<T: Send + Sync + 'static> AsyncLogger<T> {
    /// Create a new async logger.
    ///
    /// The given logger is moved to a freshly spawned thread, which then handles every message
    /// sent through this `AsyncLogger`.
    pub fn new<L>(logger: L) -> Self
    where
        L: Logger<T> + 'static,
    {
        let (sender, receiver) = mpsc::channel();
        let worker = LoggerThread::new(logger, receiver);

        Self {
            sender,
            handler: Some(std::thread::spawn(move || worker.run())),
        }
    }

    /// Sync the async logger.
    ///
    /// Blocks until the logger thread has reached the sync request that is sent here.
    ///
    /// # Panics
    ///
    /// Panics if the logger thread is no longer running.
    pub(crate) fn sync(&self) {
        let (ack_sender, ack_receiver) = mpsc::channel();

        self.sender
            .send(Message::Sync(ack_sender))
            .expect("Can send message to logger thread.");

        ack_receiver
            .recv()
            .expect("Should sync, otherwise the thread is dead.");
    }
}
impl<T: Send> Logger<T> for AsyncLogger<T> {
    /// Queues `item` for the logger thread.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be sent to the logger thread.
    fn log(&mut self, item: T) {
        let message = Message::Log(item);
        self.sender
            .send(message)
            .expect("Can log using the logger thread.");
    }
}
impl<T> Drop for AsyncLogger<T> {
    /// Stops the logger thread and waits for it to finish.
    ///
    /// # Panics
    ///
    /// Panics if the end message cannot be delivered or if the logger thread itself panicked,
    /// unless the current thread is already unwinding: a second panic raised from `drop`
    /// during unwinding would abort the whole process and hide the original failure.
    fn drop(&mut self) {
        let end_sent = self.sender.send(Message::End);
        // When sending failed the receiver is already gone, so the thread has exited (or is
        // exiting) and joining cannot block indefinitely.
        let joined = self.handler.take().map(|handler| handler.join());

        if std::thread::panicking() {
            return;
        }

        end_sent.expect("Can send the end message to the logger thread.");
        if let Some(result) = joined {
            result.expect("The logger thread should stop.");
        }
    }
}

View File

@@ -0,0 +1,9 @@
/// The logger trait.
///
/// Implementors must be [`Send`] so a logger can be moved to another thread.
pub trait Logger<T>: Send {
/// Logs an item.
///
/// # Arguments
///
/// * `item` - The item to log; ownership is transferred to the logger.
fn log(&mut self, item: T);
}

View File

@@ -0,0 +1,46 @@
use super::Logger;
use std::{fs::File, io::Write, path::Path};
/// File logger.
///
/// Writes each logged item as one line of the underlying file.
pub struct FileLogger {
/// The file the items are written to.
file: File,
}
impl FileLogger {
/// Create a new file logger.
///
/// # Arguments
///
/// * `path` - The path.
///
/// # Returns
///
/// The file logger.
pub fn new(path: impl AsRef<Path>) -> Self {
let path = path.as_ref();
let mut options = std::fs::File::options();
let file = options
.write(true)
.truncate(true)
.create(true)
.open(path)
.unwrap_or_else(|err| {
panic!(
"Should be able to create the new file '{}': {}",
path.display(),
err
)
});
Self { file }
}
}
impl<T> Logger<T> for FileLogger
where
    T: std::fmt::Display,
{
    /// Appends the `Display` form of `item`, followed by a newline, to the file.
    ///
    /// # Panics
    ///
    /// Panics if writing to the file fails.
    fn log(&mut self, item: T) {
        let written = writeln!(&mut self.file, "{item}");
        written.expect("Can log an item.");
    }
}

View File

@@ -0,0 +1,16 @@
use super::Logger;
/// In memory logger.
///
/// Keeps the string form of every logged item, in logging order.
#[derive(Default)]
pub struct InMemoryLogger {
/// The logged items, converted to strings.
pub(crate) values: Vec<String>,
}
impl<T> Logger<T> for InMemoryLogger
where
    T: std::fmt::Display,
{
    /// Stores the `Display` form of `item` after the previously logged values.
    fn log(&mut self, item: T) {
        let rendered = item.to_string();
        self.values.push(rendered);
    }
}

View File

@@ -0,0 +1,375 @@
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
use crate::metric::{
MetricDefinition, MetricEntry, MetricId, NumericEntry,
store::{EpochSummary, MetricsUpdate, Split},
};
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
/// Prefix of the per-epoch directory names (the epoch number is appended to it).
const EPOCH_PREFIX: &str = "epoch-";
/// Metric logger.
pub trait MetricLogger: Send {
/// Logs the entries of a metrics update.
///
/// # Arguments
///
/// * `update` - Update information for all registered metrics.
/// * `epoch` - Current epoch.
/// * `split` - Current dataset split.
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split);
/// Read the logged numeric entries of a metric for an epoch.
///
/// # Arguments
///
/// * `name` - Name of the metric.
/// * `epoch` - Epoch to read.
/// * `split` - Dataset split to read.
///
/// # Returns
///
/// The numeric entries that were read, or an error message.
fn read_numeric(
&mut self,
name: &str,
epoch: usize,
split: &Split,
) -> Result<Vec<NumericEntry>, String>;
/// Logs the metric definition information (name, description, unit, etc.)
fn log_metric_definition(&mut self, definition: MetricDefinition);
/// Logs summary at the end of the epoch.
fn log_epoch_summary(&mut self, summary: EpochSummary);
}
/// The file metric logger.
pub struct FileMetricLogger {
/// Open file loggers, keyed by metric name and split.
loggers: HashMap<String, AsyncLogger<String>>,
/// Root directory under which the log files are written.
directory: PathBuf,
/// Registered metric definitions, by metric id.
metric_definitions: HashMap<MetricId, MetricDefinition>,
/// Whether the logger was created with `new_eval` rather than `new`.
is_eval: bool,
/// Epoch of the most recent `log` call, used to detect epoch changes.
last_epoch: Option<usize>,
}
impl FileMetricLogger {
/// Create a new file metric logger.
///
/// # Arguments
///
/// * `directory` - The directory.
///
/// # Returns
///
/// The file metric logger.
pub fn new(directory: impl AsRef<Path>) -> Self {
Self {
loggers: HashMap::new(),
directory: directory.as_ref().to_path_buf(),
metric_definitions: HashMap::default(),
is_eval: false,
last_epoch: None,
}
}
/// Create a new file metric logger.
///
/// # Arguments
///
/// * `directory` - The directory.
///
/// # Returns
///
/// The file metric logger.
pub fn new_eval(directory: impl AsRef<Path>) -> Self {
Self {
loggers: HashMap::new(),
directory: directory.as_ref().to_path_buf(),
metric_definitions: HashMap::default(),
is_eval: true,
last_epoch: None,
}
}
pub(crate) fn split_exists(&self, split: &Split) -> bool {
self.split_dir(split).is_some()
}
pub(crate) fn split_dir(&self, split: &Split) -> Option<PathBuf> {
let split_path = match split {
Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(tag.as_str()),
other => self.directory.join(other.to_string()),
};
(split_path.exists() && split_path.is_dir()).then_some(split_path)
}
pub(crate) fn is_epoch_dir<P: AsRef<str>>(dirname: P) -> bool {
dirname.as_ref().starts_with(EPOCH_PREFIX)
}
/// Number of epochs recorded.
pub(crate) fn epochs(&self) -> usize {
if self.is_eval {
log::warn!("Number of epochs not available when testing.");
return 0;
}
let mut max_epoch = 0;
// with split
for path in fs::read_dir(&self.directory).unwrap() {
let path = path.unwrap();
if fs::metadata(path.path()).unwrap().is_dir() {
for split_path in fs::read_dir(path.path()).unwrap() {
let split_path = split_path.unwrap();
if fs::metadata(split_path.path()).unwrap().is_dir() {
let dir_name = split_path.file_name().into_string().unwrap();
if !dir_name.starts_with(EPOCH_PREFIX) {
continue;
}
let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
if let Some(epoch) = epoch
&& epoch > max_epoch
{
max_epoch = epoch;
}
}
}
}
}
max_epoch
}
fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf {
let name = format!("{EPOCH_PREFIX}{epoch}");
match split {
Split::Train | Split::Valid | Split::Test(None) => {
self.directory.join(split.to_string()).join(name)
}
Split::Test(Some(tag)) => {
let tag = format_tag(tag);
self.directory.join(split.to_string()).join(tag).join(name)
}
}
}
fn eval_directory(&self, split: &Split) -> PathBuf {
match split {
Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(),
Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(format_tag(tag)),
}
}
fn file_path(&self, name: &str, epoch: Option<usize>, split: &Split) -> PathBuf {
let directory = match epoch {
Some(epoch) => self.train_directory(epoch, split),
None => self.eval_directory(split),
};
let name = name.replace(' ', "_");
let name = format!("{name}.log");
directory.join(name)
}
fn create_directory(&self, epoch: Option<usize>, split: &Split) {
let directory = match epoch {
Some(epoch) => self.train_directory(epoch, split),
None => self.eval_directory(split),
};
std::fs::create_dir_all(directory).ok();
}
}
impl FileMetricLogger {
    /// Writes the serialized form of `item` to the log file of its metric.
    ///
    /// The logger for a metric/split pair is created on first use, together with its
    /// directory, and reused afterwards.
    ///
    /// # Panics
    ///
    /// Panics if no definition was registered for the metric of `item`.
    fn log_item(&mut self, item: &MetricEntry, epoch: Option<usize>, split: &Split) {
        let definition = self.metric_definitions.get(&item.metric_id).unwrap();
        let metric_name = &definition.name;
        let key = logger_key(metric_name, split);

        if !self.loggers.contains_key(&key) {
            self.create_directory(epoch, split);
            let path = self.file_path(metric_name, epoch, split);
            let file_logger = FileLogger::new(path);
            self.loggers
                .insert(key.clone(), AsyncLogger::new(file_logger));
        }

        let payload = item.serialized_entry.serialized.clone();
        self.loggers
            .get_mut(&key)
            .expect("Can get the previously saved logger.")
            .log(payload);
    }
}
/// Normalizes a tag for use as a directory name: surrounding whitespace is removed, spaces
/// become dashes and the result is lowercased.
fn format_tag(tag: &str) -> String {
    let trimmed = tag.trim();
    let dashed = trimmed.replace(' ', "-");
    dashed.to_lowercase()
}
impl MetricLogger for FileMetricLogger {
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
if !self.is_eval && self.last_epoch != Some(epoch) {
self.loggers.clear();
self.last_epoch = Some(epoch);
}
let entries: Vec<_> = update
.entries
.iter()
.chain(
update
.entries_numeric
.iter()
.map(|numeric_update| &numeric_update.entry),
)
.cloned()
.collect();
for item in entries.iter() {
self.log_item(item, Some(epoch), split);
}
}
fn read_numeric(
&mut self,
name: &str,
epoch: usize,
split: &Split,
) -> Result<Vec<NumericEntry>, String> {
if let Some(value) = self.loggers.get(name) {
value.sync()
}
let file_path = self.file_path(name, Some(epoch), split);
let mut errors = false;
let data = std::fs::read_to_string(file_path)
.unwrap_or_default()
.split('\n')
.filter_map(|value| {
if value.is_empty() {
None
} else {
match NumericEntry::deserialize(value) {
Ok(value) => Some(value),
Err(err) => {
log::error!("{err}");
errors = true;
None
}
}
}
})
.collect();
if errors {
Err("Parsing numeric entry errors".to_string())
} else {
Ok(data)
}
}
fn log_metric_definition(&mut self, definition: MetricDefinition) {
self.metric_definitions
.insert(definition.metric_id.clone(), definition);
}
fn log_epoch_summary(&mut self, _summary: EpochSummary) {
if !self.is_eval {
self.loggers.clear();
}
}
}
/// Builds the key identifying the logger of a metric for a split: `<name>_<split>`.
fn logger_key(name: &str, split: &Split) -> String {
    let mut key = String::from(name);
    key.push('_');
    key.push_str(&split.to_string());
    key
}
/// In memory metric logger, useful when testing and debugging.
#[derive(Default)]
pub struct InMemoryMetricLogger {
/// Loggers by metric/split key; a new logger is appended to each list when the epoch changes.
values: HashMap<String, Vec<InMemoryLogger>>,
/// Epoch of the most recent `log` call, used to detect epoch changes.
last_epoch: Option<usize>,
/// Registered metric definitions, by metric id.
metric_definitions: HashMap<MetricId, MetricDefinition>,
}
impl InMemoryMetricLogger {
    /// Creates an empty in-memory metric logger (same as the `Default` value).
    pub fn new() -> Self {
        Default::default()
    }
}
impl MetricLogger for InMemoryMetricLogger {
    /// Stores the serialized form of every entry of `update` under its metric/split key.
    ///
    /// When the epoch differs from the previous call, a fresh logger is first appended to
    /// every existing key so that the new entries are kept apart from the earlier ones.
    ///
    /// # Panics
    ///
    /// Panics if no definition was registered for the metric of an entry.
    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
        if self.last_epoch != Some(epoch) {
            self.values
                .values_mut()
                .for_each(|loggers| loggers.push(InMemoryLogger::default()));
            self.last_epoch = Some(epoch);
        }

        // `update` is owned by this call, so its entries can be borrowed directly.
        let entries = update.entries.iter().chain(
            update
                .entries_numeric
                .iter()
                .map(|numeric_update| &numeric_update.entry),
        );

        for item in entries {
            let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
            let key = logger_key(name, split);
            let loggers = self
                .values
                .entry(key)
                .or_insert_with(|| vec![InMemoryLogger::default()]);

            loggers
                .last_mut()
                .expect("Every key holds at least one logger.")
                .log(item.serialized_entry.serialized.clone());
        }
    }

    /// Reads the numeric entries stored for the metric `name` at the given epoch and split.
    ///
    /// Epochs are looked up 1-based (`epoch - 1` is the index into the stored loggers). An
    /// unknown metric/split, an epoch of `0` or an epoch without a logger yields an empty
    /// list. Values that cannot be parsed are skipped.
    fn read_numeric(
        &mut self,
        name: &str,
        epoch: usize,
        split: &Split,
    ) -> Result<Vec<NumericEntry>, String> {
        // Epoch 0 has no slot; `checked_sub` avoids the underflow of `epoch - 1`.
        let Some(index) = epoch.checked_sub(1) else {
            return Ok(Vec::new());
        };

        let entries = self
            .values
            .get(&logger_key(name, split))
            .and_then(|loggers| loggers.get(index))
            .map(|logger| {
                logger
                    .values
                    .iter()
                    .filter_map(|value| NumericEntry::deserialize(value).ok())
                    .collect()
            })
            .unwrap_or_default();

        Ok(entries)
    }

    /// Registers (or replaces) the definition of a metric, keyed by its id.
    fn log_metric_definition(&mut self, definition: MetricDefinition) {
        self.metric_definitions
            .insert(definition.metric_id.clone(), definition);
    }

    /// Epoch summaries are ignored by the in-memory logger.
    fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
}

View File

@@ -0,0 +1,11 @@
mod async_logger;
mod base;
mod file;
mod in_memory;
mod metric;
pub use async_logger::*;
pub use base::*;
pub use file::*;
pub use in_memory::*;
pub use metric::*;

View File

@@ -0,0 +1,164 @@
use core::marker::PhantomData;
use super::MetricMetadata;
use super::state::{FormatOptions, NumericMetricState};
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// The accuracy metric.
///
/// Reports the percentage of samples whose highest-scoring class equals the target.
#[derive(Clone)]
pub struct AccuracyMetric<B: Backend> {
/// The metric name.
name: MetricName,
/// Numeric state updated on every call to `update`.
state: NumericMetricState,
/// Optional target value marking padding; such targets are excluded from the accuracy.
pad_token: Option<usize>,
_b: PhantomData<B>,
}
/// The [accuracy metric](AccuracyMetric) input type.
#[derive(new)]
pub struct AccuracyInput<B: Backend> {
/// Per-class scores, read by the metric as `[batch_size, n_classes]`.
outputs: Tensor<B, 2>,
/// Target class index of each sample.
targets: Tensor<B, 1, Int>,
}
/// The default metric is the one returned by [`AccuracyMetric::new`].
impl<B: Backend> Default for AccuracyMetric<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> AccuracyMetric<B> {
/// Creates the metric.
pub fn new() -> Self {
Self {
name: MetricName::new("Accuracy".to_string()),
state: Default::default(),
pad_token: Default::default(),
_b: PhantomData,
}
}
/// Sets the pad token.
pub fn with_pad_token(mut self, index: usize) -> Self {
self.pad_token = Some(index);
self
}
}
impl<B: Backend> Metric for AccuracyMetric<B> {
type Input = AccuracyInput<B>;
fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
let targets = input.targets.clone();
let outputs = input.outputs.clone();
let [batch_size, _n_classes] = outputs.dims();
let outputs = outputs.argmax(1).reshape([batch_size]);
let accuracy = match self.pad_token {
Some(pad_token) => {
let mask = targets.clone().equal_elem(pad_token as i64);
let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
let num_pad = mask.float().sum();
let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
acc.into_scalar().elem::<f64>()
}
None => {
outputs
.equal(targets)
.int()
.sum()
.into_scalar()
.elem::<f64>()
/ batch_size as f64
}
};
self.state.update(
100.0 * accuracy,
batch_size,
FormatOptions::new(self.name()).unit("%").precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
super::NumericAttributes {
unit: Some("%".to_string()),
higher_is_better: true,
}
.into()
}
}
impl<B: Backend> Numeric for AccuracyMetric<B> {
// Current value, as reported by the numeric state.
fn value(&self) -> super::NumericEntry {
self.state.current_value()
}
// Running value, as reported by the numeric state.
fn running_value(&self) -> super::NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
// Predicted classes are 2, 1, 0, 1 and targets are 2, 2, 1, 1: two of four match.
#[test]
fn test_accuracy_without_padding() {
let device = Default::default();
let mut metric = AccuracyMetric::<TestBackend>::new();
let input = AccuracyInput::new(
Tensor::from_data(
[
[0.0, 0.2, 0.8], // 2
[1.0, 2.0, 0.5], // 1
[0.4, 0.1, 0.2], // 0
[0.6, 0.7, 0.2], // 1
],
&device,
),
Tensor::from_data([2, 2, 1, 1], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(50.0, metric.value().current());
}
// Same four samples as above plus three samples whose target is the pad token (3);
// the padded samples are ignored, so the accuracy stays at two of four.
#[test]
fn test_accuracy_with_padding() {
let device = Default::default();
let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);
let input = AccuracyInput::new(
Tensor::from_data(
[
[0.0, 0.2, 0.8, 0.0], // 2
[1.0, 2.0, 0.5, 0.0], // 1
[0.4, 0.1, 0.2, 0.0], // 0
[0.6, 0.7, 0.2, 0.0], // 1
[0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
[0.0, 0.1, 0.2, 0.0], // Error on padding should not count
[0.6, 0.0, 0.2, 0.0], // Error on padding should not count
],
&device,
),
Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(50.0, metric.value().current());
}
}

View File

@@ -0,0 +1,231 @@
use core::f64;
use core::marker::PhantomData;
use super::MetricMetadata;
use super::state::{FormatOptions, NumericMetricState};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification.
///
/// The value is reported as a percentage.
#[derive(Clone)]
pub struct AurocMetric<B: Backend> {
/// The metric name.
name: MetricName,
/// Numeric state updated on every call to `update`.
state: NumericMetricState,
_b: PhantomData<B>,
}
/// The [AUROC metric](AurocMetric) input type.
#[derive(new)]
pub struct AurocInput<B: Backend> {
/// Per-class scores, read by the metric as `[batch_size, 2]`; the metric normalizes them
/// itself (exponential divided by the row sum) and uses the second column as the positive
/// class.
outputs: Tensor<B, 2>,
/// Target of each sample: `1` for the positive class, `0` for the negative class.
targets: Tensor<B, 1, Int>,
}
/// The default metric is the one returned by [`AurocMetric::new`].
impl<B: Backend> Default for AurocMetric<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> AurocMetric<B> {
/// Creates the metric.
pub fn new() -> Self {
Self {
name: MetricName::new("AUROC".to_string()),
state: Default::default(),
_b: PhantomData,
}
}
fn binary_auroc(&self, probabilities: &Tensor<B, 1>, targets: &Tensor<B, 1, Int>) -> f64 {
let n = targets.dims()[0];
let n_pos = targets.clone().sum().into_scalar().elem::<u64>() as usize;
// Early return if we don't have both positive and negative samples
if n_pos == 0 || n_pos == n {
if n_pos == 0 {
log::warn!("Metric cannot be computed because all target values are negative.")
} else {
log::warn!("Metric cannot be computed because all target values are positive.")
}
return 0.0;
}
let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);
let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);
let valid_pairs = pos_mask * neg_mask;
let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
let correct_order = prob_i.clone().greater(prob_j.clone()).int();
let ties = prob_i.equal(prob_j).int();
// Calculate AUC components
let num_pairs = valid_pairs.clone().sum().into_scalar().elem::<f64>();
let correct_pairs = (correct_order * valid_pairs.clone())
.sum()
.into_scalar()
.elem::<f64>();
let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::<f64>();
(correct_pairs + 0.5 * tied_pairs) / num_pairs
}
}
impl<B: Backend> Metric for AurocMetric<B> {
    type Input = AurocInput<B>;

    /// Computes the AUROC of the batch, as a percentage, and records it in the state.
    ///
    /// # Panics
    ///
    /// Panics if the outputs do not have exactly two classes.
    fn update(&mut self, input: &AurocInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        let [batch_size, num_classes] = input.outputs.dims();
        assert_eq!(
            num_classes, 2,
            "Currently only binary classification is supported"
        );

        // Normalize each row, then keep the column of the positive class (index 1).
        let exponents = input.outputs.clone().exp();
        let row_sums = exponents.clone().sum_dim(1);
        let positive_column = Tensor::arange(1..2, &input.outputs.device());
        let probabilities = (exponents / row_sums)
            .select(1, positive_column)
            .squeeze_dim(1);

        let area_under_curve = self.binary_auroc(&probabilities, &input.targets);
        let format = FormatOptions::new(self.name()).unit("%").precision(2);

        self.state
            .update(100.0 * area_under_curve, batch_size, format)
    }

    /// Resets the numeric state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns the metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }
}
impl<B: Backend> Numeric for AurocMetric<B> {
// Current value, as reported by the numeric state.
fn value(&self) -> super::NumericEntry {
self.state.current_value()
}
// Running value, as reported by the numeric state.
fn running_value(&self) -> super::NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
// Both positive samples score higher on the positive class than both negative samples,
// so every (positive, negative) pair is ordered correctly.
#[test]
fn test_auroc() {
let device = Default::default();
let mut metric = AurocMetric::<TestBackend>::new();
let input = AurocInput::new(
Tensor::from_data(
[
[0.1, 0.9], // High confidence positive
[0.7, 0.3], // Low confidence negative
[0.6, 0.4], // Low confidence negative
[0.2, 0.8], // High confidence positive
],
&device,
),
Tensor::from_data([1, 0, 0, 1], &device), // True labels
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 100.0);
}
#[test]
fn test_auroc_perfect_separation() {
let device = Default::default();
let mut metric = AurocMetric::<TestBackend>::new();
let input = AurocInput::new(
Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
Tensor::from_data([1, 0, 0, 1], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 100.0); // Perfect AUC
}
// All samples get the same score, so every pair is a tie and counts for one half.
#[test]
fn test_auroc_random() {
let device = Default::default();
let mut metric = AurocMetric::<TestBackend>::new();
let input = AurocInput::new(
Tensor::from_data(
[
[0.5, 0.5], // Random predictions
[0.5, 0.5],
[0.5, 0.5],
[0.5, 0.5],
],
&device,
),
Tensor::from_data([1, 0, 0, 1], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 50.0);
}
// With a single class in the targets the metric falls back to 0.
#[test]
fn test_auroc_all_one_class() {
let device = Default::default();
let mut metric = AurocMetric::<TestBackend>::new();
let input = AurocInput::new(
Tensor::from_data(
[
[0.1, 0.9], // All positives predictions
[0.2, 0.8],
[0.3, 0.7],
[0.4, 0.6],
],
&device,
),
Tensor::from_data([1, 1, 1, 1], &device), // All positive class
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 0.0);
}
#[test]
#[should_panic(expected = "Currently only binary classification is supported")]
fn test_auroc_multiclass_error() {
let device = Default::default();
let mut metric = AurocMetric::<TestBackend>::new();
let input = AurocInput::new(
Tensor::from_data(
[
[0.1, 0.2, 0.7], // More than 2 classes not supported
[0.3, 0.5, 0.2],
],
&device,
),
Tensor::from_data([2, 1], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
}
}

View File

@@ -0,0 +1,278 @@
use std::sync::Arc;
use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
/// Metric metadata that can be used when computing metrics.
pub struct MetricMetadata {
/// The current progress.
pub progress: Progress,
/// The global progress of the training (e.g. epochs).
pub global_progress: Progress,
/// The current iteration, if any.
pub iteration: Option<usize>,
/// The current learning rate, if any.
pub lr: Option<LearningRate>,
}
impl MetricMetadata {
/// Fake metric metadata, only compiled for tests.
///
/// Reports one of one items processed, zero of one global items processed, iteration `0`
/// and no learning rate.
#[cfg(test)]
pub fn fake() -> Self {
Self {
progress: Progress {
items_processed: 1,
items_total: 1,
},
global_progress: Progress {
items_processed: 0,
items_total: 1,
},
iteration: Some(0),
lr: None,
}
}
}
/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
/// For now we take the name as id to make sure that the same metric has the same id across different runs.
#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
pub struct MetricId {
/// The metric id, shared behind an `Arc` so that cloning the id does not copy the string.
id: Arc<String>,
}
/// Metric attributes define the properties intrinsic to different types of metric.
#[derive(Clone, Debug)]
pub enum MetricAttributes {
/// Numeric attributes (unit and whether higher values are better).
Numeric(NumericAttributes),
/// No attributes.
None,
}
/// Definition of a metric.
///
/// Groups the id of a metric with its name, description and attributes.
#[derive(Clone, Debug)]
pub struct MetricDefinition {
/// The metric's id.
pub metric_id: MetricId,
/// The name of the metric.
pub name: String,
/// The description of the metric, if it provides one.
pub description: Option<String>,
/// The attributes of the metric.
pub attributes: MetricAttributes,
}
impl MetricDefinition {
    /// Builds the definition of `metric` from its name, description and attributes.
    ///
    /// # Arguments
    ///
    /// * `metric_id` - The unique id to associate with the metric.
    /// * `metric` - The metric to describe.
    pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
        let name = metric.name().to_string();
        let description = metric.description();
        let attributes = metric.attributes();

        Self {
            metric_id,
            name,
            description,
            attributes,
        }
    }
}
/// Metric trait.
///
/// # Notes
///
/// Implementations should define their own input type only used by the metric.
/// This is important since some conflict may happen when the model output is adapted for each
/// metric's input type.
pub trait Metric: Send + Sync + Clone {
/// The input type of the metric.
type Input;
/// The parameterized name of the metric.
///
/// This should be unique, so avoid using short generic names, prefer using the long name.
///
/// For a metric that can exist at different parameters (e.g., top-k accuracy for different
/// values of k), the name should be unique for each instance.
fn name(&self) -> MetricName;
/// A short description of the metric.
///
/// By default, metrics have no description.
fn description(&self) -> Option<String> {
None
}
/// Attributes of the metric.
///
/// By default, metrics have no attributes.
fn attributes(&self) -> MetricAttributes {
MetricAttributes::None
}
/// Updates the metric state with `item` and returns the current metric entry.
fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
/// Clear the metric state.
fn clear(&mut self);
}
/// Type used to store metric names efficiently: cloning a name only clones the `Arc`.
pub type MetricName = Arc<String>;
/// Adaptors are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
/// registered with the specific learning paradigm (i.e. [SupervisedTraining](crate::SupervisedTraining)).
pub trait Adaptor<T> {
/// Adapt the type to be passed to a [metric](Metric).
fn adapt(&self) -> T;
}
/// Every type can be adapted to the unit type, for metrics whose input is `()`.
impl<T> Adaptor<()> for T {
fn adapt(&self) {}
}
/// Attributes that describe intrinsic properties of a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
/// Optional unit (e.g. "%", "ms", "pixels")
pub unit: Option<String>,
/// Whether larger values are better (true) or smaller are better (false).
pub higher_is_better: bool,
}
impl From<NumericAttributes> for MetricAttributes {
fn from(attr: NumericAttributes) -> Self {
MetricAttributes::Numeric(attr)
}
}
/// By default a numeric metric has no unit and higher values are better.
impl Default for NumericAttributes {
fn default() -> Self {
Self {
unit: None,
higher_is_better: true,
}
}
}
/// Declare a metric to be numeric.
///
/// This is useful to plot the values of a metric during training.
pub trait Numeric {
/// Returns the numeric value of the metric.
fn value(&self) -> NumericEntry;
/// Returns the current aggregated value of the metric over the global step (epoch).
fn running_value(&self) -> NumericEntry;
}
/// Serialized form of a metric entry.
///
/// Carries one string meant for display and one meant for storage.
#[derive(Debug, Clone, new)]
pub struct SerializedEntry {
/// The string to be displayed.
pub formatted: String,
/// The string to be saved.
pub serialized: String,
}
/// Data type that contains the current state of a metric at a given time.
#[derive(Debug, Clone)]
pub struct MetricEntry {
/// Id of the entry's metric.
pub metric_id: MetricId,
/// The serialized form of the entry.
pub serialized_entry: SerializedEntry,
}
impl MetricEntry {
/// Create a new metric entry.
///
/// # Arguments
///
/// * `metric_id` - Id of the metric the entry belongs to.
/// * `serialized_entry` - The serialized form of the entry.
pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
Self {
metric_id,
serialized_entry,
}
}
}
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Gets the current aggregated value of the metric.
    pub fn current(&self) -> f64 {
        match self {
            NumericEntry::Value(value) => *value,
            NumericEntry::Aggregated {
                aggregated_value, ..
            } => *aggregated_value,
        }
    }

    /// Returns a String representing the NumericEntry.
    ///
    /// A single value is written as `<value>`, an aggregated one as `<value>,<count>`.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(value) => value.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// De-serializes a string representing a NumericEntry and returns a Result containing the
    /// corresponding NumericEntry.
    ///
    /// Accepts the formats produced by [`NumericEntry::serialize`]; whitespace around each
    /// comma separated field is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error message when a field cannot be parsed or when the entry has more than
    /// two comma separated fields.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let mut fields = entry.split(',').map(str::trim);

        match (fields.next(), fields.next(), fields.next()) {
            // Numeric value
            (Some(value), None, _) => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            // Aggregated numeric (value, number of elements)
            (Some(value), Some(numel), None) => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = numel.parse::<usize>().map_err(|err| err.to_string())?;

                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Compare this numeric metric's value with another one using the specified direction.
    ///
    /// Returns `true` only when this value is strictly better than `other`: strictly greater
    /// when `higher_is_better`, strictly smaller otherwise. Equal values, and comparisons
    /// involving NaN, are never better.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (mine, theirs) = (self.current(), other.current());

        if higher_is_better {
            mine > theirs
        } else {
            mine < theirs
        }
    }
}
/// Format a float with the given precision. Will use scientific notation if necessary.
///
/// Scientific notation is used when the magnitude of `float` is at most
/// `0.1^(precision - 1)`, i.e. when the fixed notation would show too few significant
/// digits. The magnitude is used so that negative values are formatted like their positive
/// counterparts.
///
/// # Arguments
///
/// * `float` - The value to format.
/// * `precision` - The number of digits after the decimal point.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    if scientific_notation_threshold >= float.abs() {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}

View File

@@ -0,0 +1,239 @@
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricMetadata, SerializedEntry};
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
use core::marker::PhantomData;
use std::sync::Arc;
/// Computes the edit distance (Levenshtein distance) between two sequences of integers.
///
/// The distance is the smallest number of single-element insertions, deletions and
/// substitutions that turn one sequence into the other. Only two rows of the dynamic
/// programming table are kept in memory at any time.
///
pub(crate) fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize {
    let width = prediction.len();
    // Row of the empty reference prefix: reaching the first `j` predicted items costs `j`.
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row: Vec<usize> = vec![0; width + 1];

    for (done, ref_item) in reference.iter().enumerate() {
        // Against an empty prediction, every reference item seen so far costs one edit.
        row[0] = done + 1;

        for col in 1..=width {
            row[col] = if *ref_item == prediction[col - 1] {
                // Items match: carry the diagonal cost over unchanged.
                above[col - 1]
            } else {
                // Cheapest of substitution, insertion and deletion, plus one edit.
                let cheapest = above[col - 1].min(above[col]).min(row[col - 1]);
                1 + cheapest
            };
        }

        core::mem::swap(&mut above, &mut row);
    }

    above[width]
}
/// Character error rate (CER) is defined as the edit distance (e.g. Levenshtein distance) between the predicted
/// and reference character sequences, divided by the total number of characters in the reference.
/// This metric is commonly used in tasks such as speech recognition, OCR, or text generation
/// to quantify how closely the predicted output matches the ground truth at a character level.
///
/// The value is reported as a percentage.
#[derive(Clone)]
pub struct CharErrorRate<B: Backend> {
/// The metric name.
name: MetricName,
/// Numeric state updated on every call to `update`.
state: NumericMetricState,
/// Optional token marking padding; padding tokens are not counted in the sequence lengths.
pad_token: Option<usize>,
_b: PhantomData<B>,
}
/// The [character error rate metric](CharErrorRate) input type.
#[derive(new)]
pub struct CerInput<B: Backend> {
/// The predicted token sequences (as a 2-D tensor of token indices), one sequence per row.
pub outputs: Tensor<B, 2, Int>,
/// The target token sequences (as a 2-D tensor of token indices), one sequence per row.
pub targets: Tensor<B, 2, Int>,
}
/// The default metric is the one returned by [`CharErrorRate::new`].
impl<B: Backend> Default for CharErrorRate<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> CharErrorRate<B> {
    /// Creates the metric, named `CER`, without a pad token.
    pub fn new() -> Self {
        let name = Arc::new(String::from("CER"));

        Self {
            name,
            state: Default::default(),
            pad_token: None,
            _b: PhantomData,
        }
    }

    /// Sets the pad token: tokens equal to `index` are not counted in the sequence lengths.
    pub fn with_pad_token(self, index: usize) -> Self {
        Self {
            pad_token: Some(index),
            ..self
        }
    }
}
/// The [character error rate metric](CharErrorRate) implementation.
impl<B: Backend> Metric for CharErrorRate<B> {
    type Input = CerInput<B>;

    /// Computes the character error rate of the batch, as a percentage, and records it.
    ///
    /// The rate is the sum of the edit distances between each predicted sequence and its
    /// target, divided by the total target length (`0` when that length is zero).
    ///
    /// When a pad token is set, the length of a sequence is its number of non-padding tokens
    /// and that many leading tokens are compared, so padding is expected at the end of the
    /// sequences. Outputs and targets may have different (padded) sequence lengths.
    ///
    /// # Panics
    ///
    /// Panics if the tensors cannot be read as `i64` values.
    // NOTE(review): `to_vec::<i64>()` presumably requires the backend's int element type to
    // be `i64`; confirm for backends using another integer type.
    fn update(&mut self, input: &CerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        let outputs = &input.outputs;
        let targets = &input.targets;
        let [batch_size, target_seq_len] = targets.dims();
        // Rows of the outputs must be sliced with their own stride, which may differ from
        // the one of the targets.
        let [_, output_seq_len] = outputs.dims();

        let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token {
            // Create boolean masks for non-padding tokens.
            let output_mask = outputs.clone().not_equal_elem(pad as i64);
            let target_mask = targets.clone().not_equal_elem(pad as i64);
            let output_lengths_tensor = output_mask.int().sum_dim(1);
            let target_lengths_tensor = target_mask.int().sum_dim(1);

            (
                output_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
                target_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
            )
        } else {
            // If there's no padding, all sequences have the full length.
            (
                vec![output_seq_len as i64; batch_size],
                vec![target_seq_len as i64; batch_size],
            )
        };

        let outputs_data = outputs.to_data().to_vec::<i64>().unwrap();
        let targets_data = targets.to_data().to_vec::<i64>().unwrap();

        let total_edit_distance: usize = (0..batch_size)
            .map(|i| {
                let output_start = i * output_seq_len;
                let target_start = i * target_seq_len;
                // Get pre-calculated lengths for the current sequence.
                let output_len = output_lengths[i] as usize;
                let target_len = target_lengths[i] as usize;

                let output_seq: Vec<i32> = outputs_data
                    [output_start..(output_start + output_len)]
                    .iter()
                    .map(|&token| token as i32)
                    .collect();
                let target_seq: Vec<i32> = targets_data
                    [target_start..(target_start + target_len)]
                    .iter()
                    .map(|&token| token as i32)
                    .collect();

                edit_distance(&target_seq, &output_seq)
            })
            .sum();

        let total_target_length = target_lengths.iter().map(|&len| len as f64).sum::<f64>();
        let value = if total_target_length > 0.0 {
            100.0 * total_edit_distance as f64 / total_target_length
        } else {
            0.0
        };

        self.state.update(
            value,
            batch_size,
            FormatOptions::new(self.name()).unit("%").precision(2),
        )
    }

    /// Resets the numeric state.
    fn clear(&mut self) {
        self.state.reset();
    }

    /// Returns the metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// The error rate is a percentage for which lower values are better.
    fn attributes(&self) -> MetricAttributes {
        super::NumericAttributes {
            unit: Some("%".to_string()),
            higher_is_better: false,
        }
        .into()
    }
}
impl<B: Backend> Numeric for CharErrorRate<B> {
// Current value, as reported by the numeric state.
fn value(&self) -> NumericEntry {
self.state.current_value()
}
// Running value, as reported by the numeric state.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
/// Perfect match ⇒ CER = 0 %.
#[test]
fn test_cer_without_padding() {
let device = Default::default();
let mut metric = CharErrorRate::<TestBackend>::new();
// Batch size = 2, sequence length = 2
let preds = Tensor::from_data([[1, 2], [3, 4]], &device);
let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);
metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());
assert_eq!(0.0, metric.value().current());
}
/// Two edits in four target tokens ⇒ 50 %.
#[test]
fn test_cer_without_padding_two_errors() {
let device = Default::default();
let mut metric = CharErrorRate::<TestBackend>::new();
// One substitution in each sequence.
let preds = Tensor::from_data([[1, 2], [3, 5]], &device);
let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);
metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());
// 2 edits / 4 tokens = 50 %
assert_eq!(50.0, metric.value().current());
}
/// Same scenario as above, but with right-padding (token 9) ignored.
#[test]
fn test_cer_with_padding() {
let device = Default::default();
let pad = 9_i64;
let mut metric = CharErrorRate::<TestBackend>::new().with_pad_token(pad as usize);
// Each row has three columns, last one is the pad token.
let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);
metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());
// Padding is excluded: still 2 edits / 4 tokens.
assert_eq!(50.0, metric.value().current());
}
/// After `clear()` the current value is NaN.
#[test]
fn test_clear_resets_state() {
let device = Default::default();
let mut metric = CharErrorRate::<TestBackend>::new();
let preds = Tensor::from_data([[1, 2]], &device);
let tgts = Tensor::from_data([[1, 3]], &device); // one error
metric.update(
&CerInput::new(preds.clone(), tgts.clone()),
&MetricMetadata::fake(),
);
assert!(metric.value().current() > 0.0);
metric.clear();
assert!(metric.value().current().is_nan());
}
}

View File

@@ -0,0 +1,33 @@
use std::num::NonZeroUsize;
/// Necessary data for classification metrics: how scores become predictions and
/// how per-class statistics are reduced.
#[derive(Default, Debug, Clone)]
pub struct ClassificationMetricConfig {
/// Rule used to decide which classes count as predicted.
pub decision_rule: DecisionRule,
/// Strategy used to reduce the statistics across classes.
pub class_reduction: ClassReduction,
}
/// How prediction scores are turned into class predictions for classification metrics.
#[derive(Debug, Clone)]
pub enum DecisionRule {
    /// A class counts as predicted when its probability exceeds the threshold.
    Threshold(f64),
    /// A class counts as predicted correctly when it is among the top `k` predicted
    /// classes based on scores.
    TopK(NonZeroUsize),
}

impl Default for DecisionRule {
    /// Defaults to thresholding at `0.5`.
    fn default() -> Self {
        const DEFAULT_THRESHOLD: f64 = 0.5;
        DecisionRule::Threshold(DEFAULT_THRESHOLD)
    }
}
/// The reduction strategy for classification metrics.
///
/// `Macro` is the default variant.
#[derive(Copy, Clone, Default, Debug)]
pub enum ClassReduction {
/// Computes the statistics over all classes before averaging
Micro,
/// Computes the statistics independently for each class before averaging
#[default]
Macro,
}

View File

@@ -0,0 +1,351 @@
use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule};
use burn_core::{
prelude::{Backend, Bool, Int, Tensor},
tensor::IndexingUpdateOp,
};
use std::fmt::{self, Debug};
/// Input for confusion statistics: a batch of predictions and the matching targets.
#[derive(new, Debug, Clone)]
pub struct ConfusionStatsInput<B: Backend> {
/// Non-thresholded, normalized predictions laid out as sample x class.
pub predictions: Tensor<B, 2>,
/// One-hot encoded targets laid out as sample x class.
pub targets: Tensor<B, 2, Bool>,
}
impl<B: Backend> From<ConfusionStatsInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {
fn from(input: ConfusionStatsInput<B>) -> Self {
(input.predictions, input.targets)
}
}
impl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for ConfusionStatsInput<B> {
    /// Builds the input from a `(predictions, targets)` pair.
    fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {
        let (predictions, targets) = value;
        Self::new(predictions, targets)
    }
}
/// Per-entry confusion outcomes of a batch, from which the confusion counts
/// (true/false positives and negatives) are aggregated.
#[derive(Clone)]
pub struct ConfusionStats<B: Backend> {
// Outcome code of every entry, built as `prediction + 2 * target`:
// 0 = true negative, 1 = false positive, 2 = false negative, 3 = true positive.
confusion_classes: Tensor<B, 2, Int>,
// Reduction applied when the outcome masks are aggregated into counts.
class_reduction: ClassReduction,
}
impl<B: Backend> Debug for ConfusionStats<B> {
    /// Formats the statistics as a struct with `tp`, `fp`, `tn` and `fn` expressed as
    /// ratios of the support, followed by the `support` itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reads a rank-1 tensor back as a plain `Vec<f32>`.
        let as_f32_vec = |tensor: Tensor<B, 1>| {
            let data = tensor.to_data();
            data.to_vec::<f32>()
                .expect("A vector representation of the input Tensor is expected")
        };
        // Divides the counts by the support before reading them back.
        let relative_to_support =
            |counts: Tensor<B, 1>| as_f32_vec(self.clone().ratio_of_support(counts));

        let mut builder = f.debug_struct("ConfusionStats");
        builder.field("tp", &relative_to_support(self.clone().true_positive()));
        builder.field("fp", &relative_to_support(self.clone().false_positive()));
        builder.field("tn", &relative_to_support(self.clone().true_negative()));
        builder.field("fn", &relative_to_support(self.clone().false_negative()));
        builder.field("support", &as_f32_vec(self.clone().support()));
        builder.finish()
    }
}
impl<B: Backend> ConfusionStats<B> {
/// Expects `predictions` to be normalized.
pub fn new(input: &ConfusionStatsInput<B>, config: &ClassificationMetricConfig) -> Self {
let prediction_mask = match config.decision_rule {
DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold),
DecisionRule::TopK(top_k) => {
let mask = input.predictions.zeros_like();
let indexes =
input
.predictions
.clone()
.argsort_descending(1)
.narrow(1, 0, top_k.get());
let values = indexes.ones_like().float();
mask.scatter(1, indexes, values, IndexingUpdateOp::Add)
.bool()
}
};
Self {
confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2,
class_reduction: config.class_reduction,
}
}
/// sum over samples
fn aggregate(
sample_class_mask: Tensor<B, 2, Bool>,
class_reduction: ClassReduction,
) -> Tensor<B, 1> {
use ClassReduction::{Macro, Micro};
match class_reduction {
Micro => sample_class_mask.float().sum(),
Macro => sample_class_mask.float().sum_dim(0).squeeze_dim(0),
}
}
pub fn true_positive(self) -> Tensor<B, 1> {
Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction)
}
pub fn true_negative(self) -> Tensor<B, 1> {
Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction)
}
pub fn false_positive(self) -> Tensor<B, 1> {
Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction)
}
pub fn false_negative(self) -> Tensor<B, 1> {
Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction)
}
pub fn positive(self) -> Tensor<B, 1> {
self.clone().true_positive() + self.false_negative()
}
pub fn negative(self) -> Tensor<B, 1> {
self.clone().true_negative() + self.false_positive()
}
pub fn predicted_positive(self) -> Tensor<B, 1> {
self.clone().true_positive() + self.false_positive()
}
pub fn support(self) -> Tensor<B, 1> {
self.clone().positive() + self.negative()
}
pub fn ratio_of_support(self, metric: Tensor<B, 1>) -> Tensor<B, 1> {
metric / self.clone().support()
}
}
// Table-driven tests for `ConfusionStats`: every test builds the statistics from the shared
// dummy classification input and compares one aggregated count with the expected values.
#[cfg(test)]
mod tests {
use super::{ConfusionStats, ConfusionStatsInput};
use crate::{
TestBackend,
metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
};
use burn_core::prelude::TensorData;
use rstest::{fixture, rstest};
use std::num::NonZeroUsize;
// Builds a configuration using the top-k decision rule.
fn top_k_config(
top_k: NonZeroUsize,
class_reduction: ClassReduction,
) -> ClassificationMetricConfig {
ClassificationMetricConfig {
decision_rule: DecisionRule::TopK(top_k),
class_reduction,
}
}
// Top-1, micro reduction.
#[fixture]
#[once]
fn top_k_config_k1_micro() -> ClassificationMetricConfig {
top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro)
}
// Top-1, macro reduction.
#[fixture]
#[once]
fn top_k_config_k1_macro() -> ClassificationMetricConfig {
top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro)
}
// Top-2, micro reduction.
#[fixture]
#[once]
fn top_k_config_k2_micro() -> ClassificationMetricConfig {
top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro)
}
// Top-2, macro reduction.
#[fixture]
#[once]
fn top_k_config_k2_macro() -> ClassificationMetricConfig {
top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro)
}
// Builds a configuration using the threshold decision rule.
fn threshold_config(
threshold: f64,
class_reduction: ClassReduction,
) -> ClassificationMetricConfig {
ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
class_reduction,
}
}
// Shared `THRESHOLD`, micro reduction.
#[fixture]
#[once]
fn threshold_config_micro() -> ClassificationMetricConfig {
threshold_config(THRESHOLD, ClassReduction::Micro)
}
// Shared `THRESHOLD`, macro reduction.
#[fixture]
#[once]
fn threshold_config_macro() -> ClassificationMetricConfig {
threshold_config(THRESHOLD, ClassReduction::Macro)
}
// Expected true-positive counts per (classification type, config) pair.
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())]
fn test_true_positive(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.true_positive()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
// Expected true-negative counts per (classification type, config) pair.
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())]
fn test_true_negative(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.true_negative()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
// Expected false-positive counts per (classification type, config) pair.
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())]
fn test_false_positive(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.false_positive()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
// Expected false-negative counts per (classification type, config) pair.
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())]
fn test_false_negatives(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.false_negative()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
// Expected actual-positive counts (true positives + false negatives).
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())]
fn test_positive(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.positive()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
// Expected actual-negative counts (true negatives + false positives).
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())]
fn test_negative(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.negative()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
// Expected predicted-positive counts (true positives + false positives).
#[rstest]
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())]
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())]
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())]
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())]
fn test_predicted_positive(
#[case] classification_type: ClassificationType,
#[case] config: ClassificationMetricConfig,
#[case] expected: Vec<i64>,
) {
let input: ConfusionStatsInput<TestBackend> =
dummy_classification_input(&classification_type).into();
ConfusionStats::new(&input, &config)
.predicted_positive()
.int()
.into_data()
.assert_eq(&TensorData::from(expected.as_slice()), true);
}
}

View File

@@ -0,0 +1,76 @@
use std::sync::Arc;
/// CPU Temperature metric
use super::MetricMetadata;
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry};
use systemstat::{Platform, System};
/// CPU Temperature in celsius degrees
#[derive(Clone)]
pub struct CpuTemperature {
// Display name of the metric.
name: MetricName,
// Last reading in °C: starts at 0 and is set to NaN when a reading fails.
temp_celsius: f32,
// Shared handle queried for the CPU temperature on every update.
sys: Arc<System>,
}
impl CpuTemperature {
    /// Creates a new CPU temperature metric whose reading starts at `0`.
    pub fn new() -> Self {
        Self {
            name: Arc::new(String::from("CPU Temperature")),
            temp_celsius: 0.0,
            sys: Arc::new(System::new()),
        }
    }
}
impl Default for CpuTemperature {
fn default() -> Self {
CpuTemperature::new()
}
}
impl Metric for CpuTemperature {
    type Input = ();

    /// Reads the CPU temperature and serializes it.
    ///
    /// A failed reading is stored as NaN and shown as `NaN °C`; otherwise the value is
    /// formatted with two decimals.
    fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        self.temp_celsius = self.sys.cpu_temp().unwrap_or(f32::NAN);

        let formatted = if self.temp_celsius.is_nan() {
            format!("{}: NaN °C", self.name())
        } else {
            format!("{}: {:.2} °C", self.name(), self.temp_celsius)
        };
        let raw = format!("{:.2}", self.temp_celsius);

        SerializedEntry::new(formatted, raw)
    }

    /// Nothing to reset: only the latest reading is kept.
    fn clear(&mut self) {}

    /// Returns a clone of the stored metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Describes this metric as a numeric value in `°C` for which lower is better.
    fn attributes(&self) -> MetricAttributes {
        let numeric = super::NumericAttributes {
            unit: Some(String::from("°C")),
            higher_is_better: false,
        };
        numeric.into()
    }
}
impl Numeric for CpuTemperature {
    /// Latest temperature reading, widened to `f64`.
    fn value(&self) -> NumericEntry {
        NumericEntry::Value(f64::from(self.temp_celsius))
    }

    /// Same as [`Numeric::value`]: no running aggregate is kept for this metric.
    fn running_value(&self) -> NumericEntry {
        NumericEntry::Value(f64::from(self.temp_celsius))
    }
}

View File

@@ -0,0 +1,103 @@
use super::MetricMetadata;
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use sysinfo::{CpuRefreshKind, RefreshKind, System};
/// General CPU Usage metric
pub struct CpuUse {
// Display name of the metric.
name: MetricName,
// Instant of the last usage refresh.
last_refresh: Instant,
// Minimum time between two refreshes.
refresh_frequency: Duration,
// System handle used to read per-CPU usage.
sys: System,
// Last computed average usage over all CPUs.
current: f64,
}
impl Clone for CpuUse {
/// Copies the name, timing fields and last value; the clone gets a fresh `System`
/// instead of a copy of the existing one.
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
last_refresh: self.last_refresh,
refresh_frequency: self.refresh_frequency,
sys: System::new(),
current: self.current,
}
}
}
impl CpuUse {
    /// Creates a new CPU metric with an initial usage reading and a 200 ms refresh period.
    pub fn new() -> Self {
        let mut sys = System::new();
        let current = Self::refresh(&mut sys);
        Self {
            name: Arc::new(String::from("CPU Usage")),
            last_refresh: Instant::now(),
            refresh_frequency: Duration::from_millis(200),
            sys,
            current,
        }
    }

    /// Refreshes the CPU usage information and returns the usage averaged over all CPUs.
    fn refresh(sys: &mut System) -> f64 {
        // Only the CPU usage is refreshed; nothing else is requested.
        let cpu_usage_only =
            RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage());
        sys.refresh_specifics(cpu_usage_only);

        let cpus = sys.cpus();
        let total_usage: f32 = cpus.iter().fold(0.0, |sum, cpu| sum + cpu.cpu_usage());
        f64::from(total_usage) / cpus.len() as f64
    }
}
impl Default for CpuUse {
fn default() -> Self {
CpuUse::new()
}
}
impl Metric for CpuUse {
    type Input = ();

    /// Serializes the CPU usage, refreshing it first when the refresh period has elapsed.
    fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        let is_stale = self.last_refresh.elapsed() >= self.refresh_frequency;
        if is_stale {
            self.current = Self::refresh(&mut self.sys);
            self.last_refresh = Instant::now();
        }

        let usage = self.current;
        let formatted = format!("{}: {usage:.2} %", self.name());
        let raw = format!("{usage:.2}");
        SerializedEntry::new(formatted, raw)
    }

    /// Nothing to reset: only the latest reading is kept.
    fn clear(&mut self) {}

    /// Returns a clone of the stored metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Describes this metric as a numeric value in `%` for which lower is better.
    fn attributes(&self) -> MetricAttributes {
        let numeric = super::NumericAttributes {
            unit: Some(String::from("%")),
            higher_is_better: false,
        };
        numeric.into()
    }
}
impl Numeric for CpuUse {
/// Returns the last computed CPU usage.
fn value(&self) -> NumericEntry {
NumericEntry::Value(self.current)
}
/// Returns the same last computed CPU usage as `value`.
fn running_value(&self) -> NumericEntry {
NumericEntry::Value(self.current)
}
}

View File

@@ -0,0 +1,108 @@
use std::sync::Arc;
use super::MetricMetadata;
use crate::metric::{Metric, MetricName, SerializedEntry};
use nvml_wrapper::Nvml;
/// Tracks basic CUDA device information (memory, utilization and power) through NVML.
#[derive(Clone)]
pub struct CudaMetric {
// Display name of the metric.
name: MetricName,
// Shared NVML handle; `None` when NVML could not be initialized.
nvml: Arc<Option<Nvml>>,
}
impl CudaMetric {
    /// Creates a new metric for CUDA.
    ///
    /// When NVML cannot be initialized a warning is logged and the metric is still
    /// created, without an NVML handle.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Cuda"));
        let nvml = match Nvml::init() {
            Ok(nvml) => Some(nvml),
            Err(err) => {
                log::warn!("Unable to initialize CUDA Metric: {err}");
                None
            }
        };
        Self {
            name,
            nvml: Arc::new(nvml),
        }
    }
}
impl Default for CudaMetric {
/// Equivalent to `CudaMetric::new`.
fn default() -> Self {
Self::new()
}
}
impl Metric for CudaMetric {
    type Input = ();

    /// Queries every CUDA device and serializes its memory usage, utilization and power.
    ///
    /// The formatted entry contains, for each device, the used/total memory, the GPU
    /// utilization and (when available) the power draw. The raw entry contains the
    /// `used/total` memory of every device, each followed by a space.
    ///
    /// Returns an `Unavailable` entry when NVML was not initialized or when the device
    /// count, a device handle, its memory info or its utilization rates cannot be read
    /// (a warning is logged in the latter cases). A failed power reading is simply omitted.
    fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> SerializedEntry {
        let not_available =
            || SerializedEntry::new("Unavailable".to_string(), "Unavailable".to_string());

        let available = |nvml: &Nvml| {
            let mut formatted = String::new();
            let mut raw_running = String::new();

            let device_count = match nvml.device_count() {
                Ok(val) => val,
                Err(err) => {
                    log::warn!("Unable to get the number of cuda devices: {err}");
                    return not_available();
                }
            };

            for index in 0..device_count {
                let device = match nvml.device_by_index(index) {
                    Ok(val) => val,
                    Err(err) => {
                        log::warn!("Unable to get device {index}: {err}");
                        return not_available();
                    }
                };
                let memory_info = match device.memory_info() {
                    Ok(info) => info,
                    Err(err) => {
                        log::warn!("Unable to get memory info from device {index}: {err}");
                        return not_available();
                    }
                };

                let used_gb = memory_info.used as f64 * 1e-9;
                let total_gb = memory_info.total as f64 * 1e-9;

                // Append in place instead of rebuilding the whole string per device.
                formatted.push_str(&format!(
                    " GPU #{index} - Memory {used_gb:.2}/{total_gb:.2} Gb"
                ));
                // Accumulate the raw memory of every device; assigning here would keep
                // only the last device.
                raw_running.push_str(&format!("{used_gb}/{total_gb} "));

                let utilization_rates = match device.utilization_rates() {
                    Ok(rate) => rate,
                    Err(err) => {
                        log::warn!("Unable to get utilization rates from device {index}: {err}");
                        return not_available();
                    }
                };
                formatted.push_str(&format!(" - Usage {}%", utilization_rates.gpu));

                // Power is the currency for perf/W. NVML reports milliwatts.
                if let Ok(power_mw) = device.power_usage() {
                    let power_w = power_mw as f64 / 1000.0;
                    formatted.push_str(&format!(" - Power {power_w:.1} W"));
                }
            }

            SerializedEntry::new(formatted, raw_running)
        };

        match self.nvml.as_ref() {
            Some(nvml) => available(nvml),
            None => not_available(),
        }
    }

    /// Nothing to reset: every update queries the devices again.
    fn clear(&mut self) {}

    /// Returns a clone of the stored metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }
}

View File

@@ -0,0 +1,259 @@
use crate::metric::{MetricName, Numeric};
use super::{
Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
confusion_stats::{ConfusionStats, ConfusionStatsInput},
state::{FormatOptions, NumericMetricState},
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::cast::ToElement,
};
use core::marker::PhantomData;
use std::{num::NonZeroUsize, sync::Arc};
/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric.
///
/// The `beta` parameter represents the ratio of recall importance to precision importance.
/// `beta > 1` gives more weight to recall, while `beta < 1` favors precision.
#[derive(Clone)]
pub struct FBetaScoreMetric<B: Backend> {
// Display name; embeds beta, the decision rule and the class reduction.
name: MetricName,
// Accumulated numeric state of the metric.
state: NumericMetricState,
// Ties the metric to a backend without storing any backend value.
_b: PhantomData<B>,
// Decision rule and class reduction used to compute the confusion statistics.
config: ClassificationMetricConfig,
// Weight of recall relative to precision.
beta: f64,
}
impl<B: Backend> Default for FBetaScoreMetric<B> {
/// Builds the metric from the default classification config and the default `f64`
/// value for `beta`.
///
/// NOTE(review): the default `f64` is `0.0`, and with `beta = 0` the score computed in
/// `update` reduces to `TP / (TP + FP)`; confirm this default is intended.
fn default() -> Self {
Self::new(Default::default(), Default::default())
}
}
impl<B: Backend> FBetaScoreMetric<B> {
#[allow(dead_code)]
fn new(config: ClassificationMetricConfig, beta: f64) -> Self {
let name = Arc::new(format!(
"FBetaScore ({}) @ {:?} [{:?}]",
beta, config.decision_rule, config.class_reduction
));
Self {
name,
config,
beta,
state: Default::default(),
_b: PhantomData,
}
}
/// F-beta score metric for binary classification.
///
/// # Arguments
///
/// * `beta` - Positive real factor to weight recall's importance.
/// * `threshold` - The threshold to transform a probability into a binary prediction.
#[allow(dead_code)]
pub fn binary(beta: f64, threshold: f64) -> Self {
Self::new(
ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
// binary classification results are the same independently of class_reduction
..Default::default()
},
beta,
)
}
/// F-beta score metric for multiclass classification.
///
/// # Arguments
///
/// * `beta` - Positive real factor to weight recall's importance.
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self {
Self::new(
ClassificationMetricConfig {
decision_rule: DecisionRule::TopK(
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
),
class_reduction,
},
beta,
)
}
/// F-beta score metric for multi-label classification.
///
/// # Arguments
///
/// * `beta` - Positive real factor to weight recall's importance.
/// * `threshold` - The threshold to transform a probability into a binary prediction.
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self {
Self::new(
ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
class_reduction,
},
beta,
)
}
fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
use ClassReduction::{Macro, Micro};
let avg_tensor = match self.config.class_reduction {
Micro => aggregated_metric,
Macro => {
if aggregated_metric
.clone()
.contains_nan()
.any()
.into_scalar()
.to_bool()
{
let nan_mask = aggregated_metric.clone().is_nan();
aggregated_metric = aggregated_metric
.clone()
.select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))
}
aggregated_metric.mean()
}
};
avg_tensor.into_scalar().to_f64()
}
}
impl<B: Backend> Metric for FBetaScoreMetric<B> {
    type Input = ConfusionStatsInput<B>;

    /// Computes the F-beta score of the batch and records it, as a percentage, in the state.
    ///
    /// The score is `(1 + beta²)·TP / ((1 + beta²)·TP + beta²·FN + FP)`, reduced across
    /// classes by `class_average`.
    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        let [sample_size, _] = input.predictions.dims();
        let stats = ConfusionStats::new(input, &self.config);

        let weighted_tp = stats.clone().true_positive() * (1.0 + self.beta.powi(2));
        let weighted_fn = stats.clone().false_negative() * self.beta.powi(2);
        let denominator = weighted_tp.clone() + weighted_fn + stats.false_positive();
        let score = self.class_average(weighted_tp / denominator);

        let format = FormatOptions::new(self.name()).unit("%").precision(2);
        self.state.update(100.0 * score, sample_size, format)
    }

    /// Resets the accumulated state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns a clone of the stored metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Describes this metric as a numeric value in `%` for which higher is better.
    fn attributes(&self) -> MetricAttributes {
        let numeric = NumericAttributes {
            unit: Some(String::from("%")),
            higher_is_better: true,
        };
        numeric.into()
    }
}
/// Numeric access to the F-beta score, backed by the metric's internal state.
impl<B: Backend> Numeric for FBetaScoreMetric<B> {
/// Returns the state's current value.
fn value(&self) -> NumericEntry {
self.state.current_value()
}
/// Returns the state's running value.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}
// Tests for `FBetaScoreMetric`: each case updates the metric once with the shared dummy
// classification input and compares the current value with `expected * 100`.
#[cfg(test)]
mod tests {
use super::{
ClassReduction::{self, *},
FBetaScoreMetric, Metric, MetricMetadata,
};
use crate::metric::Numeric;
use crate::{
TestBackend,
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
};
use burn_core::tensor::TensorData;
use burn_core::tensor::Tolerance;
use rstest::rstest;
// Binary classification with beta = 1 and beta = 2.
#[rstest]
#[case::binary_b1(1.0, THRESHOLD, 0.5)]
#[case::binary_b2(2.0, THRESHOLD, 0.5)]
fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) {
let input = dummy_classification_input(&ClassificationType::Binary).into();
let mut metric = FBetaScoreMetric::binary(beta, threshold);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// Multiclass classification over beta, class reduction and top-k combinations.
#[rstest]
#[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)]
#[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))]
#[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)]
#[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)]
#[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)]
#[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))]
#[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)]
#[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)]
fn test_multiclass_fscore(
#[case] beta: f64,
#[case] class_reduction: ClassReduction,
#[case] top_k: usize,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// Multi-label classification over beta and class reduction combinations.
#[rstest]
#[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))]
#[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)]
#[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))]
#[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)]
fn test_multilabel_fscore(
#[case] beta: f64,
#[case] class_reduction: ClassReduction,
#[case] threshold: f64,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// The metric name must change with the parameters and match for identical parameters.
#[test]
fn test_parameterized_unique_name() {
let metric_a = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 1, ClassReduction::Macro);
let metric_b = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 2, ClassReduction::Macro);
let metric_c = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 1, ClassReduction::Macro);
assert_ne!(metric_a.name(), metric_b.name());
assert_eq!(metric_a.name(), metric_c.name());
// Same threshold, different beta.
let metric_a = FBetaScoreMetric::<TestBackend>::binary(0.5, 0.5);
let metric_b = FBetaScoreMetric::<TestBackend>::binary(0.75, 0.5);
assert_ne!(metric_a.name(), metric_b.name());
}
}

View File

@@ -0,0 +1,198 @@
use core::marker::PhantomData;
use std::sync::Arc;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricMetadata, SerializedEntry};
use crate::metric::{
Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
};
use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};
/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
#[derive(Clone)]
pub struct HammingScore<B: Backend> {
// Display name; embeds the threshold.
name: MetricName,
// Accumulated numeric state of the metric.
state: NumericMetricState,
// Outputs strictly greater than this value count as positive predictions.
threshold: f32,
// Whether the sigmoid activation is applied to the outputs before thresholding.
sigmoid: bool,
// Ties the metric to a backend without storing any backend value.
_b: PhantomData<B>,
}
/// The [hamming score](HammingScore) input type.
#[derive(new)]
pub struct HammingScoreInput<B: Backend> {
// Model outputs; the first dimension is used as the batch size.
outputs: Tensor<B, 2>,
// Integer targets, converted to booleans when compared with the thresholded outputs.
targets: Tensor<B, 2, Int>,
}
impl<B: Backend> HammingScore<B> {
    /// Creates the metric with its default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Rebuilds the metric name from the current threshold.
    fn update_name(&mut self) {
        let label = format!("Hamming Score @ Threshold({})", self.threshold);
        self.name = Arc::new(label);
    }

    /// Sets the threshold.
    pub fn with_threshold(self, threshold: f32) -> Self {
        let mut metric = Self { threshold, ..self };
        metric.update_name();
        metric
    }

    /// Sets the sigmoid activation function usage.
    pub fn with_sigmoid(self, sigmoid: bool) -> Self {
        let mut metric = Self { sigmoid, ..self };
        metric.update_name();
        metric
    }
}
impl<B: Backend> Default for HammingScore<B> {
    /// Creates a new metric instance with default values: a threshold of `0.5` and no
    /// sigmoid activation.
    fn default() -> Self {
        let threshold: f32 = 0.5;
        Self {
            name: Arc::new(format!("Hamming Score @ Threshold({threshold})")),
            state: NumericMetricState::default(),
            threshold,
            sigmoid: false,
            _b: PhantomData,
        }
    }
}
impl<B: Backend> Metric for HammingScore<B> {
    type Input = HammingScoreInput<B>;

    /// Computes the share of entries whose thresholded output equals the target and
    /// records it, as a percentage, in the state.
    fn update(
        &mut self,
        input: &HammingScoreInput<B>,
        _metadata: &MetricMetadata,
    ) -> SerializedEntry {
        let [batch_size, _n_classes] = input.outputs.dims();

        let expected = input.targets.clone();
        let outputs = if self.sigmoid {
            sigmoid(input.outputs.clone())
        } else {
            input.outputs.clone()
        };

        // Entries strictly above the threshold are the positive predictions.
        let predicted = outputs.greater_elem(self.threshold);
        let matches = predicted.equal(expected.bool());
        let score = matches.float().mean().into_scalar().elem::<f64>();

        let format = FormatOptions::new(self.name()).unit("%").precision(2);
        self.state.update(100.0 * score, batch_size, format)
    }

    /// Resets the accumulated state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns a clone of the stored metric name.
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Describes this metric as a numeric value in `%` for which higher is better.
    fn attributes(&self) -> MetricAttributes {
        let numeric = NumericAttributes {
            unit: Some(String::from("%")),
            higher_is_better: true,
        };
        numeric.into()
    }
}
impl<B: Backend> Numeric for HammingScore<B> {
    /// Delegates to the state's current value.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Delegates to the state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
// Checks the score for fully matching, fully inverted and partially flipped targets.
#[test]
fn test_hamming_score() {
let device = Default::default();
let mut metric = HammingScore::<TestBackend>::new();
let x = Tensor::from_data(
[
[0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
[0.43, 0.31, 0.21, 0.63, 0.53], // [0, 0, 0, 1, 1]
[0.44, 0.25, 0.71, 0.39, 0.73], // [0, 0, 1, 0, 1]
[0.49, 0.37, 0.68, 0.39, 0.31], // [0, 0, 1, 0, 0]
],
&device,
);
// Targets identical to the thresholded outputs above.
let y = Tensor::from_data(
[
[0, 1, 0, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 1, 0, 1],
[0, 0, 1, 0, 0],
],
&device,
);
let _entry = metric.update(
&HammingScoreInput::new(x.clone(), y.clone()),
&MetricMetadata::fake(),
);
assert_eq!(100.0, metric.value().current());
// Invert all targets: y = (1 - y)
let y = y.neg().add_scalar(1);
let _entry = metric.update(
&HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)
&MetricMetadata::fake(),
);
assert_eq!(0.0, metric.value().current());
// Flip 5 of the 20 target values -> 1 - (5/20) = 0.75
let y = Tensor::from_data(
[
[0, 1, 1, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 1, 1, 0, 0],
],
&device,
);
let _entry = metric.update(
&HammingScoreInput::new(x, y), // 5 of the 20 targets differ from the predictions
&MetricMetadata::fake(),
);
assert_eq!(75.0, metric.value().current());
}
// The metric name must differ between thresholds and match for equal thresholds.
#[test]
fn test_parameterized_unique_name() {
let metric_a = HammingScore::<TestBackend>::new().with_threshold(0.5);
let metric_b = HammingScore::<TestBackend>::new().with_threshold(0.75);
let metric_c = HammingScore::<TestBackend>::new().with_threshold(0.5);
assert_ne!(metric_a.name(), metric_b.name());
assert_eq!(metric_a.name(), metric_c.name());
}
}

View File

@@ -0,0 +1,89 @@
use std::sync::Arc;
use super::MetricMetadata;
use super::SerializedEntry;
use super::state::FormatOptions;
use super::state::NumericMetricState;
use crate::metric::MetricName;
use crate::metric::Numeric;
use crate::metric::{Metric, MetricAttributes, NumericAttributes, NumericEntry};
/// The iteration speed metric, reported in iterations per second.
#[derive(Clone)]
pub struct IterationSpeedMetric {
/// Display name of the metric ("Iteration Speed").
name: MetricName,
/// Accumulator providing the current and running values.
state: NumericMetricState,
/// Start of the timing window; `None` until the first update after creation or `clear`.
instant: Option<std::time::Instant>,
}
impl Default for IterationSpeedMetric {
fn default() -> Self {
Self::new()
}
}
impl IterationSpeedMetric {
/// Create the metric.
pub fn new() -> Self {
Self {
name: Arc::new("Iteration Speed".to_string()),
state: Default::default(),
instant: Default::default(),
}
}
}
impl Metric for IterationSpeedMetric {
    type Input = ();

    /// Records the speed since the timer was started.
    ///
    /// The first call after creation or [`clear`](Metric::clear) only starts the timer
    /// and records `0.0`; later calls record `count / elapsed_seconds`.
    fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry {
        let raw = if let Some(started) = self.instant {
            // If iteration is not logged, fall back to the number of items processed:
            // 1 iteration should equal 1 item in that case.
            let count = metadata
                .iteration
                .unwrap_or(metadata.progress.items_processed);
            count as f64 / started.elapsed().as_secs_f64()
        } else {
            self.instant = Some(std::time::Instant::now());
            0.0
        };

        let format = FormatOptions::new(self.name())
            .unit("iter/sec")
            .precision(2);
        self.state.update(raw, 1, format)
    }

    /// Stops the timer; the next update starts a new timing window.
    // NOTE(review): unlike the sibling metrics, `state` is not reset here — confirm intended.
    fn clear(&mut self) {
        self.instant = None;
    }

    /// Returns the metric name (a shared handle, cloned).
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Speed is expressed in `iter/sec`, where higher is better.
    fn attributes(&self) -> MetricAttributes {
        let attributes = NumericAttributes {
            unit: Some(String::from("iter/sec")),
            higher_is_better: true,
        };
        attributes.into()
    }
}
impl Numeric for IterationSpeedMetric {
    /// Delegates to the state's current value.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Delegates to the state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}

View File

@@ -0,0 +1,67 @@
use std::sync::Arc;
use super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Track the learning rate across iterations.
#[derive(Clone)]
pub struct LearningRateMetric {
/// Display name of the metric ("Learning Rate").
name: MetricName,
/// Accumulator fed with the learning rate on every update.
state: NumericMetricState,
}
impl LearningRateMetric {
    /// Creates a new learning rate metric with an empty state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Learning Rate"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for LearningRateMetric {
fn default() -> Self {
Self::new()
}
}
impl Metric for LearningRateMetric {
    type Input = ();

    /// Records the learning rate found in the metadata, or `0.0` when none is provided.
    fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> SerializedEntry {
        let lr = metadata.lr.unwrap_or(0.0);
        let format = FormatOptions::new(self.name()).precision(2);
        self.state.update(lr, 1, format)
    }

    /// Resets the accumulated state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns the metric name (a shared handle, cloned).
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// The learning rate has no unit and is flagged as lower-is-better.
    fn attributes(&self) -> MetricAttributes {
        let attributes = NumericAttributes {
            unit: None,
            higher_is_better: false,
        };
        attributes.into()
    }
}
impl Numeric for LearningRateMetric {
    /// Delegates to the state's current value.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Delegates to the state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}

View File

@@ -0,0 +1,89 @@
use std::sync::Arc;
use super::MetricMetadata;
use super::SerializedEntry;
use super::state::FormatOptions;
use super::state::NumericMetricState;
use crate::metric::MetricName;
use crate::metric::{Metric, MetricAttributes, Numeric, NumericAttributes, NumericEntry};
use burn_core::tensor::Tensor;
use burn_core::tensor::backend::Backend;
/// The loss metric.
///
/// Records the mean of the loss tensor supplied on each update.
#[derive(Clone)]
pub struct LossMetric<B: Backend> {
/// Display name of the metric ("Loss").
name: Arc<String>,
/// Accumulator providing the current and running values.
state: NumericMetricState,
// NOTE(review): stores a `B` value (built with `Default::default()`) rather than a
// `PhantomData<B>` marker as the sibling metrics do — confirm this is intentional.
_b: B,
}
/// The [loss metric](LossMetric) input type.
#[derive(new)]
pub struct LossInput<B: Backend> {
/// Loss values; the metric records their mean and uses the tensor length as batch size.
tensor: Tensor<B, 1>,
}
impl<B: Backend> Default for LossMetric<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> LossMetric<B> {
    /// Create the metric with an empty state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Loss"));
        Self {
            name,
            state: NumericMetricState::default(),
            _b: Default::default(),
        }
    }
}
impl<B: Backend> Metric for LossMetric<B> {
    type Input = LossInput<B>;

    /// Records the mean of the loss tensor, weighted by the tensor length.
    fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        let [batch_size] = loss.tensor.dims();

        // The mean is read back as the first (and only expected) element of the data.
        let mean_data = loss.tensor.clone().mean().into_data();
        let mean_loss = mean_data.iter::<f64>().next().unwrap();

        let format = FormatOptions::new(self.name()).precision(2);
        self.state.update(mean_loss, batch_size, format)
    }

    /// Resets the accumulated state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns the metric name (a shared handle, cloned).
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// The loss has no unit, and lower is better.
    fn attributes(&self) -> MetricAttributes {
        let attributes = NumericAttributes {
            unit: None,
            higher_is_better: false,
        };
        attributes.into()
    }
}
impl<B: Backend> Numeric for LossMetric<B> {
    /// Delegates to the state's current value.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Delegates to the state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}

View File

@@ -0,0 +1,111 @@
/// RAM use metric
use super::{MetricAttributes, MetricMetadata, NumericAttributes};
use crate::metric::{Metric, Numeric, NumericEntry, SerializedEntry};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use sysinfo::System;
/// Memory information
///
/// Reports used and total RAM in gigabytes, as read through `sysinfo`.
pub struct CpuMemory {
/// Display name of the metric ("CPU Memory").
name: Arc<String>,
/// Time of the last memory refresh.
last_refresh: Instant,
/// Minimum delay between two refreshes triggered by `update`.
refresh_frequency: Duration,
/// `sysinfo` handle used to query memory usage.
sys: System,
/// Total RAM, in bytes, as of the last refresh.
ram_bytes_total: u64,
/// Used RAM, in bytes, as of the last refresh.
ram_bytes_used: u64,
}
impl Clone for CpuMemory {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
last_refresh: self.last_refresh,
refresh_frequency: self.refresh_frequency,
sys: System::new(),
ram_bytes_total: self.ram_bytes_total,
ram_bytes_used: self.ram_bytes_used,
}
}
}
impl CpuMemory {
    /// Creates a new memory metric and takes an initial reading.
    ///
    /// Subsequent readings are refreshed at most every 200 ms.
    pub fn new() -> Self {
        let mut metric = Self {
            name: Arc::new(String::from("CPU Memory")),
            last_refresh: Instant::now(),
            refresh_frequency: Duration::from_millis(200),
            sys: System::new(),
            ram_bytes_total: 0,
            ram_bytes_used: 0,
        };
        metric.refresh();
        metric
    }

    /// Re-reads memory usage and stores the total and used byte counts.
    fn refresh(&mut self) {
        self.sys.refresh_memory();
        self.last_refresh = Instant::now();

        // Total bytes of RAM.
        self.ram_bytes_total = self.sys.total_memory();
        // Bytes of RAM currently in use.
        self.ram_bytes_used = self.sys.used_memory();
    }
}
impl Default for CpuMemory {
fn default() -> Self {
CpuMemory::new()
}
}
impl Metric for CpuMemory {
    type Input = ();

    /// Reports the used RAM in gigabytes, refreshing the reading first when the last
    /// refresh is older than the configured frequency.
    fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        let stale = self.last_refresh.elapsed() >= self.refresh_frequency;
        if stale {
            self.refresh();
        }

        let used_gb = bytes2gb(self.ram_bytes_used);
        let total_gb = bytes2gb(self.ram_bytes_total);
        let formatted = format!("RAM Used: {:.2} / {:.2} Gb", used_gb, total_gb);

        SerializedEntry::new(formatted, used_gb.to_string())
    }

    /// Nothing to reset: the metric only holds the latest reading.
    fn clear(&mut self) {}

    /// Returns the metric name (a shared handle, cloned).
    fn name(&self) -> Arc<String> {
        Arc::clone(&self.name)
    }

    /// Memory is expressed in `Gb`, and lower is better.
    fn attributes(&self) -> MetricAttributes {
        let attributes = NumericAttributes {
            unit: Some(String::from("Gb")),
            higher_is_better: false,
        };
        attributes.into()
    }
}
impl Numeric for CpuMemory {
    /// Used RAM, in gigabytes, as of the last refresh.
    fn value(&self) -> NumericEntry {
        let used_gb = bytes2gb(self.ram_bytes_used);
        NumericEntry::Value(used_gb)
    }

    /// Same as [`value`](Numeric::value): no running aggregate is kept for this metric.
    fn running_value(&self) -> NumericEntry {
        let used_gb = bytes2gb(self.ram_bytes_used);
        NumericEntry::Value(used_gb)
    }
}
/// Converts a byte count into decimal gigabytes (10^9 bytes).
fn bytes2gb(bytes: u64) -> f64 {
    const BYTES_PER_GB: f64 = 1e9;
    let bytes = bytes as f64;
    bytes / BYTES_PER_GB
}

View File

@@ -0,0 +1,71 @@
/// State module.
pub mod state;
/// Module responsible for saving and exposing data collected during training.
pub mod store;
/// Metrics module for vision tasks.
#[cfg(feature = "vision")]
pub mod vision;
// Metrics for reinforcement learning.
#[cfg(feature = "rl")]
mod rl;
#[cfg(feature = "rl")]
pub use rl::*;
// System metrics
#[cfg(feature = "sys-metrics")]
mod cpu_temp;
#[cfg(feature = "sys-metrics")]
mod cpu_use;
#[cfg(feature = "sys-metrics")]
mod cuda;
#[cfg(feature = "sys-metrics")]
mod memory_use;
#[cfg(feature = "sys-metrics")]
pub use cpu_temp::*;
#[cfg(feature = "sys-metrics")]
pub use cpu_use::*;
#[cfg(feature = "sys-metrics")]
pub use cuda::*;
#[cfg(feature = "sys-metrics")]
pub use memory_use::*;
// Training metrics
mod acc;
mod auroc;
mod base;
mod cer;
mod confusion_stats;
mod fbetascore;
mod hamming;
mod iteration;
mod learning_rate;
mod loss;
mod perplexity;
mod precision;
mod recall;
mod top_k_acc;
mod wer;
pub use acc::*;
pub use auroc::*;
pub use base::*;
pub use cer::*;
pub use confusion_stats::ConfusionStatsInput;
pub use fbetascore::*;
pub use hamming::*;
pub use iteration::*;
pub use learning_rate::*;
pub use loss::*;
pub use perplexity::*;
pub use precision::*;
pub use recall::*;
pub use top_k_acc::*;
pub use wer::*;
pub(crate) mod classification;
pub(crate) mod processor;
pub use crate::metric::classification::ClassReduction;
// Expose `ItemLazy` so it can be implemented for custom types
pub use processor::ItemLazy;

View File

@@ -0,0 +1,438 @@
use core::marker::PhantomData;
use super::state::FormatOptions;
use super::{MetricMetadata, NumericEntry, SerializedEntry, format_float};
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// Custom state for perplexity metric that correctly accumulates negative log-likelihood.
///
/// Unlike other metrics that can be averaged, perplexity requires special handling:
/// - Accumulate total negative log-likelihood across all tokens
/// - Accumulate total number of effective tokens
/// - Compute perplexity as exp(total_nll / total_tokens) only at the end
///
/// When no token has been accumulated, the perplexity is reported as infinity.
#[derive(Clone)]
struct PerplexityState {
/// Sum of negative log-likelihood across all tokens
sum_nll: f64,
/// Total number of effective tokens (excluding padding)
total_tokens: usize,
/// Current batch perplexity (for display purposes); `NaN` until the first update
current: f64,
}
impl PerplexityState {
    /// Creates an empty state: nothing accumulated and no batch perplexity yet.
    fn new() -> Self {
        Self {
            sum_nll: 0.0,
            total_tokens: 0,
            current: f64::NAN,
        }
    }

    /// Puts the state back to what [`PerplexityState::new`] produces.
    fn reset(&mut self) {
        *self = Self::new();
    }

    /// Returns `exp(nll / tokens)`, or infinity when there are no tokens.
    fn perplexity(nll: f64, tokens: usize) -> f64 {
        if tokens > 0 {
            (nll / tokens as f64).exp()
        } else {
            f64::INFINITY
        }
    }

    /// Update state with negative log-likelihood and token count from current batch.
    ///
    /// `sum_log_prob` is the sum of the target log probabilities of the batch and
    /// `effective_tokens` the number of tokens that contributed to it. The returned
    /// entry shows both the running (epoch) and the batch perplexity, and serializes
    /// the running value together with the total token count.
    fn update(
        &mut self,
        sum_log_prob: f64,
        effective_tokens: usize,
        format: FormatOptions,
    ) -> SerializedEntry {
        // Log probabilities are negated to obtain the negative log-likelihood.
        let batch_nll = -sum_log_prob;

        // Accumulate across batches.
        self.sum_nll += batch_nll;
        self.total_tokens += effective_tokens;

        // Batch perplexity is kept for display; the epoch value uses the accumulated sums.
        let batch_perplexity = Self::perplexity(batch_nll, effective_tokens);
        self.current = batch_perplexity;
        let epoch_perplexity = Self::perplexity(self.sum_nll, self.total_tokens);

        let (formatted_current, formatted_running) = match format.precision_value() {
            Some(precision) => (
                format_float(batch_perplexity, precision),
                format_float(epoch_perplexity, precision),
            ),
            None => (batch_perplexity.to_string(), epoch_perplexity.to_string()),
        };

        let formatted = match format.unit_value() {
            Some(unit) => {
                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
            }
            None => format!("epoch {formatted_running} - batch {formatted_current}"),
        };

        // Serialize the running value for aggregation.
        let entry = NumericEntry::Aggregated {
            aggregated_value: epoch_perplexity,
            count: self.total_tokens,
        };
        SerializedEntry::new(formatted, entry.serialize())
    }

    /// Perplexity over everything accumulated so far, with the total token count.
    fn value(&self) -> NumericEntry {
        NumericEntry::Aggregated {
            aggregated_value: Self::perplexity(self.sum_nll, self.total_tokens),
            count: self.total_tokens,
        }
    }

    /// Same as [`PerplexityState::value`].
    fn running_value(&self) -> NumericEntry {
        self.value()
    }
}
/// The perplexity metric.
///
/// Perplexity is a measure of how well a probability distribution or probability model
/// predicts a sample. It's commonly used to evaluate language models. A lower perplexity
/// indicates that the model is more confident in its predictions.
///
/// Mathematically, perplexity is defined as the exponentiation of the cross-entropy loss:
/// PPL = exp(H(p, q)) = exp(-1/N * Σ log(p(x_i)))
///
/// where:
/// - H(p, q) is the cross-entropy between the true distribution p and predicted distribution q
/// - N is the number of tokens
/// - p(x_i) is the predicted probability of the i-th token
///
/// # Aggregation
/// Unlike other metrics, perplexity cannot be simply averaged across batches.
/// This implementation correctly accumulates the total negative log-likelihood and
/// total token count across batches, then computes perplexity as exp(total_nll / total_tokens).
#[derive(Clone)]
pub struct PerplexityMetric<B: Backend> {
/// Display name of the metric ("Perplexity").
name: MetricName,
/// Accumulated negative log-likelihood and token count.
state: PerplexityState,
/// Target value treated as padding and excluded from the computation, if any.
pad_token: Option<usize>,
_b: PhantomData<B>,
}
/// The [perplexity metric](PerplexityMetric) input type.
///
/// Row `i` of `outputs` is scored against entry `i` of `targets`.
#[derive(new)]
pub struct PerplexityInput<B: Backend> {
/// Logits tensor of shape [batch_size * sequence_length, vocab_size]
outputs: Tensor<B, 2>,
/// Target tokens tensor of shape [batch_size * sequence_length]
targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Default for PerplexityMetric<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> PerplexityMetric<B> {
    /// Creates the metric with an empty state and no pad token.
    pub fn new() -> Self {
        let name = MetricName::new("Perplexity".to_string());
        Self {
            name,
            state: PerplexityState::new(),
            pad_token: None,
            _b: PhantomData,
        }
    }

    /// Sets the pad token to exclude from perplexity calculation.
    ///
    /// Once set, targets equal to `index` are masked out: they add nothing to the
    /// summed log probability and are not counted as tokens. This matters for
    /// variable-length sequences that are padded to a common length.
    pub fn with_pad_token(self, index: usize) -> Self {
        Self {
            pad_token: Some(index),
            ..self
        }
    }
}
impl<B: Backend> Metric for PerplexityMetric<B> {
type Input = PerplexityInput<B>;
/// Accumulates one batch into the perplexity state.
///
/// Computes the log probability assigned to each target token, sums them (skipping
/// padding targets when a pad token is configured) and forwards the sum together
/// with the number of contributing tokens to the state.
fn update(
&mut self,
input: &PerplexityInput<B>,
_metadata: &MetricMetadata,
) -> SerializedEntry {
let targets = input.targets.clone();
let outputs = input.outputs.clone();
// The first dimension of the outputs is used as the token count when nothing is masked.
let [total_tokens, _vocab_size] = outputs.dims();
// Convert logits to log probabilities using log_softmax for numerical stability
let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);
// Gather the log probabilities for the target tokens
let target_log_probs = log_probs
.gather(1, targets.clone().unsqueeze_dim(1))
.squeeze_dim(1);
let (sum_log_prob, effective_tokens) = match self.pad_token {
Some(pad_token) => {
// Create a mask for non-padding tokens
let mask = targets.clone().not_equal_elem(pad_token as i64);
// Apply mask to log probabilities (set padding log probs to 0)
let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);
// Sum the log probabilities and count effective tokens
let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();
let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;
(sum_log_prob, effective_tokens)
}
None => {
// No padding, use all tokens
let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();
(sum_log_prob, total_tokens)
}
};
// Pass the sum_log_prob and effective_tokens to the state
// The state will handle the correct accumulation and perplexity calculation
self.state.update(
sum_log_prob,
effective_tokens,
FormatOptions::new(self.name()).precision(2),
)
}
/// Resets the accumulated negative log-likelihood and token count.
fn clear(&mut self) {
self.state.reset()
}
/// Returns the metric name (cloned).
fn name(&self) -> MetricName {
self.name.clone()
}
/// Perplexity has no unit, and lower is better.
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: None,
higher_is_better: false,
}
.into()
}
}
impl<B: Backend> Numeric for PerplexityMetric<B> {
    /// Delegates to the perplexity state's value.
    fn value(&self) -> NumericEntry {
        self.state.value()
    }

    /// Delegates to the perplexity state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
// Confident, correct logits must give a perplexity close to 1.
#[test]
fn test_perplexity_perfect_prediction() {
let device = Default::default();
let mut metric = PerplexityMetric::<TestBackend>::new();
// Perfect prediction: target is always the highest probability class
let input = PerplexityInput::new(
Tensor::from_data(
[
[10.0, 0.0, 0.0], // Very confident prediction for class 0
[0.0, 10.0, 0.0], // Very confident prediction for class 1
[0.0, 0.0, 10.0], // Very confident prediction for class 2
],
&device,
),
Tensor::from_data([0, 1, 2], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
let perplexity = metric.value().current();
// Perfect predictions should result in very low perplexity (close to 1.0)
assert!(
perplexity < 1.1,
"Perfect predictions should have low perplexity, got {}",
perplexity
);
}
// Equal logits over 3 classes must give a perplexity close to 3.
#[test]
fn test_perplexity_uniform_prediction() {
let device = Default::default();
let mut metric = PerplexityMetric::<TestBackend>::new();
// Uniform prediction: all classes have equal probability
let input = PerplexityInput::new(
Tensor::from_data(
[
[0.0, 0.0, 0.0], // Uniform distribution (after softmax)
[0.0, 0.0, 0.0], // Uniform distribution (after softmax)
[0.0, 0.0, 0.0], // Uniform distribution (after softmax)
],
&device,
),
Tensor::from_data([0, 1, 2], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
let perplexity = metric.value().current();
// Uniform distribution over 3 classes should have perplexity ≈ 3.0
assert!(
(perplexity - 3.0).abs() < 0.1,
"Uniform distribution perplexity should be ~3.0, got {}",
perplexity
);
}
// Targets equal to the pad token must not contribute to the perplexity.
#[test]
fn test_perplexity_with_padding() {
let device = Default::default();
let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);
let input = PerplexityInput::new(
Tensor::from_data(
[
[10.0, 0.0, 0.0, 0.0], // Good prediction for class 0
[0.0, 10.0, 0.0, 0.0], // Good prediction for class 1
[0.0, 0.0, 0.0, 1.0], // This is padding - should be ignored
[0.0, 0.0, 0.0, 1.0], // This is padding - should be ignored
],
&device,
),
Tensor::from_data([0, 1, 3, 3], &device), // 3 is pad token
);
let _entry = metric.update(&input, &MetricMetadata::fake());
let perplexity = metric.value().current();
// Should only consider the first two predictions, both of which are confident
assert!(
perplexity < 1.1,
"Good predictions with padding should have low perplexity, got {}",
perplexity
);
}
// Confident but wrong logits must give a high perplexity.
#[test]
fn test_perplexity_wrong_prediction() {
let device = Default::default();
let mut metric = PerplexityMetric::<TestBackend>::new();
// Wrong predictions: target class has very low probability
let input = PerplexityInput::new(
Tensor::from_data(
[
[0.0, 10.0, 0.0], // Predicts class 1, but target is 0
[10.0, 0.0, 0.0], // Predicts class 0, but target is 1
[0.0, 0.0, 10.0], // Predicts class 2, but target is 0
],
&device,
),
Tensor::from_data([0, 1, 0], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
let perplexity = metric.value().current();
// Wrong predictions should result in high perplexity
assert!(
perplexity > 10.0,
"Wrong predictions should have high perplexity, got {}",
perplexity
);
}
// Feeding the data in two batches must give the same result as a single batch.
#[test]
fn test_perplexity_multi_batch_aggregation() {
let device = Default::default();
let mut metric = PerplexityMetric::<TestBackend>::new();
// First batch: 2 tokens with uniform distribution (log_prob ≈ -1.0986 each)
let input1 = PerplexityInput::new(
Tensor::from_data(
[
[0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
[0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
],
&device,
),
Tensor::from_data([0, 1], &device),
);
// Second batch: 1 token with uniform distribution
let input2 = PerplexityInput::new(
Tensor::from_data(
[
[0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
],
&device,
),
Tensor::from_data([2], &device),
);
// Update with both batches
let _entry1 = metric.update(&input1, &MetricMetadata::fake());
let _entry2 = metric.update(&input2, &MetricMetadata::fake());
let aggregated_perplexity = metric.value().current();
// For uniform distribution over 3 classes: log_prob ≈ -log(3) ≈ -1.0986
// Total negative log-likelihood: 3 * 1.0986 ≈ 3.2958
// Total tokens: 3
// Expected perplexity: exp(3.2958 / 3) = exp(1.0986) ≈ 3.0
assert!(
(aggregated_perplexity - 3.0).abs() < 0.1,
"Multi-batch aggregated perplexity should be ~3.0, got {}",
aggregated_perplexity
);
// Compare with single batch containing all data
let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
let single_input = PerplexityInput::new(
Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
Tensor::from_data([0, 1, 2], &device),
);
let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());
let single_batch_perplexity = single_batch_metric.value().current();
// Multi-batch and single-batch should give the same result
assert!(
(aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
"Multi-batch ({}) and single-batch ({}) perplexity should match",
aggregated_perplexity,
single_batch_perplexity
);
}
}

View File

@@ -0,0 +1,231 @@
use crate::metric::{MetricName, Numeric};
use super::{
Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
confusion_stats::{ConfusionStats, ConfusionStatsInput},
state::{FormatOptions, NumericMetricState},
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::cast::ToElement,
};
use core::marker::PhantomData;
use std::{num::NonZeroUsize, sync::Arc};
/// The Precision Metric
///
/// Computed as true positives divided by predicted positives, reduced over classes
/// according to the configured class reduction, and reported as a percentage.
#[derive(Clone)]
pub struct PrecisionMetric<B: Backend> {
/// Display name; built from the decision rule and class reduction.
name: MetricName,
/// Accumulator providing the current and running values.
state: NumericMetricState,
_b: PhantomData<B>,
/// Decision rule and class reduction used to compute the metric.
config: ClassificationMetricConfig,
}
impl<B: Backend> Default for PrecisionMetric<B> {
    /// Builds the metric from the default classification config.
    fn default() -> Self {
        let config = Default::default();
        Self::new(config)
    }
}
impl<B: Backend> PrecisionMetric<B> {
fn new(config: ClassificationMetricConfig) -> Self {
let state = Default::default();
let name = Arc::new(format!(
"Precision @ {:?} [{:?}]",
config.decision_rule, config.class_reduction
));
Self {
state,
config,
name,
_b: Default::default(),
}
}
/// Precision metric for binary classification.
///
/// # Arguments
///
/// * `threshold` - The threshold to transform a probability into a binary prediction.
#[allow(dead_code)]
pub fn binary(threshold: f64) -> Self {
Self::new(ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
// binary classification results are the same independently of class_reduction
..Default::default()
})
}
/// Precision metric for multiclass classification.
///
/// # Arguments
///
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
Self::new(ClassificationMetricConfig {
decision_rule: DecisionRule::TopK(
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
),
class_reduction,
})
}
/// Precision metric for multi-label classification.
///
/// # Arguments
///
/// * `threshold` - The threshold to transform a probability into a binary value.
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
Self {
config: ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
class_reduction,
},
..Default::default()
}
}
fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
use ClassReduction::{Macro, Micro};
let avg_tensor = match self.config.class_reduction {
Micro => aggregated_metric,
Macro => {
if aggregated_metric
.clone()
.contains_nan()
.any()
.into_scalar()
.to_bool()
{
let nan_mask = aggregated_metric.clone().is_nan();
aggregated_metric = aggregated_metric
.clone()
.select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))
}
aggregated_metric.mean()
}
};
avg_tensor.into_scalar().to_f64()
}
}
impl<B: Backend> Metric for PrecisionMetric<B> {
    type Input = ConfusionStatsInput<B>;

    /// Records the precision of one batch as a percentage.
    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        let [sample_size, _] = input.predictions.dims();

        let cf_stats = ConfusionStats::new(input, &self.config);
        let true_positive = cf_stats.clone().true_positive();
        let predicted_positive = cf_stats.predicted_positive();
        let metric = self.class_average(true_positive / predicted_positive);

        let format = FormatOptions::new(self.name()).unit("%").precision(2);
        self.state.update(100.0 * metric, sample_size, format)
    }

    /// Resets the accumulated state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns the metric name (a shared handle, cloned).
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Precision is a percentage where higher is better.
    fn attributes(&self) -> MetricAttributes {
        let attributes = NumericAttributes {
            unit: Some(String::from("%")),
            higher_is_better: true,
        };
        attributes.into()
    }
}
impl<B: Backend> Numeric for PrecisionMetric<B> {
    /// Delegates to the state's current value.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Delegates to the state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
#[cfg(test)]
mod tests {
use super::{
ClassReduction::{self, *},
Metric, MetricMetadata, PrecisionMetric,
};
use crate::metric::Numeric;
use crate::{
TestBackend,
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
};
use burn_core::tensor::TensorData;
use burn_core::tensor::Tolerance;
use rstest::rstest;
// Binary precision on the shared dummy input; `expected` is a ratio, compared as a percentage.
#[rstest]
#[case::binary(THRESHOLD, 0.5)]
fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) {
let input = dummy_classification_input(&ClassificationType::Binary).into();
let mut metric = PrecisionMetric::binary(threshold);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// Multiclass precision for each class reduction and for top-1 / top-2 decisions.
#[rstest]
#[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]
#[case::multiclass_micro_k2(Micro, 2, 4.0/10.0)]
#[case::multiclass_macro_k1(Macro, 1, (0.5 + 0.5 + 1.0)/3.0)]
#[case::multiclass_macro_k2(Macro, 2, (0.5 + 1.0/4.0 + 0.5)/3.0)]
fn test_multiclass_precision(
#[case] class_reduction: ClassReduction,
#[case] top_k: usize,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
let mut metric = PrecisionMetric::multiclass(top_k, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// Multi-label precision for each class reduction.
#[rstest]
#[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)]
#[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)]
fn test_multilabel_precision(
#[case] class_reduction: ClassReduction,
#[case] threshold: f64,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
let mut metric = PrecisionMetric::multilabel(threshold, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// The metric name must differ between configurations and match for equal ones.
#[test]
fn test_parameterized_unique_name() {
let metric_a = PrecisionMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
let metric_b = PrecisionMetric::<TestBackend>::multiclass(2, ClassReduction::Macro);
let metric_c = PrecisionMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
assert_ne!(metric_a.name(), metric_b.name());
assert_eq!(metric_a.name(), metric_c.name());
let metric_a = PrecisionMetric::<TestBackend>::binary(0.5);
let metric_b = PrecisionMetric::<TestBackend>::binary(0.75);
assert_ne!(metric_a.name(), metric_b.name());
}
}

View File

@@ -0,0 +1,137 @@
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
use super::EventProcessorTraining;
use async_channel::{Receiver, Sender};
/// Event processor for the training process.
///
/// Forwards every event over a channel to a worker thread that owns the wrapped processor.
pub struct AsyncProcessorTraining<ET, EV> {
/// Sending half of the channel read by the worker thread.
sender: Sender<Message<ET, EV>>,
}
/// Event processor for the model evaluation.
///
/// Forwards every event over a channel to a worker thread that owns the wrapped processor.
pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
/// Sending half of the channel read by the worker thread.
sender: Sender<EvalMessage<P>>,
}
/// Worker that owns a training processor and feeds it the messages it receives.
struct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {
/// Processor the received events are forwarded to.
processor: P,
/// Receiving half of the message channel.
rec: Receiver<Message<ET, EV>>,
}
/// Worker that owns an evaluation processor and feeds it the messages it receives.
struct WorkerEvaluation<P: EventProcessorEvaluation> {
/// Processor the received events are forwarded to.
processor: P,
/// Receiving half of the message channel.
rec: Receiver<EvalMessage<P>>,
}
impl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>
    WorkerTraining<ET, EV, P>
{
    /// Spawns a thread that forwards received messages to `processor`.
    ///
    /// The thread stops when the channel is closed, or after answering a renderer
    /// request (which consumes the processor).
    pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {
        let mut worker = Self { processor, rec };

        std::thread::spawn(move || {
            loop {
                let Ok(msg) = worker.rec.recv_blocking() else {
                    return;
                };

                match msg {
                    Message::Train(event) => worker.processor.process_train(event),
                    Message::Valid(event) => worker.processor.process_valid(event),
                    Message::Renderer(callback) => {
                        let renderer = worker.processor.renderer();
                        callback.send_blocking(renderer).unwrap();
                        return;
                    }
                }
            }
        });
    }
}
impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
    /// Spawns a thread that forwards received messages to `processor`.
    ///
    /// The thread stops when the channel is closed, or after answering a renderer
    /// request (which consumes the processor).
    pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
        let mut worker = Self { processor, rec };

        std::thread::spawn(move || {
            loop {
                let Ok(message) = worker.rec.recv_blocking() else {
                    return;
                };

                match message {
                    EvalMessage::Test(event) => worker.processor.process_test(event),
                    EvalMessage::Renderer(reply) => {
                        let renderer = worker.processor.renderer();
                        reply.send_blocking(renderer).unwrap();
                        return;
                    }
                }
            }
        });
    }
}
impl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {
    /// Create an event processor for training.
    ///
    /// Starts a worker thread owning `processor`, connected through a channel of capacity 1.
    pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {
        let (sender, receiver) = async_channel::bounded(1);
        WorkerTraining::start(processor, receiver);
        Self { sender }
    }
}
impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
    /// Create an event processor for model evaluation.
    ///
    /// Starts a worker thread owning `processor`, connected through a channel of capacity 1.
    pub fn new(processor: P) -> Self {
        let (sender, receiver) = async_channel::bounded(1);
        WorkerEvaluation::start(processor, receiver);
        Self { sender }
    }
}
/// Messages sent from [`AsyncProcessorTraining`] to its worker thread.
enum Message<EventTrain, EventValid> {
/// Training event, forwarded to `process_train`.
Train(EventTrain),
/// Validation event, forwarded to `process_valid`.
Valid(EventValid),
/// Request for the renderer; the worker replies on the given sender and then stops.
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
}
/// Messages sent from [`AsyncProcessorEvaluation`] to its worker thread.
enum EvalMessage<P: EventProcessorEvaluation> {
/// Evaluation event, forwarded to `process_test`.
Test(EvaluatorEvent<P::ItemTest>),
/// Request for the renderer; the worker replies on the given sender and then stops.
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
}
impl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {
    /// Sends the training event to the worker thread, blocking while the channel is full.
    fn process_train(&mut self, event: ET) {
        let message = Message::Train(event);
        self.sender.send_blocking(message).unwrap();
    }

    /// Sends the validation event to the worker thread, blocking while the channel is full.
    fn process_valid(&mut self, event: EV) {
        let message = Message::Valid(event);
        self.sender.send_blocking(message).unwrap();
    }

    /// Asks the worker for its renderer and waits for the reply.
    ///
    /// Panics if the request cannot be sent or no reply is received.
    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
        let (reply_tx, reply_rx) = async_channel::bounded(1);
        self.sender
            .send_blocking(Message::Renderer(reply_tx))
            .unwrap();

        reply_rx
            .recv_blocking()
            .unwrap_or_else(|err| panic!("{err:?}"))
    }
}
impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
    type ItemTest = P::ItemTest;

    /// Queue a test event for the worker.
    ///
    /// # Panics
    /// Panics if the event cannot be sent.
    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
        let message = EvalMessage::Test(event);
        self.sender.send_blocking(message).unwrap();
    }

    /// Ask the worker for its renderer and block until the reply arrives.
    ///
    /// # Panics
    /// Panics if the request cannot be sent or if no reply is received.
    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
        let (reply_sender, reply_receiver) = async_channel::bounded(1);
        let request = EvalMessage::Renderer(reply_sender);
        self.sender.send_blocking(request).unwrap();
        reply_receiver
            .recv_blocking()
            .unwrap_or_else(|err| panic!("{err:?}"))
    }
}

View File

@@ -0,0 +1,126 @@
use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
use crate::{
LearnerSummary,
renderer::{EvaluationName, MetricsRenderer},
};
/// Event happening during the training/validation process.
pub enum LearnerEvent<T> {
/// Signal the start of the process (e.g., training start).
Start,
/// Signal that an item has been processed.
ProcessedItem(TrainingItem<T>),
/// Signal the end of an epoch.
EndEpoch(usize),
/// Signal the end of the process (e.g., training end), with an optional summary.
End(Option<LearnerSummary>),
}
/// Event happening during the evaluation process.
pub enum EvaluatorEvent<T> {
/// Signal the start of the process (e.g., evaluation start).
Start,
/// Signal that an item has been processed, together with the evaluation name it belongs to.
ProcessedItem(EvaluationName, EvaluationItem<T>),
/// Signal the end of the process (e.g., evaluation end), with an optional summary.
End(Option<LearnerSummary>),
}
/// Items that are lazy are not ready to be processed by metrics.
///
/// We want to sync them on a different thread to avoid blocking training.
pub trait ItemLazy: Send {
/// Item that is properly synced and ready to be processed by metrics.
type ItemSync: Send;
/// Sync the item, consuming it and returning its synced form.
fn sync(self) -> Self::ItemSync;
}
/// Process events happening during training and validation.
pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
/// Collect a training event.
fn process_train(&mut self, event: TrainEvent);
/// Collect a validation event.
fn process_valid(&mut self, event: ValidEvent);
/// Returns the renderer used for training, consuming the processor.
fn renderer(self) -> Box<dyn MetricsRenderer>;
}
/// Process events happening during evaluation.
pub trait EventProcessorEvaluation: Send {
/// The test item.
type ItemTest: ItemLazy;
/// Collect a test event.
fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
/// Returns the renderer used for evaluation, consuming the processor.
fn renderer(self) -> Box<dyn MetricsRenderer>;
}
/// A learner item.
// NOTE: the constructor is generated by `derive(new)`; keep the field order stable,
// since callers pass the values positionally.
#[derive(new)]
pub struct TrainingItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The global progress of the training (e.g. epochs).
pub global_progress: Progress,
/// The iteration, if it is different from the items processed.
pub iteration: Option<usize>,
/// The learning rate.
pub lr: Option<LearningRate>,
}
impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
    type ItemSync = TrainingItem<T::ItemSync>;

    /// Sync the wrapped item; every other field is carried over unchanged.
    fn sync(self) -> Self::ItemSync {
        let Self {
            item,
            progress,
            global_progress,
            iteration,
            lr,
        } = self;

        TrainingItem {
            item: item.sync(),
            progress,
            global_progress,
            iteration,
            lr,
        }
    }
}
/// An evaluation item.
// NOTE: the constructor is generated by `derive(new)`; keep the field order stable.
#[derive(new)]
pub struct EvaluationItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The iteration, if it is different from the items processed.
pub iteration: Option<usize>,
}
impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
    type ItemSync = EvaluationItem<T::ItemSync>;

    /// Sync the wrapped item; progress and iteration are carried over unchanged.
    fn sync(self) -> Self::ItemSync {
        let Self {
            item,
            progress,
            iteration,
        } = self;

        EvaluationItem {
            item: item.sync(),
            progress,
            iteration,
        }
    }
}
/// The unit type is already synced: syncing it yields `()` again.
impl ItemLazy for () {
type ItemSync = ();
// Nothing to sync; the empty body returns `()`.
fn sync(self) -> Self::ItemSync {}
}

View File

@@ -0,0 +1,257 @@
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation};
use crate::metric::store::{EpochSummary, EventStoreClient, Split};
use crate::renderer::{
EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress,
};
use std::sync::Arc;
/// An [event processor](EventProcessorTraining) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Rendering metrics using a [metrics renderer](MetricsRenderer).
pub struct FullEventProcessorTraining<T: ItemLazy, V: ItemLazy> {
// Metric updaters for the training (`T`) and validation (`V`) items.
metrics: MetricsTraining<T, V>,
// Renderer that receives metric states and progress.
renderer: Box<dyn MetricsRenderer>,
// Client used to record events in the event store.
store: Arc<EventStoreClient>,
}
/// An [event processor](EventProcessorEvaluation) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Rendering metrics using a [metrics renderer](MetricsRenderer).
pub struct FullEventProcessorEvaluation<T: ItemLazy> {
// Metric updaters for the test items.
metrics: MetricsEvaluation<T>,
// Renderer that receives metric states and progress.
renderer: Box<dyn MetricsRenderer>,
// Client used to record events in the event store.
store: Arc<EventStoreClient>,
}
impl<T: ItemLazy, V: ItemLazy> FullEventProcessorTraining<T, V> {
    /// Assemble a processor from its metrics, renderer and event store client.
    pub(crate) fn new(
        metrics: MetricsTraining<T, V>,
        renderer: Box<dyn MetricsRenderer>,
        store: Arc<EventStoreClient>,
    ) -> Self {
        Self {
            metrics,
            renderer,
            store,
        }
    }

    /// List the indicators to display for `progress`.
    ///
    /// The order is: "Epoch" (always), "Iteration" (when set), "Items" (when set).
    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        let epoch = ProgressType::Detailed {
            tag: String::from("Epoch"),
            progress: progress.global_progress.clone(),
        };
        let iteration = progress.iteration.map(|value| ProgressType::Value {
            tag: String::from("Iteration"),
            value,
        });
        let items = progress
            .progress
            .as_ref()
            .map(|items_progress| ProgressType::Detailed {
                tag: String::from("Items"),
                progress: items_progress.clone(),
            });

        std::iter::once(epoch)
            .chain(iteration)
            .chain(items)
            .collect()
    }
}
impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
    /// Assemble a processor from its metrics, renderer and event store client.
    pub(crate) fn new(
        metrics: MetricsEvaluation<T>,
        renderer: Box<dyn MetricsRenderer>,
        store: Arc<EventStoreClient>,
    ) -> Self {
        Self {
            metrics,
            renderer,
            store,
        }
    }

    /// List the indicators to display for `progress`.
    ///
    /// The order is: "Iteration" (when set), then "Items" (always).
    fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec<ProgressType> {
        let iteration = progress.iteration.map(|value| ProgressType::Value {
            tag: String::from("Iteration"),
            value,
        });
        let items = ProgressType::Detailed {
            tag: String::from("Items"),
            progress: progress.progress.clone(),
        };

        iteration
            .into_iter()
            .chain(std::iter::once(items))
            .collect()
    }
}
impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
type ItemTest = T;
fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
match event {
EvaluatorEvent::Start => {
let definitions = self.metrics.metric_definitions();
self.store
.add_event_train(crate::metric::store::Event::MetricsInit(
definitions.clone(),
));
definitions
.iter()
.for_each(|definition| self.renderer.register_metric(definition.clone()));
}
EvaluatorEvent::ProcessedItem(name, item) => {
let item = item.sync();
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_test(&item, &metadata);
self.store.add_event_test(
crate::metric::store::Event::MetricsUpdate(update.clone()),
name.name.clone(),
);
update.entries.into_iter().for_each(|entry| {
self.renderer
.update_test(name.clone(), MetricState::Generic(entry))
});
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_test(
name.clone(),
MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
),
)
});
let indicators = self.progress_indicators(&progress);
self.renderer.render_test(progress, indicators);
}
EvaluatorEvent::End(summary) => {
self.renderer.on_test_end(summary).ok();
}
}
}
fn renderer(self) -> Box<dyn MetricsRenderer> {
self.renderer
}
}
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
for FullEventProcessorTraining<T, V>
{
fn process_train(&mut self, event: LearnerEvent<T>) {
match event {
LearnerEvent::Start => {
let definitions = self.metrics.metric_definitions();
self.store
.add_event_train(crate::metric::store::Event::MetricsInit(
definitions.clone(),
));
definitions
.iter()
.for_each(|definition| self.renderer.register_metric(definition.clone()));
}
LearnerEvent::ProcessedItem(item) => {
let item = item.sync();
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_train(&item, &metadata);
self.store
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_train(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
let indicators = self.progress_indicators(&progress);
self.renderer.render_train(progress, indicators);
}
LearnerEvent::EndEpoch(epoch) => {
self.store
.add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new(
epoch,
Split::Train,
)));
self.metrics.end_epoch_train();
}
LearnerEvent::End(summary) => {
self.renderer.on_train_end(summary).ok();
}
}
}
fn process_valid(&mut self, event: LearnerEvent<V>) {
match event {
LearnerEvent::Start => {} // no-op for now
LearnerEvent::ProcessedItem(item) => {
let item = item.sync();
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_valid(&item, &metadata);
self.store
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_valid(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
let indicators = self.progress_indicators(&progress);
self.renderer.render_valid(progress, indicators);
}
LearnerEvent::EndEpoch(epoch) => {
self.store
.add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new(
epoch,
Split::Valid,
)));
self.metrics.end_epoch_valid();
}
LearnerEvent::End(_) => {} // no-op for now
}
}
fn renderer(self) -> Box<dyn MetricsRenderer> {
self.renderer
}
}

View File

@@ -0,0 +1,341 @@
use std::collections::HashMap;
use super::{ItemLazy, TrainingItem};
use crate::{
EvaluationItem,
metric::{
Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
store::{MetricsUpdate, NumericMetricUpdate},
},
renderer::{EvaluationProgress, TrainingProgress},
};
/// Metric updaters registered for the training (`T`) and validation (`V`) items.
pub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {
// Non-numeric metrics updated from synced training items.
train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
// Non-numeric metrics updated from synced validation items.
valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,
// Numeric metrics updated from synced training items.
train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
// Numeric metrics updated from synced validation items.
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,
// Definition of every registered metric, keyed by metric id.
metric_definitions: HashMap<MetricId, MetricDefinition>,
}
/// Metric updaters registered for the test items.
pub(crate) struct MetricsEvaluation<T: ItemLazy> {
// Non-numeric metrics updated from synced test items.
test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
// Numeric metrics updated from synced test items.
test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
// Definition of every registered metric, keyed by metric id.
metric_definitions: HashMap<MetricId, MetricDefinition>,
}
impl<T: ItemLazy> Default for MetricsEvaluation<T> {
    /// Start with no registered metric.
    fn default() -> Self {
        Self {
            test: Vec::new(),
            test_numeric: Vec::new(),
            metric_definitions: HashMap::new(),
        }
    }
}
impl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {
    /// Start with no registered metric.
    fn default() -> Self {
        Self {
            train: Vec::new(),
            valid: Vec::new(),
            train_numeric: Vec::new(),
            valid_numeric: Vec::new(),
            metric_definitions: HashMap::new(),
        }
    }
}
impl<T: ItemLazy> MetricsEvaluation<T> {
    /// Register a testing metric.
    pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)
    where
        T::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let wrapped = MetricWrapper::new(metric);
        self.register_definition(&wrapped);
        self.test.push(Box::new(wrapped));
    }

    /// Register a numeric testing metric.
    pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(
        &mut self,
        metric: Me,
    ) where
        T::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let wrapped = MetricWrapper::new(metric);
        self.register_definition(&wrapped);
        self.test_numeric.push(Box::new(wrapped));
    }

    /// Record the definition of `metric` under its id, replacing any previous entry.
    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
        let id = metric.id.clone();
        let definition = MetricDefinition::new(id.clone(), &metric.metric);
        self.metric_definitions.insert(id, definition);
    }

    /// Get metric definitions.
    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
        let definitions = self.metric_definitions.values();
        definitions.cloned().collect()
    }

    /// Update the testing information from the testing item.
    ///
    /// Non-numeric metrics are updated first, then numeric ones.
    pub(crate) fn update_test(
        &mut self,
        item: &EvaluationItem<T::ItemSync>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries: Vec<_> = self
            .test
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect();
        let entries_numeric: Vec<_> = self
            .test_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect();

        MetricsUpdate::new(entries, entries_numeric)
    }
}
impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
    /// Register a training metric.
    pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
    where
        T::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let wrapped = MetricWrapper::new(metric);
        self.register_definition(&wrapped);
        self.train.push(Box::new(wrapped));
    }

    /// Register a validation metric.
    pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
    where
        V::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let wrapped = MetricWrapper::new(metric);
        self.register_definition(&wrapped);
        self.valid.push(Box::new(wrapped));
    }

    /// Register a numeric training metric.
    pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
        &mut self,
        metric: Me,
    ) where
        T::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let wrapped = MetricWrapper::new(metric);
        self.register_definition(&wrapped);
        self.train_numeric.push(Box::new(wrapped));
    }

    /// Register a numeric validation metric.
    pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)
    where
        V::ItemSync: Adaptor<Me::Input> + 'static,
        Me: Metric + Numeric + 'static,
    {
        let wrapped = MetricWrapper::new(metric);
        self.register_definition(&wrapped);
        self.valid_numeric.push(Box::new(wrapped));
    }

    /// Record the definition of `metric` under its id, replacing any previous entry.
    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
        let id = metric.id.clone();
        let definition = MetricDefinition::new(id.clone(), &metric.metric);
        self.metric_definitions.insert(id, definition);
    }

    /// Get metric definitions for all splits
    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
        let definitions = self.metric_definitions.values();
        definitions.cloned().collect()
    }

    /// Update the training information from the training item.
    ///
    /// Non-numeric metrics are updated first, then numeric ones.
    pub(crate) fn update_train(
        &mut self,
        item: &TrainingItem<T::ItemSync>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries: Vec<_> = self
            .train
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect();
        let entries_numeric: Vec<_> = self
            .train_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect();

        MetricsUpdate::new(entries, entries_numeric)
    }

    /// Update the training information from the validation item.
    ///
    /// Non-numeric metrics are updated first, then numeric ones.
    pub(crate) fn update_valid(
        &mut self,
        item: &TrainingItem<V::ItemSync>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries: Vec<_> = self
            .valid
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect();
        let entries_numeric: Vec<_> = self
            .valid_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect();

        MetricsUpdate::new(entries, entries_numeric)
    }

    /// Signal the end of a training epoch by clearing every training metric.
    pub(crate) fn end_epoch_train(&mut self) {
        self.train.iter_mut().for_each(|metric| metric.clear());
        self.train_numeric
            .iter_mut()
            .for_each(|metric| metric.clear());
    }

    /// Signal the end of a validation epoch by clearing every validation metric.
    pub(crate) fn end_epoch_valid(&mut self) {
        self.valid.iter_mut().for_each(|metric| metric.clear());
        self.valid_numeric
            .iter_mut()
            .for_each(|metric| metric.clear());
    }
}
impl<T> From<&TrainingItem<T>> for TrainingProgress {
    /// Copy the progress information of a training item.
    fn from(item: &TrainingItem<T>) -> Self {
        let progress = Some(item.progress.clone());
        let global_progress = item.global_progress.clone();
        let iteration = item.iteration;

        Self {
            progress,
            global_progress,
            iteration,
        }
    }
}
impl<T> From<&EvaluationItem<T>> for TrainingProgress {
    /// Use the item's progress as the global progress; no per-item progress is set.
    fn from(item: &EvaluationItem<T>) -> Self {
        let global_progress = item.progress.clone();
        let iteration = item.iteration;

        Self {
            progress: None,
            global_progress,
            iteration,
        }
    }
}
impl<T> From<&EvaluationItem<T>> for EvaluationProgress {
    /// Copy the progress and iteration of an evaluation item.
    fn from(item: &EvaluationItem<T>) -> Self {
        let progress = item.progress.clone();
        let iteration = item.iteration;

        Self {
            progress,
            iteration,
        }
    }
}
impl<T> From<&TrainingItem<T>> for MetricMetadata {
    /// Copy progress, global progress, iteration and learning rate from a training item.
    fn from(item: &TrainingItem<T>) -> Self {
        let progress = item.progress.clone();
        let global_progress = item.global_progress.clone();

        Self {
            progress,
            global_progress,
            iteration: item.iteration,
            lr: item.lr,
        }
    }
}
impl<T> From<&EvaluationItem<T>> for MetricMetadata {
    /// Use the item's progress for both progress fields; no learning rate is set.
    fn from(item: &EvaluationItem<T>) -> Self {
        let progress = item.progress.clone();
        let global_progress = item.progress.clone();

        Self {
            progress,
            global_progress,
            iteration: item.iteration,
            lr: None,
        }
    }
}
/// Object-safe handle over a numeric metric that can be updated from an item of type `T`.
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
/// Update the metric with `item` and return the resulting numeric update.
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;
/// Clear the metric.
fn clear(&mut self);
}
/// Object-safe handle over a metric that can be updated from an item of type `T`.
pub(crate) trait MetricUpdater<T>: Send + Sync {
/// Update the metric with `item` and return the resulting entry.
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;
/// Clear the metric.
fn clear(&mut self);
}
/// A metric paired with the id built from its name.
pub(crate) struct MetricWrapper<M> {
/// Id of the metric.
pub id: MetricId,
/// The wrapped metric.
pub metric: M,
}
impl<M: Metric> MetricWrapper<M> {
    /// Wrap `metric`, building its id from the metric's name.
    pub fn new(metric: M) -> Self {
        // The id has to be built before `metric` is moved into the wrapper.
        let id = MetricId::new(metric.name());
        Self { id, metric }
    }
}
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
where
    T: Adaptor<M::Input> + 'static,
    M: Metric + Numeric + 'static,
{
    /// Update the metric with the adapted item, then read its numeric and running values.
    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {
        let serialized = self.metric.update(&item.adapt(), metadata);

        // Fields are evaluated in order, so the values are read after the update above.
        NumericMetricUpdate {
            entry: MetricEntry::new(self.id.clone(), serialized),
            numeric_entry: self.metric.value(),
            running_entry: self.metric.running_value(),
        }
    }

    /// Clear the wrapped metric.
    fn clear(&mut self) {
        self.metric.clear()
    }
}
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
where
    T: Adaptor<M::Input> + 'static,
    M: Metric + 'static,
{
    /// Update the metric with the adapted item and wrap the result in an entry.
    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {
        let adapted = item.adapt();
        let serialized = self.metric.update(&adapted, metadata);
        MetricEntry::new(self.id.clone(), serialized)
    }

    /// Clear the wrapped metric.
    fn clear(&mut self) {
        self.metric.clear()
    }
}

View File

@@ -0,0 +1,76 @@
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
use crate::{
metric::store::{EpochSummary, EventStoreClient, Split},
renderer::cli::CliMetricsRenderer,
};
use std::sync::Arc;
/// An [event processor](EventProcessorTraining) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
///
/// It holds no renderer of its own.
#[allow(dead_code)]
#[derive(new)]
pub(crate) struct MinimalEventProcessor<T: ItemLazy, V: ItemLazy> {
// Metric updaters for the training (`T`) and validation (`V`) items.
metrics: MetricsTraining<T, V>,
// Client used to record events in the event store.
store: Arc<EventStoreClient>,
}
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
    for MinimalEventProcessor<T, V>
{
    /// Handle one training event by updating the metrics and recording it in the store.
    fn process_train(&mut self, event: LearnerEvent<T>) {
        use crate::metric::store::Event;

        match event {
            LearnerEvent::Start => {
                let definitions = self.metrics.metric_definitions();
                self.store.add_event_train(Event::MetricsInit(definitions));
            }
            LearnerEvent::ProcessedItem(item) => {
                let synced = item.sync();
                let metadata = (&synced).into();
                let update = self.metrics.update_train(&synced, &metadata);
                self.store.add_event_train(Event::MetricsUpdate(update));
            }
            LearnerEvent::EndEpoch(epoch) => {
                // Metrics are cleared before the end-of-epoch event is recorded.
                self.metrics.end_epoch_train();
                let summary = EpochSummary::new(epoch, Split::Train);
                self.store.add_event_train(Event::EndEpoch(summary));
            }
            LearnerEvent::End(_summary) => {} // no-op for now
        }
    }

    /// Handle one validation event by updating the metrics and recording it in the store.
    fn process_valid(&mut self, event: LearnerEvent<V>) {
        use crate::metric::store::Event;

        match event {
            LearnerEvent::Start => {} // no-op for now
            LearnerEvent::ProcessedItem(item) => {
                let synced = item.sync();
                let metadata = (&synced).into();
                let update = self.metrics.update_valid(&synced, &metadata);
                self.store.add_event_valid(Event::MetricsUpdate(update));
            }
            LearnerEvent::EndEpoch(epoch) => {
                // Metrics are cleared before the end-of-epoch event is recorded.
                self.metrics.end_epoch_valid();
                let summary = EpochSummary::new(epoch, Split::Valid);
                self.store.add_event_valid(Event::EndEpoch(summary));
            }
            LearnerEvent::End(_) => {} // no-op for now
        }
    }

    /// Build a fresh CLI renderer, since this processor does not hold one.
    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
        // TODO: Check for another default.
        let renderer = CliMetricsRenderer::new();
        Box::new(renderer)
    }
}

View File

@@ -0,0 +1,77 @@
mod async_wrapper;
mod base;
mod full;
mod metrics;
mod minimal;
#[cfg(feature = "rl")]
mod rl_metrics;
#[cfg(feature = "rl")]
mod rl_processor;
pub use base::*;
pub(crate) use full::*;
pub(crate) use metrics::*;
#[cfg(feature = "rl")]
pub(crate) use rl_metrics::*;
#[cfg(feature = "rl")]
pub(crate) use rl_processor::*;
#[cfg(test)]
pub(crate) use minimal::*;
pub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining};
#[cfg(test)]
pub(crate) mod test_utils {
    use crate::metric::{
        Adaptor, LossInput,
        processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem},
    };
    use burn_core::tensor::{ElementConversion, Tensor, backend::Backend};

    use super::ItemLazy;

    /// A plain `f64` is already synced.
    impl ItemLazy for f64 {
        type ItemSync = f64;

        fn sync(self) -> Self::ItemSync {
            self
        }
    }

    /// Expose an `f64` as a loss input holding a single-element tensor.
    impl<B: Backend> Adaptor<LossInput<B>> for f64 {
        fn adapt(&self) -> LossInput<B> {
            let device = B::Device::default();
            let value = self.elem::<B::FloatElem>();
            let tensor = Tensor::from_data([value], &device);
            LossInput::new(tensor)
        }
    }

    /// Send one processed training item holding `value` to `processor`.
    ///
    /// The progress values are dummies; only the global progress depends on `epoch`.
    pub(crate) fn process_train(
        processor: &mut MinimalEventProcessor<f64, f64>,
        value: f64,
        epoch: usize,
    ) {
        use burn_core::data::dataloader::Progress;

        let progress = Progress {
            items_processed: 1,
            items_total: 10,
        };
        let global_progress = Progress {
            items_processed: epoch,
            items_total: 3,
        };
        let iteration = Some(1);

        let item = TrainingItem::new(value, progress, global_progress, iteration, None);
        processor.process_train(LearnerEvent::ProcessedItem(item));
    }

    /// Signal the end of `epoch` for the training split, then for the validation split.
    pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor<f64, f64>, epoch: usize) {
        let train_event = LearnerEvent::EndEpoch(epoch);
        processor.process_train(train_event);

        let valid_event = LearnerEvent::EndEpoch(epoch);
        processor.process_valid(valid_event);
    }
}

View File

@@ -0,0 +1,268 @@
use std::collections::HashMap;
use crate::{
EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater,
metric::{
Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate,
},
};
/// Metric updaters for reinforcement learning, grouped by the kind of item they consume:
/// train-step items (`TS`), environment-step items (`ES`) and episode summaries.
pub(crate) struct RLMetrics<TS: ItemLazy, ES: ItemLazy> {
// Non-numeric metrics.
train_step: Vec<Box<dyn MetricUpdater<TS::ItemSync>>>,
env_step: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
env_step_valid: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
episode_end: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
episode_end_valid: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
// Numeric counterparts of the collections above.
train_step_numeric: Vec<Box<dyn NumericMetricUpdater<TS::ItemSync>>>,
env_step_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
env_step_valid_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
episode_end_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
episode_end_valid_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
// Definition of every registered metric, keyed by metric id.
metric_definitions: HashMap<MetricId, MetricDefinition>,
}
impl<TS: ItemLazy, ES: ItemLazy> Default for RLMetrics<TS, ES> {
    /// Start with no registered metric.
    fn default() -> Self {
        Self {
            train_step: Vec::new(),
            env_step: Vec::new(),
            env_step_valid: Vec::new(),
            episode_end: Vec::new(),
            episode_end_valid: Vec::new(),
            train_step_numeric: Vec::new(),
            env_step_numeric: Vec::new(),
            env_step_valid_numeric: Vec::new(),
            episode_end_numeric: Vec::new(),
            episode_end_valid_numeric: Vec::new(),
            metric_definitions: HashMap::new(),
        }
    }
}
impl<TS: ItemLazy, ES: ItemLazy> RLMetrics<TS, ES> {
/// Register a training metric.
pub(crate) fn register_text_metric_agent<Me: Metric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step.push(Box::new(metric))
}
/// Register a training metric.
pub(crate) fn register_agent_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step_numeric.push(Box::new(metric))
}
/// Register a training metric.
pub(crate) fn register_text_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
where
TS::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.train_step.push(Box::new(metric))
}
/// Register a training metric.
pub(crate) fn register_metric_train<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
TS::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.train_step_numeric.push(Box::new(metric))
}
/// Register a validation env-step metric.
pub(crate) fn register_text_metric_agent_valid<Me: Metric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step_valid.push(Box::new(metric))
}
/// Register a validation env-step numeric metric.
pub(crate) fn register_agent_metric_valid<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step_valid_numeric.push(Box::new(metric))
}
/// Register an episode-end metric.
pub(crate) fn register_text_metric_episode<Me: Metric + 'static>(&mut self, metric: Me)
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end.push(Box::new(metric))
}
/// Register an episode-end numeric metric.
pub(crate) fn register_episode_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end_numeric.push(Box::new(metric))
}
/// Register an episode-end metric for validation.
pub(crate) fn register_text_metric_episode_valid<Me: Metric + 'static>(&mut self, metric: Me)
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end_valid.push(Box::new(metric))
}
/// Register an episode-end numeric metric for validation.
pub(crate) fn register_episode_metric_valid<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end_valid_numeric.push(Box::new(metric))
}
fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
self.metric_definitions.insert(
metric.id.clone(),
MetricDefinition::new(metric.id.clone(), &metric.metric),
);
}
/// Get metric definitions for all splits
pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
self.metric_definitions.values().cloned().collect()
}
/// Update the training information from the training item.
pub(crate) fn update_train_step(
&mut self,
item: &EvaluationItem<TS::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.train_step.len());
let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len());
for metric in self.train_step.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.train_step_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the env-step metrics from an environment step item.
pub(crate) fn update_env_step(
&mut self,
item: &EvaluationItem<ES::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.env_step.len());
let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len());
for metric in self.env_step.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.env_step_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the env-step metrics for validation from an environment step item.
pub(crate) fn update_env_step_valid(
&mut self,
item: &EvaluationItem<ES::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.env_step_valid.len());
let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len());
for metric in self.env_step_valid.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.env_step_valid_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the episode-end metrics from an episode summary.
pub(crate) fn update_episode_end(
&mut self,
item: &EvaluationItem<EpisodeSummary>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.episode_end.len());
let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len());
for metric in self.episode_end.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.episode_end_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the episode-end metrics for validation from an episode summary.
pub(crate) fn update_episode_end_valid(
&mut self,
item: &EvaluationItem<EpisodeSummary>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.episode_end_valid.len());
let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len());
for metric in self.episode_end_valid.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.episode_end_valid_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
}

View File

@@ -0,0 +1,177 @@
use std::sync::Arc;
use crate::{
EpisodeSummary, EvaluationItem, EventProcessorTraining, ItemLazy, LearnerSummary, RLMetrics,
metric::store::{Event, EventStoreClient, MetricsUpdate},
renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress},
};
/// Event happening during reinforcement learning.
pub enum RLEvent<TS, ES> {
/// Signal the start of the process (e.g., learning starts).
Start,
/// Signal an agent's training step.
TrainStep(EvaluationItem<TS>),
/// Signal a timestep of the agent-environment interface.
TimeStep(EvaluationItem<ES>),
/// Signal an episode end.
EpisodeEnd(EvaluationItem<EpisodeSummary>),
/// Signal the end of the process (e.g., learning ends), with an optional summary.
End(Option<LearnerSummary>),
}
/// Event happening during the evaluation of a reinforcement learning agent.
pub enum AgentEvaluationEvent<T> {
/// Signal the start of the process (e.g., evaluation start).
Start,
/// Signal a timestep of the agent-environment interface.
TimeStep(EvaluationItem<T>),
/// Signal an episode end.
EpisodeEnd(EvaluationItem<EpisodeSummary>),
/// Signal the end of the process (e.g., evaluation end).
End,
}
/// An [event processor](EventProcessorTraining) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Rendering metrics using a [metrics renderer](MetricsRenderer).
#[derive(new)]
pub struct RLEventProcessor<TS: ItemLazy, ES: ItemLazy> {
// Metric updaters for train-step (`TS`) and env-step (`ES`) items.
metrics: RLMetrics<TS, ES>,
// Renderer that receives metric states.
renderer: Box<dyn MetricsRenderer>,
// Client used to record events in the event store.
store: Arc<EventStoreClient>,
}
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
    /// Progress indicators rendered during training: a single detailed "Step" indicator
    /// carrying a copy of the global progress.
    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        vec![ProgressType::Detailed {
            tag: "Step".to_string(),
            progress: progress.global_progress.clone(),
        }]
    }

    /// Progress indicators rendered during evaluation: a single detailed "Step" indicator
    /// carrying a copy of the global progress.
    fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        vec![ProgressType::Detailed {
            tag: "Step".to_string(),
            progress: progress.global_progress.clone(),
        }]
    }
}
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
fn process_update_train(&mut self, update: MetricsUpdate) {
self.store
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_train(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
}
fn process_update_valid(&mut self, update: MetricsUpdate) {
self.store
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_valid(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
}
}
impl<TS: ItemLazy, ES: ItemLazy> EventProcessorTraining<RLEvent<TS, ES>, AgentEvaluationEvent<ES>>
    for RLEventProcessor<TS, ES>
{
    /// Handles one learning event.
    ///
    /// * `Start` publishes the metric definitions to the store and registers them with the
    ///   renderer.
    /// * `TrainStep`, `TimeStep` and `EpisodeEnd` sync the item, update the matching metrics
    ///   and forward the update; only `TimeStep` also renders training progress.
    /// * `End` notifies the renderer; an error returned by the renderer is discarded.
    fn process_train(&mut self, event: RLEvent<TS, ES>) {
        match event {
            RLEvent::Start => {
                let definitions = self.metrics.metric_definitions();
                let init = Event::MetricsInit(definitions.clone());
                self.store.add_event_train(init);
                for definition in definitions.iter() {
                    self.renderer.register_metric(definition.clone());
                }
            }
            RLEvent::TrainStep(step) => {
                let step = step.sync();
                let metadata = (&step).into();
                let update = self.metrics.update_train_step(&step, &metadata);
                self.process_update_train(update);
            }
            RLEvent::TimeStep(step) => {
                let step = step.sync();
                let progress = (&step).into();
                let metadata = (&step).into();
                let update = self.metrics.update_env_step(&step, &metadata);
                self.process_update_train(update);
                let indicators = self.progress_indicators(&progress);
                self.renderer.render_train(progress, indicators);
            }
            RLEvent::EpisodeEnd(episode) => {
                let episode = episode.sync();
                let metadata = (&episode).into();
                let update = self.metrics.update_episode_end(&episode, &metadata);
                self.process_update_train(update);
            }
            RLEvent::End(learner_summary) => {
                // The renderer's result is intentionally ignored, as before.
                let _ = self.renderer.on_train_end(learner_summary);
            }
        }
    }

    /// Handles one agent-evaluation event.
    ///
    /// `TimeStep` and `EpisodeEnd` sync the item, update the validation metrics and forward
    /// the update; only `EpisodeEnd` also renders validation progress. `Start` and `End` do
    /// nothing for now.
    fn process_valid(&mut self, event: AgentEvaluationEvent<ES>) {
        match event {
            // no-op for now
            AgentEvaluationEvent::Start | AgentEvaluationEvent::End => {}
            AgentEvaluationEvent::TimeStep(step) => {
                let step = step.sync();
                let metadata = (&step).into();
                let update = self.metrics.update_env_step_valid(&step, &metadata);
                self.process_update_valid(update);
            }
            AgentEvaluationEvent::EpisodeEnd(episode) => {
                let episode = episode.sync();
                let progress = (&episode).into();
                let metadata = (&episode).into();
                let update = self.metrics.update_episode_end_valid(&episode, &metadata);
                self.process_update_valid(update);
                let indicators = self.progress_indicators_eval(&progress);
                self.renderer.render_valid(progress, indicators);
            }
        }
    }

    /// Consumes the processor and hands back its renderer.
    fn renderer(self) -> Box<dyn MetricsRenderer> {
        let Self { renderer, .. } = self;
        renderer
    }
}

View File

@@ -0,0 +1,226 @@
use crate::metric::{MetricName, Numeric};
use super::{
Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
confusion_stats::{ConfusionStats, ConfusionStatsInput},
state::{FormatOptions, NumericMetricState},
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::cast::ToElement,
};
use core::marker::PhantomData;
use std::{num::NonZeroUsize, sync::Arc};
/// The recall metric.
///
/// Each update divides the true positives by the positives of the confusion statistics,
/// reduces the result across classes according to the configured class reduction and
/// reports it as a percentage.
#[derive(Clone)]
pub struct RecallMetric<B: Backend> {
// Display name, derived from the decision rule and the class reduction.
name: MetricName,
// Current and running numeric state.
state: NumericMetricState,
_b: PhantomData<B>,
// Decision rule and class reduction used to compute the confusion statistics.
config: ClassificationMetricConfig,
}
impl<B: Backend> Default for RecallMetric<B> {
fn default() -> Self {
Self::new(Default::default())
}
}
impl<B: Backend> RecallMetric<B> {
fn new(config: ClassificationMetricConfig) -> Self {
let state = Default::default();
let name = Arc::new(format!(
"Recall @ {:?} [{:?}]",
config.decision_rule, config.class_reduction
));
Self {
state,
config,
name,
_b: Default::default(),
}
}
/// Recall metric for binary classification.
///
/// # Arguments
///
/// * `threshold` - The threshold to transform a probability into a binary prediction.
#[allow(dead_code)]
pub fn binary(threshold: f64) -> Self {
Self::new(ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
// binary classification results are the same independently of class_reduction
..Default::default()
})
}
/// Recall metric for multiclass classification.
///
/// # Arguments
///
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
Self::new(ClassificationMetricConfig {
decision_rule: DecisionRule::TopK(
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
),
class_reduction,
})
}
/// Recall metric for multi-label classification.
///
/// # Arguments
///
/// * `threshold` - The threshold to transform a probability into a binary prediction.
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
Self::new(ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
class_reduction,
})
}
fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
use ClassReduction::{Macro, Micro};
let avg_tensor = match self.config.class_reduction {
Micro => aggregated_metric,
Macro => {
if aggregated_metric
.clone()
.contains_nan()
.any()
.into_scalar()
.to_bool()
{
let nan_mask = aggregated_metric.clone().is_nan();
aggregated_metric = aggregated_metric
.clone()
.select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))
}
aggregated_metric.mean()
}
};
avg_tensor.into_scalar().to_f64()
}
}
impl<B: Backend> Metric for RecallMetric<B> {
type Input = ConfusionStatsInput<B>;
fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
let [sample_size, _] = input.predictions.dims();
let cf_stats = ConfusionStats::new(input, &self.config);
let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive());
self.state.update(
100.0 * metric,
sample_size,
FormatOptions::new(self.name()).unit("%").precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some("%".to_string()),
higher_is_better: true,
}
.into()
}
}
impl<B: Backend> Numeric for RecallMetric<B> {
fn value(&self) -> NumericEntry {
self.state.current_value()
}
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}
// Unit tests for `RecallMetric`: each case feeds the shared dummy classification input to the
// metric and compares the reported percentage with a hand-computed recall.
#[cfg(test)]
mod tests {
use super::{
ClassReduction::{self, *},
Metric, MetricMetadata, RecallMetric,
};
use crate::metric::Numeric;
use crate::{
TestBackend,
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
};
use burn_core::tensor::{TensorData, Tolerance};
use rstest::rstest;
// Binary classification: `expected` is a fraction, the metric reports a percentage.
#[rstest]
#[case::binary(THRESHOLD, 0.5)]
fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) {
let input = dummy_classification_input(&ClassificationType::Binary).into();
let mut metric = RecallMetric::binary(threshold);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// Multiclass classification for both class reductions and top-k of 1 and 2.
#[rstest]
#[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]
#[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)]
#[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)]
#[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)]
fn test_multiclass_recall(
#[case] class_reduction: ClassReduction,
#[case] top_k: usize,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
let mut metric = RecallMetric::multiclass(top_k, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// Multi-label classification for both class reductions.
#[rstest]
#[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)]
#[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)]
fn test_multilabel_recall(
#[case] class_reduction: ClassReduction,
#[case] threshold: f64,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
let mut metric = RecallMetric::multilabel(threshold, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
}
// The metric name must change with the configuration and be stable for equal configurations.
#[test]
fn test_parameterized_unique_name() {
let metric_a = RecallMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
let metric_b = RecallMetric::<TestBackend>::multiclass(2, ClassReduction::Macro);
let metric_c = RecallMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
assert_ne!(metric_a.name(), metric_b.name());
assert_eq!(metric_a.name(), metric_c.name());
let metric_a = RecallMetric::<TestBackend>::binary(0.5);
let metric_b = RecallMetric::<TestBackend>::binary(0.75);
assert_ne!(metric_a.name(), metric_b.name());
}
}

View File

@@ -0,0 +1,78 @@
use std::sync::Arc;
use super::super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Metric for the cumulative reward of the last completed episode.
#[derive(Clone)]
pub struct CumulativeRewardMetric {
// Display name ("Cum. Reward").
name: MetricName,
// Current and running numeric state.
state: NumericMetricState,
}
impl CumulativeRewardMetric {
    /// Creates a new cumulative reward metric with an empty state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Cum. Reward"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for CumulativeRewardMetric {
fn default() -> Self {
Self::new()
}
}
/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type.
#[derive(new)]
pub struct CumulativeRewardInput {
// Cumulative reward value recorded by the metric.
cum_reward: f64,
}
impl Metric for CumulativeRewardMetric {
type Input = CumulativeRewardInput;
fn update(
&mut self,
item: &CumulativeRewardInput,
_metadata: &MetricMetadata,
) -> SerializedEntry {
self.state.update(
item.cum_reward,
1,
FormatOptions::new(self.name()).precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: None,
higher_is_better: true,
}
.into()
}
}
impl Numeric for CumulativeRewardMetric {
fn value(&self) -> NumericEntry {
self.state.current_value()
}
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}

View File

@@ -0,0 +1,71 @@
use std::sync::Arc;
use super::super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Metric for the length of the last completed episode.
#[derive(Clone)]
pub struct EpisodeLengthMetric {
// Display name ("Episode length").
name: MetricName,
// Current and running numeric state.
state: NumericMetricState,
}
impl EpisodeLengthMetric {
    /// Creates a new episode length metric with an empty state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Episode length"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for EpisodeLengthMetric {
fn default() -> Self {
Self::new()
}
}
/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type.
#[derive(new)]
pub struct EpisodeLengthInput {
// Episode length value recorded by the metric (reported in "steps").
ep_len: f64,
}
impl Metric for EpisodeLengthMetric {
type Input = EpisodeLengthInput;
fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {
self.state
.update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some(String::from("steps")),
higher_is_better: true,
}
.into()
}
}
impl Numeric for EpisodeLengthMetric {
fn value(&self) -> NumericEntry {
self.state.current_value()
}
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}

View File

@@ -0,0 +1,78 @@
use std::sync::Arc;
use super::super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Metric for the exploration rate.
#[derive(Clone)]
pub struct ExplorationRateMetric {
// Display name ("Exploration rate").
name: MetricName,
// Current and running numeric state.
state: NumericMetricState,
}
impl ExplorationRateMetric {
    /// Creates a new exploration rate metric with an empty state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Exploration rate"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for ExplorationRateMetric {
fn default() -> Self {
Self::new()
}
}
/// The [ExplorationRateMetric](ExplorationRateMetric) input type.
#[derive(new)]
pub struct ExplorationRateInput {
// Exploration rate value recorded by the metric.
exploration_rate: f64,
}
impl Metric for ExplorationRateMetric {
type Input = ExplorationRateInput;
fn update(
&mut self,
item: &ExplorationRateInput,
_metadata: &MetricMetadata,
) -> SerializedEntry {
self.state.update(
item.exploration_rate,
1,
FormatOptions::new(self.name()).precision(3),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some(String::from("%")),
higher_is_better: false,
}
.into()
}
}
impl Numeric for ExplorationRateMetric {
fn value(&self) -> NumericEntry {
self.state.current_value()
}
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}

View File

@@ -0,0 +1,7 @@
mod cum_reward;
mod ep_len;
mod exploration_rate;
pub use cum_reward::*;
pub use ep_len::*;
pub use exploration_rate::*;

View File

@@ -0,0 +1,144 @@
use std::sync::Arc;
use crate::metric::{MetricName, NumericEntry, SerializedEntry, format_float};
/// Useful utility to implement numeric metrics.
///
/// # Notes
///
/// The numeric metric state stores values as floats.
/// Even if some metrics are integers, their means are floats.
#[derive(Clone)]
pub struct NumericMetricState {
// Sum of `value * batch_size` over all updates since the last reset.
sum: f64,
// Sum of the batch sizes over all updates since the last reset.
count: usize,
// Value of the most recent update (NaN before the first update).
current: f64,
// Batch size of the most recent update.
current_count: usize,
}
/// Formatting options for the [numeric metric state](NumericMetricState).
pub struct FormatOptions {
// Name of the metric being formatted.
name: Arc<String>,
// Optional unit appended after each formatted value.
unit: Option<String>,
// Optional floating point precision; values use their default formatting when unset.
precision: Option<usize>,
}
impl FormatOptions {
    /// Create the [formatting options](FormatOptions) with a name.
    ///
    /// The unit and the precision start unset.
    pub fn new(name: MetricName) -> Self {
        // `name` is already owned, so it is moved in instead of being cloned.
        Self {
            name,
            unit: None,
            precision: None,
        }
    }

    /// Specify the metric unit.
    pub fn unit(mut self, unit: &str) -> Self {
        self.unit = Some(unit.to_string());
        self
    }

    /// Specify the floating point precision.
    pub fn precision(mut self, precision: usize) -> Self {
        self.precision = Some(precision);
        self
    }

    /// Get the metric name.
    pub fn name(&self) -> &Arc<String> {
        &self.name
    }

    /// Get the metric unit.
    pub fn unit_value(&self) -> &Option<String> {
        &self.unit
    }

    /// Get the precision.
    pub fn precision_value(&self) -> Option<usize> {
        self.precision
    }
}
impl NumericMetricState {
    /// Create a new [numeric metric state](NumericMetricState).
    ///
    /// The state starts empty: zero sum and counts, and a NaN current value.
    pub fn new() -> Self {
        Self {
            current: f64::NAN,
            current_count: 0,
            sum: 0.0,
            count: 0,
        }
    }

    /// Reset the state to the one produced by [`NumericMetricState::new`].
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Update the state.
    ///
    /// `value` is accumulated with weight `batch_size` and becomes the current value. The
    /// returned entry pairs a human-readable string (running value, then current value) with
    /// the serialized current value and its count.
    pub fn update(
        &mut self,
        value: f64,
        batch_size: usize,
        format: FormatOptions,
    ) -> SerializedEntry {
        self.sum += value * batch_size as f64;
        self.count += batch_size;
        self.current = value;
        self.current_count = batch_size;

        let running = self.sum / self.count as f64;

        // Apply the requested precision, or fall back to the default float formatting.
        let render = |number: f64| match format.precision {
            Some(precision) => format_float(number, precision),
            None => format!("{number}"),
        };
        let formatted_current = render(value);
        let formatted_running = render(running);

        // TODO: naming inconsistent with RL.
        let formatted = match &format.unit {
            Some(unit) => {
                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
            }
            None => format!("epoch {formatted_running} - batch {formatted_current}"),
        };

        // Numeric metric state is an aggregated value
        let serialized = NumericEntry::Aggregated {
            aggregated_value: value,
            count: batch_size,
        }
        .serialize();

        SerializedEntry::new(formatted, serialized)
    }

    /// Get the numeric value of the most recent update, with its batch size as the count.
    pub fn current_value(&self) -> NumericEntry {
        let (aggregated_value, count) = (self.current, self.current_count);
        NumericEntry::Aggregated {
            aggregated_value,
            count,
        }
    }

    /// Get the running aggregated value: the accumulated sum divided by the accumulated count.
    pub fn running_value(&self) -> NumericEntry {
        let mean = self.sum / self.count as f64;
        NumericEntry::Aggregated {
            aggregated_value: mean,
            count: self.count,
        }
    }
}
impl Default for NumericMetricState {
fn default() -> Self {
Self::new()
}
}

View File

@@ -0,0 +1,251 @@
use crate::{
logger::MetricLogger,
metric::{NumericEntry, store::Split},
};
use std::collections::HashMap;
use super::{Aggregate, Direction};
/// Type that can be used to fetch and use numeric metric aggregates.
#[derive(Default, Debug)]
pub(crate) struct NumericMetricsAggregate {
// Cache of already computed aggregates, keyed by metric name, epoch, split and aggregate.
value_for_each_epoch: HashMap<Key, f64>,
}
// Cache key identifying one aggregated value.
#[derive(new, Hash, PartialEq, Eq, Debug)]
struct Key {
// Metric name.
name: String,
// Epoch the value was aggregated over.
epoch: usize,
// Dataset split the value belongs to.
split: Split,
// Kind of aggregation applied.
aggregate: Aggregate,
}
impl NumericMetricsAggregate {
    /// Computes (or fetches from the cache) the aggregated value of metric `name` for the
    /// given `epoch` and `split`.
    ///
    /// The loggers are queried in order and the first one that can read the metric provides
    /// the data points. Returns `None` when that logger has no point for the epoch.
    ///
    /// # Panics
    ///
    /// Panics if no logger can read the metric.
    pub(crate) fn aggregate(
        &mut self,
        name: &str,
        epoch: usize,
        split: &Split,
        aggregate: Aggregate,
        loggers: &mut [Box<dyn MetricLogger>],
    ) -> Option<f64> {
        let key = Key::new(name.to_string(), epoch, split.clone(), aggregate);

        if let Some(value) = self.value_for_each_epoch.get(&key) {
            return Some(*value);
        }

        let points = || {
            let mut errors = Vec::new();
            for logger in loggers {
                match logger.read_numeric(name, epoch, split) {
                    Ok(points) => return Ok(points),
                    Err(err) => errors.push(err),
                };
            }

            Err(errors.join(" "))
        };

        let points = points().expect("Can read values");

        if points.is_empty() {
            return None;
        }

        // Accurately compute the aggregated value based on the *actual* number of points
        // since not all mini-batches are guaranteed to have the specified batch size
        let (sum, num_points) = points
            .into_iter()
            .map(|entry| match entry {
                NumericEntry::Value(v) => (v, 1),
                // Right now the mean is the only aggregate available, so we can assume that the sum
                // of an entry corresponds to (value * number of elements)
                NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                } => (aggregated_value * count as f64, count),
            })
            .reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))
            // `points` was checked to be non-empty above.
            .unwrap();

        let value = match aggregate {
            Aggregate::Mean => sum / num_points as f64,
        };

        self.value_for_each_epoch.insert(key, value);
        Some(value)
    }

    /// Finds the epoch (1-based) whose aggregated value is the lowest or the highest,
    /// depending on `direction`.
    ///
    /// Epochs are scanned from 1 until the first epoch without data. Ties resolve to the
    /// earliest epoch. A NaN value never beats a non-NaN one; if every value is NaN the first
    /// epoch is returned. Returns `None` when the first epoch has no data.
    ///
    /// # Panics
    ///
    /// Panics if no logger can read the metric (see [`Self::aggregate`]).
    pub(crate) fn find_epoch(
        &mut self,
        name: &str,
        split: &Split,
        aggregate: Aggregate,
        direction: Direction,
        loggers: &mut [Box<dyn MetricLogger>],
    ) -> Option<usize> {
        let mut values = Vec::new();
        let mut epoch = 1;

        while let Some(value) = self.aggregate(name, epoch, split, aggregate, loggers) {
            values.push(value);
            epoch += 1;
        }

        // The best candidate starts as the first recorded epoch rather than a sentinel value,
        // so the result always names an epoch that exists, even for NaN or infinite values.
        let mut best: Option<(usize, f64)> = None;

        for (index, value) in values.into_iter().enumerate() {
            let is_better = match best {
                None => true,
                Some((_, best_value)) => {
                    let improves = match direction {
                        Direction::Lowest => value < best_value,
                        Direction::Highest => value > best_value,
                    };
                    // Comparisons with NaN are always false, so replace a NaN candidate
                    // explicitly as soon as a real value shows up.
                    improves || (best_value.is_nan() && !value.is_nan())
                }
            };

            if is_better {
                best = Some((index + 1, value));
            }
        }

        best.map(|(best_epoch, _)| best_epoch)
    }
}
// Unit tests for `NumericMetricsAggregate`.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
logger::{FileMetricLogger, InMemoryMetricLogger},
metric::{MetricDefinition, MetricEntry, MetricId, SerializedEntry, store::MetricsUpdate},
};
use super::*;
// Helper wrapping a file logger together with the epoch that entries are logged under.
struct TestLogger {
logger: FileMetricLogger,
epoch: usize,
}
const NAME: &str = "test-logger";
impl TestLogger {
// NOTE(review): the logger is created with the hard-coded "/tmp" path - presumably its
// output directory; confirm this is acceptable for parallel or non-Unix test runs.
fn new() -> Self {
Self {
logger: FileMetricLogger::new("/tmp"),
epoch: 1,
}
}
// Logs `num` for the `NAME` metric on the training split at the current epoch.
fn log(&mut self, num: f64) {
let entry = MetricEntry::new(
MetricId::new(Arc::new(NAME.into())),
SerializedEntry::new(num.to_string(), num.to_string()),
);
let entries = Vec::from([entry]);
let metrics_update = MetricsUpdate::new(entries, vec![]);
self.logger.log(metrics_update, self.epoch, &Split::Train);
}
// Registers the definition of the `NAME` metric with the logger.
fn log_definition(&mut self) {
let definition = MetricDefinition {
metric_id: MetricId::new(Arc::new(NAME.into())),
name: NAME.into(),
attributes: crate::metric::MetricAttributes::None,
description: None,
};
self.logger.log_metric_definition(definition);
}
// Moves on to the next epoch.
fn new_epoch(&mut self) {
self.epoch += 1;
}
}
// Epoch means are 750, 600 and 10000, so the lowest mean is found at epoch 2.
#[test]
fn should_find_epoch() {
let mut logger = TestLogger::new();
let mut aggregate = NumericMetricsAggregate::default();
logger.log_definition();
logger.log(500.); // Epoch 1
logger.log(1000.); // Epoch 1
logger.new_epoch();
logger.log(200.); // Epoch 2
logger.log(1000.); // Epoch 2
logger.new_epoch();
logger.log(10000.); // Epoch 3
let value = aggregate
.find_epoch(
NAME,
&Split::Train,
Aggregate::Mean,
Direction::Lowest,
&mut [Box::new(logger.logger)],
)
.unwrap();
assert_eq!(value, 2);
}
// An aggregated entry must be weighted by its count when computing the epoch mean.
#[test]
fn should_aggregate_numeric_entry() {
let mut logger = InMemoryMetricLogger::default();
let mut aggregate = NumericMetricsAggregate::default();
let metric_name = Arc::new("Loss".to_string());
let metric_id = MetricId::new(metric_name.clone());
let definition = MetricDefinition {
metric_id: metric_id.clone(),
name: metric_name.to_string(),
attributes: crate::metric::MetricAttributes::None,
description: None,
};
logger.log_metric_definition(definition);
// Epoch 1
let loss_1 = 0.5;
let loss_2 = 1.25; // (1.5 + 1.0) / 2 = 2.5 / 2
let entry = MetricEntry::new(
metric_id.clone(),
SerializedEntry::new(loss_1.to_string(), NumericEntry::Value(loss_1).serialize()),
);
let entries = Vec::from([entry]);
let metrics_update = MetricsUpdate::new(entries, vec![]);
logger.log(metrics_update, 1, &Split::Train);
let entry = MetricEntry::new(
metric_id.clone(),
SerializedEntry::new(
loss_2.to_string(),
NumericEntry::Aggregated {
aggregated_value: loss_2,
count: 2,
}
.serialize(),
),
);
let entries = Vec::from([entry]);
let metrics_update = MetricsUpdate::new(entries, vec![]);
logger.log(metrics_update, 1, &Split::Train);
let value = aggregate
.aggregate(
&metric_name,
1,
&Split::Train,
Aggregate::Mean,
&mut [Box::new(logger)],
)
.unwrap();
// Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875
assert_eq!(value, 1.0);
}
}

View File

@@ -0,0 +1,105 @@
use std::sync::Arc;
use crate::metric::{MetricDefinition, MetricEntry, NumericEntry};
/// Event happening during the training/validation process.
pub enum Event {
/// Signal the initialization of the metrics.
MetricsInit(Vec<MetricDefinition>),
/// Signal that metrics have been updated.
MetricsUpdate(MetricsUpdate),
/// Signal the end of an epoch.
EndEpoch(EpochSummary),
}
/// Contains all the information of one numeric metric update.
#[derive(new, Clone, Debug)]
pub struct NumericMetricUpdate {
/// Generic metric information.
pub entry: MetricEntry,
/// The numeric information.
pub numeric_entry: NumericEntry,
/// Numeric value averaged over the global step (epoch).
pub running_entry: NumericEntry,
}
/// Contains all metric information produced by one update, split into non-numeric and
/// numeric metrics.
#[derive(new, Clone, Debug)]
pub struct MetricsUpdate {
/// Metrics information related to non-numeric metrics.
pub entries: Vec<MetricEntry>,
/// Metrics information related to numeric metrics.
pub entries_numeric: Vec<NumericMetricUpdate>,
}
/// Summary information about a given epoch.
#[derive(new, Clone, Debug)]
pub struct EpochSummary {
/// Epoch number.
pub epoch_number: usize,
/// Dataset split (train, valid, test).
pub split: Split,
}
/// Defines how training and validation events are collected and searched.
///
/// This trait also exposes methods that use the collected data to compute useful information.
pub trait EventStore: Send {
/// Collect a training/validation event for the given split.
fn add_event(&mut self, event: Event, split: Split);
/// Find the epoch following the given criteria from the collected data.
///
/// Returns `None` when no epoch matches.
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: &Split,
) -> Option<usize>;
/// Find the metric value for the given epoch following the given criteria.
///
/// Returns `None` when no value is available.
fn find_metric(
&mut self,
name: &str,
epoch: usize,
aggregate: Aggregate,
split: &Split,
) -> Option<f64>;
}
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
/// How to aggregate the metric values.
pub enum Aggregate {
/// Compute the average.
Mean,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
/// The dataset split to use.
pub enum Split {
/// The training split.
Train,
/// The validation split.
Valid,
/// The testing split, which might be tagged.
Test(Option<Arc<String>>),
}
impl std::fmt::Display for Split {
    /// Writes the lowercase split name; the optional tag of a test split is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Split::Train => "train",
            Split::Valid => "valid",
            Split::Test(_) => "test",
        };
        f.write_str(label)
    }
}
#[derive(Copy, Clone)]
/// The direction of the query, i.e. which end of the value range is looked for.
pub enum Direction {
/// Lower is better.
Lowest,
/// Higher is better.
Highest,
}

View File

@@ -0,0 +1,171 @@
use super::EventStore;
use super::{Aggregate, Direction, Event, Split};
use std::sync::Arc;
use std::{sync::mpsc, thread::JoinHandle};
/// Type that allows communicating with an [event store](EventStore).
pub struct EventStoreClient {
// Channel used to send messages to the worker thread.
sender: mpsc::Sender<Message>,
// Join handle of the worker thread; taken when the client is dropped.
handler: Option<JoinHandle<()>>,
}
impl EventStoreClient {
    /// Create a new [event store](EventStore) client.
    ///
    /// The store is moved to a freshly spawned worker thread that processes the messages
    /// sent by this client.
    pub(crate) fn new<C>(store: C) -> Self
    where
        C: EventStore + 'static,
    {
        let (sender, receiver) = mpsc::channel();
        let worker = WorkerThread::new(store, receiver);
        let join_handle = std::thread::spawn(move || worker.run());

        Self {
            sender,
            handler: Some(join_handle),
        }
    }
}
impl EventStoreClient {
    /// Add a training event to the [event store](EventStore).
    ///
    /// # Panics
    ///
    /// Panics if the event cannot be sent to the event store thread.
    pub(crate) fn add_event_train(&self, event: Event) {
        let message = Message::OnEventTrain(event);
        self.sender
            .send(message)
            .expect("Can send event to event store thread.");
    }

    /// Add a validation event to the [event store](EventStore).
    ///
    /// # Panics
    ///
    /// Panics if the event cannot be sent to the event store thread.
    pub(crate) fn add_event_valid(&self, event: Event) {
        let message = Message::OnEventValid(event);
        self.sender
            .send(message)
            .expect("Can send event to event store thread.");
    }

    /// Add a testing event to the [event store](EventStore).
    ///
    /// # Panics
    ///
    /// Panics if the event cannot be sent to the event store thread.
    pub(crate) fn add_event_test(&self, event: Event, tag: Arc<String>) {
        let message = Message::OnEventTest(event, tag);
        self.sender
            .send(message)
            .expect("Can send event to event store thread.");
    }

    /// Find the epoch following the given criteria from the collected data.
    ///
    /// Blocks until the event store thread answers.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent or if no answer is received.
    pub fn find_epoch(
        &self,
        name: &str,
        aggregate: Aggregate,
        direction: Direction,
        split: &Split,
    ) -> Option<usize> {
        let (reply_sender, reply_receiver) = mpsc::sync_channel(1);
        let request = Message::FindEpoch(
            name.to_string(),
            aggregate,
            direction,
            split.clone(),
            reply_sender,
        );
        self.sender
            .send(request)
            .expect("Can send event to event store thread.");

        reply_receiver
            .recv()
            .unwrap_or_else(|err| panic!("Event store thread crashed: {err:?}"))
    }

    /// Find the metric value for the given epoch following the given criteria.
    ///
    /// Blocks until the event store thread answers.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent or if no answer is received.
    pub fn find_metric(
        &self,
        name: &str,
        epoch: usize,
        aggregate: Aggregate,
        split: &Split,
    ) -> Option<f64> {
        let (reply_sender, reply_receiver) = mpsc::sync_channel(1);
        let request = Message::FindMetric(
            name.to_string(),
            epoch,
            aggregate,
            split.clone(),
            reply_sender,
        );
        self.sender
            .send(request)
            .expect("Can send event to event store thread.");

        reply_receiver
            .recv()
            .unwrap_or_else(|err| panic!("Event store thread crashed: {err:?}"))
    }
}
// Worker owning the event store and the receiving end of the client's channel.
#[derive(new)]
struct WorkerThread<S> {
// The event store that messages are applied to.
store: S,
// Receiving end of the channel fed by the client.
receiver: mpsc::Receiver<Message>,
}
impl<C> WorkerThread<C>
where
    C: EventStore,
{
    /// Processes messages until `Message::End` is received or the channel is closed.
    ///
    /// # Panics
    ///
    /// Panics if a query response cannot be sent back through its callback channel.
    fn run(mut self) {
        while let Ok(message) = self.receiver.recv() {
            match message {
                Message::End => break,
                Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
                Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
                Message::OnEventTest(event, tag) => {
                    self.store.add_event(event, Split::Test(Some(tag)))
                }
                Message::FindEpoch(name, aggregate, direction, split, callback) => {
                    let found = self.store.find_epoch(&name, aggregate, direction, &split);
                    callback
                        .send(found)
                        .expect("Can send response using callback channel.");
                }
                Message::FindMetric(name, epoch, aggregate, split, callback) => {
                    let found = self.store.find_metric(&name, epoch, aggregate, &split);
                    callback
                        .send(found)
                        .expect("Can send response using callback channel.");
                }
            }
        }
    }
}
// Messages sent from the client to the worker thread.
enum Message {
// Record a testing event under the given tag.
OnEventTest(Event, Arc<String>),
// Record a training event.
OnEventTrain(Event),
// Record a validation event.
OnEventValid(Event),
// Stop the worker thread.
End,
// Epoch query: metric name, aggregate, direction, split and the channel for the answer.
FindEpoch(
String,
Aggregate,
Direction,
Split,
mpsc::SyncSender<Option<usize>>,
),
// Metric query: metric name, epoch, aggregate, split and the channel for the answer.
FindMetric(
String,
usize,
Aggregate,
Split,
mpsc::SyncSender<Option<f64>>,
),
}
impl Drop for EventStoreClient {
    /// Asks the worker thread to stop and waits for it to finish.
    ///
    /// # Panics
    ///
    /// Panics if the stop message cannot be delivered or if the worker thread panicked,
    /// unless the current thread is already unwinding from another panic.
    fn drop(&mut self) {
        // A failed send means the worker already dropped its receiver, so it is stopping or
        // has stopped and joining it below cannot block indefinitely.
        let stop_delivered = self.sender.send(Message::End).is_ok();

        let worker_joined = match self.handler.take() {
            Some(handler) => handler.join().is_ok(),
            None => true,
        };

        // Panicking while already unwinding aborts the whole process. This happens when the
        // client is dropped as a consequence of the worker crash being reported, so the
        // failure is only raised when no other panic is in flight.
        if std::thread::panicking() {
            return;
        }

        assert!(
            stop_delivered,
            "Can send the end message to the event store thread."
        );
        assert!(worker_joined, "The event store thread should stop.");
    }
}

View File

@@ -0,0 +1,72 @@
use std::collections::HashMap;
use super::{Aggregate, Direction, Event, EventStore, Split, aggregate::NumericMetricsAggregate};
use crate::logger::MetricLogger;
// Event store that forwards events to the registered metric loggers and answers queries by
// aggregating the values read back from them.
#[derive(Default)]
pub(crate) struct LogEventStore {
// Registered metric loggers.
loggers: Vec<Box<dyn MetricLogger>>,
// Aggregate computation with its cache.
aggregate: NumericMetricsAggregate,
// Current epoch for each split; a split starts at epoch 1.
epochs: HashMap<Split, usize>,
}
impl EventStore for LogEventStore {
    /// Forwards the event to every registered logger.
    ///
    /// Metric updates are logged under the current epoch of `split` (1 until an epoch end was
    /// seen for it). An epoch end moves the split to `summary.epoch_number + 1`.
    fn add_event(&mut self, event: Event, split: Split) {
        let epoch = *self.epochs.entry(split.clone()).or_insert(1);

        match event {
            Event::MetricsInit(definitions) => {
                for definition in definitions.iter() {
                    for logger in self.loggers.iter_mut() {
                        logger.log_metric_definition(definition.clone());
                    }
                }
            }
            Event::MetricsUpdate(update) => {
                for logger in self.loggers.iter_mut() {
                    logger.log(update.clone(), epoch, &split);
                }
            }
            Event::EndEpoch(summary) => {
                self.epochs.insert(split, summary.epoch_number + 1);
                for logger in self.loggers.iter_mut() {
                    logger.log_epoch_summary(summary.clone());
                }
            }
        }
    }

    /// Finds the epoch matching the criteria using the data readable from the loggers.
    fn find_epoch(
        &mut self,
        name: &str,
        aggregate: Aggregate,
        direction: Direction,
        split: &Split,
    ) -> Option<usize> {
        let loggers = &mut self.loggers;
        self.aggregate
            .find_epoch(name, split, aggregate, direction, loggers)
    }

    /// Computes the aggregated metric value of an epoch using the data readable from the
    /// loggers.
    fn find_metric(
        &mut self,
        name: &str,
        epoch: usize,
        aggregate: Aggregate,
        split: &Split,
    ) -> Option<f64> {
        let loggers = &mut self.loggers;
        self.aggregate
            .aggregate(name, epoch, split, aggregate, loggers)
    }
}
impl LogEventStore {
    /// Register a logger for metrics.
    pub(crate) fn register_logger<ML: MetricLogger + 'static>(&mut self, logger: ML) {
        let boxed: Box<dyn MetricLogger> = Box::new(logger);
        self.loggers.push(boxed);
    }

    /// Returns whether any loggers are registered.
    pub(crate) fn has_loggers(&self) -> bool {
        let no_loggers = self.loggers.is_empty();
        !no_loggers
    }
}

View File

@@ -0,0 +1,9 @@
pub(crate) mod aggregate;
mod base;
mod client;
mod log;
pub(crate) use self::log::*;
pub use base::*;
pub use client::*;

View File

@@ -0,0 +1,185 @@
use core::marker::PhantomData;
use std::sync::Arc;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricMetadata, SerializedEntry};
use crate::metric::{
Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// The Top-K accuracy metric.
///
/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).
#[derive(Default, Clone)]
pub struct TopKAccuracyMetric<B: Backend> {
// Display name, which embeds the value of `k`.
name: Arc<String>,
// Number of highest-scoring predictions considered for each sample.
k: usize,
// Current and running numeric state.
state: NumericMetricState,
/// If specified, targets equal to this value will be considered padding and will not count
/// towards the metric
pad_token: Option<usize>,
_b: PhantomData<B>,
}
/// The [top-k accuracy metric](TopKAccuracyMetric) input type.
///
/// The constructor is generated by `#[derive(new)]` and takes the fields in declaration order.
#[derive(new)]
pub struct TopKAccuracyInput<B: Backend> {
/// The outputs (batch_size, num_classes): one score per class for every sample.
outputs: Tensor<B, 2>,
/// The labels (batch_size): the target class index of every sample.
targets: Tensor<B, 1, Int>,
}
impl<B: Backend> TopKAccuracyMetric<B> {
    /// Creates a metric that checks whether the target is among the `k` best-scored classes.
    ///
    /// The metric name includes `k`; no pad token is configured.
    pub fn new(k: usize) -> Self {
        let name = Arc::new(format!("Top-K Accuracy @ TopK({})", k));
        Self {
            name,
            k,
            ..Default::default()
        }
    }

    /// Returns the metric with `index` set as the pad token: targets equal to it are ignored.
    pub fn with_pad_token(self, index: usize) -> Self {
        Self {
            pad_token: Some(index),
            ..self
        }
    }
}
impl<B: Backend> Metric for TopKAccuracyMetric<B> {
type Input = TopKAccuracyInput<B>;
fn update(
&mut self,
input: &TopKAccuracyInput<B>,
_metadata: &MetricMetadata,
) -> SerializedEntry {
let [batch_size, _n_classes] = input.outputs.dims();
let targets = input.targets.clone().to_device(&B::Device::default());
let outputs = input
.outputs
.clone()
.argsort_descending(1)
.narrow(1, 0, self.k)
.to_device(&B::Device::default())
.reshape([batch_size, self.k]);
let (targets, num_pad) = match self.pad_token {
Some(pad_token) => {
// we ignore the samples where the target is equal to the pad token
let mask = targets.clone().equal_elem(pad_token as i64);
let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
(targets.clone().mask_fill(mask, -1_i64), num_pad)
}
None => (targets.clone(), 0_f64),
};
let accuracy = targets
.reshape([batch_size, 1])
.repeat_dim(1, self.k)
.equal(outputs)
.int()
.sum()
.into_scalar()
.elem::<f64>()
/ (batch_size as f64 - num_pad);
self.state.update(
100.0 * accuracy,
batch_size,
FormatOptions::new(self.name()).unit("%").precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some("%".to_string()),
higher_is_better: true,
}
.into()
}
}
/// Exposes the values tracked by the metric state as numeric entries.
impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
/// Returns the current value held by the metric state.
fn value(&self) -> NumericEntry {
self.state.current_value()
}
/// Returns the running value held by the metric state.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Two of the four targets are among the two best-scored classes, hence 50 %.
    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();
        let outputs: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8], // 2, 1
                [1.0, 2.0, 0.5], // 1, 0
                [0.4, 0.1, 0.2], // 0, 2
                [0.6, 0.7, 0.2], // 1, 0
            ],
            &device,
        );
        let targets: Tensor<TestBackend, 1, Int> = Tensor::from_data([2, 2, 1, 1], &device);
        let batch = TopKAccuracyInput::new(outputs, targets);

        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
        let _entry = metric.update(&batch, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    /// Samples whose target is the pad token (3) are ignored, so the result stays at 50 %.
    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();
        let outputs: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8, 0.0], // 2, 1
                [1.0, 2.0, 0.5, 0.0], // 1, 0
                [0.4, 0.1, 0.2, 0.0], // 0, 2
                [0.6, 0.7, 0.2, 0.0], // 1, 0
                [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
                [0.0, 0.1, 0.2, 0.0], // Error on padding should not count
                [0.6, 0.0, 0.2, 0.0], // Error on padding should not count
            ],
            &device,
        );
        let targets: Tensor<TestBackend, 1, Int> =
            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device);
        let batch = TopKAccuracyInput::new(outputs, targets);

        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
        let _entry = metric.update(&batch, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    /// The metric name depends on `k` only.
    #[test]
    fn test_parameterized_unique_name() {
        let top_2 = TopKAccuracyMetric::<TestBackend>::new(2);
        let top_1 = TopKAccuracyMetric::<TestBackend>::new(1);
        let top_2_again = TopKAccuracyMetric::<TestBackend>::new(2);

        assert_ne!(top_2.name(), top_1.name());
        assert_eq!(top_2.name(), top_2_again.name());
    }
}

View File

@@ -0,0 +1,345 @@
use crate::metric::{MetricAttributes, MetricName, SerializedEntry};
use super::super::{
Metric, MetricMetadata,
state::{FormatOptions, NumericMetricState},
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::{ElementConversion, Int, s},
};
use core::marker::PhantomData;
/// Input type for the [DiceMetric].
///
/// Both tensors are expected to be laid out as `[batch, classes, ...]`: the metric reads the
/// batch size from dimension 0 and the number of classes from dimension 1.
///
/// # Type Parameters
/// - `B`: Backend type.
/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).
pub struct DiceInput<B: Backend, const D: usize = 4> {
/// Model outputs (predictions), as an integer tensor.
outputs: Tensor<B, D, Int>,
/// Ground truth targets, as an integer tensor with the same shape as `outputs`.
targets: Tensor<B, D, Int>,
}
impl<B: Backend, const D: usize> DiceInput<B, D> {
    /// Builds a `DiceInput` from model outputs and ground-truth targets.
    ///
    /// Both tensors are expected to be shaped `[B, C, ...]`, where `B` is the batch size,
    /// `C` the number of classes and `...` any further dimensions (e.g. height and width
    /// for images).
    ///
    /// When `C` is greater than 1, the first class (index 0) is treated as the background.
    /// One-hot encoding is the responsibility of the caller.
    ///
    /// # Arguments
    /// - `outputs`: The model outputs as a tensor.
    /// - `targets`: The ground truth targets as a tensor.
    ///
    /// # Returns
    /// A new instance of `DiceInput`.
    ///
    /// # Panics
    /// - If `D` is less than 3.
    /// - If `outputs` and `targets` do not have the same shape.
    pub fn new(outputs: Tensor<B, D, Int>, targets: Tensor<B, D, Int>) -> Self {
        assert!(D >= 3, "DiceInput requires at least 3 dimensions.");

        let output_dims = outputs.dims();
        let target_dims = targets.dims();
        assert!(
            output_dims == target_dims,
            "Outputs and targets must have the same dimensions. Got {:?} and {:?}",
            output_dims,
            target_dims
        );

        Self { outputs, targets }
    }
}
/// Configuration for the [DiceMetric].
#[derive(Debug, Clone, Copy)]
pub struct DiceMetricConfig {
    /// Small constant added to the numerator and the denominator of the Dice ratio so
    /// that the division is defined even when both masks are empty.
    pub epsilon: f64,
    /// Whether the background class takes part in the metric calculation.
    /// The background is assumed to be the first class (index 0).
    /// If `true`, updating the metric panics when there are fewer than 2 classes.
    pub include_background: bool,
}

impl Default for DiceMetricConfig {
    /// Returns the default configuration: `epsilon = 1e-7`, background excluded.
    fn default() -> Self {
        DiceMetricConfig {
            include_background: false,
            epsilon: 1e-7,
        }
    }
}
/// The Sørensen–Dice coefficient (DSC) for evaluating overlap between two binary masks.
/// The DSC is defined as:
/// `DSC = 2 * (|X ∩ Y|) / (|X| + |Y|)`
/// where `X` is the model output and `Y` is the ground truth target.
///
/// # Type Parameters
/// - `B`: Backend type.
/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).
#[derive(Default, Clone)]
pub struct DiceMetric<B: Backend, const D: usize = 4> {
/// Name of the metric.
name: MetricName,
/// Internal state for numeric metric aggregation.
state: NumericMetricState,
/// Marker for backend type.
_b: PhantomData<B>,
/// Configuration for the metric.
config: DiceMetricConfig,
}
impl<B: Backend, const D: usize> DiceMetric<B, D> {
    /// Creates a Dice metric that uses the default [DiceMetricConfig].
    ///
    /// # Panics
    /// If `D` is less than 3.
    pub fn new() -> Self {
        Self::with_config(Default::default())
    }

    /// Creates a Dice metric that uses the given `config`.
    ///
    /// The metric name embeds the number of dimensions `D`.
    ///
    /// # Panics
    /// If `D` is less than 3.
    pub fn with_config(config: DiceMetricConfig) -> Self {
        assert!(D >= 3, "DiceMetric requires at least 3 dimensions.");

        Self {
            name: MetricName::new(format!("{D}D Dice Metric")),
            config,
            ..Default::default()
        }
    }
}
impl<B: Backend, const D: usize> Metric for DiceMetric<B, D> {
    type Input = DiceInput<B, D>;

    /// Returns the name given to the metric at construction time.
    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Computes the Dice coefficient `2 * (|X ∩ Y|) / (|X| + |Y|)` of one batch and records
    /// it in the state.
    ///
    /// The configured epsilon is added to both the numerator and the denominator.
    ///
    /// # Panics
    /// - If the outputs and targets do not have the same shape.
    /// - If the background is included and there are fewer than 2 classes.
    fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
        let output_dims = item.outputs.dims();
        let target_dims = item.targets.dims();
        assert!(
            output_dims == target_dims,
            "Outputs and targets must have the same dimensions. Got {:?} and {:?}",
            output_dims,
            target_dims
        );

        let batch_size = output_dims[0];
        let n_classes = output_dims[1];

        let (outputs, targets) = match (self.config.include_background, n_classes) {
            // Including the background only makes sense with a foreground class next to it.
            (true, n) if n < 2 => {
                panic!("Dice metric requires at least 2 classes when including background.")
            }
            // Background excluded: drop the first class (index 0) from both tensors.
            (false, n) if n > 1 => (
                item.outputs.clone().slice(s![.., 1..]),
                item.targets.clone().slice(s![.., 1..]),
            ),
            _ => (item.outputs.clone(), item.targets.clone()),
        };

        let intersection = (outputs.clone() * targets.clone())
            .sum()
            .into_scalar()
            .elem::<f64>();
        let outputs_total = outputs.sum().into_scalar().elem::<f64>();
        let targets_total = targets.sum().into_scalar().elem::<f64>();

        let epsilon = self.config.epsilon;
        let numerator = 2.0 * intersection + epsilon;
        let denominator = outputs_total + targets_total + epsilon;
        let dice = numerator / denominator;

        let format = FormatOptions::new(self.name()).precision(4);
        self.state.update(dice, batch_size, format)
    }

    /// Clears the metric state.
    fn clear(&mut self) {
        self.state.reset();
    }

    /// Describes the metric as unit-less, with higher values being better.
    fn attributes(&self) -> MetricAttributes {
        let attributes = crate::metric::NumericAttributes {
            unit: None,
            higher_is_better: true,
        };
        attributes.into()
    }
}
/// Exposes the values tracked by the metric state as numeric entries.
impl<B: Backend, const D: usize> crate::metric::Numeric for DiceMetric<B, D> {
/// Returns the current value held by the metric state.
fn value(&self) -> crate::metric::NumericEntry {
self.state.current_value()
}
/// Returns the running value held by the metric state.
fn running_value(&self) -> crate::metric::NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, metric::Numeric};
    use burn_core::tensor::{Shape, Tensor};

    /// Identical masks give a Dice coefficient of 1.
    #[test]
    fn test_dice_perfect_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    /// Disjoint masks give a Dice coefficient of (almost) 0.
    #[test]
    fn test_dice_no_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
            Tensor::from_data([[[[0, 1], [0, 1]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!(metric.value().current() < 1e-6);
    }

    /// Masks sharing one of their two positive pixels give a Dice coefficient of 0.5.
    #[test]
    fn test_dice_partial_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 1], [0, 0]]]], &device),
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        // intersection = 1, sum = 2+2=4, dice = 2*1/4 = 0.5
        assert!((metric.value().current() - 0.5).abs() < 1e-6);
    }

    /// Two empty masks give epsilon / epsilon = 1.
    #[test]
    fn test_dice_empty_masks() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    /// With a single class and the default config, that class is kept as is.
    #[test]
    fn test_dice_no_background() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    /// Two classes with the background included.
    #[test]
    fn test_dice_with_background() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    /// Two classes with the background (class 0) excluded.
    #[test]
    fn test_dice_ignored_background() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: false,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "DiceInput requires at least 3 dimensions.")]
    fn test_invalid_input_dimensions() {
        let device = Default::default();
        // D = 2, should panic
        let _ = DiceInput::<TestBackend, 2>::new(
            Tensor::from_data([[0.0, 0.0]], &device),
            Tensor::from_data([[0.0, 0.0]], &device),
        );
    }

    #[test]
    #[should_panic(
        expected = "Outputs and targets must have the same dimensions. Got [1, 1, 2, 2] and [1, 1, 2, 3]"
    )]
    fn test_mismatched_shape() {
        let device = Default::default();
        // shapes differ
        let _ = DiceInput::<TestBackend, 4>::new(
            Tensor::from_data([[[[0.0; 2]; 2]; 1]; 1], &device),
            Tensor::from_data([[[[0.0; 3]; 2]; 1]; 1], &device),
        );
    }

    /// Including the background with exactly two classes must not panic.
    ///
    /// This lives outside the `should_panic` test below: inside it, an unexpected panic
    /// with the same message would make the test pass for the wrong reason.
    #[test]
    fn test_include_background_two_classes() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        // Shape [1, 2, 1, 1]: n_classes = 2, should not panic
        let input = DiceInput::new(
            Tensor::from_data([[[[1.0; 1]; 1]; 2]; 1], &device),
            Tensor::from_data([[[[1.0; 1]; 1]; 2]; 1], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Dice metric requires at least 2 classes when including background.")]
    fn test_include_background_panic() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        // Shape [1, 1, 1, 1]: n_classes = 1, should panic
        let input = DiceInput::new(
            Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),
            Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
    }
}

Some files were not shown because too many files have changed in this diff Show More