feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Training crate for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-train"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-train"
|
||||
documentation = "https://docs.rs/burn-train"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["sys-metrics", "tui", "rl"]
|
||||
doc = ["default"]
|
||||
vision = ["burn-nn", "burn-store/pytorch", "burn-std/network", "dirs"]
|
||||
tracing = [
|
||||
"burn-core/tracing",
|
||||
"burn-optim/tracing",
|
||||
"burn-collective?/tracing",
|
||||
]
|
||||
|
||||
|
||||
sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"]
|
||||
tui = ["ratatui"]
|
||||
rl = ["burn-rl"]
|
||||
# Distributed Data Parallel
|
||||
ddp = ["burn-collective", "burn-optim/collective"]
|
||||
|
||||
[dependencies]
|
||||
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [
|
||||
"dataset",
|
||||
"std",
|
||||
], default-features = false }
|
||||
burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [
|
||||
"std",
|
||||
], default-features = false }
|
||||
burn-rl = { path = "../burn-rl", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
|
||||
burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
|
||||
dirs = { workspace = true, optional = true }
|
||||
|
||||
log = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
tracing-appender = { workspace = true }
|
||||
tracing-core = { workspace = true }
|
||||
|
||||
# System Metrics
|
||||
nvml-wrapper = { workspace = true, optional = true }
|
||||
sysinfo = { workspace = true, optional = true }
|
||||
systemstat = { workspace = true, optional = true }
|
||||
|
||||
# Text UI
|
||||
ratatui = { workspace = true, optional = true, features = [
|
||||
"all-widgets",
|
||||
"crossterm",
|
||||
] }
|
||||
|
||||
# Utilities
|
||||
derive-new = { workspace = true }
|
||||
serde = { workspace = true, features = ["std", "derive"] }
|
||||
async-channel = { workspace = true }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
rstest.workspace = true
|
||||
thiserror.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-train/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-train/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,6 @@
|
||||
# Burn Train
|
||||
|
||||
This crate should be used with [burn](https://github.com/tracel-ai/burn).
|
||||
|
||||
[](https://crates.io/crates/burn-train)
|
||||
[](https://github.com/tracel-ai/burn-train/blob/master/README.md)
|
||||
@@ -0,0 +1,171 @@
|
||||
use super::{Checkpointer, CheckpointerError};
|
||||
use crate::Interrupter;
|
||||
use burn_core::{record::Record, tensor::backend::Backend};
|
||||
use std::sync::mpsc;
|
||||
|
||||
enum Message<R, B: Backend> {
|
||||
Restore(
|
||||
usize,
|
||||
B::Device,
|
||||
mpsc::SyncSender<Result<R, CheckpointerError>>,
|
||||
Option<Interrupter>,
|
||||
),
|
||||
Save(usize, R, Option<Interrupter>),
|
||||
Delete(usize, Option<Interrupter>),
|
||||
End,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct CheckpointerThread<C, R, B: Backend> {
|
||||
checkpointer: C,
|
||||
receiver: mpsc::Receiver<Message<R, B>>,
|
||||
}
|
||||
|
||||
impl<C, R, B> CheckpointerThread<C, R, B>
|
||||
where
|
||||
C: Checkpointer<R, B>,
|
||||
R: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
fn run(self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::Restore(epoch, device, callback, interrupter) => {
|
||||
let record = self.checkpointer.restore(epoch, &device);
|
||||
callback.send(record).unwrap_or_else(|err| {
|
||||
interrupter.map_or_else(
|
||||
|| {
|
||||
panic!(
|
||||
"Error when sending response through callback channel: {err}"
|
||||
)
|
||||
},
|
||||
|int| int.stop(Some(&err.to_string())),
|
||||
)
|
||||
});
|
||||
}
|
||||
Message::Save(epoch, state, interrupter) => {
|
||||
self.checkpointer.save(epoch, state).unwrap_or_else(|err| {
|
||||
interrupter.map_or_else(
|
||||
|| panic!("Error when saving the state: {err}"),
|
||||
|int| int.stop(Some(&err.to_string())),
|
||||
)
|
||||
});
|
||||
}
|
||||
Message::Delete(epoch, interrupter) => {
|
||||
self.checkpointer.delete(epoch).unwrap_or_else(|err| {
|
||||
interrupter.map_or_else(
|
||||
|| panic!("Error when deleting the state: {err}"),
|
||||
|int| int.stop(Some(&err.to_string())),
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Async checkpointer.
|
||||
pub struct AsyncCheckpointer<Record, B: Backend> {
|
||||
sender: mpsc::SyncSender<Message<Record, B>>,
|
||||
handler: Option<std::thread::JoinHandle<()>>,
|
||||
interrupter: Option<Interrupter>,
|
||||
}
|
||||
|
||||
impl<R, B> AsyncCheckpointer<R, B>
|
||||
where
|
||||
R: Record<B> + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
/// Create a new async checkpointer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `checkpointer` - The checkpointer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The async checkpointer.
|
||||
pub fn new<C>(checkpointer: C) -> Self
|
||||
where
|
||||
C: Checkpointer<R, B> + Send + 'static,
|
||||
{
|
||||
// Only on checkpoint can be done in advance.
|
||||
let (sender, receiver) = mpsc::sync_channel(0);
|
||||
let thread = CheckpointerThread::new(checkpointer, receiver);
|
||||
let handler = Some(std::thread::spawn(move || thread.run()));
|
||||
|
||||
Self {
|
||||
sender,
|
||||
handler,
|
||||
interrupter: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a handle used to interrupt training in case of checkpointing error.
|
||||
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
|
||||
self.interrupter = Some(interrupter);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
|
||||
where
|
||||
R: Record<B> + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
self.sender
|
||||
.send(Message::Save(epoch, record, self.interrupter.clone()))
|
||||
.expect("Can send message to checkpointer thread.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::Restore(
|
||||
epoch,
|
||||
device.clone(),
|
||||
sender,
|
||||
self.interrupter.clone(),
|
||||
))
|
||||
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
|
||||
|
||||
if let Ok(record) = receiver.recv() {
|
||||
return record;
|
||||
};
|
||||
|
||||
Err(CheckpointerError::Unknown("Channel error.".to_string()))
|
||||
}
|
||||
|
||||
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
|
||||
self.sender
|
||||
.send(Message::Delete(epoch, self.interrupter.clone()))
|
||||
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, B> Drop for AsyncCheckpointer<E, B>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
self.sender
|
||||
.send(Message::End)
|
||||
.expect("Can send the end message to the checkpointer thread.");
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler
|
||||
.join()
|
||||
.expect("The checkpointer thread should stop.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
use burn_core::{
|
||||
record::{Record, RecorderError},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
/// The error type for checkpointer.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CheckpointerError {
|
||||
/// IO error.
|
||||
#[error("I/O Error: `{0}`")]
|
||||
IOError(std::io::Error),
|
||||
|
||||
/// Recorder error.
|
||||
#[error("Recorder error: `{0}`")]
|
||||
RecorderError(RecorderError),
|
||||
|
||||
/// Other errors.
|
||||
#[error("Unknown error: `{0}`")]
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
/// The trait for checkpointer.
|
||||
pub trait Checkpointer<R, B>: Send + Sync
|
||||
where
|
||||
R: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
/// Save the record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
/// * `record` - The record.
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
|
||||
|
||||
/// Delete the record at the given epoch if present.
|
||||
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;
|
||||
|
||||
/// Restore the record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
/// * `device` - The device used to restore the record.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The record.
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::{
|
||||
record::{FileRecorder, Record},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
/// The file checkpointer.
|
||||
pub struct FileCheckpointer<FR> {
|
||||
directory: PathBuf,
|
||||
name: String,
|
||||
recorder: FR,
|
||||
}
|
||||
|
||||
impl<FR> FileCheckpointer<FR> {
|
||||
/// Creates a new file checkpointer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `recorder` - The file recorder.
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
/// * `name` - The name of the checkpoint.
|
||||
pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
|
||||
let directory = directory.as_ref();
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
|
||||
Self {
|
||||
directory: directory.to_path_buf(),
|
||||
name: name.to_string(),
|
||||
recorder,
|
||||
}
|
||||
}
|
||||
|
||||
fn path_for_epoch(&self, epoch: usize) -> PathBuf {
|
||||
self.directory.join(format!("{}-{}", self.name, epoch))
|
||||
}
|
||||
}
|
||||
|
||||
impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
|
||||
where
|
||||
R: Record<B>,
|
||||
FR: FileRecorder<B>,
|
||||
B: Backend,
|
||||
{
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::trace!("Saving checkpoint {} to {}", epoch, file_path.display());
|
||||
|
||||
self.recorder
|
||||
.record(record, file_path)
|
||||
.map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!(
|
||||
"Restoring checkpoint {} from {}",
|
||||
epoch,
|
||||
file_path.display()
|
||||
);
|
||||
let record = self
|
||||
.recorder
|
||||
.load(file_path, device)
|
||||
.map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
|
||||
let file_to_remove = format!(
|
||||
"{}.{}",
|
||||
self.path_for_epoch(epoch).display(),
|
||||
FR::file_extension(),
|
||||
);
|
||||
|
||||
if std::path::Path::new(&file_to_remove).exists() {
|
||||
log::trace!("Removing checkpoint {file_to_remove}");
|
||||
std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
mod async_checkpoint;
|
||||
mod base;
|
||||
mod file;
|
||||
mod strategy;
|
||||
|
||||
pub use async_checkpoint::*;
|
||||
pub use base::*;
|
||||
pub use file::*;
|
||||
pub use strategy::*;
|
||||
@@ -0,0 +1,34 @@
|
||||
use std::ops::DerefMut;
|
||||
|
||||
use crate::metric::store::EventStoreClient;
|
||||
|
||||
/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum CheckpointingAction {
|
||||
/// Delete the given epoch.
|
||||
Delete(usize),
|
||||
/// Save the current record.
|
||||
Save,
|
||||
}
|
||||
|
||||
/// Define when checkpoint should be saved and deleted.
|
||||
pub trait CheckpointingStrategy: Send {
|
||||
/// Based on the epoch, determine if the checkpoint should be saved.
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
collector: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction>;
|
||||
}
|
||||
|
||||
// We make dyn box implement the checkpointing strategy so that it can be used with generic, but
|
||||
// still be dynamic.
|
||||
impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
collector: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
self.deref_mut().checkpointing(epoch, collector)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
use crate::metric::store::EventStoreClient;
|
||||
|
||||
use super::{CheckpointingAction, CheckpointingStrategy};
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an
|
||||
/// epoch to be deleted.
|
||||
pub struct ComposedCheckpointingStrategy {
|
||||
strategies: Vec<Box<dyn CheckpointingStrategy>>,
|
||||
deleted: Vec<HashSet<usize>>,
|
||||
}
|
||||
|
||||
/// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.
|
||||
#[derive(Default)]
|
||||
pub struct ComposedCheckpointingStrategyBuilder {
|
||||
strategies: Vec<Box<dyn CheckpointingStrategy>>,
|
||||
}
|
||||
|
||||
impl ComposedCheckpointingStrategyBuilder {
|
||||
/// Add a new [checkpointing strategy](CheckpointingStrategy).
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn add<S>(mut self, strategy: S) -> Self
|
||||
where
|
||||
S: CheckpointingStrategy + 'static,
|
||||
{
|
||||
self.strategies.push(Box::new(strategy));
|
||||
self
|
||||
}
|
||||
|
||||
/// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy).
|
||||
pub fn build(self) -> ComposedCheckpointingStrategy {
|
||||
ComposedCheckpointingStrategy::new(self.strategies)
|
||||
}
|
||||
}
|
||||
|
||||
impl ComposedCheckpointingStrategy {
|
||||
fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
|
||||
Self {
|
||||
deleted: strategies.iter().map(|_| HashSet::new()).collect(),
|
||||
strategies,
|
||||
}
|
||||
}
|
||||
/// Create a new builder which help compose multiple
|
||||
/// [checkpointing strategies](CheckpointingStrategy).
|
||||
pub fn builder() -> ComposedCheckpointingStrategyBuilder {
|
||||
ComposedCheckpointingStrategyBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckpointingStrategy for ComposedCheckpointingStrategy {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
collector: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
let mut saved = false;
|
||||
let mut actions = Vec::new();
|
||||
let mut epochs_to_check = Vec::new();
|
||||
|
||||
for (i, strategy) in self.strategies.iter_mut().enumerate() {
|
||||
let actions = strategy.checkpointing(epoch, collector);
|
||||
// We assume that the strategy would not want the current epoch to be saved.
|
||||
// So we flag it as deleted.
|
||||
if actions.is_empty() {
|
||||
self.deleted
|
||||
.get_mut(i)
|
||||
.expect("As many 'deleted' as 'strategies'.")
|
||||
.insert(epoch);
|
||||
}
|
||||
|
||||
for action in actions {
|
||||
match action {
|
||||
CheckpointingAction::Delete(epoch) => {
|
||||
self.deleted
|
||||
.get_mut(i)
|
||||
.expect("As many 'deleted' as 'strategies'.")
|
||||
.insert(epoch);
|
||||
epochs_to_check.push(epoch);
|
||||
}
|
||||
CheckpointingAction::Save => saved = true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if saved {
|
||||
actions.push(CheckpointingAction::Save);
|
||||
}
|
||||
|
||||
for epoch in epochs_to_check.into_iter() {
|
||||
let mut num_true = 0;
|
||||
for i in 0..self.strategies.len() {
|
||||
if self
|
||||
.deleted
|
||||
.get(i)
|
||||
.expect("Ad many 'deleted' as 'strategies'.")
|
||||
.contains(&epoch)
|
||||
{
|
||||
num_true += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if num_true == self.strategies.len() {
|
||||
actions.push(CheckpointingAction::Delete(epoch));
|
||||
|
||||
for i in 0..self.strategies.len() {
|
||||
self.deleted
|
||||
.get_mut(i)
|
||||
.expect("As many 'deleted' as 'strategies'.")
|
||||
.remove(&epoch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actions
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};
|
||||
|
||||
#[test]
|
||||
fn should_delete_when_both_deletes() {
|
||||
let store = EventStoreClient::new(LogEventStore::default());
|
||||
let mut strategy = ComposedCheckpointingStrategy::builder()
|
||||
.add(KeepLastNCheckpoints::new(1))
|
||||
.add(KeepLastNCheckpoints::new(2))
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(1, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(2, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
|
||||
strategy.checkpointing(3, &store)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
use super::CheckpointingStrategy;
|
||||
use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};
|
||||
|
||||
/// Keep the last N checkpoints.
|
||||
///
|
||||
/// Very useful when training, minimizing disk space while ensuring that the training can be
|
||||
/// resumed even if something goes wrong.
|
||||
#[derive(new)]
|
||||
pub struct KeepLastNCheckpoints {
|
||||
num_keep: usize,
|
||||
}
|
||||
|
||||
impl CheckpointingStrategy for KeepLastNCheckpoints {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
_store: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
let mut actions = vec![CheckpointingAction::Save];
|
||||
|
||||
if let Some(epoch) = usize::checked_sub(epoch, self.num_keep)
|
||||
&& epoch > 0
|
||||
{
|
||||
actions.push(CheckpointingAction::Delete(epoch));
|
||||
}
|
||||
|
||||
actions
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::metric::store::LogEventStore;
|
||||
|
||||
#[test]
|
||||
fn should_always_delete_lastn_epoch_if_higher_than_one() {
|
||||
let mut strategy = KeepLastNCheckpoints::new(2);
|
||||
let store = EventStoreClient::new(LogEventStore::default());
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(1, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(2, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
|
||||
strategy.checkpointing(3, &store)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
use super::CheckpointingStrategy;
|
||||
use crate::{
|
||||
checkpoint::CheckpointingAction,
|
||||
metric::{
|
||||
Metric, MetricName,
|
||||
store::{Aggregate, Direction, EventStoreClient, Split},
|
||||
},
|
||||
};
|
||||
|
||||
/// Keep the best checkpoint based on a metric.
|
||||
pub struct MetricCheckpointingStrategy {
|
||||
current: Option<usize>,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
name: MetricName,
|
||||
}
|
||||
|
||||
impl MetricCheckpointingStrategy {
|
||||
/// Create a new metric checkpointing strategy.
|
||||
pub fn new<M>(metric: &M, aggregate: Aggregate, direction: Direction, split: Split) -> Self
|
||||
where
|
||||
M: Metric,
|
||||
{
|
||||
Self {
|
||||
current: None,
|
||||
name: metric.name(),
|
||||
aggregate,
|
||||
direction,
|
||||
split,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckpointingStrategy for MetricCheckpointingStrategy {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
store: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
let best_epoch =
|
||||
match store.find_epoch(&self.name, self.aggregate, self.direction, &self.split) {
|
||||
Some(epoch_best) => epoch_best,
|
||||
None => epoch,
|
||||
};
|
||||
|
||||
let mut actions = Vec::new();
|
||||
|
||||
if let Some(current) = self.current
|
||||
&& current != best_epoch
|
||||
{
|
||||
actions.push(CheckpointingAction::Delete(current));
|
||||
}
|
||||
|
||||
if best_epoch == epoch {
|
||||
actions.push(CheckpointingAction::Save);
|
||||
}
|
||||
|
||||
self.current = Some(best_epoch);
|
||||
|
||||
actions
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
EventProcessorTraining, TestBackend,
|
||||
logger::InMemoryMetricLogger,
|
||||
metric::{
|
||||
LossMetric,
|
||||
processor::{
|
||||
MetricsTraining, MinimalEventProcessor,
|
||||
test_utils::{end_epoch, process_train},
|
||||
},
|
||||
store::LogEventStore,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn always_keep_the_best_epoch() {
|
||||
let loss = LossMetric::<TestBackend>::new();
|
||||
let mut store = LogEventStore::default();
|
||||
let mut strategy = MetricCheckpointingStrategy::new(
|
||||
&loss,
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Train,
|
||||
);
|
||||
let mut metrics = MetricsTraining::<f64, f64>::default();
|
||||
// Register an in memory logger.
|
||||
store.register_logger(InMemoryMetricLogger::default());
|
||||
// Register the loss metric.
|
||||
metrics.register_train_metric_numeric(loss);
|
||||
let store = Arc::new(EventStoreClient::new(store));
|
||||
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
|
||||
processor.process_train(crate::LearnerEvent::Start);
|
||||
|
||||
// Two points for the first epoch. Mean 0.75
|
||||
let mut epoch = 1;
|
||||
process_train(&mut processor, 1.0, epoch);
|
||||
process_train(&mut processor, 0.5, epoch);
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
// Should save the current record.
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(epoch, &store)
|
||||
);
|
||||
|
||||
// Two points for the second epoch. Mean 0.4
|
||||
epoch += 1;
|
||||
process_train(&mut processor, 0.5, epoch);
|
||||
process_train(&mut processor, 0.3, epoch);
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
// Should save the current record and delete the previous one.
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
|
||||
strategy.checkpointing(epoch, &store)
|
||||
);
|
||||
|
||||
// Two points for the last epoch. Mean 2.0
|
||||
epoch += 1;
|
||||
process_train(&mut processor, 1.0, epoch);
|
||||
process_train(&mut processor, 3.0, epoch);
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
// Should not delete the previous record, since it's the best one, and should not save a
|
||||
// new one.
|
||||
assert!(strategy.checkpointing(epoch, &store).is_empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
mod base;
|
||||
mod composed;
|
||||
mod lastn;
|
||||
mod metric;
|
||||
|
||||
pub use base::*;
|
||||
pub use composed::*;
|
||||
pub use lastn::*;
|
||||
pub use metric::*;
|
||||
@@ -0,0 +1,66 @@
|
||||
use crate::{InferenceStep, TrainStep};
|
||||
use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};
|
||||
use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Components used for a model to learn, grouped in one trait.
|
||||
pub trait LearningComponentsTypes {
|
||||
/// The backend used for training.
|
||||
type Backend: AutodiffBackend;
|
||||
/// The learning rate scheduler used for training.
|
||||
type LrScheduler: LrScheduler + 'static;
|
||||
/// The model to train.
|
||||
type TrainingModel: TrainStep
|
||||
+ AutodiffModule<Self::Backend, InnerModule = Self::InferenceModel>
|
||||
+ core::fmt::Display
|
||||
+ 'static;
|
||||
/// The non-autodiff type of the model.
|
||||
type InferenceModel: InferenceStep;
|
||||
/// The optimizer used for training.
|
||||
type Optimizer: Optimizer<Self::TrainingModel, Self::Backend> + 'static;
|
||||
}
|
||||
|
||||
/// Concrete type that implements the [LearningComponentsTypes](LearningComponentsTypes) trait.
|
||||
pub struct LearningComponentsMarker<B, LR, M, O> {
|
||||
_backend: PhantomData<B>,
|
||||
_lr_scheduler: PhantomData<LR>,
|
||||
_model: PhantomData<M>,
|
||||
_optimizer: PhantomData<O>,
|
||||
}
|
||||
|
||||
impl<B, LR, M, O> LearningComponentsTypes for LearningComponentsMarker<B, LR, M, O>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
LR: LrScheduler + 'static,
|
||||
M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
|
||||
M::InnerModule: InferenceStep,
|
||||
O: Optimizer<M, B> + 'static,
|
||||
{
|
||||
type Backend = B;
|
||||
type LrScheduler = LR;
|
||||
type TrainingModel = M;
|
||||
type InferenceModel = M::InnerModule;
|
||||
type Optimizer = O;
|
||||
}
|
||||
|
||||
/// The training backend.
|
||||
pub type TrainingBackend<LC> = <LC as LearningComponentsTypes>::Backend;
|
||||
/// The inference backend.
|
||||
pub(crate) type InferenceBackend<LC> =
|
||||
<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend;
|
||||
/// The model used for training.
|
||||
pub type TrainingModel<LC> = <LC as LearningComponentsTypes>::TrainingModel;
|
||||
/// The non-autodiff model.
|
||||
pub(crate) type InferenceModel<LC> = <LC as LearningComponentsTypes>::InferenceModel;
|
||||
/// Type for training input.
|
||||
pub(crate) type TrainingModelInput<LC> =
|
||||
<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input;
|
||||
/// Type for inference input.
|
||||
pub(crate) type InferenceModelInput<LC> =
|
||||
<<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input;
|
||||
/// Type for training output.
|
||||
pub(crate) type TrainingModelOutput<LC> =
|
||||
<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output;
|
||||
/// Type for inference output.
|
||||
pub(crate) type InferenceModelOutput<LC> =
|
||||
<<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output;
|
||||
@@ -0,0 +1,72 @@
|
||||
use crate::{
|
||||
AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep,
|
||||
Interrupter, LearnerSummaryConfig,
|
||||
evaluator::components::EvaluatorComponentTypes,
|
||||
metric::processor::{EvaluatorEvent, EventProcessorEvaluation},
|
||||
renderer::{EvaluationName, MetricsRenderer},
|
||||
};
|
||||
use burn_core::{data::dataloader::DataLoader, module::Module};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(crate) type TestBackend<EC> = <EC as EvaluatorComponentTypes>::Backend;
|
||||
pub(crate) type TestInput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Input;
|
||||
pub(crate) type TestOutput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Output;
|
||||
|
||||
pub(crate) type TestLoader<EC> = Arc<dyn DataLoader<TestBackend<EC>, TestInput<EC>>>;
|
||||
|
||||
/// Evaluates a model on a specific dataset.
|
||||
pub struct Evaluator<EC: EvaluatorComponentTypes> {
|
||||
pub(crate) model: EC::Model,
|
||||
pub(crate) interrupter: Interrupter,
|
||||
pub(crate) event_processor:
|
||||
AsyncProcessorEvaluation<FullEventProcessorEvaluation<TestOutput<EC>>>,
|
||||
/// Config for creating a summary of the evaluation
|
||||
pub summary: Option<LearnerSummaryConfig>,
|
||||
}
|
||||
|
||||
impl<EC: EvaluatorComponentTypes> Evaluator<EC> {
|
||||
/// Run the evaluation on the given dataset.
|
||||
///
|
||||
/// The data will be stored and displayed under the provided name.
|
||||
pub fn eval<S: core::fmt::Display>(
|
||||
mut self,
|
||||
name: S,
|
||||
dataloader: TestLoader<EC>,
|
||||
) -> Box<dyn MetricsRenderer> {
|
||||
// Move dataloader to the model device
|
||||
let dataloader = dataloader.to_device(self.model.devices().first().unwrap());
|
||||
|
||||
let name = EvaluationName::new(name);
|
||||
let mut iterator = dataloader.iter();
|
||||
let mut iteration = 0;
|
||||
|
||||
self.event_processor.process_test(EvaluatorEvent::Start);
|
||||
while let Some(item) = iterator.next() {
|
||||
let progress = iterator.progress();
|
||||
iteration += 1;
|
||||
|
||||
let item = self.model.step(item);
|
||||
let item = EvaluationItem::new(item, progress, Some(iteration));
|
||||
|
||||
self.event_processor
|
||||
.process_test(EvaluatorEvent::ProcessedItem(name.clone(), item));
|
||||
|
||||
if self.interrupter.should_stop() {
|
||||
log::info!("Testing interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let summary = self.summary.and_then(|summary| {
|
||||
summary
|
||||
.init()
|
||||
.map(|summary| summary.with_model(self.model.to_string()))
|
||||
.ok()
|
||||
});
|
||||
|
||||
self.event_processor
|
||||
.process_test(EvaluatorEvent::End(summary));
|
||||
|
||||
self.event_processor.renderer()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
use crate::{
|
||||
ApplicationLoggerInstaller, Evaluator, FileApplicationLoggerInstaller, InferenceStep,
|
||||
Interrupter, LearnerSummaryConfig, TestOutput,
|
||||
evaluator::components::{EvaluatorComponentTypes, EvaluatorComponentTypesMarker},
|
||||
logger::FileMetricLogger,
|
||||
metric::{
|
||||
Adaptor, ItemLazy, Metric, Numeric,
|
||||
processor::{AsyncProcessorEvaluation, FullEventProcessorEvaluation, MetricsEvaluation},
|
||||
store::{EventStoreClient, LogEventStore},
|
||||
},
|
||||
renderer::{MetricsRenderer, default_renderer},
|
||||
};
|
||||
use burn_core::{module::Module, prelude::Backend};
|
||||
use std::{
|
||||
collections::BTreeSet,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
/// Struct to configure and create an [evaluator](Evaluator).
|
||||
///
|
||||
/// The generics components of the builder should probably not be set manually, as they are
|
||||
/// optimized for Rust type inference.
|
||||
pub struct EvaluatorBuilder<EC: EvaluatorComponentTypes> {
|
||||
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
event_store: LogEventStore,
|
||||
summary_metrics: BTreeSet<String>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
interrupter: Interrupter,
|
||||
metrics: MetricsEvaluation<TestOutput<EC>>,
|
||||
directory: PathBuf,
|
||||
summary: bool,
|
||||
}
|
||||
|
||||
impl<B, M> EvaluatorBuilder<EvaluatorComponentTypesMarker<B, M>>
|
||||
where
|
||||
B: Backend,
|
||||
M: Module<B> + InferenceStep + core::fmt::Display + 'static,
|
||||
{
|
||||
/// Creates a new evaluator builder.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
pub fn new(directory: impl AsRef<Path>) -> Self {
|
||||
let directory = directory.as_ref().to_path_buf();
|
||||
let log_file = directory.join("evaluation.log");
|
||||
|
||||
Self {
|
||||
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(log_file))),
|
||||
event_store: LogEventStore::default(),
|
||||
summary_metrics: Default::default(),
|
||||
renderer: None,
|
||||
interrupter: Interrupter::new(),
|
||||
summary: false,
|
||||
metrics: MetricsEvaluation::default(),
|
||||
directory,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<EC: EvaluatorComponentTypes> EvaluatorBuilder<EC> {
|
||||
/// Registers [numeric](crate::metric::Numeric) test [metrics](Metric).
|
||||
pub fn metrics<Me: EvalMetricRegistration<EC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Registers text [metrics](Metric).
|
||||
pub fn metrics_text<Me: EvalTextMetricRegistration<EC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// By default, Rust logs are captured and written into
|
||||
/// `evaluation.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
pub fn with_application_logger(
|
||||
mut self,
|
||||
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
) -> Self {
|
||||
self.tracing_logger = logger;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) test [metric](Metric).
|
||||
pub fn metric_numeric<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_test_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a text test [metric](Metric).
|
||||
pub fn metric<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + 'static,
|
||||
<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_test_metric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default CLI renderer with a custom one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The custom renderer.
|
||||
pub fn renderer(mut self, renderer: Box<dyn MetricsRenderer + 'static>) -> Self {
|
||||
self.renderer = Some(renderer);
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable the evaluation summary report.
|
||||
///
|
||||
/// The summary will be displayed at the end of `.eval()`.
|
||||
pub fn summary(mut self) -> Self {
|
||||
self.summary = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the evaluator.
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn build(mut self, model: EC::Model) -> Evaluator<EC> {
|
||||
let renderer = self
|
||||
.renderer
|
||||
.unwrap_or_else(|| default_renderer(self.interrupter.clone(), None));
|
||||
|
||||
self.event_store
|
||||
.register_logger(FileMetricLogger::new_eval(self.directory.clone()));
|
||||
let event_store = Arc::new(EventStoreClient::new(self.event_store));
|
||||
|
||||
let event_processor = AsyncProcessorEvaluation::new(FullEventProcessorEvaluation::new(
|
||||
self.metrics,
|
||||
renderer,
|
||||
event_store,
|
||||
));
|
||||
|
||||
let summary = if self.summary {
|
||||
Some(LearnerSummaryConfig {
|
||||
directory: self.directory,
|
||||
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Evaluator {
|
||||
model,
|
||||
interrupter: self.interrupter,
|
||||
event_processor,
|
||||
summary,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics.
|
||||
pub trait EvalMetricRegistration<EC: EvaluatorComponentTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: EvaluatorBuilder<EC>) -> EvaluatorBuilder<EC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics.
|
||||
pub trait EvalTextMetricRegistration<EC: EvaluatorComponentTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: EvaluatorBuilder<EC>) -> EvaluatorBuilder<EC>;
|
||||
}
|
||||
|
||||
macro_rules! gen_tuple {
|
||||
($($M:ident),*) => {
|
||||
impl<$($M,)* EC: EvaluatorComponentTypes> EvalTextMetricRegistration<EC> for ($($M,)*)
|
||||
where
|
||||
$(<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: EvaluatorBuilder<EC>,
|
||||
) -> EvaluatorBuilder<EC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric($M);)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* EC: EvaluatorComponentTypes> EvalMetricRegistration<EC> for ($($M,)*)
|
||||
where
|
||||
$(<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + $crate::metric::Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: EvaluatorBuilder<EC>,
|
||||
) -> EvaluatorBuilder<EC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_numeric($M);)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
gen_tuple!(M1);
|
||||
gen_tuple!(M1, M2);
|
||||
gen_tuple!(M1, M2, M3);
|
||||
gen_tuple!(M1, M2, M3, M4);
|
||||
gen_tuple!(M1, M2, M3, M4, M5);
|
||||
gen_tuple!(M1, M2, M3, M4, M5, M6);
|
||||
@@ -0,0 +1,25 @@
|
||||
use crate::InferenceStep;
|
||||
use burn_core::{module::Module, prelude::Backend};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// All components necessary to evaluate a model grouped in one trait.
|
||||
pub trait EvaluatorComponentTypes {
|
||||
/// The backend in used for the evaluation.
|
||||
type Backend: Backend;
|
||||
/// The model to evaluate.
|
||||
type Model: Module<Self::Backend> + InferenceStep + core::fmt::Display + 'static;
|
||||
}
|
||||
|
||||
/// A marker type used to provide [evaluation components](EvaluatorComponentTypes).
|
||||
pub struct EvaluatorComponentTypesMarker<B, M> {
|
||||
_p: PhantomData<(B, M)>,
|
||||
}
|
||||
|
||||
impl<B, M> EvaluatorComponentTypes for EvaluatorComponentTypesMarker<B, M>
|
||||
where
|
||||
B: Backend,
|
||||
M: Module<B> + InferenceStep + core::fmt::Display + 'static,
|
||||
{
|
||||
type Backend = B;
|
||||
type Model = M;
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod base;
|
||||
mod builder;
|
||||
|
||||
pub(crate) mod components;
|
||||
|
||||
pub use base::*;
|
||||
pub use builder::*;
|
||||
@@ -0,0 +1,69 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing_core::{Level, LevelFilter};
|
||||
use tracing_subscriber::filter::filter_fn;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{Layer, registry};
|
||||
|
||||
/// This trait is used to install an application logger.
|
||||
pub trait ApplicationLoggerInstaller {
|
||||
/// Install the application logger.
|
||||
fn install(&self) -> Result<(), String>;
|
||||
}
|
||||
|
||||
/// This struct is used to install a local file application logger to output logs to a given file path.
|
||||
pub struct FileApplicationLoggerInstaller {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl FileApplicationLoggerInstaller {
|
||||
/// Create a new file application logger.
|
||||
pub fn new(path: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
path: path.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
|
||||
fn install(&self) -> Result<(), String> {
|
||||
let path = Path::new(&self.path);
|
||||
let writer = tracing_appender::rolling::never(
|
||||
path.parent().unwrap_or_else(|| Path::new(".")),
|
||||
path.file_name().unwrap_or_else(|| {
|
||||
panic!("The path '{}' to point to a file.", self.path.display())
|
||||
}),
|
||||
);
|
||||
let layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false)
|
||||
.with_writer(writer)
|
||||
.with_filter(LevelFilter::INFO)
|
||||
.with_filter(filter_fn(|m| {
|
||||
if let Some(path) = m.module_path() {
|
||||
// The wgpu crate is logging too much, so we skip `info` level.
|
||||
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}));
|
||||
|
||||
if registry().with(layer).try_init().is_err() {
|
||||
return Err("Failed to install the file logger.".to_string());
|
||||
}
|
||||
|
||||
let hook = std::panic::take_hook();
|
||||
let file_path = self.path.to_owned();
|
||||
|
||||
std::panic::set_hook(Box::new(move |info| {
|
||||
log::error!("PANIC => {info}");
|
||||
eprintln!(
|
||||
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
|
||||
'{}'\n=============",
|
||||
file_path.display()
|
||||
);
|
||||
hook(info);
|
||||
}));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
use crate::LearningComponentsMarker;
|
||||
use crate::checkpoint::{
|
||||
AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
|
||||
};
|
||||
use crate::components::{LearningComponentsTypes, TrainingBackend};
|
||||
use crate::metric::store::EventStoreClient;
|
||||
use crate::{
|
||||
CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
|
||||
TrainingModelOutput,
|
||||
};
|
||||
use burn_core::module::{AutodiffModule, Module};
|
||||
use burn_core::prelude::Backend;
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_optim::lr_scheduler::LrScheduler;
|
||||
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// The record of the learner's model.
|
||||
pub type LearnerModelRecord<LC> =
|
||||
<<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
|
||||
/// The record of the optimizer.
|
||||
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
|
||||
<LC as LearningComponentsTypes>::TrainingModel,
|
||||
TrainingBackend<LC>,
|
||||
>>::Record;
|
||||
/// The record of the LR scheduler.
|
||||
pub type LearnerSchedulerRecord<LC> =
|
||||
<<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;
|
||||
|
||||
/// Learner struct encapsulating all components necessary to train a Neural Network model.
|
||||
pub struct Learner<LC: LearningComponentsTypes> {
|
||||
pub(crate) model: LC::TrainingModel,
|
||||
optim: LC::Optimizer,
|
||||
lr_scheduler: LC::LrScheduler,
|
||||
lr: f64,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
model: self.model.clone(),
|
||||
optim: self.optim.clone(),
|
||||
lr_scheduler: self.lr_scheduler.clone(),
|
||||
lr: self.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
LR: LrScheduler + 'static,
|
||||
M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
|
||||
M::InnerModule: InferenceStep,
|
||||
O: Optimizer<M, B> + 'static,
|
||||
{
|
||||
/// Create a learner.
|
||||
pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
|
||||
Self {
|
||||
model,
|
||||
optim,
|
||||
lr_scheduler,
|
||||
lr: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> Learner<LC> {
|
||||
/// Fork the learner's model to the given device.
|
||||
pub fn fork(&mut self, device: &<TrainingBackend<LC> as Backend>::Device) {
|
||||
self.model = self.model().fork(device);
|
||||
}
|
||||
|
||||
/// Returns the current model.
|
||||
pub fn model(&self) -> LC::TrainingModel {
|
||||
self.model.clone()
|
||||
}
|
||||
|
||||
/// Returns the current learning rate.
|
||||
pub fn lr_current(&self) -> f64 {
|
||||
self.lr
|
||||
}
|
||||
|
||||
/// Executes a step of the learning rate scheduler.
|
||||
pub fn lr_step(&mut self) {
|
||||
self.lr = self.lr_scheduler.step();
|
||||
}
|
||||
|
||||
/// Runs a step of the model for training, which executes the forward and backward passes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The input for the model.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output containing the model output and the gradients.
|
||||
pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
|
||||
self.model.step(item)
|
||||
}
|
||||
|
||||
/// Optimize the current module with the provided gradients and learning rate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `optim`: Optimizer used for learning.
|
||||
/// * `lr`: The learning rate used for this step.
|
||||
/// * `grads`: The gradients of each parameter in the current model.
|
||||
pub fn optimizer_step(&mut self, grads: GradientsParams) {
|
||||
self.model = self.model().optimize(&mut self.optim, self.lr, grads);
|
||||
}
|
||||
|
||||
/// Optimize the current module with the provided gradients and learning rate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `optim`: Optimizer used for learning.
|
||||
/// * `lr`: The learning rate used for this step.
|
||||
/// * `grads`: Multiple gradients associated to each parameter in the current model.
|
||||
pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
|
||||
self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);
|
||||
}
|
||||
|
||||
/// Load the module state from a [record](LearnerModelRecord<LC>).
|
||||
pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
|
||||
self.model = self.model.clone().load_record(record);
|
||||
}
|
||||
|
||||
/// Load the state of the learner's optimizer as a [record](LearnerOptimizerRecord<LC>).
|
||||
pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
|
||||
self.optim = self.optim.clone().load_record(record);
|
||||
}
|
||||
|
||||
/// Load the state of the learner's scheduler as a [record](LearnerSchedulerRecord<LC>).
|
||||
pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
|
||||
self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
/// Used to create, delete, or load checkpoints of the training process.
|
||||
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
|
||||
model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
|
||||
optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
|
||||
lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
|
||||
strategy: Box<dyn CheckpointingStrategy>,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
|
||||
/// Create checkpoint for the training process.
|
||||
pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
|
||||
let actions = self.strategy.checkpointing(epoch, store);
|
||||
|
||||
for action in actions {
|
||||
match action {
|
||||
CheckpointingAction::Delete(epoch) => {
|
||||
self.model
|
||||
.delete(epoch)
|
||||
.expect("Can delete model checkpoint.");
|
||||
self.optim
|
||||
.delete(epoch)
|
||||
.expect("Can delete optimizer checkpoint.");
|
||||
self.lr_scheduler
|
||||
.delete(epoch)
|
||||
.expect("Can delete learning rate scheduler checkpoint.");
|
||||
}
|
||||
CheckpointingAction::Save => {
|
||||
self.model
|
||||
.save(epoch, learner.model.clone().into_record())
|
||||
.expect("Can save model checkpoint.");
|
||||
self.optim
|
||||
.save(epoch, learner.optim.to_record())
|
||||
.expect("Can save optimizer checkpoint.");
|
||||
self.lr_scheduler
|
||||
.save(epoch, learner.lr_scheduler.to_record())
|
||||
.expect("Can save learning rate scheduler checkpoint.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a training checkpoint.
|
||||
pub fn load_checkpoint(
|
||||
&self,
|
||||
mut learner: Learner<LC>,
|
||||
device: &Device<LC::Backend>,
|
||||
epoch: usize,
|
||||
) -> Learner<LC> {
|
||||
let record = self
|
||||
.model
|
||||
.restore(epoch, device)
|
||||
.expect("Can load model checkpoint.");
|
||||
learner.load_model(record);
|
||||
|
||||
let record = self
|
||||
.optim
|
||||
.restore(epoch, device)
|
||||
.expect("Can load optimizer checkpoint.");
|
||||
learner.load_optim(record);
|
||||
|
||||
let record = self
|
||||
.lr_scheduler
|
||||
.restore(epoch, device)
|
||||
.expect("Can load learning rate scheduler checkpoint.");
|
||||
learner.load_scheduler(record);
|
||||
|
||||
learner
|
||||
}
|
||||
}
|
||||
|
||||
/// Cloneable reference to an early stopping strategy
|
||||
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
/// A handle that allows aborting the training/evaluation process early.
|
||||
pub struct Interrupter {
|
||||
state: Arc<AtomicBool>,
|
||||
message: Arc<Mutex<Option<String>>>,
|
||||
}
|
||||
|
||||
impl Interrupter {
|
||||
/// Create a new instance.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Notify the learner that it should stop.
|
||||
/// # Arguments
|
||||
/// * `reason` - A string describing the reason the training was stopped.
|
||||
pub fn stop(&self, reason: Option<&str>) {
|
||||
self.state.store(true, Ordering::Relaxed);
|
||||
reason.inspect(|r| {
|
||||
let mut message = self.message.lock().unwrap();
|
||||
*message = Some(String::from(*r));
|
||||
});
|
||||
}
|
||||
|
||||
/// Reset the interrupter.
|
||||
pub fn reset(&self) {
|
||||
self.state.store(false, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// True if .stop() has been called.
|
||||
pub fn should_stop(&self) -> bool {
|
||||
self.state.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get the message associated with the interrupt.
|
||||
pub fn get_message(&self) -> Option<String> {
|
||||
let message = self.message.lock().unwrap();
|
||||
message.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
use crate::metric::{
|
||||
AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,
|
||||
PerplexityInput, TopKAccuracyInput, processor::ItemLazy,
|
||||
};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{Int, Tensor, Transaction};
|
||||
use burn_ndarray::NdArray;
|
||||
|
||||
/// Simple classification output adapted for multiple metrics.
|
||||
///
|
||||
/// Supported metrics:
|
||||
/// - Accuracy
|
||||
/// - AUROC
|
||||
/// - TopKAccuracy
|
||||
/// - Perplexity
|
||||
/// - Precision (via ConfusionStatsInput)
|
||||
/// - Recall (via ConfusionStatsInput)
|
||||
/// - FBetaScore (via ConfusionStatsInput)
|
||||
/// - Loss.
|
||||
#[derive(new)]
|
||||
pub struct ClassificationOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// The class logits or probabilities. Shape: \[batch_size, num_classes\].
|
||||
pub output: Tensor<B, 2>,
|
||||
|
||||
/// The ground truth class index for each sample. Shape: \[batch_size\].
|
||||
pub targets: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ItemLazy for ClassificationOutput<B> {
|
||||
type ItemSync = ClassificationOutput<NdArray>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
let [output, loss, targets] = Transaction::default()
|
||||
.register(self.output)
|
||||
.register(self.loss)
|
||||
.register(self.targets)
|
||||
.execute()
|
||||
.try_into()
|
||||
.expect("Correct amount of tensor data");
|
||||
|
||||
let device = &Default::default();
|
||||
|
||||
ClassificationOutput {
|
||||
output: Tensor::from_data(output, device),
|
||||
loss: Tensor::from_data(loss, device),
|
||||
targets: Tensor::from_data(targets, device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> AccuracyInput<B> {
|
||||
AccuracyInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<AurocInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> AurocInput<B> {
|
||||
AurocInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> TopKAccuracyInput<B> {
|
||||
TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> PerplexityInput<B> {
|
||||
PerplexityInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> ConfusionStatsInput<B> {
|
||||
let [_, num_classes] = self.output.dims();
|
||||
if num_classes > 1 {
|
||||
ConfusionStatsInput::new(
|
||||
self.output.clone(),
|
||||
self.targets.clone().one_hot(num_classes).bool(),
|
||||
)
|
||||
} else {
|
||||
ConfusionStatsInput::new(
|
||||
self.output.clone(),
|
||||
self.targets.clone().unsqueeze_dim(1).bool(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-label classification output adapted for multiple metrics.
|
||||
///
|
||||
/// Supported metrics:
|
||||
/// - HammingScore
|
||||
/// - Precision (via ConfusionStatsInput)
|
||||
/// - Recall (via ConfusionStatsInput)
|
||||
/// - FBetaScore (via ConfusionStatsInput)
|
||||
/// - Loss
|
||||
#[derive(new)]
|
||||
pub struct MultiLabelClassificationOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// The label logits or probabilities. Shape: \[batch_size, num_classes\].
|
||||
pub output: Tensor<B, 2>,
|
||||
|
||||
/// The ground truth labels. Shape: \[batch_size, num_classes\].
|
||||
pub targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
|
||||
type ItemSync = MultiLabelClassificationOutput<NdArray>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
let [output, loss, targets] = Transaction::default()
|
||||
.register(self.output)
|
||||
.register(self.loss)
|
||||
.register(self.targets)
|
||||
.execute()
|
||||
.try_into()
|
||||
.expect("Correct amount of tensor data");
|
||||
|
||||
let device = &Default::default();
|
||||
|
||||
MultiLabelClassificationOutput {
|
||||
output: Tensor::from_data(output, device),
|
||||
loss: Tensor::from_data(loss, device),
|
||||
targets: Tensor::from_data(targets, device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
|
||||
fn adapt(&self) -> HammingScoreInput<B> {
|
||||
HammingScoreInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
|
||||
fn adapt(&self) -> ConfusionStatsInput<B> {
|
||||
ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
use crate::metric::{
|
||||
Metric, MetricName,
|
||||
store::{Aggregate, Direction, EventStoreClient, Split},
|
||||
};
|
||||
|
||||
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
|
||||
#[derive(Clone)]
|
||||
pub enum StoppingCondition {
|
||||
/// When no improvement has happened since the given number of epochs.
|
||||
NoImprovementSince {
|
||||
/// The number of epochs allowed to worsen before it gets better.
|
||||
n_epochs: usize,
|
||||
},
|
||||
}
|
||||
|
||||
/// A strategy that checks if the training should be stopped.
|
||||
pub trait EarlyStoppingStrategy: Send {
|
||||
/// Update its current state and returns if the training should be stopped.
|
||||
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
|
||||
}
|
||||
|
||||
/// A helper trait to provide type-erased cloning.
|
||||
pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
|
||||
/// Clone into a boxed trait object.
|
||||
fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
|
||||
}
|
||||
|
||||
/// Blanket-implement `CloneEarlyStoppingStrategy` for any `T` that
|
||||
/// already implements your strategy + `Clone` + `Send` + `'static`.
|
||||
impl<T> CloneEarlyStoppingStrategy for T
|
||||
where
|
||||
T: EarlyStoppingStrategy + Clone + Send + 'static,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Now you can `impl Clone` for the boxed trait object.
|
||||
impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
|
||||
fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
|
||||
/// during training or validation.
|
||||
#[derive(Clone)]
|
||||
pub struct MetricEarlyStoppingStrategy {
|
||||
condition: StoppingCondition,
|
||||
metric_name: MetricName,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
best_epoch: usize,
|
||||
best_value: f64,
|
||||
warmup_epochs: Option<usize>,
|
||||
}
|
||||
|
||||
impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
|
||||
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
|
||||
let current_value =
|
||||
match store.find_metric(&self.metric_name, epoch, self.aggregate, &self.split) {
|
||||
Some(value) => value,
|
||||
None => {
|
||||
log::warn!("Can't find metric for early stopping.");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
let is_best = match self.direction {
|
||||
Direction::Lowest => current_value < self.best_value,
|
||||
Direction::Highest => current_value > self.best_value,
|
||||
};
|
||||
|
||||
if is_best {
|
||||
log::info!(
|
||||
"New best epoch found {} {}: {}",
|
||||
epoch,
|
||||
self.metric_name,
|
||||
current_value
|
||||
);
|
||||
self.best_value = current_value;
|
||||
self.best_epoch = epoch;
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(warmup_epochs) = self.warmup_epochs
|
||||
&& epoch <= warmup_epochs
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
match self.condition {
|
||||
StoppingCondition::NoImprovementSince { n_epochs } => {
|
||||
let should_stop = epoch - self.best_epoch >= n_epochs;
|
||||
|
||||
if should_stop {
|
||||
log::info!(
|
||||
"Stopping training loop, no improvement since epoch {}, {}: {}, current \
|
||||
epoch {}, {}: {}",
|
||||
self.best_epoch,
|
||||
self.metric_name,
|
||||
self.best_value,
|
||||
epoch,
|
||||
self.metric_name,
|
||||
current_value
|
||||
);
|
||||
}
|
||||
|
||||
should_stop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricEarlyStoppingStrategy {
|
||||
/// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
|
||||
/// during training or validation.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The metric should be registered for early stopping to work, otherwise no data is collected.
|
||||
pub fn new<Me: Metric>(
|
||||
metric: &Me,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
condition: StoppingCondition,
|
||||
) -> Self {
|
||||
let init_value = match direction {
|
||||
Direction::Lowest => f64::MAX,
|
||||
Direction::Highest => f64::MIN,
|
||||
};
|
||||
|
||||
Self {
|
||||
metric_name: metric.name(),
|
||||
condition,
|
||||
aggregate,
|
||||
direction,
|
||||
split,
|
||||
best_epoch: 1,
|
||||
best_value: init_value,
|
||||
warmup_epochs: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the warmup period.
|
||||
///
|
||||
/// Early stopping will not trigger during the warmup epochs.
|
||||
pub fn warmup_epochs(&self) -> Option<usize> {
|
||||
self.warmup_epochs
|
||||
}
|
||||
|
||||
/// Set the warmup epochs.
|
||||
///
|
||||
/// Early stopping will not trigger during the warmup epochs.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `warmup`: the number of warmup epochs, or None.
|
||||
pub fn with_warmup_epochs(self, warmup: Option<usize>) -> Self {
|
||||
Self {
|
||||
warmup_epochs: warmup,
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
EventProcessorTraining, TestBackend,
|
||||
logger::InMemoryMetricLogger,
|
||||
metric::{
|
||||
LossMetric,
|
||||
processor::{
|
||||
MetricsTraining, MinimalEventProcessor,
|
||||
test_utils::{end_epoch, process_train},
|
||||
},
|
||||
store::LogEventStore,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn never_early_stop_while_it_is_improving() {
|
||||
test_early_stopping(
|
||||
None,
|
||||
1,
|
||||
&[
|
||||
(&[0.5, 0.3], false, "Should not stop first epoch"),
|
||||
(&[0.4, 0.3], false, "Should not stop when improving"),
|
||||
(&[0.3, 0.3], false, "Should not stop when improving"),
|
||||
(&[0.2, 0.3], false, "Should not stop when improving"),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stop_when_no_improvement_since_two_epochs() {
|
||||
test_early_stopping(
|
||||
None,
|
||||
2,
|
||||
&[
|
||||
(&[1.0, 0.5], false, "Should not stop first epoch"),
|
||||
(&[0.5, 0.3], false, "Should not stop when improving"),
|
||||
(
|
||||
&[1.0, 3.0],
|
||||
false,
|
||||
"Should not stop first time it gets worse",
|
||||
),
|
||||
(
|
||||
&[1.0, 2.0],
|
||||
true,
|
||||
"Should stop since two following epochs didn't improve",
|
||||
),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stopping_with_warmup() {
|
||||
test_early_stopping(
|
||||
Some(3),
|
||||
2,
|
||||
&[
|
||||
(&[1.0, 0.5], false, "Should not stop during warmup"),
|
||||
(&[1.0, 0.5], false, "Should not stop during warmup"),
|
||||
(&[1.0, 0.5], false, "Should not stop during warmup"),
|
||||
(
|
||||
&[1.0, 0.5],
|
||||
true,
|
||||
"Should stop when not improving after warmup",
|
||||
),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stop_when_stays_equal() {
|
||||
test_early_stopping(
|
||||
None,
|
||||
2,
|
||||
&[
|
||||
(&[0.5, 0.3], false, "Should not stop first epoch"),
|
||||
(
|
||||
&[0.5, 0.3],
|
||||
false,
|
||||
"Should not stop first time it stars the same",
|
||||
),
|
||||
(
|
||||
&[0.5, 0.3],
|
||||
true,
|
||||
"Should stop since two following epochs didn't improve",
|
||||
),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
fn test_early_stopping(warmup: Option<usize>, n_epochs: usize, data: &[(&[f64], bool, &str)]) {
|
||||
let loss = LossMetric::<TestBackend>::new();
|
||||
let mut early_stopping = MetricEarlyStoppingStrategy::new(
|
||||
&loss,
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Train,
|
||||
StoppingCondition::NoImprovementSince { n_epochs },
|
||||
)
|
||||
.with_warmup_epochs(warmup);
|
||||
let mut store = LogEventStore::default();
|
||||
let mut metrics = MetricsTraining::<f64, f64>::default();
|
||||
|
||||
store.register_logger(InMemoryMetricLogger::default());
|
||||
metrics.register_train_metric_numeric(loss);
|
||||
|
||||
let store = Arc::new(EventStoreClient::new(store));
|
||||
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
|
||||
|
||||
let mut epoch = 1;
|
||||
processor.process_train(crate::LearnerEvent::Start);
|
||||
for (points, should_start, comment) in data {
|
||||
for point in points.iter() {
|
||||
process_train(&mut processor, *point, epoch);
|
||||
}
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
assert_eq!(
|
||||
*should_start,
|
||||
early_stopping.should_stop(epoch, &store),
|
||||
"{comment}"
|
||||
);
|
||||
epoch += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
#[cfg(feature = "rl")]
|
||||
mod rl;
|
||||
#[cfg(feature = "rl")]
|
||||
pub use rl::*;
|
||||
|
||||
mod application_logger;
|
||||
mod base;
|
||||
mod classification;
|
||||
mod early_stopping;
|
||||
mod regression;
|
||||
mod sequence;
|
||||
mod summary;
|
||||
mod supervised;
|
||||
mod train_val;
|
||||
|
||||
pub use application_logger::*;
|
||||
pub use base::*;
|
||||
pub use classification::*;
|
||||
pub use early_stopping::*;
|
||||
pub use regression::*;
|
||||
pub use sequence::*;
|
||||
pub use summary::*;
|
||||
pub use supervised::*;
|
||||
pub use train_val::*;
|
||||
@@ -0,0 +1,46 @@
|
||||
use crate::metric::processor::ItemLazy;
|
||||
use crate::metric::{Adaptor, LossInput};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{Tensor, Transaction};
|
||||
use burn_ndarray::NdArray;
|
||||
|
||||
/// Regression output adapted for the loss metric.
|
||||
#[derive(new)]
|
||||
pub struct RegressionOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// The predicted values. Shape: \[batch_size, num_targets\].
|
||||
pub output: Tensor<B, 2>,
|
||||
|
||||
/// The ground truth values. Shape: \[batch_size, num_targets\].
|
||||
pub targets: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ItemLazy for RegressionOutput<B> {
|
||||
type ItemSync = RegressionOutput<NdArray>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
let [output, loss, targets] = Transaction::default()
|
||||
.register(self.output)
|
||||
.register(self.loss)
|
||||
.register(self.targets)
|
||||
.execute()
|
||||
.try_into()
|
||||
.expect("Correct amount of tensor data");
|
||||
|
||||
let device = &Default::default();
|
||||
|
||||
RegressionOutput {
|
||||
output: Tensor::from_data(output, device),
|
||||
loss: Tensor::from_data(loss, device),
|
||||
targets: Tensor::from_data(targets, device),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
use burn_core::tensor::Device;
|
||||
use burn_rl::{Policy, PolicyLearner, PolicyState};
|
||||
|
||||
use crate::RLAgentRecord;
|
||||
use crate::{
|
||||
RLComponentsTypes, RLPolicyRecord,
|
||||
checkpoint::Checkpointer,
|
||||
checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},
|
||||
metric::store::EventStoreClient,
|
||||
};
|
||||
|
||||
#[derive(new)]
|
||||
/// Used to create, delete, or load checkpoints of the training process.
|
||||
pub struct RLCheckpointer<RLC: RLComponentsTypes> {
|
||||
policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
|
||||
learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
|
||||
strategy: Box<dyn CheckpointingStrategy>,
|
||||
}
|
||||
|
||||
impl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {
|
||||
/// Create checkpoint for the training process.
|
||||
pub fn checkpoint(
|
||||
&mut self,
|
||||
policy: &RLC::PolicyState,
|
||||
learning_agent: &RLC::LearningAgent,
|
||||
epoch: usize,
|
||||
store: &EventStoreClient,
|
||||
) {
|
||||
let actions = self.strategy.checkpointing(epoch, store);
|
||||
|
||||
for action in actions {
|
||||
match action {
|
||||
CheckpointingAction::Delete(epoch) => {
|
||||
self.policy
|
||||
.delete(epoch)
|
||||
.expect("Can delete policy checkpoint.");
|
||||
self.learning_agent
|
||||
.delete(epoch)
|
||||
.expect("Can delete learning agent checkpoint.")
|
||||
}
|
||||
CheckpointingAction::Save => {
|
||||
self.policy
|
||||
.save(epoch, policy.clone().into_record())
|
||||
.expect("Can save policy checkpoint.");
|
||||
self.learning_agent
|
||||
.save(epoch, learning_agent.record())
|
||||
.expect("Can save learning agent checkpoint.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a training checkpoint.
|
||||
pub fn load_checkpoint(
|
||||
&self,
|
||||
learning_agent: RLC::LearningAgent,
|
||||
device: &Device<RLC::Backend>,
|
||||
epoch: usize,
|
||||
) -> RLC::LearningAgent {
|
||||
let record = self
|
||||
.policy
|
||||
.restore(epoch, device)
|
||||
.expect("Can load model checkpoint.");
|
||||
let policy = learning_agent.policy().load_record(record);
|
||||
|
||||
let record = self
|
||||
.learning_agent
|
||||
.restore(epoch, device)
|
||||
.expect("Can load learning agent checkpoint.");
|
||||
let mut learning_agent = learning_agent.load_record(record);
|
||||
learning_agent.update_policy(policy);
|
||||
|
||||
learning_agent
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};
|
||||
|
||||
use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};
|
||||
|
||||
/// All components used by the reinforcement learning paradigm, grouped in one trait.
|
||||
pub trait RLComponentsTypes {
|
||||
/// The backend used for training.
|
||||
type Backend: AutodiffBackend;
|
||||
/// The learning environment.
|
||||
type Env: Environment<State = Self::State, Action = Self::Action> + 'static;
|
||||
/// Specifies how to initialize the environment.
|
||||
type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;
|
||||
/// The type of the environment state.
|
||||
type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;
|
||||
/// The type of the environment action.
|
||||
type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>
|
||||
+ Into<<Self::Policy as Policy<Self::Backend>>::Action>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static;
|
||||
|
||||
/// The policy used to take actions in the environment.
|
||||
type Policy: Policy<
|
||||
Self::Backend,
|
||||
Observation = Self::PolicyObs,
|
||||
ActionDistribution = Self::PolicyAD,
|
||||
Action = Self::PolicyAction,
|
||||
ActionContext = Self::ActionContext,
|
||||
PolicyState = Self::PolicyState,
|
||||
> + Send
|
||||
+ 'static;
|
||||
/// The policy's observation type.
|
||||
type PolicyObs: Clone + Send + Batchable + 'static;
|
||||
/// The policy's action distribution type.
|
||||
type PolicyAD: Clone + Send + Batchable;
|
||||
/// The policy's action type.
|
||||
type PolicyAction: Clone + Send + Batchable;
|
||||
/// Additional data as context for an agent's action.
|
||||
type ActionContext: ItemLazy + Clone + Send + 'static;
|
||||
/// The state of the parameterized policy.
|
||||
type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;
|
||||
|
||||
/// The learning agent.
|
||||
type LearningAgent: PolicyLearner<
|
||||
Self::Backend,
|
||||
TrainContext = Self::TrainingOutput,
|
||||
InnerPolicy = Self::Policy,
|
||||
> + Send
|
||||
+ 'static;
|
||||
/// The output data of a training step.
|
||||
type TrainingOutput: ItemLazy + Clone + Send;
|
||||
}
|
||||
|
||||
/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait.
|
||||
pub struct RLComponentsMarker<B, E, EI, A> {
|
||||
_backend: PhantomData<B>,
|
||||
_env: PhantomData<E>,
|
||||
_env_init: PhantomData<EI>,
|
||||
_agent: PhantomData<A>,
|
||||
}
|
||||
|
||||
impl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
E: Environment + 'static,
|
||||
EI: EnvironmentInit<E> + Send + 'static,
|
||||
A: PolicyLearner<B> + Send + 'static,
|
||||
A::TrainContext: ItemLazy + Clone + Send,
|
||||
A::InnerPolicy: Policy<B> + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
|
||||
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
|
||||
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
|
||||
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Into<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
type Backend = B;
|
||||
type Env = E;
|
||||
type EnvInit = EI;
|
||||
type LearningAgent = A;
|
||||
type Policy = A::InnerPolicy;
|
||||
type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;
|
||||
type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;
|
||||
type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;
|
||||
type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;
|
||||
type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;
|
||||
type TrainingOutput = A::TrainContext;
|
||||
type State = E::State;
|
||||
type Action = E::Action;
|
||||
}
|
||||
|
||||
pub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
|
||||
<RLC as RLComponentsTypes>::Backend,
|
||||
>>::InnerPolicy;
|
||||
/// The event processor type for reinforcement learning.
|
||||
pub type RLEventProcessorType<RLC> = AsyncProcessorTraining<
|
||||
RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,
|
||||
AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,
|
||||
>;
|
||||
/// The record of the policy.
|
||||
pub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<
|
||||
<RLC as RLComponentsTypes>::Backend,
|
||||
>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;
|
||||
/// The record of the learning agent.
|
||||
pub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
|
||||
<RLC as RLComponentsTypes>::Backend,
|
||||
>>::Record;
|
||||
@@ -0,0 +1,703 @@
|
||||
use rand::prelude::SliceRandom;
|
||||
use std::{
|
||||
sync::mpsc::{Receiver, Sender},
|
||||
thread::spawn,
|
||||
};
|
||||
|
||||
use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};
|
||||
use burn_rl::EnvironmentInit;
|
||||
use burn_rl::Policy;
|
||||
use burn_rl::Transition;
|
||||
use burn_rl::{AsyncPolicy, Environment};
|
||||
|
||||
use crate::{
|
||||
AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
|
||||
Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,
|
||||
RlPolicy, TimeStep, Trajectory,
|
||||
};
|
||||
|
||||
enum RequestMessage {
|
||||
Step(),
|
||||
Episode(),
|
||||
}
|
||||
|
||||
/// Configuration for an async agent/environment loop.
|
||||
pub struct AsyncAgentEnvLoopConfig {
|
||||
/// If the loop is used for evaluation (as opposed to training).
|
||||
pub eval: bool,
|
||||
/// If the agent should take action deterministically.
|
||||
pub deterministic: bool,
|
||||
/// An arbitrary ID for the loop.
|
||||
pub id: usize,
|
||||
}
|
||||
|
||||
/// An asynchronous agent/environement interface.
|
||||
pub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {
|
||||
eval: bool,
|
||||
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
|
||||
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
|
||||
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
|
||||
request_sender: Sender<RequestMessage>,
|
||||
}
|
||||
|
||||
impl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {
|
||||
/// Create a new asynchronous runner.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `env_init` - A function returning an environment instance.
|
||||
/// * `agent` - An [AsyncPolicy](AsyncPolicy) taking actions in the loop.
|
||||
/// * `config` - An [AsyncAgentEnvLoopConfig](AsyncAgentEnvLoopConfig).
|
||||
/// * `transition_sender` - Optional sender for transitions if you want to drive the requests from outside of the loop instance.
|
||||
/// * `trajectory_sender` - Optional sender for trajectories if you want to drive the requests from outside of the loop instance.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An async Agent/Environement loop.
|
||||
pub fn new(
|
||||
env_init: RLC::EnvInit,
|
||||
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
|
||||
config: AsyncAgentEnvLoopConfig,
|
||||
transition_device: &Device<BT>,
|
||||
transition_sender: Option<Sender<RLTimeStep<BT, RLC>>>,
|
||||
trajectory_sender: Option<Sender<RLTrajectory<BT, RLC>>>,
|
||||
) -> Self {
|
||||
let (loop_transition_sender, transition_receiver) = std::sync::mpsc::channel();
|
||||
let (loop_trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
|
||||
let (request_sender, request_receiver) = std::sync::mpsc::channel();
|
||||
let loop_transition_sender = transition_sender.unwrap_or(loop_transition_sender);
|
||||
let loop_trajectory_sender = trajectory_sender.unwrap_or(loop_trajectory_sender);
|
||||
|
||||
let device = transition_device.clone();
|
||||
let mut loop_agent = agent.clone();
|
||||
let eval = config.eval;
|
||||
|
||||
let mut current_steps = vec![];
|
||||
let mut current_reward = 0.0;
|
||||
let mut step_num = 0;
|
||||
spawn(move || {
|
||||
let mut env = env_init.init();
|
||||
env.reset();
|
||||
|
||||
let mut request_episode = false;
|
||||
loop {
|
||||
let state = env.state();
|
||||
let (action, context) =
|
||||
loop_agent.action(state.clone().into(), config.deterministic);
|
||||
|
||||
let env_action = RLC::Action::from(action);
|
||||
let step_result = env.step(env_action.clone());
|
||||
|
||||
current_reward += step_result.reward;
|
||||
step_num += 1;
|
||||
|
||||
let transition = Transition::new(
|
||||
state.clone(),
|
||||
step_result.next_state,
|
||||
env_action,
|
||||
Tensor::from_data([step_result.reward], &device),
|
||||
Tensor::from_data(
|
||||
[(step_result.done || step_result.truncated) as i32 as f64],
|
||||
&device,
|
||||
),
|
||||
);
|
||||
|
||||
if !request_episode {
|
||||
loop_agent.decrement_agents(1);
|
||||
let request = match request_receiver.recv() {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
log::error!("Error in env runner : {}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
loop_agent.increment_agents(1);
|
||||
|
||||
match request {
|
||||
RequestMessage::Step() => (),
|
||||
RequestMessage::Episode() => request_episode = true,
|
||||
}
|
||||
}
|
||||
|
||||
let time_step = TimeStep {
|
||||
env_id: config.id,
|
||||
transition,
|
||||
done: step_result.done,
|
||||
ep_len: step_num,
|
||||
cum_reward: current_reward,
|
||||
action_context: context[0].clone(),
|
||||
};
|
||||
current_steps.push(time_step.clone());
|
||||
|
||||
if !request_episode && let Err(err) = loop_transition_sender.send(time_step) {
|
||||
log::error!("Error in env runner : {}", err);
|
||||
break;
|
||||
}
|
||||
|
||||
if step_result.done || step_result.truncated {
|
||||
if request_episode {
|
||||
request_episode = false;
|
||||
loop_trajectory_sender
|
||||
.send(Trajectory {
|
||||
timesteps: current_steps.clone(),
|
||||
})
|
||||
.expect("Can send trajectory to main thread.");
|
||||
}
|
||||
current_steps.clear();
|
||||
|
||||
env.reset();
|
||||
current_reward = 0.;
|
||||
step_num = 0;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
eval,
|
||||
agent,
|
||||
transition_receiver,
|
||||
trajectory_receiver,
|
||||
request_sender,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>
|
||||
where
|
||||
BT: Backend,
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
for _ in 0..num_steps {
|
||||
self.request_sender
|
||||
.send(RequestMessage::Step())
|
||||
.expect("Can request transitions.");
|
||||
let transition = self
|
||||
.transition_receiver
|
||||
.recv()
|
||||
.expect("Can receive transitions.");
|
||||
items.push(transition.clone());
|
||||
|
||||
if !self.eval {
|
||||
progress.items_processed += 1;
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
transition.action_context,
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
|
||||
if transition.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: transition.ep_len,
|
||||
cum_reward: transition.cum_reward,
|
||||
},
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
_progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
self.agent.increment_agents(1);
|
||||
for episode_num in 0..num_episodes {
|
||||
self.request_sender
|
||||
.send(RequestMessage::Episode())
|
||||
.expect("Can request episodes.");
|
||||
let trajectory = self
|
||||
.trajectory_receiver
|
||||
.recv()
|
||||
.expect("Main thread can receive trajectory.");
|
||||
|
||||
for (i, step) in trajectory.timesteps.iter().enumerate() {
|
||||
// TODO : clean this.
|
||||
if self.eval {
|
||||
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
|
||||
EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items.push(trajectory);
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.agent.decrement_agents(1);
|
||||
items
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, update: RLC::PolicyState) {
|
||||
self.agent.update(update);
|
||||
}
|
||||
|
||||
fn policy(&self) -> RLC::PolicyState {
|
||||
self.agent.state()
|
||||
}
|
||||
}
|
||||
|
||||
/// An asynchronous runner for multiple agent/environement interfaces.
|
||||
pub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
|
||||
num_envs: usize,
|
||||
eval: bool,
|
||||
agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
|
||||
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
|
||||
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
|
||||
request_senders: Vec<Sender<RequestMessage>>,
|
||||
}
|
||||
|
||||
impl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {
|
||||
/// Create a new asynchronous runner for multiple agent/environement interfaces.
|
||||
pub fn new(
|
||||
num_envs: usize,
|
||||
env_init: RLC::EnvInit,
|
||||
agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
|
||||
eval: bool,
|
||||
deterministic: bool,
|
||||
device: &Device<BT>,
|
||||
) -> Self {
|
||||
let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
|
||||
let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
|
||||
let mut request_senders = vec![];
|
||||
|
||||
// Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps.
|
||||
agent.increment_agents(num_envs);
|
||||
|
||||
for i in 0..num_envs {
|
||||
let config = AsyncAgentEnvLoopConfig {
|
||||
eval,
|
||||
deterministic,
|
||||
id: i,
|
||||
};
|
||||
let runner = AgentEnvAsyncLoop::<BT, RLC>::new(
|
||||
env_init.clone(),
|
||||
agent.clone(),
|
||||
config,
|
||||
&device.clone(),
|
||||
Some(transition_sender.clone()),
|
||||
Some(trajectory_sender.clone()),
|
||||
);
|
||||
request_senders.push(runner.request_sender.clone());
|
||||
}
|
||||
|
||||
// Double batching : The environments are always one step ahead.
|
||||
request_senders.iter().for_each(|s| {
|
||||
s.send(RequestMessage::Step())
|
||||
.expect("Main thread can send step requests.")
|
||||
});
|
||||
|
||||
Self {
|
||||
num_envs,
|
||||
eval,
|
||||
agent: agent.clone(),
|
||||
transition_receiver,
|
||||
trajectory_receiver,
|
||||
request_senders,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>
|
||||
where
|
||||
BT: Backend,
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
for _ in 0..num_steps {
|
||||
let transition = self
|
||||
.transition_receiver
|
||||
.recv()
|
||||
.expect("Can receive transitions.");
|
||||
items.push(transition.clone());
|
||||
|
||||
self.request_senders[transition.env_id]
|
||||
.send(RequestMessage::Step())
|
||||
.expect("Main thread can request steps.");
|
||||
|
||||
if !self.eval {
|
||||
progress.items_processed += 1;
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
transition.action_context,
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
|
||||
if transition.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: transition.ep_len,
|
||||
cum_reward: transition.cum_reward,
|
||||
},
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, update: RLC::PolicyState) {
|
||||
self.agent.update(update);
|
||||
}
|
||||
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
_progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>> {
|
||||
// Send `num_episodes` initial requests.
|
||||
let mut idx = vec![];
|
||||
if num_episodes < self.num_envs {
|
||||
let mut rng = rand::rng();
|
||||
let mut vec: Vec<usize> = (0..self.num_envs).collect();
|
||||
vec.shuffle(&mut rng);
|
||||
idx = vec.into_iter().take(num_episodes).collect();
|
||||
} else {
|
||||
idx = (0..self.num_envs).collect();
|
||||
}
|
||||
let num_requests = self.num_envs.min(num_episodes);
|
||||
idx.into_iter().for_each(|i| {
|
||||
self.request_senders[i]
|
||||
.send(RequestMessage::Episode())
|
||||
.expect("Main thread can request steps.");
|
||||
});
|
||||
|
||||
let mut items = vec![];
|
||||
for episode_num in 0..num_episodes {
|
||||
let trajectory = self
|
||||
.trajectory_receiver
|
||||
.recv()
|
||||
.expect("Can receive trajectory.");
|
||||
items.push(trajectory.clone());
|
||||
if items.len() + num_requests <= num_episodes {
|
||||
self.request_senders[trajectory.timesteps[0].env_id]
|
||||
.send(RequestMessage::Episode())
|
||||
.expect("Main thread can request steps.");
|
||||
}
|
||||
for (i, step) in trajectory.timesteps.iter().enumerate() {
|
||||
if self.eval {
|
||||
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
|
||||
EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
fn policy(&self) -> RLC::PolicyState {
|
||||
self.agent.state()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
mod tests {
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_rl::AsyncPolicy;
|
||||
|
||||
use crate::learner::rl::env_runner::async_runner::AsyncAgentEnvLoopConfig;
|
||||
use crate::learner::rl::env_runner::base::AgentEnvLoop;
|
||||
use crate::learner::tests::{MockPolicyState, MockProcessor};
|
||||
use crate::{
|
||||
AgentEnvAsyncLoop, TestBackend,
|
||||
learner::tests::{MockEnvInit, MockPolicy, MockRLComponents},
|
||||
};
|
||||
use crate::{AsyncProcessorTraining, Interrupter, MultiAgentEnvLoop};
|
||||
|
||||
fn setup_async_loop(
|
||||
state: usize,
|
||||
eval: bool,
|
||||
deterministic: bool,
|
||||
) -> AgentEnvAsyncLoop<TestBackend, MockRLComponents> {
|
||||
let env_init = MockEnvInit;
|
||||
let agent = MockPolicy(state);
|
||||
let config = AsyncAgentEnvLoopConfig {
|
||||
eval,
|
||||
deterministic,
|
||||
id: 0,
|
||||
};
|
||||
AgentEnvAsyncLoop::<TestBackend, MockRLComponents>::new(
|
||||
env_init,
|
||||
AsyncPolicy::new(1, agent),
|
||||
config,
|
||||
&Default::default(),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
fn setup_multi_loop(
|
||||
num_envs: usize,
|
||||
autobatch_size: usize,
|
||||
state: usize,
|
||||
eval: bool,
|
||||
deterministic: bool,
|
||||
) -> MultiAgentEnvLoop<TestBackend, MockRLComponents> {
|
||||
let env_init = MockEnvInit;
|
||||
let agent = MockPolicy(state);
|
||||
MultiAgentEnvLoop::<TestBackend, MockRLComponents>::new(
|
||||
num_envs,
|
||||
env_init,
|
||||
AsyncPolicy::new(autobatch_size, agent),
|
||||
eval,
|
||||
deterministic,
|
||||
&Default::default(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_async_loop() {
|
||||
let runner = setup_async_loop(1000, false, false);
|
||||
let policy_state = runner.policy();
|
||||
assert_eq!(policy_state.0, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_policy_async_loop() {
|
||||
let mut runner = setup_async_loop(0, false, false);
|
||||
|
||||
runner.update_policy(MockPolicyState(1));
|
||||
assert_eq!(runner.policy().0, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_steps_returns_requested_number_async_loop() {
|
||||
let mut runner = setup_async_loop(0, false, false);
|
||||
let mut processor = AsyncProcessorTraining::new(MockProcessor);
|
||||
let interrupter = Interrupter::new();
|
||||
let mut progress = Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
};
|
||||
|
||||
let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(steps.len(), 1);
|
||||
let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(steps.len(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_episodes_returns_requested_number_async_loop() {
|
||||
let mut runner = setup_async_loop(0, false, false);
|
||||
let mut processor = AsyncProcessorTraining::new(MockProcessor);
|
||||
let interrupter = Interrupter::new();
|
||||
let mut progress = Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
};
|
||||
|
||||
let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(trajectories.len(), 1);
|
||||
assert_ne!(trajectories[0].timesteps.len(), 0);
|
||||
let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(trajectories.len(), 8);
|
||||
for i in 0..8 {
|
||||
assert_ne!(trajectories[i].timesteps.len(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_multi_loop() {
|
||||
let runner = setup_multi_loop(4, 4, 1000, false, false);
|
||||
let policy_state = runner.policy();
|
||||
assert_eq!(policy_state.0, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_policy_multi_loop() {
|
||||
let mut runner = setup_multi_loop(4, 4, 0, false, false);
|
||||
|
||||
runner.update_policy(MockPolicyState(1));
|
||||
assert_eq!(runner.policy().0, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_steps_returns_requested_number_multi_loop() {
|
||||
fn run_test(num_envs: usize, autobatch_size: usize) {
|
||||
let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
|
||||
let mut processor = AsyncProcessorTraining::new(MockProcessor);
|
||||
let interrupter = Interrupter::new();
|
||||
let mut progress = Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
};
|
||||
|
||||
// Kickstart tests by running some steps to make sure it's not a double batching edge case success.
|
||||
let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(steps.len(), 8);
|
||||
|
||||
for i in 0..16 {
|
||||
let steps = runner.run_steps(i, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(steps.len(), i);
|
||||
}
|
||||
}
|
||||
|
||||
// num_envs == autobatch_size
|
||||
run_test(1, 1);
|
||||
run_test(4, 4);
|
||||
// num_envs < autobatch_size
|
||||
run_test(1, 2);
|
||||
run_test(1, 3);
|
||||
run_test(2, 3);
|
||||
run_test(2, 4);
|
||||
run_test(5, 19);
|
||||
// num_envs > autobatch_size
|
||||
run_test(2, 1);
|
||||
run_test(8, 1);
|
||||
run_test(3, 2);
|
||||
run_test(8, 2);
|
||||
run_test(8, 3);
|
||||
run_test(8, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_episodes_returns_requested_number_multi_loop() {
|
||||
fn run_test(num_envs: usize, autobatch_size: usize) {
|
||||
let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
|
||||
let mut processor = AsyncProcessorTraining::new(MockProcessor);
|
||||
let interrupter = Interrupter::new();
|
||||
let mut progress = Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
};
|
||||
|
||||
// Kickstart tests by running some episodes to make sure it's not a double batching edge case success.
|
||||
let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(trajectories.len(), 8);
|
||||
for j in 0..8 {
|
||||
assert_ne!(trajectories[j].timesteps.len(), 0);
|
||||
}
|
||||
|
||||
for i in 0..16 {
|
||||
let trajectories =
|
||||
runner.run_episodes(i, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(trajectories.len(), i);
|
||||
for j in 0..i {
|
||||
assert_ne!(trajectories[j].timesteps.len(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// num_envs == autobatch_size
|
||||
run_test(1, 1);
|
||||
run_test(4, 4);
|
||||
// num_envs < autobatch_size
|
||||
run_test(1, 2);
|
||||
run_test(1, 3);
|
||||
run_test(2, 3);
|
||||
run_test(2, 4);
|
||||
run_test(5, 19);
|
||||
// num_envs > autobatch_size
|
||||
run_test(2, 1);
|
||||
run_test(8, 1);
|
||||
run_test(3, 2);
|
||||
run_test(8, 2);
|
||||
run_test(8, 3);
|
||||
run_test(8, 7);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,343 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::{Tensor, prelude::Backend};
|
||||
use burn_rl::Policy;
|
||||
use burn_rl::Transition;
|
||||
use burn_rl::{Environment, EnvironmentInit};
|
||||
|
||||
use crate::RLEvent;
|
||||
use crate::{
|
||||
AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
|
||||
RLEventProcessorType,
|
||||
};
|
||||
use crate::{Interrupter, RLComponentsTypes};
|
||||
|
||||
/// A trajectory, i.e. a list of ordered [TimeStep](TimeStep).
|
||||
#[derive(Clone, new)]
|
||||
pub struct Trajectory<B: Backend, S, A, C> {
|
||||
/// A list of ordered [TimeStep](TimeStep)s.
|
||||
pub timesteps: Vec<TimeStep<B, S, A, C>>,
|
||||
}
|
||||
|
||||
/// A timestep debscribing an iteration of the state/decision process.
|
||||
#[derive(Clone)]
|
||||
pub struct TimeStep<B: Backend, S, A, C> {
|
||||
/// The environment id.
|
||||
pub env_id: usize,
|
||||
/// The [burn_rl::Transition](burn_rl::Transition).
|
||||
pub transition: Transition<B, S, A>,
|
||||
/// True if the environment reaches a terminal state.
|
||||
pub done: bool,
|
||||
/// The running length of the current episode.
|
||||
pub ep_len: usize,
|
||||
/// The running cumulative reward.
|
||||
pub cum_reward: f64,
|
||||
/// The action's context for this timestep.
|
||||
pub action_context: C,
|
||||
}
|
||||
|
||||
pub(crate) type RLTimeStep<B, RLC> = TimeStep<
|
||||
B,
|
||||
<RLC as RLComponentsTypes>::State,
|
||||
<RLC as RLComponentsTypes>::Action,
|
||||
<RLC as RLComponentsTypes>::ActionContext,
|
||||
>;
|
||||
|
||||
pub(crate) type RLTrajectory<B, RLC> = Trajectory<
|
||||
B,
|
||||
<RLC as RLComponentsTypes>::State,
|
||||
<RLC as RLComponentsTypes>::Action,
|
||||
<RLC as RLComponentsTypes>::ActionContext,
|
||||
>;
|
||||
|
||||
/// Trait for a structure that implements an agent/environement interface.
|
||||
pub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
|
||||
/// Run a certain number of timesteps.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `num_steps` - The number of time_steps to run.
|
||||
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
|
||||
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
|
||||
/// * `num_steps` - The number of time_steps to run.
|
||||
/// * `progress` - A mutable reference to the learning progress.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A list of ordered timesteps.
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>>;
|
||||
/// Run a certain number of episodes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `num_episodes` - The number of episodes to run.
|
||||
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
|
||||
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
|
||||
/// * `progress` - A mutable reference to the learning progress.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A list of ordered timesteps.
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>>;
|
||||
/// Update the runner's agent.
|
||||
fn update_policy(&mut self, update: RLC::PolicyState);
|
||||
/// Get the state of the runner's agent.
|
||||
fn policy(&self) -> RLC::PolicyState;
|
||||
}
|
||||
|
||||
/// A simple, synchronized agent/environement interface.
|
||||
pub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {
|
||||
env: RLC::Env,
|
||||
eval: bool,
|
||||
agent: RLC::Policy,
|
||||
deterministic: bool,
|
||||
current_reward: f64,
|
||||
run_num: usize,
|
||||
step_num: usize,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {
|
||||
/// Create a new base runner.
|
||||
pub fn new(
|
||||
env_init: RLC::EnvInit,
|
||||
agent: RLC::Policy,
|
||||
eval: bool,
|
||||
deterministic: bool,
|
||||
) -> Self {
|
||||
let mut env = env_init.init();
|
||||
env.reset();
|
||||
|
||||
Self {
|
||||
env,
|
||||
eval,
|
||||
agent: agent.clone(),
|
||||
deterministic,
|
||||
current_reward: 0.0,
|
||||
run_num: 0,
|
||||
step_num: 0,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>
|
||||
where
|
||||
BT: Backend,
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
let device = Default::default();
|
||||
for _ in 0..num_steps {
|
||||
let state = self.env.state();
|
||||
let (action, context) = self.agent.action(state.clone().into(), self.deterministic);
|
||||
|
||||
let step_result = self.env.step(RLC::Action::from(action.clone()));
|
||||
|
||||
self.current_reward += step_result.reward;
|
||||
self.step_num += 1;
|
||||
|
||||
let transition = Transition::new(
|
||||
state.clone(),
|
||||
step_result.next_state,
|
||||
RLC::Action::from(action),
|
||||
Tensor::from_data([step_result.reward], &device),
|
||||
Tensor::from_data(
|
||||
[(step_result.done || step_result.truncated) as i32 as f64],
|
||||
&device,
|
||||
),
|
||||
);
|
||||
items.push(TimeStep {
|
||||
env_id: 0,
|
||||
transition,
|
||||
done: step_result.done,
|
||||
ep_len: self.step_num,
|
||||
cum_reward: self.current_reward,
|
||||
action_context: context[0].clone(),
|
||||
});
|
||||
|
||||
if !self.eval {
|
||||
progress.items_processed += 1;
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
context[0].clone(),
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step_result.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: self.step_num,
|
||||
cum_reward: self.current_reward,
|
||||
},
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
|
||||
if step_result.done || step_result.truncated {
|
||||
self.env.reset();
|
||||
self.current_reward = 0.;
|
||||
self.step_num = 0;
|
||||
self.run_num += 1;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, update: RLC::PolicyState) {
|
||||
self.agent.update(update);
|
||||
}
|
||||
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>> {
|
||||
self.env.reset();
|
||||
|
||||
let mut items = vec![];
|
||||
for ep in 0..num_episodes {
|
||||
let mut steps = vec![];
|
||||
loop {
|
||||
let step = self.run_steps(1, processor, interrupter, progress)[0].clone();
|
||||
steps.push(step.clone());
|
||||
|
||||
if self.eval {
|
||||
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(steps.len() + 1, steps.len() + 1),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
|
||||
EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(ep + 1, num_episodes),
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() || step.done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items.push(Trajectory::new(steps));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
fn policy(&self) -> RLC::PolicyState {
|
||||
self.agent.state()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
mod tests {
|
||||
use crate::{AsyncProcessorTraining, TestBackend};
|
||||
|
||||
use crate::learner::tests::{
|
||||
MockEnvInit, MockPolicy, MockPolicyState, MockProcessor, MockRLComponents,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn setup(
|
||||
state: usize,
|
||||
eval: bool,
|
||||
deterministic: bool,
|
||||
) -> AgentEnvBaseLoop<TestBackend, MockRLComponents> {
|
||||
let env_init = MockEnvInit;
|
||||
let agent = MockPolicy(state);
|
||||
AgentEnvBaseLoop::<TestBackend, MockRLComponents>::new(env_init, agent, eval, deterministic)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_returns_agent_state() {
|
||||
let runner = setup(1000, false, false);
|
||||
let policy_state = runner.policy();
|
||||
assert_eq!(policy_state.0, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_policy() {
|
||||
let mut runner = setup(0, false, false);
|
||||
|
||||
runner.update_policy(MockPolicyState(1));
|
||||
assert_eq!(runner.policy().0, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_steps_returns_requested_number() {
|
||||
let mut runner = setup(0, false, false);
|
||||
let mut processor = AsyncProcessorTraining::new(MockProcessor);
|
||||
let interrupter = Interrupter::new();
|
||||
let mut progress = Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
};
|
||||
|
||||
let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(steps.len(), 1);
|
||||
let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(steps.len(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_episodes_returns_requested_number() {
|
||||
let mut runner = setup(0, false, false);
|
||||
let mut processor = AsyncProcessorTraining::new(MockProcessor);
|
||||
let interrupter = Interrupter::new();
|
||||
let mut progress = Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
};
|
||||
|
||||
let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(trajectories.len(), 1);
|
||||
assert_ne!(trajectories[0].timesteps.len(), 0);
|
||||
let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
|
||||
assert_eq!(trajectories.len(), 8);
|
||||
for i in 0..8 {
|
||||
assert_ne!(trajectories[i].timesteps.len(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
mod async_runner;
|
||||
mod base;
|
||||
|
||||
pub use async_runner::*;
|
||||
pub use base::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyState};
|
||||
|
||||
use crate::tests::TestAutodiffBackend;
|
||||
use crate::{
|
||||
AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLComponentsTypes, RLEvent,
|
||||
};
|
||||
use burn_rl::{LearnerTransitionBatch, PolicyLearner, RLTrainOutput, StepResult};
|
||||
|
||||
/// Mock policy for testing
|
||||
///
|
||||
/// Calling `forward()` with a [MockObservation](MockObservation) (list of f32) returns a [MockActionDistribution](MockActionDistribution)
|
||||
/// containing a list of 0s of the same length as the observation.
|
||||
///
|
||||
/// Calling `action()` with a [MockObservation](MockObservation) (list of f32) returns a [MockPolicyAction](MockPolicyAction) with a list of actions of the same length as the observation.
|
||||
/// The actions are all 1 if the call is requested as deterministic, or else 0.
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockPolicy(pub usize);
|
||||
|
||||
impl Policy<TestAutodiffBackend> for MockPolicy {
|
||||
type Observation = MockObservation;
|
||||
type ActionDistribution = MockActionDistribution;
|
||||
type Action = MockPolicyAction;
|
||||
type ActionContext = MockActionContext;
|
||||
type PolicyState = MockPolicyState;
|
||||
|
||||
fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {
|
||||
let mut dists = vec![];
|
||||
for _ in obs.0 {
|
||||
dists.push(MockActionDistribution(vec![0.]));
|
||||
}
|
||||
MockActionDistribution::batch(dists)
|
||||
}
|
||||
|
||||
fn action(
|
||||
&mut self,
|
||||
obs: Self::Observation,
|
||||
deterministic: bool,
|
||||
) -> (Self::Action, Vec<Self::ActionContext>) {
|
||||
let mut actions = vec![];
|
||||
let mut contexts = vec![];
|
||||
|
||||
for _ in obs.0 {
|
||||
if deterministic {
|
||||
actions.push(MockPolicyAction(vec![1]));
|
||||
} else {
|
||||
actions.push(MockPolicyAction(vec![0]));
|
||||
}
|
||||
contexts.push(MockActionContext);
|
||||
}
|
||||
|
||||
(MockPolicyAction::batch(actions), contexts)
|
||||
}
|
||||
|
||||
fn update(&mut self, update: Self::PolicyState) {
|
||||
self.0 = update.0;
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::PolicyState {
|
||||
MockPolicyState(self.0)
|
||||
}
|
||||
|
||||
fn load_record(
|
||||
self,
|
||||
_record: <Self::PolicyState as PolicyState<TestAutodiffBackend>>::Record,
|
||||
) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockObservation(pub Vec<f32>);
|
||||
|
||||
/// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockPolicyAction(pub Vec<i32>);
|
||||
|
||||
/// Mock action distribution for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockActionDistribution(Vec<f32>);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockActionContext;
|
||||
|
||||
/// Mock policy state for testing represented as an arbitrary `usize` that has no effect on the policy.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockPolicyState(pub usize);
|
||||
|
||||
impl PolicyState<TestAutodiffBackend> for MockPolicyState {
|
||||
type Record = ();
|
||||
|
||||
fn into_record(self) -> Self::Record {}
|
||||
|
||||
fn load_record(&self, _record: Self::Record) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Batchable for MockObservation {
|
||||
fn batch(items: Vec<Self>) -> Self {
|
||||
MockObservation(items.iter().flat_map(|m| m.0.clone()).collect())
|
||||
}
|
||||
|
||||
fn unbatch(self) -> Vec<Self> {
|
||||
vec![MockObservation(self.0)]
|
||||
}
|
||||
}
|
||||
|
||||
impl Batchable for MockPolicyAction {
|
||||
fn batch(items: Vec<Self>) -> Self {
|
||||
MockPolicyAction(items.iter().flat_map(|m| m.0.clone()).collect())
|
||||
}
|
||||
|
||||
fn unbatch(self) -> Vec<Self> {
|
||||
let mut actions = vec![];
|
||||
for a in self.0 {
|
||||
actions.push(MockPolicyAction(vec![a]));
|
||||
}
|
||||
actions
|
||||
}
|
||||
}
|
||||
|
||||
impl Batchable for MockActionDistribution {
|
||||
fn batch(items: Vec<Self>) -> Self {
|
||||
MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect())
|
||||
}
|
||||
|
||||
fn unbatch(self) -> Vec<Self> {
|
||||
let mut dists = vec![];
|
||||
for _ in self.0 {
|
||||
dists.push(MockActionDistribution(vec![0.]));
|
||||
}
|
||||
dists
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock environment for testing
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockEnv {
|
||||
counter: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct MockState;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct MockAction(pub i32);
|
||||
|
||||
impl From<MockState> for MockObservation {
|
||||
fn from(_value: MockState) -> Self {
|
||||
MockObservation(vec![0.])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MockPolicyAction> for MockAction {
|
||||
fn from(value: MockPolicyAction) -> Self {
|
||||
MockAction(value.0[0])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MockAction> for MockPolicyAction {
|
||||
fn from(value: MockAction) -> Self {
|
||||
MockPolicyAction(vec![value.0])
|
||||
}
|
||||
}
|
||||
|
||||
impl ItemLazy for MockActionContext {
|
||||
type ItemSync = MockActionContext;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl MockEnv {
|
||||
fn new() -> Self {
|
||||
Self { counter: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl Environment for MockEnv {
|
||||
type State = MockState;
|
||||
type Action = MockAction;
|
||||
const MAX_STEPS: usize = 5;
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.counter = 0;
|
||||
}
|
||||
|
||||
fn step(&mut self, _action: Self::Action) -> StepResult<Self::State> {
|
||||
self.counter += 1;
|
||||
let done = self.counter >= Self::MAX_STEPS;
|
||||
|
||||
burn_rl::StepResult {
|
||||
next_state: MockState,
|
||||
reward: 1.0,
|
||||
done,
|
||||
truncated: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::State {
|
||||
MockState
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock environment init for testing
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockEnvInit;
|
||||
|
||||
impl EnvironmentInit<MockEnv> for MockEnvInit {
|
||||
fn init(&self) -> MockEnv {
|
||||
MockEnv::new()
|
||||
}
|
||||
}
|
||||
|
||||
// Mock RLComponentsTypes for testing
|
||||
pub(crate) struct MockRLComponents;
|
||||
|
||||
impl RLComponentsTypes for MockRLComponents {
|
||||
type Backend = TestAutodiffBackend;
|
||||
type Env = MockEnv;
|
||||
type EnvInit = MockEnvInit;
|
||||
type State = MockState;
|
||||
type Action = MockAction;
|
||||
type Policy = MockPolicy;
|
||||
type PolicyObs = MockObservation;
|
||||
type PolicyAD = MockActionDistribution;
|
||||
type PolicyAction = MockPolicyAction;
|
||||
type ActionContext = MockActionContext;
|
||||
type PolicyState = MockPolicyState;
|
||||
type LearningAgent = MockLearningAgent;
|
||||
type TrainingOutput = ();
|
||||
}
|
||||
|
||||
// Mock learning agent for testing
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MockLearningAgent;
|
||||
|
||||
impl PolicyLearner<TestAutodiffBackend> for MockLearningAgent {
|
||||
type InnerPolicy = MockPolicy;
|
||||
type TrainContext = ();
|
||||
type Record = ();
|
||||
|
||||
fn train(
|
||||
&mut self,
|
||||
_input: LearnerTransitionBatch<TestAutodiffBackend, Self::InnerPolicy>,
|
||||
) -> RLTrainOutput<
|
||||
Self::TrainContext,
|
||||
<Self::InnerPolicy as Policy<TestAutodiffBackend>>::PolicyState,
|
||||
> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn policy(&self) -> Self::InnerPolicy {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, _update: Self::InnerPolicy) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn record(&self) -> Self::Record {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
// Mock event processor for testing
|
||||
pub(crate) struct MockProcessor;
|
||||
|
||||
impl
|
||||
EventProcessorTraining<
|
||||
RLEvent<(), MockActionContext>,
|
||||
AgentEvaluationEvent<MockActionContext>,
|
||||
> for MockProcessor
|
||||
{
|
||||
fn process_train(&mut self, _event: RLEvent<(), MockActionContext>) {
|
||||
// Mock process train
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, _event: AgentEvaluationEvent<MockActionContext>) {
|
||||
// Mock process valid
|
||||
}
|
||||
|
||||
fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
mod checkpointer;
|
||||
mod components;
|
||||
mod env_runner;
|
||||
mod off_policy;
|
||||
mod output;
|
||||
mod paradigm;
|
||||
mod strategy;
|
||||
|
||||
pub use checkpointer::*;
|
||||
pub use components::*;
|
||||
pub use env_runner::*;
|
||||
pub use off_policy::*;
|
||||
pub use output::*;
|
||||
pub use paradigm::*;
|
||||
pub use strategy::*;
|
||||
@@ -0,0 +1,189 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem,
|
||||
EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent,
|
||||
RLEventProcessorType, RLStrategy,
|
||||
};
|
||||
use burn_core::{self as burn};
|
||||
use burn_core::{config::Config, data::dataloader::Progress};
|
||||
use burn_ndarray::NdArray;
|
||||
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer};
|
||||
|
||||
/// Parameters of an on policy training with multi environments and double-batching.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct OffPolicyConfig {
|
||||
/// The number of environments to run simultaneously for experience collection.
|
||||
#[config(default = 1)]
|
||||
pub num_envs: usize,
|
||||
/// Number of environment state to accumulate before running one step of inference with the policy.
|
||||
/// Must be equal or less than the number of simultaneous environments.
|
||||
#[config(default = 1)]
|
||||
pub autobatch_size: usize,
|
||||
/// Max number of transitions stored in the replay buffer.
|
||||
#[config(default = 1024)]
|
||||
pub replay_buffer_size: usize,
|
||||
/// The number of steps to collect between each step of training.
|
||||
#[config(default = 1)]
|
||||
pub train_interval: usize,
|
||||
/// Number of optimization steps done each `train_interval`.
|
||||
#[config(default = 1)]
|
||||
pub train_steps: usize,
|
||||
/// The number of steps to collect between each evaluation.
|
||||
#[config(default = 10_000)]
|
||||
pub eval_interval: usize,
|
||||
/// The number of episodes to run for each evaluation.
|
||||
#[config(default = 1)]
|
||||
pub eval_episodes: usize,
|
||||
/// The number of transition to train on.
|
||||
#[config(default = 32)]
|
||||
pub train_batch_size: usize,
|
||||
/// Number of steps to collect before starting to train.
|
||||
#[config(default = 0)]
|
||||
pub warmup_steps: usize,
|
||||
}
|
||||
|
||||
/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.
|
||||
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
|
||||
config: OffPolicyConfig,
|
||||
_components: PhantomData<RLC>,
|
||||
}
|
||||
impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
|
||||
/// Create a new off-policy base strategy.
|
||||
pub fn new(config: OffPolicyConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
_components: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
|
||||
where
|
||||
RLC: RLComponentsTypes,
|
||||
RLC::PolicyObs: SliceAccess<RLC::Backend>,
|
||||
RLC::PolicyAction: SliceAccess<RLC::Backend>,
|
||||
{
|
||||
fn train_loop(
|
||||
&self,
|
||||
training_components: RLComponents<RLC>,
|
||||
learner_agent: &mut RLC::LearningAgent,
|
||||
starting_epoch: usize,
|
||||
env_init: RLC::EnvInit,
|
||||
) -> (RLC::Policy, RLEventProcessorType<RLC>) {
|
||||
let mut event_processor = training_components.event_processor;
|
||||
let mut checkpointer = training_components.checkpointer;
|
||||
let num_steps_total = training_components.num_steps;
|
||||
|
||||
let mut env_runner = MultiAgentEnvLoop::<NdArray, RLC>::new(
|
||||
self.config.num_envs,
|
||||
env_init.clone(),
|
||||
AsyncPolicy::new(
|
||||
self.config.num_envs.min(self.config.autobatch_size),
|
||||
learner_agent.policy(),
|
||||
),
|
||||
false,
|
||||
false,
|
||||
&Default::default(),
|
||||
);
|
||||
let runner_config = AsyncAgentEnvLoopConfig {
|
||||
eval: true,
|
||||
deterministic: true,
|
||||
id: 0,
|
||||
};
|
||||
let mut env_runner_valid = AgentEnvAsyncLoop::<NdArray, RLC>::new(
|
||||
env_init,
|
||||
AsyncPolicy::new(1, learner_agent.policy()),
|
||||
runner_config,
|
||||
&Default::default(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let device: <RLC::Backend as burn_core::prelude::Backend>::Device = Default::default();
|
||||
let mut transition_buffer = TransitionBuffer::<
|
||||
RLC::Backend,
|
||||
RLC::PolicyObs,
|
||||
RLC::PolicyAction,
|
||||
>::new(self.config.replay_buffer_size, &device);
|
||||
|
||||
let mut valid_next = self.config.eval_interval + starting_epoch - 1;
|
||||
let mut progress = Progress {
|
||||
items_processed: starting_epoch,
|
||||
items_total: num_steps_total,
|
||||
};
|
||||
|
||||
let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
|
||||
None;
|
||||
while progress.items_processed < num_steps_total {
|
||||
if training_components.interrupter.should_stop() {
|
||||
let reason = training_components
|
||||
.interrupter
|
||||
.get_message()
|
||||
.unwrap_or(String::from("Reason unknown"));
|
||||
log::info!("Training interrupted: {reason}");
|
||||
break;
|
||||
}
|
||||
|
||||
let previous_steps = progress.items_processed;
|
||||
let items = env_runner.run_steps(
|
||||
self.config.train_interval,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
&mut progress,
|
||||
);
|
||||
|
||||
for item in &items {
|
||||
let t = &item.transition;
|
||||
let state: RLC::PolicyObs = t.state.clone().into();
|
||||
let next_state: RLC::PolicyObs = t.next_state.clone().into();
|
||||
let action: RLC::PolicyAction = t.action.clone().into();
|
||||
let reward = t.reward.to_data().to_vec::<f32>().unwrap()[0];
|
||||
let done = t.done.to_data().to_vec::<f32>().unwrap()[0] > 0.5;
|
||||
transition_buffer.push(state, next_state, action, reward, done);
|
||||
}
|
||||
|
||||
if transition_buffer.len() >= self.config.train_batch_size
|
||||
&& progress.items_processed >= self.config.warmup_steps
|
||||
{
|
||||
if let Some(ref u) = intermediary_update {
|
||||
env_runner.update_policy(u.clone());
|
||||
}
|
||||
for _ in 0..self.config.train_steps {
|
||||
let batch = transition_buffer.sample(self.config.train_batch_size);
|
||||
let train_item = learner_agent.train(batch);
|
||||
intermediary_update = Some(train_item.policy);
|
||||
|
||||
event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
|
||||
train_item.item,
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if valid_next > previous_steps && valid_next <= progress.items_processed {
|
||||
env_runner_valid.update_policy(learner_agent.policy().state());
|
||||
env_runner_valid.run_episodes(
|
||||
self.config.eval_episodes,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
&mut progress,
|
||||
);
|
||||
|
||||
if let Some(checkpointer) = &mut checkpointer {
|
||||
checkpointer.checkpoint(
|
||||
&env_runner.policy(),
|
||||
learner_agent,
|
||||
valid_next,
|
||||
&training_components.event_store,
|
||||
);
|
||||
}
|
||||
|
||||
valid_next += self.config.eval_interval;
|
||||
}
|
||||
}
|
||||
|
||||
(learner_agent.policy(), event_processor)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
ItemLazy,
|
||||
metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput},
|
||||
};
|
||||
|
||||
/// Summary of an episode.
|
||||
pub struct EpisodeSummary {
|
||||
/// The total length of the episode.
|
||||
pub episode_length: usize,
|
||||
/// The final cumulative reward.
|
||||
pub cum_reward: f64,
|
||||
}
|
||||
|
||||
impl ItemLazy for EpisodeSummary {
|
||||
type ItemSync = EpisodeSummary;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Adaptor<EpisodeLengthInput> for EpisodeSummary {
|
||||
fn adapt(&self) -> EpisodeLengthInput {
|
||||
EpisodeLengthInput::new(self.episode_length as f64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Adaptor<CumulativeRewardInput> for EpisodeSummary {
|
||||
fn adapt(&self) -> CumulativeRewardInput {
|
||||
CumulativeRewardInput::new(self.cum_reward)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,525 @@
|
||||
use crate::checkpoint::{
|
||||
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
|
||||
KeepLastNCheckpoints, MetricCheckpointingStrategy,
|
||||
};
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
||||
use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};
|
||||
use crate::renderer::{MetricsRenderer, default_renderer};
|
||||
use crate::{
|
||||
ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,
|
||||
LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,
|
||||
RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,
|
||||
RLPolicyRecord, RLStrategy,
|
||||
};
|
||||
use crate::{EpisodeSummary, RLStrategies};
|
||||
use burn_core::record::FileRecorder;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, SliceAccess};
|
||||
use std::collections::BTreeSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Structure to configure and launch reinforcement learning trainings.
|
||||
pub struct RLTraining<RLC: RLComponentsTypes> {
|
||||
// Not that complex. Extracting into yet another type would only make it more confusing.
|
||||
#[allow(clippy::type_complexity)]
|
||||
checkpointers: Option<(
|
||||
AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
|
||||
AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
|
||||
)>,
|
||||
num_steps: usize,
|
||||
checkpoint: Option<usize>,
|
||||
directory: PathBuf,
|
||||
grad_accumulation: Option<usize>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,
|
||||
event_store: LogEventStore,
|
||||
interrupter: Interrupter,
|
||||
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
||||
learning_strategy: RLStrategies<RLC>,
|
||||
// Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
|
||||
summary_metrics: BTreeSet<String>,
|
||||
summary: bool,
|
||||
env_initializer: RLC::EnvInit,
|
||||
}
|
||||
|
||||
impl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
E: Environment + 'static,
|
||||
EI: EnvironmentInit<E> + Send + 'static,
|
||||
A: PolicyLearner<B> + Send + 'static,
|
||||
A::TrainContext: ItemLazy + Clone + Send,
|
||||
A::InnerPolicy: Policy<B> + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
|
||||
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
|
||||
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
|
||||
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Into<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
/// Creates a new runner for reinforcement learning.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
/// * `env_init` - Specifies how to initialize the environment.
|
||||
pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {
|
||||
let directory = directory.as_ref().to_path_buf();
|
||||
let experiment_log_file = directory.join("experiment.log");
|
||||
Self {
|
||||
num_steps: 1,
|
||||
checkpoint: None,
|
||||
checkpointers: None,
|
||||
directory,
|
||||
grad_accumulation: None,
|
||||
metrics: RLMetrics::default(),
|
||||
event_store: LogEventStore::default(),
|
||||
renderer: None,
|
||||
interrupter: Interrupter::new(),
|
||||
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
|
||||
experiment_log_file,
|
||||
))),
|
||||
checkpointer_strategy: Box::new(
|
||||
ComposedCheckpointingStrategy::builder()
|
||||
.add(KeepLastNCheckpoints::new(2))
|
||||
.add(MetricCheckpointingStrategy::new(
|
||||
&EpisodeLengthMetric::new(), // default to evaluations' cumulative reward.
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Valid,
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),
|
||||
summary_metrics: BTreeSet::new(),
|
||||
summary: false,
|
||||
env_initializer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {
|
||||
/// Replace the default learning strategy (Off Policy learning) with the provided one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `training_strategy` - The training strategy.
|
||||
pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {
|
||||
self.learning_strategy = learning_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default metric loggers with the provided ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `logger` - The training logger.
|
||||
pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
|
||||
where
|
||||
ML: MetricLogger + 'static,
|
||||
{
|
||||
self.event_store.register_logger(logger);
|
||||
self
|
||||
}
|
||||
|
||||
/// Update the checkpointing_strategy.
|
||||
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
|
||||
mut self,
|
||||
strategy: CS,
|
||||
) -> Self {
|
||||
self.checkpointer_strategy = Box::new(strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default CLI renderer with a custom one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The custom renderer.
|
||||
pub fn renderer<MR>(mut self, renderer: MR) -> Self
|
||||
where
|
||||
MR: MetricsRenderer + 'static,
|
||||
{
|
||||
self.renderer = Some(Box::new(renderer));
|
||||
self
|
||||
}
|
||||
|
||||
/// Register numerical metrics for a training step of the agent.
|
||||
pub fn metrics_train<Me: TrainMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register textual metrics for a training step of the agent.
|
||||
pub fn text_metrics_train<Me: TrainTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register numerical metrics for each action of the agent.
|
||||
pub fn metrics_agent<Me: AgentMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register textual metrics for each action of the agent.
|
||||
pub fn text_metrics_agent<Me: AgentTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register numerical metrics for a completed episode.
|
||||
pub fn metrics_episode<Me: EpisodeMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register textual metrics for a completed episode.
|
||||
pub fn text_metrics_episode<Me: EpisodeTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register a textual metric for a training step.
|
||||
pub fn text_metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics.register_text_metric_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step.
|
||||
pub fn metric_train<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_metric_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a textual metric for each action taken by the agent.
|
||||
pub fn text_metric_agent<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics.register_text_metric_agent(metric.clone());
|
||||
self.metrics.register_text_metric_agent_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent.
|
||||
pub fn metric_agent<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_agent_metric(metric.clone());
|
||||
self.metrics.register_agent_metric_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a textual metric for a completed episode.
|
||||
pub fn text_metric_episode<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
self.metrics.register_text_metric_episode(metric.clone());
|
||||
self.metrics.register_text_metric_episode_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode.
|
||||
pub fn metric_episode<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_episode_metric(metric.clone());
|
||||
self.metrics.register_episode_metric_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// The number of environment steps to train for.
|
||||
pub fn num_steps(mut self, num_steps: usize) -> Self {
|
||||
self.num_steps = num_steps;
|
||||
self
|
||||
}
|
||||
|
||||
/// The step from which the training must resume.
|
||||
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
|
||||
self.checkpoint = Some(checkpoint);
|
||||
self
|
||||
}
|
||||
|
||||
/// Provides a handle that can be used to interrupt training.
|
||||
pub fn interrupter(&self) -> Interrupter {
|
||||
self.interrupter.clone()
|
||||
}
|
||||
|
||||
/// Override the handle for stopping training with an externally provided handle
|
||||
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
|
||||
self.interrupter = interrupter;
|
||||
self
|
||||
}
|
||||
|
||||
/// By default, Rust logs are captured and written into
|
||||
/// `experiment.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
pub fn with_application_logger(
|
||||
mut self,
|
||||
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
) -> Self {
|
||||
self.tracing_logger = logger;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a checkpointer that will save the environment runner's [policy](Policy)
|
||||
/// and the [PolicyLearner](PolicyLearner) state to different files.
|
||||
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
|
||||
where
|
||||
FR: FileRecorder<RLC::Backend> + 'static,
|
||||
FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,
|
||||
{
|
||||
let checkpoint_dir = self.directory.join("checkpoint");
|
||||
let checkpointer_policy =
|
||||
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy");
|
||||
let checkpointer_learning =
|
||||
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent");
|
||||
|
||||
self.checkpointers = Some((
|
||||
AsyncCheckpointer::new(checkpointer_policy),
|
||||
AsyncCheckpointer::new(checkpointer_learning),
|
||||
));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable the training summary report.
|
||||
///
|
||||
/// The summary will be displayed after `.launch()`, when the renderer is dropped.
|
||||
pub fn summary(mut self) -> Self {
|
||||
self.summary = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment.
|
||||
pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy>
|
||||
where
|
||||
RLC::PolicyObs: SliceAccess<RLC::Backend>,
|
||||
RLC::PolicyAction: SliceAccess<RLC::Backend>,
|
||||
{
|
||||
if self.tracing_logger.is_some()
|
||||
&& let Err(e) = self.tracing_logger.as_ref().unwrap().install()
|
||||
{
|
||||
log::warn!("Failed to install the experiment logger: {e}");
|
||||
}
|
||||
let renderer = self
|
||||
.renderer
|
||||
.unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
|
||||
|
||||
if !self.event_store.has_loggers() {
|
||||
self.event_store
|
||||
.register_logger(FileMetricLogger::new(self.directory.clone()));
|
||||
}
|
||||
|
||||
let event_store = Arc::new(EventStoreClient::new(self.event_store));
|
||||
let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(
|
||||
self.metrics,
|
||||
renderer,
|
||||
event_store.clone(),
|
||||
));
|
||||
|
||||
let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {
|
||||
RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)
|
||||
});
|
||||
|
||||
let summary = if self.summary {
|
||||
Some(LearnerSummaryConfig {
|
||||
directory: self.directory,
|
||||
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let components = RLComponents::<RLC> {
|
||||
checkpoint: self.checkpoint,
|
||||
checkpointer,
|
||||
interrupter: self.interrupter,
|
||||
event_processor,
|
||||
event_store,
|
||||
num_steps: self.num_steps,
|
||||
grad_accumulation: self.grad_accumulation,
|
||||
summary,
|
||||
};
|
||||
|
||||
match self.learning_strategy {
|
||||
RLStrategies::OffPolicyStrategy(config) => {
|
||||
let strategy = OffPolicyStrategy::new(config);
|
||||
strategy.train(learner_agent, components, self.env_initializer)
|
||||
}
|
||||
RLStrategies::Custom(strategy) => {
|
||||
strategy.train(learner_agent, components, self.env_initializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer).
|
||||
pub struct RLResult<P> {
|
||||
/// The learned policy.
|
||||
pub policy: P,
|
||||
/// The renderer that can be used for follow up training and evaluation.
|
||||
pub renderer: Box<dyn MetricsRenderer>,
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for train step metrics.
|
||||
pub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for train step text metrics.
|
||||
pub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for env step metrics.
|
||||
pub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for env step text metrics.
|
||||
pub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for episode metrics.
|
||||
pub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for episode text metrics.
|
||||
pub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
macro_rules! gen_tuple {
|
||||
($($M:ident),*) => {
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.text_metric_train($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_train($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.text_metric_agent($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_agent($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.text_metric_episode($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_episode($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
gen_tuple!(M1);
|
||||
gen_tuple!(M1, M2);
|
||||
gen_tuple!(M1, M2, M3);
|
||||
gen_tuple!(M1, M2, M3, M4);
|
||||
gen_tuple!(M1, M2, M3, M4, M5);
|
||||
gen_tuple!(M1, M2, M3, M4, M5, M6);
|
||||
@@ -0,0 +1,99 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
|
||||
RLEventProcessorType, RLResult,
|
||||
metric::{processor::EventProcessorTraining, store::EventStoreClient},
|
||||
};
|
||||
|
||||
/// Struct to minimise parameters passed to [RLStrategy::train].
|
||||
pub struct RLComponents<RLC: RLComponentsTypes> {
|
||||
/// The total number of environment steps.
|
||||
pub num_steps: usize,
|
||||
/// The step number from which to continue the training.
|
||||
pub checkpoint: Option<usize>,
|
||||
/// A checkpointer used to load and save learning checkpoints.
|
||||
pub checkpointer: Option<RLCheckpointer<RLC>>,
|
||||
/// Enables gradients accumulation.
|
||||
pub grad_accumulation: Option<usize>,
|
||||
/// An [Interupter](Interrupter) that allows aborting the training/evaluation process early.
|
||||
pub interrupter: Interrupter,
|
||||
/// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation.
|
||||
pub event_processor: RLEventProcessorType<RLC>,
|
||||
/// A reference to an [EventStoreClient](EventStoreClient).
|
||||
pub event_store: Arc<EventStoreClient>,
|
||||
/// Config for creating a summary of the learning
|
||||
pub summary: Option<LearnerSummaryConfig>,
|
||||
}
|
||||
|
||||
/// The strategy for reinforcement learning.
|
||||
#[derive(Clone)]
|
||||
pub enum RLStrategies<RLC: RLComponentsTypes> {
|
||||
/// Training on one device
|
||||
OffPolicyStrategy(OffPolicyConfig),
|
||||
/// Training using a custom learning strategy
|
||||
Custom(CustomRLStrategy<RLC>),
|
||||
}
|
||||
|
||||
/// A reference to an implementation of [RLStrategy].
|
||||
pub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;
|
||||
|
||||
/// Provides the `fit` function for any learning strategy
|
||||
pub trait RLStrategy<RLC: RLComponentsTypes> {
|
||||
/// Train the learner agent with this strategy.
|
||||
fn train(
|
||||
&self,
|
||||
mut learner_agent: RLC::LearningAgent,
|
||||
mut training_components: RLComponents<RLC>,
|
||||
env_init: RLC::EnvInit,
|
||||
) -> RLResult<RLC::Policy> {
|
||||
let starting_epoch = match training_components.checkpoint {
|
||||
Some(checkpoint) => {
|
||||
if let Some(checkpointer) = &mut training_components.checkpointer {
|
||||
learner_agent = checkpointer.load_checkpoint(
|
||||
learner_agent,
|
||||
&Default::default(),
|
||||
checkpoint,
|
||||
);
|
||||
}
|
||||
checkpoint + 1
|
||||
}
|
||||
None => 1,
|
||||
};
|
||||
|
||||
let summary_config = training_components.summary.clone();
|
||||
|
||||
// Event processor start training
|
||||
training_components
|
||||
.event_processor
|
||||
.process_train(RLEvent::Start);
|
||||
|
||||
// Training loop
|
||||
let (policy, mut event_processor) = self.train_loop(
|
||||
training_components,
|
||||
&mut learner_agent,
|
||||
starting_epoch,
|
||||
env_init,
|
||||
);
|
||||
|
||||
let summary = summary_config.and_then(|summary| summary.init().ok());
|
||||
|
||||
// Signal training end. For the TUI renderer, this handles the exit & return to main screen.
|
||||
// TODO: summary makes sense for RL?
|
||||
event_processor.process_train(RLEvent::End(summary));
|
||||
|
||||
// let model = model.valid();
|
||||
let renderer = event_processor.renderer();
|
||||
|
||||
RLResult { policy, renderer }
|
||||
}
|
||||
|
||||
/// Training loop for this strategy
|
||||
fn train_loop(
|
||||
&self,
|
||||
training_components: RLComponents<RLC>,
|
||||
learner_agent: &mut RLC::LearningAgent,
|
||||
starting_epoch: usize,
|
||||
env_init: RLC::EnvInit,
|
||||
) -> (RLC::Policy, RLEventProcessorType<RLC>);
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
use crate::metric::{AccuracyInput, PerplexityInput, TopKAccuracyInput};
|
||||
use crate::metric::{Adaptor, CerInput, LossInput, WerInput, processor::ItemLazy};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{Int, Tensor, Transaction};
|
||||
use burn_ndarray::NdArray;
|
||||
|
||||
/// Sequence prediction output adapted for multiple metrics.
|
||||
///
|
||||
/// Supported metrics:
|
||||
/// - Accuracy
|
||||
/// - TopKAccuracy
|
||||
/// - Perplexity
|
||||
/// - Loss
|
||||
/// - CER
|
||||
/// - WER
|
||||
#[derive(new)]
|
||||
pub struct SequenceOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// Raw logits. Shape: `[batch_size, seq_len, vocab_size]`
|
||||
pub logits: Tensor<B, 3>,
|
||||
|
||||
/// Optional predicted token indices. Shape: `[batch_size, seq_length]`.
|
||||
/// If not provided, predictions default to argmax of `logits` along the last dimension.
|
||||
pub predictions: Option<Tensor<B, 2, Int>>,
|
||||
|
||||
/// The target token indices. Shape: `[batch_size, seq_length]`
|
||||
pub targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SequenceOutput<B> {
|
||||
fn predicted_tokens(&self) -> Tensor<B, 2, Int> {
|
||||
match &self.predictions {
|
||||
Some(preds) => preds.clone(),
|
||||
None => self.logits.clone().argmax(2).squeeze_dim::<2>(2),
|
||||
}
|
||||
}
|
||||
|
||||
fn flat_logits(&self) -> Tensor<B, 2> {
|
||||
let [batch_size, seq_len, vocab_size] = self.logits.dims();
|
||||
self.logits
|
||||
.clone()
|
||||
.reshape([batch_size * seq_len, vocab_size])
|
||||
}
|
||||
|
||||
fn flat_targets(&self) -> Tensor<B, 1, Int> {
|
||||
let [batch_size, seq_len] = self.targets.dims();
|
||||
self.targets.clone().reshape([batch_size * seq_len])
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ItemLazy for SequenceOutput<B> {
|
||||
type ItemSync = SequenceOutput<NdArray>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
let device = &Default::default();
|
||||
|
||||
match self.predictions {
|
||||
Some(preds) => {
|
||||
let [logits, loss, targets, predictions] = Transaction::default()
|
||||
.register(self.logits)
|
||||
.register(self.loss)
|
||||
.register(self.targets)
|
||||
.register(preds)
|
||||
.execute()
|
||||
.try_into()
|
||||
.expect("Correct amount of tensor data");
|
||||
|
||||
SequenceOutput {
|
||||
logits: Tensor::from_data(logits, device),
|
||||
loss: Tensor::from_data(loss, device),
|
||||
targets: Tensor::from_data(targets, device),
|
||||
predictions: Some(Tensor::from_data(predictions, device)),
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let [logits, loss, targets] = Transaction::default()
|
||||
.register(self.logits)
|
||||
.register(self.loss)
|
||||
.register(self.targets)
|
||||
.execute()
|
||||
.try_into()
|
||||
.expect("Correct amount of tensor data");
|
||||
|
||||
SequenceOutput {
|
||||
logits: Tensor::from_data(logits, device),
|
||||
loss: Tensor::from_data(loss, device),
|
||||
targets: Tensor::from_data(targets, device),
|
||||
predictions: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for SequenceOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<CerInput<B>> for SequenceOutput<B> {
|
||||
fn adapt(&self) -> CerInput<B> {
|
||||
CerInput::new(self.predicted_tokens(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<WerInput<B>> for SequenceOutput<B> {
|
||||
fn adapt(&self) -> WerInput<B> {
|
||||
WerInput::new(self.predicted_tokens(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<AccuracyInput<B>> for SequenceOutput<B> {
|
||||
fn adapt(&self) -> AccuracyInput<B> {
|
||||
AccuracyInput::new(self.flat_logits(), self.flat_targets())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for SequenceOutput<B> {
|
||||
fn adapt(&self) -> TopKAccuracyInput<B> {
|
||||
TopKAccuracyInput::new(self.flat_logits(), self.flat_targets())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<PerplexityInput<B>> for SequenceOutput<B> {
|
||||
fn adapt(&self) -> PerplexityInput<B> {
|
||||
PerplexityInput::new(self.flat_logits(), self.flat_targets())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,475 @@
|
||||
use core::cmp::Ordering;
|
||||
use std::{
|
||||
collections::{HashMap, hash_map::Entry},
|
||||
fmt::Display,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
logger::FileMetricLogger,
|
||||
metric::store::{Aggregate, EventStore, LogEventStore, Split},
|
||||
};
|
||||
|
||||
/// Contains the metric value at a given time.
|
||||
#[derive(Debug)]
|
||||
pub struct MetricEntry {
|
||||
/// The step at which the metric was recorded (i.e., epoch).
|
||||
pub step: usize,
|
||||
/// The metric value.
|
||||
pub value: f64,
|
||||
}
|
||||
|
||||
/// Contains the summary of recorded values for a given metric.
|
||||
#[derive(Debug)]
|
||||
pub struct MetricSummary {
|
||||
/// The metric name.
|
||||
pub name: String,
|
||||
/// The metric entries.
|
||||
pub entries: Vec<MetricEntry>,
|
||||
}
|
||||
|
||||
impl MetricSummary {
|
||||
fn collect<E: EventStore>(
|
||||
event_store: &mut E,
|
||||
metric: &str,
|
||||
split: &Split,
|
||||
num_epochs: usize,
|
||||
) -> Option<Self> {
|
||||
let entries = (1..=num_epochs)
|
||||
.filter_map(|epoch| {
|
||||
event_store
|
||||
.find_metric(metric, epoch, Aggregate::Mean, split)
|
||||
.map(|value| MetricEntry { step: epoch, value })
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if entries.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Self {
|
||||
name: metric.to_string(),
|
||||
entries,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains the summary of recorded metrics for the training and validation steps.
|
||||
pub struct SummaryMetrics {
|
||||
/// Training metrics summary.
|
||||
pub train: Vec<MetricSummary>,
|
||||
/// Validation metrics summary.
|
||||
pub valid: Vec<MetricSummary>,
|
||||
/// Test metrics summary per test split tag.
|
||||
///
|
||||
/// Each key corresponds to a `Split::Test(Some(tag))`.
|
||||
/// The empty string represents `Split::Test(None)`.
|
||||
pub test: HashMap<String, Vec<MetricSummary>>,
|
||||
}
|
||||
|
||||
/// Detailed training summary.
|
||||
pub struct LearnerSummary {
|
||||
/// The number of epochs completed.
|
||||
pub epochs: usize,
|
||||
/// The summary of recorded metrics during training.
|
||||
pub metrics: SummaryMetrics,
|
||||
/// The model name (only recorded within the learner).
|
||||
pub(crate) model: Option<String>,
|
||||
}
|
||||
|
||||
impl LearnerSummary {
|
||||
/// Creates a new learner summary for the specified metrics.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
|
||||
/// * `metrics` - The list of metrics to collect for the summary.
|
||||
pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
|
||||
let directory = directory.as_ref();
|
||||
if !directory.exists() {
|
||||
return Err(format!(
|
||||
"Artifact directory does not exist at: {}",
|
||||
directory.display()
|
||||
));
|
||||
}
|
||||
|
||||
let mut event_store = LogEventStore::default();
|
||||
let train_split = Split::Train;
|
||||
let valid_split = Split::Valid;
|
||||
|
||||
let logger = FileMetricLogger::new(directory);
|
||||
let test_split_root = logger.split_dir(&Split::Test(None));
|
||||
if !logger.split_exists(&train_split)
|
||||
&& !logger.split_exists(&valid_split)
|
||||
&& test_split_root.is_none()
|
||||
{
|
||||
return Err(format!(
|
||||
"No training, validation or test artifacts found at: {}",
|
||||
directory.display()
|
||||
));
|
||||
}
|
||||
|
||||
// Number of recorded epochs
|
||||
let epochs = logger.epochs();
|
||||
|
||||
event_store.register_logger(logger);
|
||||
|
||||
let train_summary = metrics
|
||||
.iter()
|
||||
.filter_map(|metric| {
|
||||
MetricSummary::collect(&mut event_store, metric.as_ref(), &train_split, epochs)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let valid_summary = metrics
|
||||
.iter()
|
||||
.filter_map(|metric| {
|
||||
MetricSummary::collect(&mut event_store, metric.as_ref(), &valid_split, epochs)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let test_summary = match test_split_root {
|
||||
Some(root) => collect_test_split_metrics(root, metrics, &mut event_store, epochs),
|
||||
None => Default::default(),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
epochs,
|
||||
metrics: SummaryMetrics {
|
||||
train: train_summary,
|
||||
valid: valid_summary,
|
||||
test: test_summary,
|
||||
},
|
||||
model: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn with_model(mut self, name: String) -> Self {
|
||||
self.model = Some(name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Merges another summary into this one, combining all metric entries.
|
||||
pub(crate) fn merge(mut self, other: LearnerSummary) -> Self {
|
||||
fn merge_metrics(
|
||||
base: Vec<MetricSummary>,
|
||||
incoming: Vec<MetricSummary>,
|
||||
) -> Vec<MetricSummary> {
|
||||
let mut map: HashMap<String, MetricSummary> =
|
||||
base.into_iter().map(|m| (m.name.clone(), m)).collect();
|
||||
|
||||
for metric in incoming {
|
||||
match map.entry(metric.name.clone()) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().entries.extend(metric.entries);
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(metric);
|
||||
}
|
||||
}
|
||||
}
|
||||
map.into_values().collect()
|
||||
}
|
||||
|
||||
self.metrics.train = merge_metrics(self.metrics.train, other.metrics.train);
|
||||
self.metrics.valid = merge_metrics(self.metrics.valid, other.metrics.valid);
|
||||
|
||||
for (tag, metrics) in other.metrics.test {
|
||||
match self.metrics.test.entry(tag) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let current = std::mem::take(entry.get_mut());
|
||||
let merged = merge_metrics(current, metrics);
|
||||
*entry.get_mut() = merged;
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(metrics);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if self.model != other.model {
|
||||
self.model = None;
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_test_split_metrics<P: AsRef<Path>, S: AsRef<str>>(
|
||||
root: P,
|
||||
metrics: &[S],
|
||||
event_store: &mut LogEventStore,
|
||||
epochs: usize,
|
||||
) -> HashMap<String, Vec<MetricSummary>> {
|
||||
// Collect immediate child directories
|
||||
let dirs = match std::fs::read_dir(root) {
|
||||
Ok(entries) => entries
|
||||
.filter_map(|entry| {
|
||||
let entry = entry.ok()?;
|
||||
let file_type = entry.file_type().ok()?;
|
||||
if file_type.is_dir() {
|
||||
Some(entry.file_name().to_string_lossy().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let mut map = HashMap::new();
|
||||
|
||||
if dirs.is_empty() {
|
||||
return map;
|
||||
}
|
||||
|
||||
// Detect if all directories are epoch directories
|
||||
let all_epochs = dirs.iter().all(FileMetricLogger::is_epoch_dir);
|
||||
|
||||
if all_epochs {
|
||||
// Single untagged test split
|
||||
let split = Split::Test(None);
|
||||
|
||||
let summaries = metrics
|
||||
.iter()
|
||||
.filter_map(|metric| {
|
||||
MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Untagged marked with empty string
|
||||
map.insert("".to_string(), summaries);
|
||||
} else {
|
||||
// Tagged splits
|
||||
for tag in dirs {
|
||||
let split = Split::Test(Some(tag.clone().into()));
|
||||
|
||||
let summaries = metrics
|
||||
.iter()
|
||||
.filter_map(|metric| {
|
||||
MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
map.insert(tag, summaries);
|
||||
}
|
||||
}
|
||||
|
||||
map
|
||||
}
|
||||
|
||||
impl Display for LearnerSummary {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Compute the max length for each column
|
||||
let mut max_split_len = 5; // "Train"
|
||||
let mut max_metric_len = "Metric".len();
|
||||
for metric in self.metrics.train.iter() {
|
||||
max_metric_len = max_metric_len.max(metric.name.len());
|
||||
}
|
||||
for metric in self.metrics.valid.iter() {
|
||||
max_metric_len = max_metric_len.max(metric.name.len());
|
||||
}
|
||||
for (tag, metrics) in self.metrics.test.iter() {
|
||||
let split_name = if tag.is_empty() {
|
||||
"Test".to_string()
|
||||
} else {
|
||||
format!("Test ({tag})")
|
||||
};
|
||||
|
||||
max_split_len = max_split_len.max(split_name.len());
|
||||
|
||||
for metric in metrics {
|
||||
max_metric_len = max_metric_len.max(metric.name.len());
|
||||
}
|
||||
}
|
||||
|
||||
// Summary header
|
||||
writeln!(
|
||||
f,
|
||||
"{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
|
||||
"",
|
||||
"",
|
||||
width_symbol = 24,
|
||||
)?;
|
||||
|
||||
if let Some(model) = &self.model {
|
||||
writeln!(f, "Model:\n{model}")?;
|
||||
}
|
||||
writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
|
||||
|
||||
// Metrics table header
|
||||
writeln!(
|
||||
f,
|
||||
"| {:<width_split$} | {:<width_metric$} | Min. | Epoch | Max. | Epoch |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
|
||||
"Split",
|
||||
"Metric",
|
||||
"",
|
||||
"",
|
||||
width_split = max_split_len,
|
||||
width_metric = max_metric_len,
|
||||
)?;
|
||||
|
||||
// Table entries
|
||||
fn cmp_f64(a: &f64, b: &f64) -> Ordering {
|
||||
match (a.is_nan(), b.is_nan()) {
|
||||
(true, true) => Ordering::Equal,
|
||||
(true, false) => Ordering::Greater,
|
||||
(false, true) => Ordering::Less,
|
||||
_ => a.partial_cmp(b).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_val(val: f64) -> String {
|
||||
if val < 1e-2 {
|
||||
// Use scientific notation for small values which would otherwise be truncated
|
||||
format!("{val:<9.3e}")
|
||||
} else {
|
||||
format!("{val:<9.3}")
|
||||
}
|
||||
}
|
||||
|
||||
let mut write_metrics_summary =
|
||||
|metrics: &[MetricSummary], split: String| -> std::fmt::Result {
|
||||
for metric in metrics.iter() {
|
||||
if metric.entries.is_empty() {
|
||||
continue; // skip metrics with no recorded values
|
||||
}
|
||||
|
||||
// Compute the min & max for each metric
|
||||
let metric_min = metric
|
||||
.entries
|
||||
.iter()
|
||||
.min_by(|a, b| cmp_f64(&a.value, &b.value))
|
||||
.unwrap();
|
||||
let metric_max = metric
|
||||
.entries
|
||||
.iter()
|
||||
.max_by(|a, b| cmp_f64(&a.value, &b.value))
|
||||
.unwrap();
|
||||
|
||||
writeln!(
|
||||
f,
|
||||
"| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|",
|
||||
split,
|
||||
metric.name,
|
||||
fmt_val(metric_min.value),
|
||||
metric_min.step,
|
||||
fmt_val(metric_max.value),
|
||||
metric_max.step,
|
||||
width_split = max_split_len,
|
||||
width_metric = max_metric_len,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
};
|
||||
|
||||
write_metrics_summary(&self.metrics.train, format!("{:?}", Split::Train))?;
|
||||
write_metrics_summary(&self.metrics.valid, format!("{:?}", Split::Valid))?;
|
||||
|
||||
for (tag, metrics) in &self.metrics.test {
|
||||
let split_name = if tag.is_empty() {
|
||||
"Test".to_string()
|
||||
} else {
|
||||
format!("Test ({tag})")
|
||||
};
|
||||
|
||||
write_metrics_summary(metrics, split_name)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: rename to `ExperimentSummary`? Used in learner + evaluator.
|
||||
|
||||
#[derive(Clone)]
|
||||
/// Learning summary config.
|
||||
pub struct LearnerSummaryConfig {
|
||||
pub(crate) directory: PathBuf,
|
||||
pub(crate) metrics: Vec<String>,
|
||||
}
|
||||
|
||||
impl LearnerSummaryConfig {
|
||||
/// Create the learning summary.
|
||||
pub fn init(&self) -> Result<LearnerSummary, String> {
|
||||
LearnerSummary::new(&self.directory, &self.metrics[..])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Summary artifacts should exist"]
|
||||
fn test_artifact_dir_should_exist() {
|
||||
let dir = "/tmp/learner-summary-not-found";
|
||||
let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Summary artifacts should exist"]
|
||||
fn test_train_valid_artifacts_should_exist() {
|
||||
let dir = "/tmp/test-learner-summary-empty";
|
||||
std::fs::create_dir_all(dir).ok();
|
||||
let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summary_should_be_empty() {
|
||||
let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
|
||||
std::fs::create_dir_all(dir).unwrap();
|
||||
std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
|
||||
std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
|
||||
let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
|
||||
.expect("Summary artifacts should exist");
|
||||
|
||||
assert_eq!(summary.epochs, 1);
|
||||
|
||||
assert_eq!(summary.metrics.train.len(), 0);
|
||||
assert_eq!(summary.metrics.valid.len(), 0);
|
||||
|
||||
std::fs::remove_dir_all(dir).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summary_should_be_collected() {
|
||||
let dir = Path::new("/tmp/test-learner-summary");
|
||||
let train_dir = dir.join("train/epoch-1");
|
||||
let valid_dir = dir.join("valid/epoch-1");
|
||||
std::fs::create_dir_all(dir).unwrap();
|
||||
std::fs::create_dir_all(&train_dir).unwrap();
|
||||
std::fs::create_dir_all(&valid_dir).unwrap();
|
||||
|
||||
std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
|
||||
std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");
|
||||
|
||||
let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
|
||||
.expect("Summary artifacts should exist");
|
||||
|
||||
assert_eq!(summary.epochs, 1);
|
||||
|
||||
// Only Loss metric
|
||||
assert_eq!(summary.metrics.train.len(), 1);
|
||||
assert_eq!(summary.metrics.valid.len(), 1);
|
||||
|
||||
// Aggregated train metric entries for 1 epoch
|
||||
let train_metric = &summary.metrics.train[0];
|
||||
assert_eq!(train_metric.name, "Loss");
|
||||
assert_eq!(train_metric.entries.len(), 1);
|
||||
let entry = &train_metric.entries[0];
|
||||
assert_eq!(entry.step, 1); // epoch = 1
|
||||
assert_eq!(entry.value, 1.5); // (1 + 2) / 2
|
||||
|
||||
// Aggregated valid metric entries for 1 epoch
|
||||
let valid_metric = &summary.metrics.valid[0];
|
||||
assert_eq!(valid_metric.name, "Loss");
|
||||
assert_eq!(valid_metric.entries.len(), 1);
|
||||
let entry = &valid_metric.entries[0];
|
||||
assert_eq!(entry.step, 1); // epoch = 1
|
||||
assert_eq!(entry.value, 1.0);
|
||||
|
||||
std::fs::remove_dir_all(dir).unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod paradigm;
|
||||
mod step;
|
||||
mod strategies;
|
||||
|
||||
pub use paradigm::*;
|
||||
pub use step::*;
|
||||
pub use strategies::*;
|
||||
@@ -0,0 +1,488 @@
|
||||
use crate::checkpoint::{
|
||||
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
|
||||
KeepLastNCheckpoints, MetricCheckpointingStrategy,
|
||||
};
|
||||
use crate::components::{InferenceModelOutput, TrainingModelOutput};
|
||||
use crate::learner::EarlyStoppingStrategy;
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::processor::{
|
||||
AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
|
||||
};
|
||||
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
||||
use crate::metric::{Adaptor, LossMetric, Metric, Numeric};
|
||||
use crate::multi::MultiDeviceLearningStrategy;
|
||||
use crate::renderer::{MetricsRenderer, default_renderer};
|
||||
use crate::single::SingleDeviceTrainingStrategy;
|
||||
use crate::{
|
||||
ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
|
||||
InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent,
|
||||
LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig,
|
||||
LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult,
|
||||
TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy,
|
||||
};
|
||||
use crate::{Learner, SupervisedLearningStrategy};
|
||||
use burn_core::data::dataloader::DataLoader;
|
||||
use burn_core::module::{AutodiffModule, Module};
|
||||
use burn_core::record::FileRecorder;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_optim::Optimizer;
|
||||
use burn_optim::lr_scheduler::LrScheduler;
|
||||
use std::collections::BTreeSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A reference to the training split [DataLoader](DataLoader).
|
||||
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;
|
||||
/// A reference to the validation split [DataLoader](DataLoader).
|
||||
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
|
||||
/// The event processor type for supervised learning.
|
||||
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
|
||||
LearnerEvent<TrainingModelOutput<LC>>,
|
||||
LearnerEvent<InferenceModelOutput<LC>>,
|
||||
>;
|
||||
|
||||
/// Structure to configure and launch supervised learning trainings.
|
||||
pub struct SupervisedTraining<LC>
|
||||
where
|
||||
LC: LearningComponentsTypes,
|
||||
{
|
||||
// Not that complex. Extracting into another type would only make it more confusing.
|
||||
#[allow(clippy::type_complexity)]
|
||||
checkpointers: Option<(
|
||||
AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,
|
||||
AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,
|
||||
AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,
|
||||
)>,
|
||||
num_epochs: usize,
|
||||
checkpoint: Option<usize>,
|
||||
directory: PathBuf,
|
||||
grad_accumulation: Option<usize>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
|
||||
event_store: LogEventStore,
|
||||
interrupter: Interrupter,
|
||||
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
||||
early_stopping: Option<EarlyStoppingStrategyRef>,
|
||||
training_strategy: Option<TrainingStrategy<LC>>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: ValidLoader<LC>,
|
||||
// Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
|
||||
summary_metrics: BTreeSet<String>,
|
||||
summary: bool,
|
||||
}
|
||||
|
||||
impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
LR: LrScheduler + 'static,
|
||||
M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
|
||||
M::InnerModule: InferenceStep,
|
||||
O: Optimizer<M, B> + 'static,
|
||||
{
|
||||
/// Creates a new runner for a supervised training.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
/// * `dataloader_train` - The dataloader for the training split.
|
||||
/// * `dataloader_valid` - The dataloader for the validation split.
|
||||
pub fn new(
|
||||
directory: impl AsRef<Path>,
|
||||
dataloader_train: Arc<dyn DataLoader<B, M::Input>>,
|
||||
dataloader_valid: Arc<
|
||||
dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,
|
||||
>,
|
||||
) -> Self {
|
||||
let directory = directory.as_ref().to_path_buf();
|
||||
let experiment_log_file = directory.join("experiment.log");
|
||||
Self {
|
||||
num_epochs: 1,
|
||||
checkpoint: None,
|
||||
checkpointers: None,
|
||||
directory,
|
||||
grad_accumulation: None,
|
||||
metrics: MetricsTraining::default(),
|
||||
event_store: LogEventStore::default(),
|
||||
renderer: None,
|
||||
interrupter: Interrupter::new(),
|
||||
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
|
||||
experiment_log_file,
|
||||
))),
|
||||
checkpointer_strategy: Box::new(
|
||||
ComposedCheckpointingStrategy::builder()
|
||||
.add(KeepLastNCheckpoints::new(2))
|
||||
.add(MetricCheckpointingStrategy::new(
|
||||
&LossMetric::<B>::new(), // default to valid loss
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Valid,
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
early_stopping: None,
|
||||
training_strategy: None,
|
||||
summary_metrics: BTreeSet::new(),
|
||||
summary: false,
|
||||
dataloader_train,
|
||||
dataloader_valid,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
|
||||
/// Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `training_strategy` - The training strategy.
|
||||
pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {
|
||||
self.training_strategy = Some(training_strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default metric loggers with the provided ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `logger` - The training logger.
|
||||
pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
|
||||
where
|
||||
ML: MetricLogger + 'static,
|
||||
{
|
||||
self.event_store.register_logger(logger);
|
||||
self
|
||||
}
|
||||
|
||||
/// Update the checkpointing_strategy.
|
||||
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
|
||||
mut self,
|
||||
strategy: CS,
|
||||
) -> Self {
|
||||
self.checkpointer_strategy = Box::new(strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default CLI renderer with a custom one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The custom renderer.
|
||||
pub fn renderer<MR>(mut self, renderer: MR) -> Self
|
||||
where
|
||||
MR: MetricsRenderer + 'static,
|
||||
{
|
||||
self.renderer = Some(Box::new(renderer));
|
||||
self
|
||||
}
|
||||
|
||||
/// Register all metrics as numeric for the training and validation set.
|
||||
pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register all metrics as text for the training and validation set.
|
||||
pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics.register_train_metric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a validation metric.
|
||||
pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics.register_valid_metric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable gradients accumulation.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// When you enable gradients accumulation, the gradients object used by the optimizer will be
|
||||
/// the sum of all gradients generated by each backward pass. It might be a good idea to
|
||||
/// reduce the learning to compensate.
|
||||
///
|
||||
/// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation`
|
||||
/// amount.
|
||||
pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
|
||||
self.grad_accumulation = Some(accumulation);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
|
||||
pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_train_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
|
||||
pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_valid_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// The number of epochs the training should last.
|
||||
pub fn num_epochs(mut self, num_epochs: usize) -> Self {
|
||||
self.num_epochs = num_epochs;
|
||||
self
|
||||
}
|
||||
|
||||
/// The epoch from which the training must resume.
|
||||
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
|
||||
self.checkpoint = Some(checkpoint);
|
||||
self
|
||||
}
|
||||
|
||||
/// Provides a handle that can be used to interrupt training.
|
||||
pub fn interrupter(&self) -> Interrupter {
|
||||
self.interrupter.clone()
|
||||
}
|
||||
|
||||
/// Override the handle for stopping training with an externally provided handle
|
||||
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
|
||||
self.interrupter = interrupter;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
|
||||
/// conditions are meet.
|
||||
pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
|
||||
where
|
||||
Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
|
||||
{
|
||||
self.early_stopping = Some(Box::new(strategy));
|
||||
self
|
||||
}
|
||||
|
||||
/// By default, Rust logs are captured and written into
|
||||
/// `experiment.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
pub fn with_application_logger(
|
||||
mut self,
|
||||
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
) -> Self {
|
||||
self.tracing_logger = logger;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a checkpointer that will save the [optimizer](Optimizer), the
|
||||
/// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
|
||||
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
|
||||
where
|
||||
FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,
|
||||
FR: FileRecorder<
|
||||
<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,
|
||||
> + 'static,
|
||||
{
|
||||
let checkpoint_dir = self.directory.join("checkpoint");
|
||||
let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
|
||||
let checkpointer_optimizer =
|
||||
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
|
||||
let checkpointer_scheduler: FileCheckpointer<FR> =
|
||||
FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
|
||||
|
||||
self.checkpointers = Some((
|
||||
AsyncCheckpointer::new(checkpointer_model),
|
||||
AsyncCheckpointer::new(checkpointer_optimizer),
|
||||
AsyncCheckpointer::new(checkpointer_scheduler),
|
||||
));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable the training summary report.
|
||||
///
|
||||
/// The summary will be displayed after `.fit()`, when the renderer is dropped.
|
||||
pub fn summary(mut self) -> Self {
|
||||
self.summary = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes + Send + 'static> SupervisedTraining<LC> {
|
||||
/// Launch this training with the given [Learner](Learner).
|
||||
pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {
|
||||
if self.tracing_logger.is_some()
|
||||
&& let Err(e) = self.tracing_logger.as_ref().unwrap().install()
|
||||
{
|
||||
log::warn!("Failed to install the experiment logger: {e}");
|
||||
}
|
||||
let renderer = self
|
||||
.renderer
|
||||
.unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
|
||||
|
||||
if !self.event_store.has_loggers() {
|
||||
self.event_store
|
||||
.register_logger(FileMetricLogger::new(self.directory.clone()));
|
||||
}
|
||||
|
||||
let event_store = Arc::new(EventStoreClient::new(self.event_store));
|
||||
let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
|
||||
self.metrics,
|
||||
renderer,
|
||||
event_store.clone(),
|
||||
));
|
||||
|
||||
let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
|
||||
LearningCheckpointer::new(
|
||||
model.with_interrupter(self.interrupter.clone()),
|
||||
optim.with_interrupter(self.interrupter.clone()),
|
||||
scheduler.with_interrupter(self.interrupter.clone()),
|
||||
self.checkpointer_strategy,
|
||||
)
|
||||
});
|
||||
|
||||
let summary = if self.summary {
|
||||
Some(LearnerSummaryConfig {
|
||||
directory: self.directory,
|
||||
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let components = TrainingComponents {
|
||||
checkpoint: self.checkpoint,
|
||||
checkpointer,
|
||||
interrupter: self.interrupter,
|
||||
early_stopping: self.early_stopping,
|
||||
event_processor,
|
||||
event_store,
|
||||
num_epochs: self.num_epochs,
|
||||
grad_accumulation: self.grad_accumulation,
|
||||
summary,
|
||||
};
|
||||
|
||||
// Default to single device based on model
|
||||
let training_strategy = self
|
||||
.training_strategy
|
||||
.unwrap_or(TrainingStrategy::SingleDevice(
|
||||
learner.model.devices()[0].clone(),
|
||||
));
|
||||
|
||||
match training_strategy {
|
||||
TrainingStrategy::SingleDevice(device) => {
|
||||
let single_device: SingleDeviceTrainingStrategy<LC> =
|
||||
SingleDeviceTrainingStrategy::new(device);
|
||||
single_device.train(
|
||||
learner,
|
||||
self.dataloader_train,
|
||||
self.dataloader_valid,
|
||||
components,
|
||||
)
|
||||
}
|
||||
TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(
|
||||
learner,
|
||||
self.dataloader_train,
|
||||
self.dataloader_valid,
|
||||
components,
|
||||
),
|
||||
TrainingStrategy::MultiDevice(devices, multi_device_optim) => {
|
||||
let strategy: Box<dyn SupervisedLearningStrategy<LC>> = match devices.len() == 1 {
|
||||
true => Box::new(SingleDeviceTrainingStrategy::new(devices[0].clone())),
|
||||
false => Box::new(MultiDeviceLearningStrategy::new(
|
||||
devices,
|
||||
multi_device_optim,
|
||||
)),
|
||||
};
|
||||
strategy.train(
|
||||
learner,
|
||||
self.dataloader_train,
|
||||
self.dataloader_valid,
|
||||
components,
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "ddp")]
|
||||
TrainingStrategy::DistributedDataParallel { devices, config } => {
|
||||
use crate::ddp::DdpTrainingStrategy;
|
||||
|
||||
let ddp = DdpTrainingStrategy::new(devices.clone(), config.clone());
|
||||
ddp.train(
|
||||
learner,
|
||||
self.dataloader_train,
|
||||
self.dataloader_valid,
|
||||
components,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics.
|
||||
pub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics.
|
||||
pub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
|
||||
}
|
||||
|
||||
macro_rules! gen_tuple {
|
||||
($($M:ident),*) => {
|
||||
impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)
|
||||
where
|
||||
$(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: SupervisedTraining<LC>,
|
||||
) -> SupervisedTraining<LC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_train($M.clone());)*
|
||||
$(let builder = builder.metric_valid($M);)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)
|
||||
where
|
||||
$(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: SupervisedTraining<LC>,
|
||||
) -> SupervisedTraining<LC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_train_numeric($M.clone());)*
|
||||
$(let builder = builder.metric_valid_numeric($M);)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
gen_tuple!(M1);
|
||||
gen_tuple!(M1, M2);
|
||||
gen_tuple!(M1, M2, M3);
|
||||
gen_tuple!(M1, M2, M3, M4);
|
||||
gen_tuple!(M1, M2, M3, M4, M5);
|
||||
gen_tuple!(M1, M2, M3, M4, M5, M6);
|
||||
@@ -0,0 +1,2 @@
|
||||
/// The trainer module.
|
||||
pub mod train;
|
||||
@@ -0,0 +1,151 @@
|
||||
use crate::{LearningComponentsTypes, TrainingModel};
|
||||
use crate::{TrainOutput, TrainStep, TrainingBackend, TrainingModelInput, TrainingModelOutput};
|
||||
use burn_core::data::dataloader::DataLoaderIterator;
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::module::Module;
|
||||
use burn_core::prelude::DeviceOps;
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::DeviceId;
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
use std::thread::spawn;
|
||||
|
||||
/// Multi devices train step.
|
||||
pub struct MultiDevicesTrainStep<LC: LearningComponentsTypes> {
|
||||
workers: Vec<Worker<LC>>,
|
||||
receiver: Receiver<MultiTrainOutput<TrainingModelOutput<LC>>>,
|
||||
}
|
||||
|
||||
struct Message<M, TI> {
|
||||
item: TI,
|
||||
model: M,
|
||||
}
|
||||
|
||||
struct Worker<LC: LearningComponentsTypes> {
|
||||
// Not that complex. Extracting into another type would only make it more confusing.
|
||||
#[allow(clippy::type_complexity)]
|
||||
sender_input: Sender<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
|
||||
device: Device<TrainingBackend<LC>>,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> Worker<LC> {
|
||||
fn register(&self, item: TrainingModelInput<LC>, model: &TrainingModel<LC>) {
|
||||
let message = Message {
|
||||
item,
|
||||
model: model.clone(),
|
||||
};
|
||||
self.sender_input.send(message).unwrap();
|
||||
}
|
||||
|
||||
// Not that complex. Extracting into another type would only make it more confusing.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn start(
|
||||
&self,
|
||||
sender_output: Sender<MultiTrainOutput<TrainingModelOutput<LC>>>,
|
||||
receiver_input: Receiver<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
|
||||
) {
|
||||
let device = self.device.clone();
|
||||
|
||||
spawn(move || {
|
||||
loop {
|
||||
match receiver_input.recv() {
|
||||
Ok(item) => {
|
||||
let model = item.model.fork(&device);
|
||||
let output = model.step(item.item);
|
||||
let item = MultiTrainOutput {
|
||||
output,
|
||||
device: device.to_id(),
|
||||
};
|
||||
|
||||
sender_output.send(item).unwrap();
|
||||
}
|
||||
Err(_err) => {
|
||||
log::info!("Closing thread on device {device:?}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiple output items.
|
||||
pub struct MultiTrainOutput<TO> {
|
||||
/// The training output.
|
||||
pub output: TrainOutput<TO>,
|
||||
/// The device on which the computing happened.
|
||||
pub device: DeviceId,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> MultiDevicesTrainStep<LC> {
|
||||
/// Create a new multi devices train step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `devices` - Devices.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// MultiDevicesTrainStep instance.
|
||||
pub fn new(devices: &[Device<TrainingBackend<LC>>]) -> Self {
|
||||
let (sender_output, receiver_output) = std::sync::mpsc::channel();
|
||||
let workers = devices
|
||||
.iter()
|
||||
.map(|device| {
|
||||
let (sender_input, receiver_input) = std::sync::mpsc::channel();
|
||||
let worker = Worker {
|
||||
sender_input,
|
||||
device: device.clone(),
|
||||
};
|
||||
|
||||
worker.start(sender_output.clone(), receiver_input);
|
||||
worker
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
workers,
|
||||
receiver: receiver_output,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect outputs from workers for one step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - Model.
|
||||
/// * `dataloaders` - The data loader for each worker.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Outputs.
|
||||
pub fn step<'a>(
|
||||
&self,
|
||||
dataloaders: &mut [Box<dyn DataLoaderIterator<TrainingModelInput<LC>> + 'a>],
|
||||
model: &TrainingModel<LC>,
|
||||
) -> (Vec<MultiTrainOutput<TrainingModelOutput<LC>>>, Progress) {
|
||||
let mut num_send = 0;
|
||||
|
||||
let mut items_total = 0;
|
||||
let mut items_processed = 0;
|
||||
|
||||
for (i, worker) in self.workers.iter().enumerate() {
|
||||
let dataloader = &mut dataloaders[i];
|
||||
if let Some(item) = dataloader.next() {
|
||||
worker.register(item, model);
|
||||
num_send += 1;
|
||||
let progress = dataloader.progress();
|
||||
items_total += progress.items_total;
|
||||
items_processed += progress.items_processed;
|
||||
}
|
||||
}
|
||||
|
||||
let mut outputs = Vec::with_capacity(num_send);
|
||||
|
||||
for _ in 0..num_send {
|
||||
let output = self.receiver.recv().unwrap();
|
||||
outputs.push(output);
|
||||
}
|
||||
|
||||
(outputs, Progress::new(items_processed, items_total))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "ddp")]
|
||||
use burn_collective::CollectiveConfig;
|
||||
use burn_core::{module::AutodiffModule, prelude::Backend};
|
||||
|
||||
use crate::{
|
||||
EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
|
||||
LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
|
||||
TrainingModel, ValidLoader,
|
||||
components::LearningComponentsTypes,
|
||||
metric::{
|
||||
processor::{EventProcessorTraining, LearnerEvent},
|
||||
store::EventStoreClient,
|
||||
},
|
||||
};
|
||||
|
||||
type LearnerDevice<LC> = <<LC as LearningComponentsTypes>::Backend as Backend>::Device;
|
||||
|
||||
/// A reference to an implementation of SupervisedLearningStrategy.
|
||||
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
/// Determine how the optimization is performed when training with multiple devices.
|
||||
pub enum MultiDeviceOptim {
|
||||
/// The optimization is done on an elected device.
|
||||
OptimMainDevice,
|
||||
/// The optimization is sharded across all devices.
|
||||
OptimSharded,
|
||||
}
|
||||
|
||||
/// How should the learner run the learning for the model
|
||||
#[derive(Clone)]
|
||||
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
|
||||
/// Training on one device
|
||||
SingleDevice(LearnerDevice<LC>),
|
||||
/// Performs data-parallel distributed training where the optimization is
|
||||
/// done on an elected master device.
|
||||
MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),
|
||||
/// Training using a custom learning strategy
|
||||
Custom(CustomLearningStrategy<LC>),
|
||||
/// Training with input distributed across devices, each device has its own copy of the model.
|
||||
/// Collective ops are used to sync the gradients after each pass.
|
||||
#[cfg(feature = "ddp")]
|
||||
DistributedDataParallel {
|
||||
/// Devices on this node for the DDP
|
||||
devices: Vec<LearnerDevice<LC>>,
|
||||
|
||||
/// The configuration for collective operations
|
||||
/// num_devices is ignored
|
||||
config: CollectiveConfig,
|
||||
},
|
||||
}
|
||||
|
||||
/// Constructor for a distributed data parallel (DDP) learning strategy
|
||||
#[cfg(feature = "ddp")]
|
||||
pub fn ddp<LC: LearningComponentsTypes>(
|
||||
devices: Vec<LearnerDevice<LC>>,
|
||||
config: CollectiveConfig,
|
||||
) -> TrainingStrategy<LC> {
|
||||
TrainingStrategy::DistributedDataParallel { devices, config }
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
|
||||
fn default() -> Self {
|
||||
Self::SingleDevice(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct to minimise parameters passed to [SupervisedLearningStrategy::train].
|
||||
/// These components are used during training.
|
||||
pub struct TrainingComponents<LC: LearningComponentsTypes> {
|
||||
/// The total number of epochs
|
||||
pub num_epochs: usize,
|
||||
/// The epoch number from which to continue the training.
|
||||
pub checkpoint: Option<usize>,
|
||||
/// A checkpointer used to load and save learner checkpoints.
|
||||
pub checkpointer: Option<LearningCheckpointer<LC>>,
|
||||
/// Enables gradients accumulation.
|
||||
pub grad_accumulation: Option<usize>,
|
||||
/// An [Interupter](Interrupter) that allows aborting the training/evaluation process early.
|
||||
pub interrupter: Interrupter,
|
||||
/// Cloneable reference to an early stopping strategy.
|
||||
pub early_stopping: Option<EarlyStoppingStrategyRef>,
|
||||
/// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and validation.
|
||||
pub event_processor: SupervisedTrainingEventProcessor<LC>,
|
||||
/// A reference to an [EventStoreClient](EventStoreClient).
|
||||
pub event_store: Arc<EventStoreClient>,
|
||||
/// Config for creating a summary of the learning
|
||||
pub summary: Option<LearnerSummaryConfig>,
|
||||
}
|
||||
|
||||
/// Provides the `fit` function for any learning strategy
|
||||
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
|
||||
/// Train the learner's model with this strategy.
|
||||
fn train(
|
||||
&self,
|
||||
mut learner: Learner<LC>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: ValidLoader<LC>,
|
||||
mut training_components: TrainingComponents<LC>,
|
||||
) -> LearningResult<InferenceModel<LC>> {
|
||||
let starting_epoch = match training_components.checkpoint {
|
||||
Some(checkpoint) => {
|
||||
if let Some(checkpointer) = &mut training_components.checkpointer {
|
||||
learner =
|
||||
checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
|
||||
}
|
||||
checkpoint + 1
|
||||
}
|
||||
None => 1,
|
||||
};
|
||||
|
||||
let summary_config = training_components.summary.clone();
|
||||
|
||||
// Event processor start training
|
||||
training_components
|
||||
.event_processor
|
||||
.process_train(LearnerEvent::Start);
|
||||
// Training loop
|
||||
let (model, mut event_processor) = self.fit(
|
||||
training_components,
|
||||
learner,
|
||||
dataloader_train,
|
||||
dataloader_valid,
|
||||
starting_epoch,
|
||||
);
|
||||
|
||||
let summary = summary_config.and_then(|summary| {
|
||||
summary
|
||||
.init()
|
||||
.map(|summary| summary.with_model(model.to_string()))
|
||||
.ok()
|
||||
});
|
||||
|
||||
// Signal training end. For the TUI renderer, this handles the exit & return to main screen.
|
||||
event_processor.process_train(LearnerEvent::End(summary));
|
||||
|
||||
let model = model.valid();
|
||||
let renderer = event_processor.renderer();
|
||||
|
||||
LearningResult::<InferenceModel<LC>> { model, renderer }
|
||||
}
|
||||
|
||||
/// Training loop for this strategy
|
||||
fn fit(
|
||||
&self,
|
||||
training_components: TrainingComponents<LC>,
|
||||
learner: Learner<LC>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: ValidLoader<LC>,
|
||||
starting_epoch: usize,
|
||||
) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
## DDP
|
||||
Distributed Data Parallel
|
||||
|
||||
The DDP is a learning strategy that trains a replica of the model on each device.
|
||||
|
||||
The DDP launches threads for each local device. Each thread on each node will run the model.
|
||||
After the forward and backward passes, the gradients are synced between all peers on all nodes
|
||||
with an `all-reduce` operation.
|
||||
|
||||
While the DDP launches threads for each local device, it is the user's responsibility to launch the
|
||||
DDP on each node, and assure the collective configuration matches.
|
||||
|
||||
## Main device vs secondary devices
|
||||
|
||||
The main device is responsible for validation, as well as event processing, which is used in the UI.
|
||||
|
||||
The first device is chosen as the main device.
|
||||
@@ -0,0 +1,234 @@
|
||||
use burn_collective::{PeerId, ReduceOperation};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::module::AutodiffModule;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_optim::GradientsAccumulator;
|
||||
use burn_optim::GradientsParams;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::mpsc::{Receiver, SyncSender};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::SupervisedTrainingEventProcessor;
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
|
||||
use crate::{
|
||||
InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader,
|
||||
};
|
||||
|
||||
/// A validation epoch.
|
||||
#[derive(new)]
|
||||
pub struct DdpValidEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: ValidLoader<LC>,
|
||||
}
|
||||
|
||||
/// A training epoch.
|
||||
#[derive(new)]
|
||||
pub struct DdpTrainEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: TrainLoader<LC>,
|
||||
grad_accumulation: Option<usize>,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {
|
||||
/// Runs the validation epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to validate.
|
||||
/// * `processor` - The event processor to use.
|
||||
pub fn run(
|
||||
&self,
|
||||
model: &<LC as LearningComponentsTypes>::TrainingModel,
|
||||
global_progress: &Progress,
|
||||
processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing validation step for epoch {}", epoch);
|
||||
let model = model.valid();
|
||||
|
||||
let mut iterator = self.dataloader.iter();
|
||||
let mut iteration = 0;
|
||||
|
||||
while let Some(item) = iterator.next() {
|
||||
let progress = iterator.progress();
|
||||
iteration += 1;
|
||||
|
||||
let item = model.step(item);
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
None,
|
||||
);
|
||||
|
||||
processor.process_valid(LearnerEvent::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
processor.process_valid(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {
|
||||
/// Runs the training epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `scheduler` - The learning rate scheduler to use.
|
||||
/// * `processor` - The event processor to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The trained model and the optimizer.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
global_progress: &Progress,
|
||||
processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
|
||||
interrupter: &Interrupter,
|
||||
peer_id: PeerId,
|
||||
peer_count: usize,
|
||||
is_main: bool,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing training step for epoch {}", epoch,);
|
||||
|
||||
let mut iterator = self.dataloader.iter();
|
||||
let mut iteration = 0;
|
||||
let mut accumulator = GradientsAccumulator::new();
|
||||
let mut accumulation_current = 0;
|
||||
|
||||
let grads_syncer = GradsSyncer::<
|
||||
TrainingBackend<LC>,
|
||||
<LC as LearningComponentsTypes>::TrainingModel,
|
||||
>::new(false, peer_id);
|
||||
|
||||
while let Some(item) = iterator.next() {
|
||||
for _ in 0..peer_count {
|
||||
iteration += 1;
|
||||
learner.lr_step();
|
||||
}
|
||||
log::info!("Iteration {iteration}");
|
||||
|
||||
let mut progress = iterator.progress();
|
||||
progress.items_processed *= peer_count;
|
||||
progress.items_total *= peer_count;
|
||||
|
||||
let item = learner.train_step(item);
|
||||
|
||||
match self.grad_accumulation {
|
||||
Some(accumulation) => {
|
||||
accumulator.accumulate(&learner.model(), item.grads);
|
||||
accumulation_current += 1;
|
||||
|
||||
if accumulation <= accumulation_current {
|
||||
let grads = accumulator.grads();
|
||||
|
||||
// With double buffering, these are the previous iteration's gradients
|
||||
let grads = grads_syncer.sync(grads);
|
||||
if let Some(grads) = grads {
|
||||
learner.optimizer_step(grads);
|
||||
}
|
||||
|
||||
accumulation_current = 0;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// With double buffering, these are the previous iteration's gradients
|
||||
let grads = grads_syncer.sync(item.grads);
|
||||
|
||||
if let Some(grads) = grads {
|
||||
learner.optimizer_step(grads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let item = TrainingItem::new(
|
||||
item.item,
|
||||
progress,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
{
|
||||
let mut processor = processor.lock().unwrap();
|
||||
processor.process_train(LearnerEvent::ProcessedItem(item));
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if is_main {
|
||||
let mut processor = processor.lock().unwrap();
|
||||
processor.process_train(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Worker that is responsible for syncing gradients for the DDP worker. With double buffering,
|
||||
/// this allows for more optimization.
|
||||
struct GradsSyncer<B: AutodiffBackend, M: AutodiffModule<B> + 'static> {
|
||||
msg_send: SyncSender<GradientsParams>,
|
||||
// Optional because with double buffering, the first iteration yields no gradients.
|
||||
result_recv: Receiver<Option<GradientsParams>>,
|
||||
|
||||
_p: PhantomData<(B, M)>,
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend, M: AutodiffModule<B> + 'static> GradsSyncer<B, M> {
|
||||
fn new(double_buffering: bool, peer_id: PeerId) -> Self {
|
||||
let (msg_send, msg_recv) = std::sync::mpsc::sync_channel::<GradientsParams>(1);
|
||||
let (result_send, result_recv) =
|
||||
std::sync::mpsc::sync_channel::<Option<GradientsParams>>(1);
|
||||
std::thread::spawn(move || {
|
||||
Self::run_worker(double_buffering, peer_id, result_send, msg_recv)
|
||||
});
|
||||
Self {
|
||||
msg_send,
|
||||
result_recv,
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn sync(&self, grads: GradientsParams) -> Option<GradientsParams> {
|
||||
self.msg_send.send(grads).unwrap();
|
||||
self.result_recv.recv().unwrap()
|
||||
}
|
||||
|
||||
fn run_worker(
|
||||
double_buffering: bool,
|
||||
peer_id: PeerId,
|
||||
send: SyncSender<Option<GradientsParams>>,
|
||||
recv: Receiver<GradientsParams>,
|
||||
) {
|
||||
let mut grads_buffer = None;
|
||||
|
||||
while let Ok(new_grads) = recv.recv() {
|
||||
// Sync grads with collective
|
||||
let new_grads = new_grads
|
||||
.all_reduce::<B::InnerBackend>(peer_id, ReduceOperation::Mean)
|
||||
.expect("DDP worker could not sync gradients!");
|
||||
|
||||
if double_buffering {
|
||||
let old_grads = grads_buffer.take();
|
||||
grads_buffer = Some(new_grads);
|
||||
|
||||
send.send(old_grads).unwrap();
|
||||
} else {
|
||||
send.send(Some(new_grads)).unwrap();
|
||||
}
|
||||
}
|
||||
// GradsSyncer dropped, channel closed, this thread can end
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod epoch;
|
||||
mod strategy;
|
||||
mod worker;
|
||||
|
||||
pub use strategy::*;
|
||||
@@ -0,0 +1,140 @@
|
||||
use core::panic;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use burn_collective::CollectiveConfig;
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::DeviceOps;
|
||||
|
||||
use crate::ddp::worker::DdpWorker;
|
||||
use crate::metric::store::EventStoreClient;
|
||||
use crate::{
|
||||
EarlyStoppingStrategyRef, Interrupter, Learner, LearningComponentsTypes,
|
||||
SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend,
|
||||
TrainingComponents, TrainingModel, ValidLoader,
|
||||
};
|
||||
use burn_core::data::dataloader::split::split_dataloader;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct WorkerComponents {
|
||||
/// The total number of epochs
|
||||
pub num_epochs: usize,
|
||||
/// Enables gradients accumulation.
|
||||
pub grad_accumulation: Option<usize>,
|
||||
/// An [Interupter](Interrupter) that allows aborting the training/evaluation process early.
|
||||
pub interrupter: Interrupter,
|
||||
/// Cloneable reference to an early stopping strategy.
|
||||
pub early_stopping: Option<EarlyStoppingStrategyRef>,
|
||||
/// A reference to an [EventStoreClient](EventStoreClient).
|
||||
pub event_store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
pub struct DdpTrainingStrategy<LC: LearningComponentsTypes> {
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
config: CollectiveConfig,
|
||||
}
|
||||
impl<LC: LearningComponentsTypes> DdpTrainingStrategy<LC> {
|
||||
pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, config: CollectiveConfig) -> Self {
|
||||
let config = config.with_num_devices(devices.len());
|
||||
Self { devices, config }
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes + Send + 'static> SupervisedLearningStrategy<LC>
|
||||
for DdpTrainingStrategy<LC>
|
||||
{
|
||||
fn fit(
|
||||
&self,
|
||||
training_components: TrainingComponents<LC>,
|
||||
learner: Learner<LC>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: ValidLoader<LC>,
|
||||
starting_epoch: usize,
|
||||
) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
|
||||
// The reference model is always on the first device provided.
|
||||
let main_device = self.devices.first().unwrap();
|
||||
// One worker per device, so we use a fixed device strategy
|
||||
// for each (worker) data loader. This matches the expected device on the worker, so we
|
||||
// don't have to move the data between devices.
|
||||
let mut dataloaders_train = split_dataloader(dataloader_train, &self.devices);
|
||||
let dataloader_valid = dataloader_valid.to_device(main_device.inner());
|
||||
|
||||
let main_device = self.devices[0].clone();
|
||||
let peer_count = self.devices.len();
|
||||
let event_processor = Arc::new(Mutex::new(training_components.event_processor));
|
||||
|
||||
let interrupter = training_components.interrupter;
|
||||
let worker_components = WorkerComponents {
|
||||
num_epochs: training_components.num_epochs,
|
||||
grad_accumulation: training_components.grad_accumulation,
|
||||
interrupter: interrupter.clone(),
|
||||
early_stopping: training_components.early_stopping,
|
||||
event_store: training_components.event_store,
|
||||
};
|
||||
|
||||
// Start worker for main device
|
||||
// First training dataloader corresponds to main device
|
||||
let main_handle = DdpWorker::<LC>::start(
|
||||
0.into(),
|
||||
main_device,
|
||||
learner.clone(),
|
||||
event_processor.clone(),
|
||||
worker_components.clone(),
|
||||
training_components.checkpointer,
|
||||
dataloaders_train.remove(0),
|
||||
Some(dataloader_valid),
|
||||
self.config.clone(),
|
||||
starting_epoch,
|
||||
peer_count,
|
||||
true,
|
||||
);
|
||||
|
||||
// Spawn other workers for the other devices, starting with peer id 1
|
||||
let mut peer_id = 1;
|
||||
let mut secondary_workers = vec![];
|
||||
for device in &self.devices[1..] {
|
||||
let handle = DdpWorker::<LC>::start(
|
||||
peer_id.into(),
|
||||
device.clone(),
|
||||
learner.clone(),
|
||||
event_processor.clone(),
|
||||
worker_components.clone(),
|
||||
None,
|
||||
dataloaders_train.remove(0),
|
||||
None,
|
||||
self.config.clone(),
|
||||
starting_epoch,
|
||||
peer_count,
|
||||
false,
|
||||
);
|
||||
|
||||
peer_id += 1;
|
||||
|
||||
secondary_workers.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all devices to finish
|
||||
for worker in secondary_workers {
|
||||
worker
|
||||
.join()
|
||||
.expect("Distributed data parallel worker failed");
|
||||
}
|
||||
// Main worker had the event processor
|
||||
let model = main_handle
|
||||
.join()
|
||||
.expect("Distributed data parallel main worker failed");
|
||||
|
||||
if interrupter.should_stop() {
|
||||
let reason = interrupter
|
||||
.get_message()
|
||||
.unwrap_or(String::from("Reason unknown"));
|
||||
log::info!("Training interrupted: {reason}");
|
||||
}
|
||||
let Ok(event_processor) = Arc::try_unwrap(event_processor) else {
|
||||
panic!("Event processor still held!");
|
||||
};
|
||||
let Ok(event_processor) = event_processor.into_inner() else {
|
||||
panic!("Event processor lock poisoned");
|
||||
};
|
||||
(model, event_processor)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch};
|
||||
use crate::ddp::strategy::WorkerComponents;
|
||||
use crate::single::TrainingLoop;
|
||||
use crate::{
|
||||
Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor,
|
||||
TrainLoader, TrainingBackend, ValidLoader,
|
||||
};
|
||||
use burn_collective::{self, CollectiveConfig, PeerId};
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread::JoinHandle;
|
||||
|
||||
/// A worker runs the model, syncing gradients using collective operations.
|
||||
/// Event processing and validation is optional too.
|
||||
pub(crate) struct DdpWorker<LC>
|
||||
where
|
||||
LC: LearningComponentsTypes + Send + 'static,
|
||||
{
|
||||
peer_id: PeerId,
|
||||
device: Device<TrainingBackend<LC>>,
|
||||
learner: Learner<LC>,
|
||||
event_processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
|
||||
components: WorkerComponents,
|
||||
checkpointer: Option<LearningCheckpointer<LC>>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: Option<ValidLoader<LC>>,
|
||||
collective_config: CollectiveConfig,
|
||||
starting_epoch: usize,
|
||||
peer_count: usize,
|
||||
is_main: bool,
|
||||
}
|
||||
|
||||
impl<LC> DdpWorker<LC>
|
||||
where
|
||||
LC: LearningComponentsTypes + Send + 'static,
|
||||
{
|
||||
/// Starts a worker that runs the model in a data distributed parallel
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn start(
|
||||
peer_id: PeerId,
|
||||
device: Device<TrainingBackend<LC>>,
|
||||
learner: Learner<LC>,
|
||||
event_processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
|
||||
components: WorkerComponents,
|
||||
checkpointer: Option<LearningCheckpointer<LC>>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: Option<ValidLoader<LC>>,
|
||||
collective_config: CollectiveConfig,
|
||||
starting_epoch: usize,
|
||||
peer_count: usize,
|
||||
is_main: bool,
|
||||
) -> JoinHandle<<LC as LearningComponentsTypes>::TrainingModel> {
|
||||
let worker = Self {
|
||||
peer_id,
|
||||
device,
|
||||
learner,
|
||||
event_processor,
|
||||
components,
|
||||
checkpointer,
|
||||
dataloader_train,
|
||||
dataloader_valid,
|
||||
collective_config,
|
||||
starting_epoch,
|
||||
peer_count,
|
||||
is_main,
|
||||
};
|
||||
|
||||
std::thread::spawn(|| worker.fit())
|
||||
}
|
||||
|
||||
/// Fits the model,
|
||||
pub fn fit(mut self) -> <LC as LearningComponentsTypes>::TrainingModel {
|
||||
burn_collective::register::<<TrainingBackend<LC> as AutodiffBackend>::InnerBackend>(
|
||||
self.peer_id,
|
||||
self.device.clone(),
|
||||
self.collective_config.clone(),
|
||||
)
|
||||
.expect("Couldn't register for collective operations!");
|
||||
|
||||
let num_epochs = self.components.num_epochs;
|
||||
let interrupter = self.components.interrupter;
|
||||
|
||||
// Changed the train epoch to keep the dataloaders
|
||||
let epoch_train = DdpTrainEpoch::<LC>::new(
|
||||
self.dataloader_train.clone(),
|
||||
self.components.grad_accumulation,
|
||||
);
|
||||
let epoch_valid = self
|
||||
.dataloader_valid
|
||||
.map(|dataloader| DdpValidEpoch::<LC>::new(dataloader));
|
||||
self.learner.fork(&self.device);
|
||||
|
||||
for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) {
|
||||
let epoch = training_progress.items_processed;
|
||||
|
||||
epoch_train.run(
|
||||
&mut self.learner,
|
||||
&training_progress,
|
||||
self.event_processor.clone(),
|
||||
&interrupter,
|
||||
self.peer_id,
|
||||
self.peer_count,
|
||||
self.is_main,
|
||||
);
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Validation
|
||||
if let Some(runner) = &epoch_valid {
|
||||
let mut event_processor = self.event_processor.lock().unwrap();
|
||||
runner.run(
|
||||
&self.learner.model(),
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&interrupter,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(checkpointer) = &mut self.checkpointer {
|
||||
checkpointer.checkpoint(&self.learner, epoch, &self.components.event_store);
|
||||
}
|
||||
|
||||
if let Some(early_stopping) = &mut self.components.early_stopping
|
||||
&& early_stopping.should_stop(epoch, &self.components.event_store)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
self.learner.model()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod base;
|
||||
|
||||
#[cfg(feature = "ddp")]
|
||||
pub(crate) mod ddp;
|
||||
pub(crate) mod multi;
|
||||
pub(crate) mod single;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,219 @@
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
|
||||
use crate::train::MultiDevicesTrainStep;
|
||||
use crate::{
|
||||
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor,
|
||||
TrainLoader, TrainingBackend,
|
||||
};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::prelude::DeviceOps;
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::DeviceId;
|
||||
use burn_optim::GradientsAccumulator;
|
||||
use burn_optim::MultiGradientsParams;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A training epoch.
|
||||
#[derive(new)]
|
||||
pub struct MultiDeviceTrainEpoch<LC: LearningComponentsTypes> {
|
||||
dataloaders: Vec<TrainLoader<LC>>,
|
||||
grad_accumulation: Option<usize>,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
|
||||
/// Runs the training epoch on multiple devices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `lr_scheduler` - The learning rate scheduler to use.
|
||||
/// * `processor` - The event processor to use.
|
||||
/// * `devices` - The devices to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The trained model and the optimizer.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
global_progress: &Progress,
|
||||
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
strategy: MultiDeviceOptim,
|
||||
) {
|
||||
match strategy {
|
||||
MultiDeviceOptim::OptimMainDevice => self.run_optim_main(
|
||||
learner,
|
||||
global_progress,
|
||||
event_processor,
|
||||
interrupter,
|
||||
devices,
|
||||
),
|
||||
MultiDeviceOptim::OptimSharded => self.run_optim_distr(
|
||||
learner,
|
||||
global_progress,
|
||||
event_processor,
|
||||
interrupter,
|
||||
devices,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_optim_main(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
global_progress: &Progress,
|
||||
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!(
|
||||
"Executing training step for epoch {} on devices {:?}",
|
||||
epoch,
|
||||
devices
|
||||
);
|
||||
|
||||
let mut iterators = self
|
||||
.dataloaders
|
||||
.iter()
|
||||
.map(|d| d.iter())
|
||||
.collect::<Vec<_>>();
|
||||
let mut iteration = 0;
|
||||
let mut accumulator = GradientsAccumulator::new();
|
||||
let mut accumulation_current = 0;
|
||||
|
||||
let accumulation = self.grad_accumulation.unwrap_or(1);
|
||||
let step = MultiDevicesTrainStep::<LC>::new(&devices);
|
||||
|
||||
// The main device is always the first in the list.
|
||||
let device_main = devices.first().expect("A minimum of one device.").clone();
|
||||
|
||||
loop {
|
||||
let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model());
|
||||
if items.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
learner.lr_step();
|
||||
|
||||
let mut progress_items = Vec::with_capacity(items.len());
|
||||
for item in items.into_iter() {
|
||||
let grads = item.output.grads.to_device(&device_main, &learner.model());
|
||||
accumulator.accumulate(&learner.model(), grads);
|
||||
progress_items.push(item.output.item);
|
||||
}
|
||||
|
||||
accumulation_current += 1;
|
||||
|
||||
if accumulation <= accumulation_current {
|
||||
let grads = accumulator.grads();
|
||||
learner.optimizer_step(grads);
|
||||
accumulation_current = 0;
|
||||
}
|
||||
|
||||
for item in progress_items {
|
||||
iteration += 1;
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress.clone(),
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
event_processor.process_train(LearnerEvent::ProcessedItem(item));
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
event_processor.process_train(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
|
||||
fn run_optim_distr(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
global_progress: &Progress,
|
||||
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!(
|
||||
"Executing training step for epoch {} on devices {:?}",
|
||||
epoch,
|
||||
devices
|
||||
);
|
||||
|
||||
let mut iterators = self
|
||||
.dataloaders
|
||||
.iter()
|
||||
.map(|d| d.iter())
|
||||
.collect::<Vec<_>>();
|
||||
let mut iteration = 0;
|
||||
let mut accumulators = HashMap::<
|
||||
DeviceId,
|
||||
GradientsAccumulator<<LC as LearningComponentsTypes>::TrainingModel>,
|
||||
>::new();
|
||||
for device in devices.iter() {
|
||||
accumulators.insert(device.to_id(), GradientsAccumulator::new());
|
||||
}
|
||||
let mut accumulation_current = 0;
|
||||
|
||||
let accumulation = self.grad_accumulation.unwrap_or(1);
|
||||
let step = MultiDevicesTrainStep::<LC>::new(&devices);
|
||||
|
||||
loop {
|
||||
let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model());
|
||||
if items.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
learner.lr_step();
|
||||
|
||||
let mut progress_items = Vec::with_capacity(items.len());
|
||||
for item in items.into_iter() {
|
||||
let accumulator = accumulators.get_mut(&item.device).unwrap();
|
||||
accumulator.accumulate(&learner.model(), item.output.grads);
|
||||
progress_items.push(item.output.item);
|
||||
}
|
||||
|
||||
accumulation_current += 1;
|
||||
|
||||
if accumulation <= accumulation_current {
|
||||
let mut grads = MultiGradientsParams::default();
|
||||
for (device_id, accumulator) in accumulators.iter_mut() {
|
||||
let grad = accumulator.grads();
|
||||
grads.grads.push((grad, *device_id));
|
||||
}
|
||||
learner.optimizer_step_multi(grads);
|
||||
accumulation_current = 0;
|
||||
}
|
||||
|
||||
for item in progress_items {
|
||||
iteration += 1;
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress.clone(),
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
event_processor.process_train(LearnerEvent::ProcessedItem(item));
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
event_processor.process_train(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub(crate) mod epoch;
|
||||
mod strategy;
|
||||
|
||||
pub use strategy::*;
|
||||
@@ -0,0 +1,100 @@
|
||||
use crate::{
|
||||
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy,
|
||||
SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents,
|
||||
TrainingModel, ValidLoader,
|
||||
multi::epoch::MultiDeviceTrainEpoch,
|
||||
single::{TrainingLoop, epoch::SingleDeviceValidEpoch},
|
||||
};
|
||||
use burn_core::{
|
||||
data::dataloader::split::split_dataloader,
|
||||
tensor::{Device, backend::DeviceOps},
|
||||
};
|
||||
|
||||
pub struct MultiDeviceLearningStrategy<LC: LearningComponentsTypes> {
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
optim: MultiDeviceOptim,
|
||||
}
|
||||
impl<LC: LearningComponentsTypes> MultiDeviceLearningStrategy<LC> {
|
||||
pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, optim: MultiDeviceOptim) -> Self {
|
||||
Self { devices, optim }
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
for MultiDeviceLearningStrategy<LC>
|
||||
{
|
||||
fn fit(
|
||||
&self,
|
||||
training_components: TrainingComponents<LC>,
|
||||
mut learner: Learner<LC>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: ValidLoader<LC>,
|
||||
starting_epoch: usize,
|
||||
) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
|
||||
let main_device = self.devices.first().unwrap();
|
||||
|
||||
// `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy
|
||||
// for each (worker) data loader. This matches the expected device on the worker, so we
|
||||
// don't have to move the data between devices.
|
||||
let dataloader_train = split_dataloader(dataloader_train, &self.devices);
|
||||
let dataloader_valid = dataloader_valid.to_device(main_device.inner());
|
||||
|
||||
learner.fork(main_device);
|
||||
let mut event_processor = training_components.event_processor;
|
||||
let mut checkpointer = training_components.checkpointer;
|
||||
let mut early_stopping = training_components.early_stopping;
|
||||
|
||||
let epoch_train = MultiDeviceTrainEpoch::<LC>::new(
|
||||
dataloader_train.clone(),
|
||||
training_components.grad_accumulation,
|
||||
);
|
||||
let epoch_valid: SingleDeviceValidEpoch<LC> =
|
||||
SingleDeviceValidEpoch::new(dataloader_valid.clone());
|
||||
|
||||
for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
|
||||
let epoch = training_progress.items_processed;
|
||||
epoch_train.run(
|
||||
&mut learner,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
self.devices.to_vec(),
|
||||
self.optim,
|
||||
);
|
||||
|
||||
if training_components.interrupter.should_stop() {
|
||||
let reason = training_components
|
||||
.interrupter
|
||||
.get_message()
|
||||
.unwrap_or(String::from("Reason unknown"));
|
||||
log::info!("Training interrupted: {reason}");
|
||||
break;
|
||||
}
|
||||
|
||||
// After OptimSharded training, model parameters are scattered across
|
||||
// devices. Fork back to main_device before single-device validation.
|
||||
if matches!(self.optim, MultiDeviceOptim::OptimSharded) {
|
||||
learner.fork(main_device);
|
||||
}
|
||||
|
||||
epoch_valid.run(
|
||||
&learner,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
);
|
||||
|
||||
if let Some(checkpointer) = &mut checkpointer {
|
||||
checkpointer.checkpoint(&learner, epoch, &training_components.event_store);
|
||||
}
|
||||
|
||||
if let Some(early_stopping) = &mut early_stopping
|
||||
&& early_stopping.should_stop(epoch, &training_components.event_store)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(learner.model(), event_processor)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
|
||||
use crate::{
|
||||
InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader,
|
||||
ValidLoader,
|
||||
};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::module::AutodiffModule;
|
||||
use burn_optim::GradientsAccumulator;
|
||||
|
||||
/// A validation epoch.
|
||||
#[derive(new)]
|
||||
pub struct SingleDeviceValidEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: ValidLoader<LC>,
|
||||
}
|
||||
|
||||
/// A training epoch.
|
||||
#[derive(new)]
|
||||
pub struct SingleDeviceTrainEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: TrainLoader<LC>,
|
||||
grad_accumulation: Option<usize>,
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {
|
||||
/// Runs the validation epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to validate.
|
||||
/// * `processor` - The event processor to use.
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &Learner<LC>,
|
||||
global_progress: &Progress,
|
||||
processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing validation step for epoch {}", epoch);
|
||||
let model = learner.model().valid();
|
||||
|
||||
let mut iterator = self.dataloader.iter();
|
||||
let mut iteration = 0;
|
||||
|
||||
while let Some(item) = iterator.next() {
|
||||
let progress = iterator.progress();
|
||||
iteration += 1;
|
||||
|
||||
let item = model.step(item);
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
None,
|
||||
);
|
||||
|
||||
processor.process_valid(LearnerEvent::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
processor.process_valid(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {
|
||||
/// Runs the training epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `scheduler` - The learning rate scheduler to use.
|
||||
/// * `processor` - The event processor to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The trained model and the optimizer.
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
global_progress: &Progress,
|
||||
processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing training step for epoch {}", epoch,);
|
||||
|
||||
// Single device / dataloader
|
||||
let mut iterator = self.dataloader.iter();
|
||||
let mut iteration = 0;
|
||||
let mut accumulator = GradientsAccumulator::new();
|
||||
let mut accumulation_current = 0;
|
||||
|
||||
while let Some(item) = iterator.next() {
|
||||
iteration += 1;
|
||||
learner.lr_step();
|
||||
log::info!("Iteration {iteration}");
|
||||
|
||||
let progress = iterator.progress();
|
||||
let item = learner.train_step(item);
|
||||
|
||||
match self.grad_accumulation {
|
||||
Some(accumulation) => {
|
||||
accumulator.accumulate(&learner.model(), item.grads);
|
||||
accumulation_current += 1;
|
||||
|
||||
if accumulation <= accumulation_current {
|
||||
let grads = accumulator.grads();
|
||||
|
||||
learner.optimizer_step(grads);
|
||||
accumulation_current = 0;
|
||||
}
|
||||
}
|
||||
None => learner.optimizer_step(item.grads),
|
||||
}
|
||||
|
||||
let item = TrainingItem::new(
|
||||
item.item,
|
||||
progress,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
processor.process_train(LearnerEvent::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
processor.process_train(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub(crate) mod epoch;
|
||||
mod strategy;
|
||||
|
||||
pub use strategy::*;
|
||||
@@ -0,0 +1,108 @@
|
||||
use crate::{
|
||||
Learner, LearningComponentsTypes, SupervisedLearningStrategy, SupervisedTrainingEventProcessor,
|
||||
TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader,
|
||||
single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch},
|
||||
};
|
||||
use burn_core::{
|
||||
data::dataloader::Progress,
|
||||
tensor::{Device, backend::DeviceOps},
|
||||
};
|
||||
|
||||
/// Simplest learning strategy possible, with only a single devices doing both the training and
|
||||
/// validation.
|
||||
pub struct SingleDeviceTrainingStrategy<LC: LearningComponentsTypes> {
|
||||
device: Device<TrainingBackend<LC>>,
|
||||
}
|
||||
impl<LC: LearningComponentsTypes> SingleDeviceTrainingStrategy<LC> {
|
||||
pub fn new(device: Device<TrainingBackend<LC>>) -> Self {
|
||||
Self { device }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct TrainingLoop {
|
||||
next_iteration: usize,
|
||||
total_iteration: usize,
|
||||
}
|
||||
|
||||
/// An iterator that returns the progress of the training.
|
||||
impl Iterator for TrainingLoop {
|
||||
type Item = Progress;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.next_iteration > self.total_iteration {
|
||||
return None;
|
||||
}
|
||||
|
||||
let progress = Progress {
|
||||
items_processed: self.next_iteration,
|
||||
items_total: self.total_iteration,
|
||||
};
|
||||
|
||||
self.next_iteration += 1;
|
||||
Some(progress)
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
for SingleDeviceTrainingStrategy<LC>
|
||||
{
|
||||
fn fit(
|
||||
&self,
|
||||
training_components: TrainingComponents<LC>,
|
||||
mut learner: Learner<LC>,
|
||||
dataloader_train: TrainLoader<LC>,
|
||||
dataloader_valid: ValidLoader<LC>,
|
||||
starting_epoch: usize,
|
||||
) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
|
||||
let dataloader_train = dataloader_train.to_device(&self.device);
|
||||
let dataloader_valid = dataloader_valid.to_device(self.device.inner());
|
||||
learner.fork(&self.device);
|
||||
let mut event_processor = training_components.event_processor;
|
||||
let mut checkpointer = training_components.checkpointer;
|
||||
let mut early_stopping = training_components.early_stopping;
|
||||
|
||||
let epoch_train: SingleDeviceTrainEpoch<LC> =
|
||||
SingleDeviceTrainEpoch::new(dataloader_train, training_components.grad_accumulation);
|
||||
let epoch_valid: SingleDeviceValidEpoch<LC> =
|
||||
SingleDeviceValidEpoch::new(dataloader_valid.clone());
|
||||
|
||||
for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
|
||||
let epoch = training_progress.items_processed;
|
||||
epoch_train.run(
|
||||
&mut learner,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
);
|
||||
|
||||
if training_components.interrupter.should_stop() {
|
||||
let reason = training_components
|
||||
.interrupter
|
||||
.get_message()
|
||||
.unwrap_or(String::from("Reason unknown"));
|
||||
log::info!("Training interrupted: {reason}");
|
||||
break;
|
||||
}
|
||||
|
||||
epoch_valid.run(
|
||||
&learner,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
);
|
||||
|
||||
if let Some(checkpointer) = &mut checkpointer {
|
||||
checkpointer.checkpoint(&learner, epoch, &training_components.event_store);
|
||||
}
|
||||
|
||||
if let Some(early_stopping) = &mut early_stopping
|
||||
&& early_stopping.should_stop(epoch, &training_components.event_store)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(learner.model(), event_processor)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
use crate::{ItemLazy, renderer::MetricsRenderer};
|
||||
use burn_core::module::AutodiffModule;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
|
||||
|
||||
/// A training output.
|
||||
pub struct TrainOutput<TO> {
|
||||
/// The gradients.
|
||||
pub grads: GradientsParams,
|
||||
|
||||
/// The item.
|
||||
pub item: TO,
|
||||
}
|
||||
|
||||
impl<TO> TrainOutput<TO> {
|
||||
/// Creates a new training output.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module` - The module.
|
||||
/// * `grads` - The gradients.
|
||||
/// * `item` - The item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new training output.
|
||||
pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
module: &M,
|
||||
grads: B::Gradients,
|
||||
item: TO,
|
||||
) -> Self {
|
||||
let grads = GradientsParams::from_grads(grads, module);
|
||||
Self { grads, item }
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to be implemented for models to be able to be trained.
|
||||
///
|
||||
/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
|
||||
///
|
||||
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
|
||||
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
|
||||
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// To be used with the [Learner](crate::Learner) struct, the struct which implements this trait must
|
||||
/// also implement the [AutodiffModule] trait, which is done automatically with the
|
||||
/// [Module](burn_core::module::Module) derive.
|
||||
pub trait TrainStep {
|
||||
/// Type of input for a step of the training stage.
|
||||
type Input: Send + 'static;
|
||||
/// Type of output for a step of the training stage.
|
||||
type Output: ItemLazy + 'static;
|
||||
/// Runs a step for training, which executes the forward and backward passes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The input for the model.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output containing the model output and the gradients.
|
||||
fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
|
||||
/// Optimize the current module with the provided gradients and learning rate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `optim`: Optimizer used for learning.
|
||||
/// * `lr`: The learning rate used for this step.
|
||||
/// * `grads`: The gradients of each parameter in the current model.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The updated model.
|
||||
fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
O: Optimizer<Self, B>,
|
||||
Self: AutodiffModule<B>,
|
||||
{
|
||||
optim.step(lr, self, grads)
|
||||
}
|
||||
/// Optimize the current module with the provided gradients and learning rate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `optim`: Optimizer used for learning.
|
||||
/// * `lr`: The learning rate used for this step.
|
||||
/// * `grads`: Multiple gradients associated to each parameter in the current model.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The updated model.
|
||||
fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
O: Optimizer<Self, B>,
|
||||
Self: AutodiffModule<B>,
|
||||
{
|
||||
optim.step_multi(lr, self, grads)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to be implemented for validating models.
|
||||
pub trait InferenceStep {
|
||||
/// Type of input for an inference step.
|
||||
type Input: Send + 'static;
|
||||
/// Type of output for an inference step.
|
||||
type Output: ItemLazy + 'static;
|
||||
/// Runs a validation step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The item to validate on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The validation output.
|
||||
fn step(&self, item: Self::Input) -> Self::Output;
|
||||
}
|
||||
|
||||
/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
|
||||
pub struct LearningResult<M> {
|
||||
/// The model with the learned weights.
|
||||
pub model: M,
|
||||
/// The renderer that can be used for follow up training and evaluation.
|
||||
pub renderer: Box<dyn MetricsRenderer>,
|
||||
}
|
||||
118
crates/stable-diffusion-burn/burn-crates/burn-train/src/lib.rs
Normal file
118
crates/stable-diffusion-burn/burn-crates/burn-train/src/lib.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! A library for training neural networks using the burn crate.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// The checkpoint module.
|
||||
pub mod checkpoint;
|
||||
|
||||
pub(crate) mod components;
|
||||
|
||||
/// Renderer modules to display metrics and training information.
|
||||
pub mod renderer;
|
||||
|
||||
/// The logger module.
|
||||
pub mod logger;
|
||||
|
||||
/// The metric module.
|
||||
pub mod metric;
|
||||
|
||||
pub use metric::processor::*;
|
||||
|
||||
mod learner;
|
||||
|
||||
pub use learner::*;
|
||||
|
||||
mod evaluator;
|
||||
|
||||
pub use evaluator::*;
|
||||
|
||||
pub use components::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::TestBackend;
|
||||
use burn_core::{prelude::Tensor, tensor::Bool};
|
||||
use std::default::Default;
|
||||
|
||||
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
|
||||
|
||||
/// Probability of tp before adding errors
|
||||
pub const THRESHOLD: f64 = 0.5;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub enum ClassificationType {
|
||||
#[default]
|
||||
Binary,
|
||||
Multiclass,
|
||||
Multilabel,
|
||||
}
|
||||
|
||||
/// Sample x Class shaped matrix for use in
|
||||
/// classification metrics testing
|
||||
pub fn dummy_classification_input(
|
||||
classification_type: &ClassificationType,
|
||||
) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
|
||||
match classification_type {
|
||||
ClassificationType::Binary => {
|
||||
(
|
||||
Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
|
||||
// targets
|
||||
Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
|
||||
// predictions @ threshold=0.5
|
||||
// [[0], [0], [1], [0], [1]]
|
||||
)
|
||||
}
|
||||
ClassificationType::Multiclass => {
|
||||
(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.2, 0.8, 0.0],
|
||||
[0.3, 0.6, 0.1],
|
||||
[0.7, 0.25, 0.05],
|
||||
[0.1, 0.15, 0.8],
|
||||
[0.9, 0.03, 0.07],
|
||||
],
|
||||
&Default::default(),
|
||||
),
|
||||
Tensor::from_data(
|
||||
// targets
|
||||
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
|
||||
// predictions @ top_k=1
|
||||
// [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
|
||||
// predictions @ top_k=2
|
||||
// [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
|
||||
&Default::default(),
|
||||
),
|
||||
)
|
||||
}
|
||||
ClassificationType::Multilabel => {
|
||||
(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.7, 0.6],
|
||||
[0.3, 0.9, 0.05],
|
||||
[0.8, 0.9, 0.4],
|
||||
[0.7, 0.5, 0.9],
|
||||
[1.0, 0.3, 0.2],
|
||||
],
|
||||
&Default::default(),
|
||||
),
|
||||
// targets
|
||||
Tensor::from_data(
|
||||
[[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
|
||||
// predictions @ threshold=0.5
|
||||
// [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
|
||||
&Default::default(),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
use super::Logger;
|
||||
use std::sync::mpsc;
|
||||
|
||||
enum Message<T> {
|
||||
Log(T),
|
||||
End,
|
||||
Sync(mpsc::Sender<()>),
|
||||
}
|
||||
/// Async logger.
|
||||
pub struct AsyncLogger<T> {
|
||||
sender: mpsc::Sender<Message<T>>,
|
||||
handler: Option<std::thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct LoggerThread<T, L: Logger<T>> {
|
||||
logger: L,
|
||||
receiver: mpsc::Receiver<Message<T>>,
|
||||
}
|
||||
|
||||
impl<T, L> LoggerThread<T, L>
|
||||
where
|
||||
L: Logger<T>,
|
||||
{
|
||||
fn run(mut self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::Log(item) => {
|
||||
self.logger.log(item);
|
||||
}
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
Message::Sync(callback) => {
|
||||
callback
|
||||
.send(())
|
||||
.expect("Can return result with the callback channel.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send + Sync + 'static> AsyncLogger<T> {
|
||||
/// Create a new async logger.
|
||||
pub fn new<L>(logger: L) -> Self
|
||||
where
|
||||
L: Logger<T> + 'static,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = LoggerThread::new(logger, receiver);
|
||||
|
||||
let handler = Some(std::thread::spawn(move || thread.run()));
|
||||
|
||||
Self { sender, handler }
|
||||
}
|
||||
|
||||
/// Sync the async logger.
|
||||
pub(crate) fn sync(&self) {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
|
||||
self.sender
|
||||
.send(Message::Sync(sender))
|
||||
.expect("Can send message to logger thread.");
|
||||
|
||||
receiver
|
||||
.recv()
|
||||
.expect("Should sync, otherwise the thread is dead.");
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send> Logger<T> for AsyncLogger<T> {
|
||||
fn log(&mut self, item: T) {
|
||||
self.sender
|
||||
.send(Message::Log(item))
|
||||
.expect("Can log using the logger thread.");
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for AsyncLogger<T> {
|
||||
fn drop(&mut self) {
|
||||
self.sender
|
||||
.send(Message::End)
|
||||
.expect("Can send the end message to the logger thread.");
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler.join().expect("The logger thread should stop.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
/// The logger trait.
|
||||
pub trait Logger<T>: Send {
|
||||
/// Logs an item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The item.
|
||||
fn log(&mut self, item: T);
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
use super::Logger;
|
||||
use std::{fs::File, io::Write, path::Path};
|
||||
|
||||
/// File logger.
|
||||
pub struct FileLogger {
|
||||
file: File,
|
||||
}
|
||||
|
||||
impl FileLogger {
|
||||
/// Create a new file logger.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - The path.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The file logger.
|
||||
pub fn new(path: impl AsRef<Path>) -> Self {
|
||||
let path = path.as_ref();
|
||||
let mut options = std::fs::File::options();
|
||||
let file = options
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.create(true)
|
||||
.open(path)
|
||||
.unwrap_or_else(|err| {
|
||||
panic!(
|
||||
"Should be able to create the new file '{}': {}",
|
||||
path.display(),
|
||||
err
|
||||
)
|
||||
});
|
||||
|
||||
Self { file }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Logger<T> for FileLogger
|
||||
where
|
||||
T: std::fmt::Display,
|
||||
{
|
||||
fn log(&mut self, item: T) {
|
||||
writeln!(&mut self.file, "{item}").expect("Can log an item.");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use super::Logger;
|
||||
|
||||
/// In memory logger.
|
||||
#[derive(Default)]
|
||||
pub struct InMemoryLogger {
|
||||
pub(crate) values: Vec<String>,
|
||||
}
|
||||
|
||||
impl<T> Logger<T> for InMemoryLogger
|
||||
where
|
||||
T: std::fmt::Display,
|
||||
{
|
||||
fn log(&mut self, item: T) {
|
||||
self.values.push(item.to_string());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
|
||||
use crate::metric::{
|
||||
MetricDefinition, MetricEntry, MetricId, NumericEntry,
|
||||
store::{EpochSummary, MetricsUpdate, Split},
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
const EPOCH_PREFIX: &str = "epoch-";
|
||||
|
||||
/// Metric logger.
|
||||
pub trait MetricLogger: Send {
|
||||
/// Logs an item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `update` - Update information for all registered metrics.
|
||||
/// * `epoch` - Current epoch.
|
||||
/// * `split` - Current dataset split.
|
||||
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split);
|
||||
|
||||
/// Read the logs for an epoch.
|
||||
fn read_numeric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
split: &Split,
|
||||
) -> Result<Vec<NumericEntry>, String>;
|
||||
|
||||
/// Logs the metric definition information (name, description, unit, etc.)
|
||||
fn log_metric_definition(&mut self, definition: MetricDefinition);
|
||||
|
||||
/// Logs summary at the end of the epoch.
|
||||
fn log_epoch_summary(&mut self, summary: EpochSummary);
|
||||
}
|
||||
|
||||
/// The file metric logger.
|
||||
pub struct FileMetricLogger {
|
||||
loggers: HashMap<String, AsyncLogger<String>>,
|
||||
directory: PathBuf,
|
||||
metric_definitions: HashMap<MetricId, MetricDefinition>,
|
||||
is_eval: bool,
|
||||
last_epoch: Option<usize>,
|
||||
}
|
||||
|
||||
impl FileMetricLogger {
|
||||
/// Create a new file metric logger.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The file metric logger.
|
||||
pub fn new(directory: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
loggers: HashMap::new(),
|
||||
directory: directory.as_ref().to_path_buf(),
|
||||
metric_definitions: HashMap::default(),
|
||||
is_eval: false,
|
||||
last_epoch: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new file metric logger.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The file metric logger.
|
||||
pub fn new_eval(directory: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
loggers: HashMap::new(),
|
||||
directory: directory.as_ref().to_path_buf(),
|
||||
metric_definitions: HashMap::default(),
|
||||
is_eval: true,
|
||||
last_epoch: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn split_exists(&self, split: &Split) -> bool {
|
||||
self.split_dir(split).is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn split_dir(&self, split: &Split) -> Option<PathBuf> {
|
||||
let split_path = match split {
|
||||
Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(tag.as_str()),
|
||||
other => self.directory.join(other.to_string()),
|
||||
};
|
||||
(split_path.exists() && split_path.is_dir()).then_some(split_path)
|
||||
}
|
||||
|
||||
pub(crate) fn is_epoch_dir<P: AsRef<str>>(dirname: P) -> bool {
|
||||
dirname.as_ref().starts_with(EPOCH_PREFIX)
|
||||
}
|
||||
|
||||
/// Number of epochs recorded.
|
||||
pub(crate) fn epochs(&self) -> usize {
|
||||
if self.is_eval {
|
||||
log::warn!("Number of epochs not available when testing.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
let mut max_epoch = 0;
|
||||
|
||||
// with split
|
||||
for path in fs::read_dir(&self.directory).unwrap() {
|
||||
let path = path.unwrap();
|
||||
|
||||
if fs::metadata(path.path()).unwrap().is_dir() {
|
||||
for split_path in fs::read_dir(path.path()).unwrap() {
|
||||
let split_path = split_path.unwrap();
|
||||
|
||||
if fs::metadata(split_path.path()).unwrap().is_dir() {
|
||||
let dir_name = split_path.file_name().into_string().unwrap();
|
||||
|
||||
if !dir_name.starts_with(EPOCH_PREFIX) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
|
||||
|
||||
if let Some(epoch) = epoch
|
||||
&& epoch > max_epoch
|
||||
{
|
||||
max_epoch = epoch;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
max_epoch
|
||||
}
|
||||
|
||||
fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf {
|
||||
let name = format!("{EPOCH_PREFIX}{epoch}");
|
||||
|
||||
match split {
|
||||
Split::Train | Split::Valid | Split::Test(None) => {
|
||||
self.directory.join(split.to_string()).join(name)
|
||||
}
|
||||
Split::Test(Some(tag)) => {
|
||||
let tag = format_tag(tag);
|
||||
self.directory.join(split.to_string()).join(tag).join(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_directory(&self, split: &Split) -> PathBuf {
|
||||
match split {
|
||||
Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(),
|
||||
Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(format_tag(tag)),
|
||||
}
|
||||
}
|
||||
|
||||
fn file_path(&self, name: &str, epoch: Option<usize>, split: &Split) -> PathBuf {
|
||||
let directory = match epoch {
|
||||
Some(epoch) => self.train_directory(epoch, split),
|
||||
None => self.eval_directory(split),
|
||||
};
|
||||
let name = name.replace(' ', "_");
|
||||
let name = format!("{name}.log");
|
||||
directory.join(name)
|
||||
}
|
||||
|
||||
fn create_directory(&self, epoch: Option<usize>, split: &Split) {
|
||||
let directory = match epoch {
|
||||
Some(epoch) => self.train_directory(epoch, split),
|
||||
None => self.eval_directory(split),
|
||||
};
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
}
|
||||
}
|
||||
|
||||
impl FileMetricLogger {
|
||||
fn log_item(&mut self, item: &MetricEntry, epoch: Option<usize>, split: &Split) {
|
||||
let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
|
||||
let key = logger_key(name, split);
|
||||
let value = &item.serialized_entry.serialized;
|
||||
|
||||
let logger = match self.loggers.get_mut(&key) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
self.create_directory(epoch, split);
|
||||
|
||||
let file_path = self.file_path(name, epoch, split);
|
||||
let logger = FileLogger::new(file_path);
|
||||
let logger = AsyncLogger::new(logger);
|
||||
|
||||
self.loggers.insert(key.clone(), logger);
|
||||
self.loggers
|
||||
.get_mut(&key)
|
||||
.expect("Can get the previously saved logger.")
|
||||
}
|
||||
};
|
||||
|
||||
logger.log(value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tag(tag: &str) -> String {
|
||||
tag.trim().replace(' ', "-").to_lowercase()
|
||||
}
|
||||
|
||||
impl MetricLogger for FileMetricLogger {
|
||||
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
|
||||
if !self.is_eval && self.last_epoch != Some(epoch) {
|
||||
self.loggers.clear();
|
||||
self.last_epoch = Some(epoch);
|
||||
}
|
||||
|
||||
let entries: Vec<_> = update
|
||||
.entries
|
||||
.iter()
|
||||
.chain(
|
||||
update
|
||||
.entries_numeric
|
||||
.iter()
|
||||
.map(|numeric_update| &numeric_update.entry),
|
||||
)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
for item in entries.iter() {
|
||||
self.log_item(item, Some(epoch), split);
|
||||
}
|
||||
}
|
||||
|
||||
fn read_numeric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
split: &Split,
|
||||
) -> Result<Vec<NumericEntry>, String> {
|
||||
if let Some(value) = self.loggers.get(name) {
|
||||
value.sync()
|
||||
}
|
||||
|
||||
let file_path = self.file_path(name, Some(epoch), split);
|
||||
|
||||
let mut errors = false;
|
||||
|
||||
let data = std::fs::read_to_string(file_path)
|
||||
.unwrap_or_default()
|
||||
.split('\n')
|
||||
.filter_map(|value| {
|
||||
if value.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match NumericEntry::deserialize(value) {
|
||||
Ok(value) => Some(value),
|
||||
Err(err) => {
|
||||
log::error!("{err}");
|
||||
errors = true;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if errors {
|
||||
Err("Parsing numeric entry errors".to_string())
|
||||
} else {
|
||||
Ok(data)
|
||||
}
|
||||
}
|
||||
|
||||
fn log_metric_definition(&mut self, definition: MetricDefinition) {
|
||||
self.metric_definitions
|
||||
.insert(definition.metric_id.clone(), definition);
|
||||
}
|
||||
|
||||
fn log_epoch_summary(&mut self, _summary: EpochSummary) {
|
||||
if !self.is_eval {
|
||||
self.loggers.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn logger_key(name: &str, split: &Split) -> String {
|
||||
format!("{name}_{split}")
|
||||
}
|
||||
|
||||
/// In memory metric logger, useful when testing and debugging.
|
||||
#[derive(Default)]
|
||||
pub struct InMemoryMetricLogger {
|
||||
values: HashMap<String, Vec<InMemoryLogger>>,
|
||||
last_epoch: Option<usize>,
|
||||
metric_definitions: HashMap<MetricId, MetricDefinition>,
|
||||
}
|
||||
|
||||
impl InMemoryMetricLogger {
|
||||
/// Create a new in-memory metric logger.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricLogger for InMemoryMetricLogger {
|
||||
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
|
||||
if self.last_epoch != Some(epoch) {
|
||||
self.values
|
||||
.values_mut()
|
||||
.for_each(|loggers| loggers.push(InMemoryLogger::default()));
|
||||
self.last_epoch = Some(epoch);
|
||||
}
|
||||
|
||||
let entries: Vec<_> = update
|
||||
.entries
|
||||
.iter()
|
||||
.chain(
|
||||
update
|
||||
.entries_numeric
|
||||
.iter()
|
||||
.map(|numeric_update| &numeric_update.entry),
|
||||
)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
for item in entries.iter() {
|
||||
let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
|
||||
let key = logger_key(name, split);
|
||||
|
||||
if !self.values.contains_key(&key) {
|
||||
self.values
|
||||
.insert(key.to_string(), vec![InMemoryLogger::default()]);
|
||||
}
|
||||
|
||||
let values = self.values.get_mut(&key).unwrap();
|
||||
|
||||
values
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.log(item.serialized_entry.serialized.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn read_numeric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
split: &Split,
|
||||
) -> Result<Vec<NumericEntry>, String> {
|
||||
let key = logger_key(name, split);
|
||||
let values = match self.values.get(&key) {
|
||||
Some(values) => values,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
match values.get(epoch - 1) {
|
||||
Some(logger) => Ok(logger
|
||||
.values
|
||||
.iter()
|
||||
.filter_map(|value| NumericEntry::deserialize(value).ok())
|
||||
.collect()),
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_metric_definition(&mut self, definition: MetricDefinition) {
|
||||
self.metric_definitions
|
||||
.insert(definition.metric_id.clone(), definition);
|
||||
}
|
||||
|
||||
fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
mod async_logger;
|
||||
mod base;
|
||||
mod file;
|
||||
mod in_memory;
|
||||
mod metric;
|
||||
|
||||
pub use async_logger::*;
|
||||
pub use base::*;
|
||||
pub use file::*;
|
||||
pub use in_memory::*;
|
||||
pub use metric::*;
|
||||
@@ -0,0 +1,164 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use super::MetricMetadata;
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{ElementConversion, Int, Tensor};
|
||||
|
||||
/// The accuracy metric.
|
||||
#[derive(Clone)]
|
||||
pub struct AccuracyMetric<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
pad_token: Option<usize>,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [accuracy metric](AccuracyMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct AccuracyInput<B: Backend> {
|
||||
outputs: Tensor<B, 2>,
|
||||
targets: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for AccuracyMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> AccuracyMetric<B> {
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: MetricName::new("Accuracy".to_string()),
|
||||
state: Default::default(),
|
||||
pad_token: Default::default(),
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the pad token.
|
||||
pub fn with_pad_token(mut self, index: usize) -> Self {
|
||||
self.pad_token = Some(index);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for AccuracyMetric<B> {
|
||||
type Input = AccuracyInput<B>;
|
||||
|
||||
fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let targets = input.targets.clone();
|
||||
let outputs = input.outputs.clone();
|
||||
|
||||
let [batch_size, _n_classes] = outputs.dims();
|
||||
|
||||
let outputs = outputs.argmax(1).reshape([batch_size]);
|
||||
|
||||
let accuracy = match self.pad_token {
|
||||
Some(pad_token) => {
|
||||
let mask = targets.clone().equal_elem(pad_token as i64);
|
||||
let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
|
||||
let num_pad = mask.float().sum();
|
||||
|
||||
let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
|
||||
|
||||
acc.into_scalar().elem::<f64>()
|
||||
}
|
||||
None => {
|
||||
outputs
|
||||
.equal(targets)
|
||||
.int()
|
||||
.sum()
|
||||
.into_scalar()
|
||||
.elem::<f64>()
|
||||
/ batch_size as f64
|
||||
}
|
||||
};
|
||||
|
||||
self.state.update(
|
||||
100.0 * accuracy,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
super::NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for AccuracyMetric<B> {
|
||||
fn value(&self) -> super::NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> super::NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn test_accuracy_without_padding() {
|
||||
let device = Default::default();
|
||||
let mut metric = AccuracyMetric::<TestBackend>::new();
|
||||
let input = AccuracyInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.2, 0.8], // 2
|
||||
[1.0, 2.0, 0.5], // 1
|
||||
[0.4, 0.1, 0.2], // 0
|
||||
[0.6, 0.7, 0.2], // 1
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2, 2, 1, 1], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(50.0, metric.value().current());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accuracy_with_padding() {
|
||||
let device = Default::default();
|
||||
let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);
|
||||
let input = AccuracyInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.2, 0.8, 0.0], // 2
|
||||
[1.0, 2.0, 0.5, 0.0], // 1
|
||||
[0.4, 0.1, 0.2, 0.0], // 0
|
||||
[0.6, 0.7, 0.2, 0.0], // 1
|
||||
[0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
|
||||
[0.0, 0.1, 0.2, 0.0], // Error on padding should not count
|
||||
[0.6, 0.0, 0.2, 0.0], // Error on padding should not count
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(50.0, metric.value().current());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
use core::f64;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use super::MetricMetadata;
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{ElementConversion, Int, Tensor};
|
||||
|
||||
/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification.
|
||||
#[derive(Clone)]
|
||||
pub struct AurocMetric<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [AUROC metric](AurocMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct AurocInput<B: Backend> {
|
||||
outputs: Tensor<B, 2>,
|
||||
targets: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for AurocMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> AurocMetric<B> {
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: MetricName::new("AUROC".to_string()),
|
||||
state: Default::default(),
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn binary_auroc(&self, probabilities: &Tensor<B, 1>, targets: &Tensor<B, 1, Int>) -> f64 {
|
||||
let n = targets.dims()[0];
|
||||
|
||||
let n_pos = targets.clone().sum().into_scalar().elem::<u64>() as usize;
|
||||
|
||||
// Early return if we don't have both positive and negative samples
|
||||
if n_pos == 0 || n_pos == n {
|
||||
if n_pos == 0 {
|
||||
log::warn!("Metric cannot be computed because all target values are negative.")
|
||||
} else {
|
||||
log::warn!("Metric cannot be computed because all target values are positive.")
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);
|
||||
let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);
|
||||
|
||||
let valid_pairs = pos_mask * neg_mask;
|
||||
|
||||
let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
|
||||
let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
|
||||
|
||||
let correct_order = prob_i.clone().greater(prob_j.clone()).int();
|
||||
|
||||
let ties = prob_i.equal(prob_j).int();
|
||||
|
||||
// Calculate AUC components
|
||||
let num_pairs = valid_pairs.clone().sum().into_scalar().elem::<f64>();
|
||||
let correct_pairs = (correct_order * valid_pairs.clone())
|
||||
.sum()
|
||||
.into_scalar()
|
||||
.elem::<f64>();
|
||||
let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::<f64>();
|
||||
|
||||
(correct_pairs + 0.5 * tied_pairs) / num_pairs
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for AurocMetric<B> {
|
||||
type Input = AurocInput<B>;
|
||||
|
||||
fn update(&mut self, input: &AurocInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let [batch_size, num_classes] = input.outputs.dims();
|
||||
|
||||
assert_eq!(
|
||||
num_classes, 2,
|
||||
"Currently only binary classification is supported"
|
||||
);
|
||||
|
||||
let probabilities = {
|
||||
let exponents = input.outputs.clone().exp();
|
||||
let sum = exponents.clone().sum_dim(1);
|
||||
(exponents / sum)
|
||||
.select(1, Tensor::arange(1..2, &input.outputs.device()))
|
||||
.squeeze_dim(1)
|
||||
};
|
||||
|
||||
let area_under_curve = self.binary_auroc(&probabilities, &input.targets);
|
||||
|
||||
self.state.update(
|
||||
100.0 * area_under_curve,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for AurocMetric<B> {
|
||||
fn value(&self) -> super::NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> super::NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn test_auroc() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::<TestBackend>::new();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.9], // High confidence positive
|
||||
[0.7, 0.3], // Low confidence negative
|
||||
[0.6, 0.4], // Low confidence negative
|
||||
[0.2, 0.8], // High confidence positive
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([1, 0, 0, 1], &device), // True labels
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 100.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_perfect_separation() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::<TestBackend>::new();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
|
||||
Tensor::from_data([1, 0, 0, 1], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 100.0); // Perfect AUC
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_random() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::<TestBackend>::new();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.5, 0.5], // Random predictions
|
||||
[0.5, 0.5],
|
||||
[0.5, 0.5],
|
||||
[0.5, 0.5],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([1, 0, 0, 1], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 50.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_all_one_class() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::<TestBackend>::new();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.9], // All positives predictions
|
||||
[0.2, 0.8],
|
||||
[0.3, 0.7],
|
||||
[0.4, 0.6],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([1, 1, 1, 1], &device), // All positive class
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Currently only binary classification is supported")]
|
||||
fn test_auroc_multiclass_error() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::<TestBackend>::new();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.2, 0.7], // More than 2 classes not supported
|
||||
[0.3, 0.5, 0.2],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2, 1], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_optim::LearningRate;
|
||||
|
||||
/// Metric metadata that can be used when computing metrics.
|
||||
pub struct MetricMetadata {
|
||||
/// The current progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The global progress of the training (e.g. epochs).
|
||||
pub global_progress: Progress,
|
||||
|
||||
/// The current iteration.
|
||||
pub iteration: Option<usize>,
|
||||
|
||||
/// The current learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
||||
|
||||
impl MetricMetadata {
|
||||
/// Fake metric metadata
|
||||
#[cfg(test)]
|
||||
pub fn fake() -> Self {
|
||||
Self {
|
||||
progress: Progress {
|
||||
items_processed: 1,
|
||||
items_total: 1,
|
||||
},
|
||||
global_progress: Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
},
|
||||
iteration: Some(0),
|
||||
lr: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
|
||||
/// For now we take the name as id to make sure that the same metric has the same id across different runs.
|
||||
#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
|
||||
pub struct MetricId {
|
||||
/// The metric id.
|
||||
id: Arc<String>,
|
||||
}
|
||||
|
||||
/// Metric attributes define the properties intrinsic to different types of metric.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum MetricAttributes {
|
||||
/// Numeric attributes.
|
||||
Numeric(NumericAttributes),
|
||||
/// No attributes.
|
||||
None,
|
||||
}
|
||||
|
||||
/// Definition of a metric.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MetricDefinition {
|
||||
/// The metric's id.
|
||||
pub metric_id: MetricId,
|
||||
/// The name of the metric.
|
||||
pub name: String,
|
||||
/// The description of the metric.
|
||||
pub description: Option<String>,
|
||||
/// The attributes of the metric.
|
||||
pub attributes: MetricAttributes,
|
||||
}
|
||||
|
||||
impl MetricDefinition {
|
||||
/// Create a new metric definition given the metric and a unique id.
|
||||
pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
|
||||
Self {
|
||||
metric_id,
|
||||
name: metric.name().to_string(),
|
||||
description: metric.description(),
|
||||
attributes: metric.attributes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metric trait.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Implementations should define their own input type only used by the metric.
|
||||
/// This is important since some conflict may happen when the model output is adapted for each
|
||||
/// metric's input type.
|
||||
pub trait Metric: Send + Sync + Clone {
|
||||
/// The input type of the metric.
|
||||
type Input;
|
||||
|
||||
/// The parameterized name of the metric.
|
||||
///
|
||||
/// This should be unique, so avoid using short generic names, prefer using the long name.
|
||||
///
|
||||
/// For a metric that can exist at different parameters (e.g., top-k accuracy for different
|
||||
/// values of k), the name should be unique for each instance.
|
||||
fn name(&self) -> MetricName;
|
||||
|
||||
/// A short description of the metric.
|
||||
fn description(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Attributes of the metric.
|
||||
///
|
||||
/// By default, metrics have no attributes.
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
MetricAttributes::None
|
||||
}
|
||||
|
||||
/// Update the metric state and returns the current metric entry.
|
||||
fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
|
||||
|
||||
/// Clear the metric state.
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
/// Type used to store metric names efficiently.
|
||||
pub type MetricName = Arc<String>;
|
||||
|
||||
/// Adaptor are used to transform types so that they can be used by metrics.
|
||||
///
|
||||
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
|
||||
/// registered with the specific learning paradigm (i.e. [SupervisedTraining](crate::SupervisedTraining)).
|
||||
pub trait Adaptor<T> {
|
||||
/// Adapt the type to be passed to a [metric](Metric).
|
||||
fn adapt(&self) -> T;
|
||||
}
|
||||
|
||||
impl<T> Adaptor<()> for T {
|
||||
fn adapt(&self) {}
|
||||
}
|
||||
|
||||
/// Attributes that describe intrinsic properties of a numeric metric.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NumericAttributes {
|
||||
/// Optional unit (e.g. "%", "ms", "pixels")
|
||||
pub unit: Option<String>,
|
||||
/// Whether larger values are better (true) or smaller are better (false).
|
||||
pub higher_is_better: bool,
|
||||
}
|
||||
|
||||
impl From<NumericAttributes> for MetricAttributes {
|
||||
fn from(attr: NumericAttributes) -> Self {
|
||||
MetricAttributes::Numeric(attr)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NumericAttributes {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
unit: None,
|
||||
higher_is_better: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Declare a metric to be numeric.
|
||||
///
|
||||
/// This is useful to plot the values of a metric during training.
|
||||
pub trait Numeric {
|
||||
/// Returns the numeric value of the metric.
|
||||
fn value(&self) -> NumericEntry;
|
||||
/// Returns the current aggregated value of the metric over the global step (epoch).
|
||||
fn running_value(&self) -> NumericEntry;
|
||||
}
|
||||
|
||||
/// Serialized form of a metric entry.
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct SerializedEntry {
|
||||
/// The string to be displayed.
|
||||
pub formatted: String,
|
||||
/// The string to be saved.
|
||||
pub serialized: String,
|
||||
}
|
||||
|
||||
/// Data type that contains the current state of a metric at a given time.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetricEntry {
|
||||
/// Id of the entry's metric.
|
||||
pub metric_id: MetricId,
|
||||
/// The serialized form of the entry.
|
||||
pub serialized_entry: SerializedEntry,
|
||||
}
|
||||
|
||||
impl MetricEntry {
|
||||
/// Create a new metric.
|
||||
pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
|
||||
Self {
|
||||
metric_id,
|
||||
serialized_entry,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Numeric metric entry.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum NumericEntry {
|
||||
/// Single numeric value.
|
||||
Value(f64),
|
||||
/// Aggregated numeric (value, number of elements).
|
||||
Aggregated {
|
||||
/// The aggregated value of all entries.
|
||||
aggregated_value: f64,
|
||||
/// The number of entries present in the aggregated value.
|
||||
count: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl NumericEntry {
|
||||
/// Gets the current aggregated value of the metric.
|
||||
pub fn current(&self) -> f64 {
|
||||
match self {
|
||||
NumericEntry::Value(val) => *val,
|
||||
NumericEntry::Aggregated {
|
||||
aggregated_value, ..
|
||||
} => *aggregated_value,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a String representing the NumericEntry
|
||||
pub fn serialize(&self) -> String {
|
||||
match self {
|
||||
Self::Value(v) => v.to_string(),
|
||||
Self::Aggregated {
|
||||
aggregated_value,
|
||||
count,
|
||||
} => format!("{aggregated_value},{count}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// De-serializes a string representing a NumericEntry and returns a Result containing the corresponding NumericEntry.
|
||||
pub fn deserialize(entry: &str) -> Result<Self, String> {
|
||||
// Check for comma separated values
|
||||
let values = entry.split(',').collect::<Vec<_>>();
|
||||
let num_values = values.len();
|
||||
|
||||
if num_values == 1 {
|
||||
// Numeric value
|
||||
match values[0].parse::<f64>() {
|
||||
Ok(value) => Ok(NumericEntry::Value(value)),
|
||||
Err(err) => Err(err.to_string()),
|
||||
}
|
||||
} else if num_values == 2 {
|
||||
// Aggregated numeric (value, number of elements)
|
||||
let (value, numel) = (values[0], values[1]);
|
||||
match value.parse::<f64>() {
|
||||
Ok(value) => match numel.parse::<usize>() {
|
||||
Ok(numel) => Ok(NumericEntry::Aggregated {
|
||||
aggregated_value: value,
|
||||
count: numel,
|
||||
}),
|
||||
Err(err) => Err(err.to_string()),
|
||||
},
|
||||
Err(err) => Err(err.to_string()),
|
||||
}
|
||||
} else {
|
||||
Err("Invalid number of values for numeric entry".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare this numeric metric's value with another one using the specified direction.
|
||||
pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
|
||||
(self.current() > other.current()) == higher_is_better
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a float with the given precision. Will use scientific notation if necessary.
|
||||
pub fn format_float(float: f64, precision: usize) -> String {
|
||||
let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
|
||||
|
||||
match scientific_notation_threshold >= float {
|
||||
true => format!("{float:.precision$e}"),
|
||||
false => format!("{float:.precision$}"),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use super::{MetricMetadata, SerializedEntry};
|
||||
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{Int, Tensor};
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Computes the edit distance (Levenshtein distance) between two sequences of integers.
|
||||
///
|
||||
/// The edit distance is defined as the minimum number of single-element edits (insertions,
|
||||
/// deletions, or substitutions) required to change one sequence into the other. This
|
||||
/// implementation is optimized for space, using only two rows of the dynamic programming table.
|
||||
///
|
||||
pub(crate) fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize {
|
||||
let mut prev = (0..=prediction.len()).collect::<Vec<_>>();
|
||||
let mut curr = vec![0; prediction.len() + 1];
|
||||
|
||||
for (i, &r) in reference.iter().enumerate() {
|
||||
curr[0] = i + 1;
|
||||
for (j, &p) in prediction.iter().enumerate() {
|
||||
curr[j + 1] = if r == p {
|
||||
prev[j] // no operation needed
|
||||
} else {
|
||||
1 + prev[j].min(prev[j + 1]).min(curr[j]) // substitution, insertion, deletion
|
||||
};
|
||||
}
|
||||
core::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
prev[prediction.len()]
|
||||
}
|
||||
|
||||
/// Character error rate (CER) is defined as the edit distance (e.g. Levenshtein distance) between the predicted
|
||||
/// and reference character sequences, divided by the total number of characters in the reference.
|
||||
/// This metric is commonly used in tasks such as speech recognition, OCR, or text generation
|
||||
/// to quantify how closely the predicted output matches the ground truth at a character level.
|
||||
///
|
||||
#[derive(Clone)]
|
||||
pub struct CharErrorRate<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
pad_token: Option<usize>,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [character error rate metric](CharErrorRate) input type.
|
||||
#[derive(new)]
|
||||
pub struct CerInput<B: Backend> {
|
||||
/// The predicted token sequences (as a 2-D tensor of token indices).
|
||||
pub outputs: Tensor<B, 2, Int>,
|
||||
/// The target token sequences (as a 2-D tensor of token indices).
|
||||
pub targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for CharErrorRate<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> CharErrorRate<B> {
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("CER".to_string()),
|
||||
state: NumericMetricState::default(),
|
||||
pad_token: None,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the pad token.
|
||||
pub fn with_pad_token(mut self, index: usize) -> Self {
|
||||
self.pad_token = Some(index);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// The [character error rate metric](CharErrorRate) implementation.
|
||||
impl<B: Backend> Metric for CharErrorRate<B> {
|
||||
type Input = CerInput<B>;
|
||||
|
||||
fn update(&mut self, input: &CerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let outputs = &input.outputs;
|
||||
let targets = &input.targets;
|
||||
let [batch_size, seq_len] = targets.dims();
|
||||
|
||||
let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token {
|
||||
// Create boolean masks for non-padding tokens.
|
||||
let output_mask = outputs.clone().not_equal_elem(pad as i64);
|
||||
let target_mask = targets.clone().not_equal_elem(pad as i64);
|
||||
|
||||
let output_lengths_tensor = output_mask.int().sum_dim(1);
|
||||
let target_lengths_tensor = target_mask.int().sum_dim(1);
|
||||
|
||||
(
|
||||
output_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
|
||||
target_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
|
||||
)
|
||||
} else {
|
||||
// If there's no padding, all sequences have the full length.
|
||||
(
|
||||
vec![seq_len as i64; batch_size],
|
||||
vec![seq_len as i64; batch_size],
|
||||
)
|
||||
};
|
||||
|
||||
let outputs_data = outputs.to_data().to_vec::<i64>().unwrap();
|
||||
let targets_data = targets.to_data().to_vec::<i64>().unwrap();
|
||||
|
||||
let total_edit_distance: usize = (0..batch_size)
|
||||
.map(|i| {
|
||||
let start = i * seq_len;
|
||||
|
||||
// Get pre-calculated lengths for the current sequence.
|
||||
let output_len = output_lengths[i] as usize;
|
||||
let target_len = target_lengths[i] as usize;
|
||||
|
||||
let output_seq_slice = &outputs_data[start..(start + output_len)];
|
||||
let target_seq_slice = &targets_data[start..(start + target_len)];
|
||||
let output_seq: Vec<i32> = output_seq_slice.iter().map(|&x| x as i32).collect();
|
||||
let target_seq: Vec<i32> = target_seq_slice.iter().map(|&x| x as i32).collect();
|
||||
|
||||
edit_distance(&target_seq, &output_seq)
|
||||
})
|
||||
.sum();
|
||||
|
||||
let total_target_length = target_lengths.iter().map(|&x| x as f64).sum::<f64>();
|
||||
|
||||
let value = if total_target_length > 0.0 {
|
||||
100.0 * total_edit_distance as f64 / total_target_length
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
self.state.update(
|
||||
value,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset();
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
super::NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for CharErrorRate<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
/// Perfect match ⇒ CER = 0 %.
|
||||
#[test]
|
||||
fn test_cer_without_padding() {
|
||||
let device = Default::default();
|
||||
let mut metric = CharErrorRate::<TestBackend>::new();
|
||||
|
||||
// Batch size = 2, sequence length = 2
|
||||
let preds = Tensor::from_data([[1, 2], [3, 4]], &device);
|
||||
let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);
|
||||
|
||||
metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());
|
||||
|
||||
assert_eq!(0.0, metric.value().current());
|
||||
}
|
||||
|
||||
/// Two edits in four target tokens ⇒ 50 %.
|
||||
#[test]
|
||||
fn test_cer_without_padding_two_errors() {
|
||||
let device = Default::default();
|
||||
let mut metric = CharErrorRate::<TestBackend>::new();
|
||||
|
||||
// One substitution in each sequence.
|
||||
let preds = Tensor::from_data([[1, 2], [3, 5]], &device);
|
||||
let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);
|
||||
|
||||
metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());
|
||||
|
||||
// 2 edits / 4 tokens = 50 %
|
||||
assert_eq!(50.0, metric.value().current());
|
||||
}
|
||||
|
||||
/// Same scenario as above, but with right-padding (token 9) ignored.
|
||||
#[test]
|
||||
fn test_cer_with_padding() {
|
||||
let device = Default::default();
|
||||
let pad = 9_i64;
|
||||
let mut metric = CharErrorRate::<TestBackend>::new().with_pad_token(pad as usize);
|
||||
|
||||
// Each row has three columns, last one is the pad token.
|
||||
let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
|
||||
let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);
|
||||
|
||||
metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());
|
||||
assert_eq!(50.0, metric.value().current());
|
||||
}
|
||||
|
||||
/// `clear()` must reset the running statistics to zero.
|
||||
#[test]
|
||||
fn test_clear_resets_state() {
|
||||
let device = Default::default();
|
||||
let mut metric = CharErrorRate::<TestBackend>::new();
|
||||
|
||||
let preds = Tensor::from_data([[1, 2]], &device);
|
||||
let tgts = Tensor::from_data([[1, 3]], &device); // one error
|
||||
|
||||
metric.update(
|
||||
&CerInput::new(preds.clone(), tgts.clone()),
|
||||
&MetricMetadata::fake(),
|
||||
);
|
||||
assert!(metric.value().current() > 0.0);
|
||||
|
||||
metric.clear();
|
||||
assert!(metric.value().current().is_nan());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
/// Necessary data for classification metrics.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct ClassificationMetricConfig {
|
||||
pub decision_rule: DecisionRule,
|
||||
pub class_reduction: ClassReduction,
|
||||
}
|
||||
|
||||
/// The prediction decision rule for classification metrics.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DecisionRule {
|
||||
/// Consider a class predicted if its probability exceeds the threshold.
|
||||
Threshold(f64),
|
||||
/// Consider a class predicted correctly if it is within the top k predicted classes based on scores.
|
||||
TopK(NonZeroUsize),
|
||||
}
|
||||
|
||||
impl Default for DecisionRule {
|
||||
fn default() -> Self {
|
||||
Self::Threshold(0.5)
|
||||
}
|
||||
}
|
||||
|
||||
/// The reduction strategy for classification metrics.
|
||||
#[derive(Copy, Clone, Default, Debug)]
|
||||
pub enum ClassReduction {
|
||||
/// Computes the statistics over all classes before averaging
|
||||
Micro,
|
||||
/// Computes the statistics independently for each class before averaging
|
||||
#[default]
|
||||
Macro,
|
||||
}
|
||||
@@ -0,0 +1,351 @@
|
||||
use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule};
|
||||
use burn_core::{
|
||||
prelude::{Backend, Bool, Int, Tensor},
|
||||
tensor::IndexingUpdateOp,
|
||||
};
|
||||
use std::fmt::{self, Debug};
|
||||
|
||||
/// Input for confusion statistics error types.
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct ConfusionStatsInput<B: Backend> {
|
||||
/// Sample x Class Non thresholded normalized predictions.
|
||||
pub predictions: Tensor<B, 2>,
|
||||
/// Sample x Class one-hot encoded target.
|
||||
pub targets: Tensor<B, 2, Bool>,
|
||||
}
|
||||
|
||||
impl<B: Backend> From<ConfusionStatsInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {
|
||||
fn from(input: ConfusionStatsInput<B>) -> Self {
|
||||
(input.predictions, input.targets)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for ConfusionStatsInput<B> {
|
||||
fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {
|
||||
Self::new(value.0, value.1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConfusionStats<B: Backend> {
|
||||
confusion_classes: Tensor<B, 2, Int>,
|
||||
class_reduction: ClassReduction,
|
||||
}
|
||||
|
||||
impl<B: Backend> Debug for ConfusionStats<B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let to_vec = |tensor_data: Tensor<B, 1>| {
|
||||
tensor_data
|
||||
.to_data()
|
||||
.to_vec::<f32>()
|
||||
.expect("A vector representation of the input Tensor is expected")
|
||||
};
|
||||
let ratio_of_support_vec =
|
||||
|metric: Tensor<B, 1>| to_vec(self.clone().ratio_of_support(metric));
|
||||
f.debug_struct("ConfusionStats")
|
||||
.field("tp", &ratio_of_support_vec(self.clone().true_positive()))
|
||||
.field("fp", &ratio_of_support_vec(self.clone().false_positive()))
|
||||
.field("tn", &ratio_of_support_vec(self.clone().true_negative()))
|
||||
.field("fn", &ratio_of_support_vec(self.clone().false_negative()))
|
||||
.field("support", &to_vec(self.clone().support()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConfusionStats<B> {
|
||||
/// Expects `predictions` to be normalized.
|
||||
pub fn new(input: &ConfusionStatsInput<B>, config: &ClassificationMetricConfig) -> Self {
|
||||
let prediction_mask = match config.decision_rule {
|
||||
DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold),
|
||||
DecisionRule::TopK(top_k) => {
|
||||
let mask = input.predictions.zeros_like();
|
||||
let indexes =
|
||||
input
|
||||
.predictions
|
||||
.clone()
|
||||
.argsort_descending(1)
|
||||
.narrow(1, 0, top_k.get());
|
||||
let values = indexes.ones_like().float();
|
||||
mask.scatter(1, indexes, values, IndexingUpdateOp::Add)
|
||||
.bool()
|
||||
}
|
||||
};
|
||||
Self {
|
||||
confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2,
|
||||
class_reduction: config.class_reduction,
|
||||
}
|
||||
}
|
||||
|
||||
/// sum over samples
|
||||
fn aggregate(
|
||||
sample_class_mask: Tensor<B, 2, Bool>,
|
||||
class_reduction: ClassReduction,
|
||||
) -> Tensor<B, 1> {
|
||||
use ClassReduction::{Macro, Micro};
|
||||
match class_reduction {
|
||||
Micro => sample_class_mask.float().sum(),
|
||||
Macro => sample_class_mask.float().sum_dim(0).squeeze_dim(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn true_positive(self) -> Tensor<B, 1> {
|
||||
Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction)
|
||||
}
|
||||
|
||||
pub fn true_negative(self) -> Tensor<B, 1> {
|
||||
Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction)
|
||||
}
|
||||
|
||||
pub fn false_positive(self) -> Tensor<B, 1> {
|
||||
Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction)
|
||||
}
|
||||
|
||||
pub fn false_negative(self) -> Tensor<B, 1> {
|
||||
Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction)
|
||||
}
|
||||
|
||||
pub fn positive(self) -> Tensor<B, 1> {
|
||||
self.clone().true_positive() + self.false_negative()
|
||||
}
|
||||
|
||||
pub fn negative(self) -> Tensor<B, 1> {
|
||||
self.clone().true_negative() + self.false_positive()
|
||||
}
|
||||
|
||||
pub fn predicted_positive(self) -> Tensor<B, 1> {
|
||||
self.clone().true_positive() + self.false_positive()
|
||||
}
|
||||
|
||||
pub fn support(self) -> Tensor<B, 1> {
|
||||
self.clone().positive() + self.negative()
|
||||
}
|
||||
|
||||
pub fn ratio_of_support(self, metric: Tensor<B, 1>) -> Tensor<B, 1> {
|
||||
metric / self.clone().support()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ConfusionStats, ConfusionStatsInput};
|
||||
use crate::{
|
||||
TestBackend,
|
||||
metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
|
||||
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
|
||||
};
|
||||
use burn_core::prelude::TensorData;
|
||||
use rstest::{fixture, rstest};
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
fn top_k_config(
|
||||
top_k: NonZeroUsize,
|
||||
class_reduction: ClassReduction,
|
||||
) -> ClassificationMetricConfig {
|
||||
ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::TopK(top_k),
|
||||
class_reduction,
|
||||
}
|
||||
}
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn top_k_config_k1_micro() -> ClassificationMetricConfig {
|
||||
top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro)
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn top_k_config_k1_macro() -> ClassificationMetricConfig {
|
||||
top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro)
|
||||
}
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn top_k_config_k2_micro() -> ClassificationMetricConfig {
|
||||
top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro)
|
||||
}
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn top_k_config_k2_macro() -> ClassificationMetricConfig {
|
||||
top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro)
|
||||
}
|
||||
|
||||
fn threshold_config(
|
||||
threshold: f64,
|
||||
class_reduction: ClassReduction,
|
||||
) -> ClassificationMetricConfig {
|
||||
ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
class_reduction,
|
||||
}
|
||||
}
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn threshold_config_micro() -> ClassificationMetricConfig {
|
||||
threshold_config(THRESHOLD, ClassReduction::Micro)
|
||||
}
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn threshold_config_macro() -> ClassificationMetricConfig {
|
||||
threshold_config(THRESHOLD, ClassReduction::Macro)
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())]
|
||||
fn test_true_positive(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.true_positive()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())]
|
||||
fn test_true_negative(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.true_negative()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())]
|
||||
fn test_false_positive(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.false_positive()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())]
|
||||
fn test_false_negatives(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.false_negative()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())]
|
||||
fn test_positive(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.positive()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())]
|
||||
fn test_negative(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.negative()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
|
||||
#[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())]
|
||||
#[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
|
||||
#[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())]
|
||||
#[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())]
|
||||
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())]
|
||||
fn test_predicted_positive(
|
||||
#[case] classification_type: ClassificationType,
|
||||
#[case] config: ClassificationMetricConfig,
|
||||
#[case] expected: Vec<i64>,
|
||||
) {
|
||||
let input: ConfusionStatsInput<TestBackend> =
|
||||
dummy_classification_input(&classification_type).into();
|
||||
ConfusionStats::new(&input, &config)
|
||||
.predicted_positive()
|
||||
.int()
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from(expected.as_slice()), true);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
/// CPU Temperature metric
|
||||
use super::MetricMetadata;
|
||||
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry};
|
||||
use systemstat::{Platform, System};
|
||||
|
||||
/// CPU Temperature in celsius degrees
|
||||
#[derive(Clone)]
|
||||
pub struct CpuTemperature {
|
||||
name: MetricName,
|
||||
temp_celsius: f32,
|
||||
sys: Arc<System>,
|
||||
}
|
||||
|
||||
impl CpuTemperature {
|
||||
/// Creates a new CPU temp metric
|
||||
pub fn new() -> Self {
|
||||
let name = Arc::new("CPU Temperature".to_string());
|
||||
|
||||
Self {
|
||||
name,
|
||||
temp_celsius: 0.,
|
||||
sys: Arc::new(System::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CpuTemperature {
|
||||
fn default() -> Self {
|
||||
CpuTemperature::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for CpuTemperature {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
match self.sys.cpu_temp() {
|
||||
Ok(temp) => self.temp_celsius = temp,
|
||||
Err(_) => self.temp_celsius = f32::NAN,
|
||||
}
|
||||
|
||||
let formatted = match self.temp_celsius.is_nan() {
|
||||
true => format!("{}: NaN °C", self.name()),
|
||||
false => format!("{}: {:.2} °C", self.name(), self.temp_celsius),
|
||||
};
|
||||
let raw = format!("{:.2}", self.temp_celsius);
|
||||
|
||||
SerializedEntry::new(formatted, raw)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
super::NumericAttributes {
|
||||
unit: Some("°C".to_string()),
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for CpuTemperature {
|
||||
fn value(&self) -> NumericEntry {
|
||||
NumericEntry::Value(self.temp_celsius as f64)
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
NumericEntry::Value(self.temp_celsius as f64)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use super::MetricMetadata;
|
||||
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry};
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use sysinfo::{CpuRefreshKind, RefreshKind, System};
|
||||
|
||||
/// General CPU Usage metric
|
||||
pub struct CpuUse {
|
||||
name: MetricName,
|
||||
last_refresh: Instant,
|
||||
refresh_frequency: Duration,
|
||||
sys: System,
|
||||
current: f64,
|
||||
}
|
||||
|
||||
impl Clone for CpuUse {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
name: self.name.clone(),
|
||||
last_refresh: self.last_refresh,
|
||||
refresh_frequency: self.refresh_frequency,
|
||||
sys: System::new(),
|
||||
current: self.current,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CpuUse {
|
||||
/// Creates a new CPU metric
|
||||
pub fn new() -> Self {
|
||||
let mut sys = System::new();
|
||||
let current = Self::refresh(&mut sys);
|
||||
let name = "CPU Usage".to_string();
|
||||
|
||||
Self {
|
||||
name: Arc::new(name),
|
||||
last_refresh: Instant::now(),
|
||||
refresh_frequency: Duration::from_millis(200),
|
||||
sys,
|
||||
current,
|
||||
}
|
||||
}
|
||||
|
||||
fn refresh(sys: &mut System) -> f64 {
|
||||
sys.refresh_specifics(
|
||||
RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()),
|
||||
);
|
||||
|
||||
let cpus = sys.cpus();
|
||||
let num_cpus = cpus.len();
|
||||
let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64;
|
||||
|
||||
use_percentage / num_cpus as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CpuUse {
|
||||
fn default() -> Self {
|
||||
CpuUse::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for CpuUse {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
if self.last_refresh.elapsed() >= self.refresh_frequency {
|
||||
self.current = Self::refresh(&mut self.sys);
|
||||
self.last_refresh = Instant::now();
|
||||
}
|
||||
|
||||
let formatted = format!("{}: {:.2} %", self.name(), self.current);
|
||||
let raw = format!("{:.2}", self.current);
|
||||
|
||||
SerializedEntry::new(formatted, raw)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
super::NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for CpuUse {
|
||||
fn value(&self) -> NumericEntry {
|
||||
NumericEntry::Value(self.current)
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
NumericEntry::Value(self.current)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::MetricMetadata;
|
||||
use crate::metric::{Metric, MetricName, SerializedEntry};
|
||||
use nvml_wrapper::Nvml;
|
||||
|
||||
/// Track basic cuda infos.
|
||||
#[derive(Clone)]
|
||||
pub struct CudaMetric {
|
||||
name: MetricName,
|
||||
nvml: Arc<Option<Nvml>>,
|
||||
}
|
||||
|
||||
impl CudaMetric {
|
||||
/// Creates a new metric for CUDA.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Cuda".to_string()),
|
||||
nvml: Arc::new(Nvml::init().map(Some).unwrap_or_else(|err| {
|
||||
log::warn!("Unable to initialize CUDA Metric: {err}");
|
||||
None
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CudaMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for CudaMetric {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let not_available =
|
||||
|| SerializedEntry::new("Unavailable".to_string(), "Unavailable".to_string());
|
||||
|
||||
let available = |nvml: &Nvml| {
|
||||
let mut formatted = String::new();
|
||||
let mut raw_running = String::new();
|
||||
|
||||
let device_count = match nvml.device_count() {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
log::warn!("Unable to get the number of cuda devices: {err}");
|
||||
return not_available();
|
||||
}
|
||||
};
|
||||
|
||||
for index in 0..device_count {
|
||||
let device = match nvml.device_by_index(index) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
log::warn!("Unable to get device {index}: {err}");
|
||||
return not_available();
|
||||
}
|
||||
};
|
||||
let memory_info = match device.memory_info() {
|
||||
Ok(info) => info,
|
||||
Err(err) => {
|
||||
log::warn!("Unable to get memory info from device {index}: {err}");
|
||||
return not_available();
|
||||
}
|
||||
};
|
||||
|
||||
let used_gb = memory_info.used as f64 * 1e-9;
|
||||
let total_gb = memory_info.total as f64 * 1e-9;
|
||||
|
||||
let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
|
||||
let memory_info_raw = format!("{used_gb}/{total_gb}");
|
||||
|
||||
formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
|
||||
raw_running = format!("{memory_info_raw} ");
|
||||
|
||||
let utilization_rates = match device.utilization_rates() {
|
||||
Ok(rate) => rate,
|
||||
Err(err) => {
|
||||
log::warn!("Unable to get utilization rates from device {index}: {err}");
|
||||
return not_available();
|
||||
}
|
||||
};
|
||||
let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
|
||||
formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
|
||||
|
||||
// Power is the currency for perf/W. NVML reports milliwatts.
|
||||
if let Ok(power_mw) = device.power_usage() {
|
||||
let power_w = power_mw as f64 / 1000.0;
|
||||
formatted = format!("{formatted} - Power {power_w:.1} W");
|
||||
}
|
||||
}
|
||||
|
||||
SerializedEntry::new(formatted, raw_running)
|
||||
};
|
||||
|
||||
match self.nvml.as_ref() {
|
||||
Some(nvml) => available(nvml),
|
||||
None => not_available(),
|
||||
}
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
use crate::metric::{MetricName, Numeric};
|
||||
|
||||
use super::{
|
||||
Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,
|
||||
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
|
||||
confusion_stats::{ConfusionStats, ConfusionStatsInput},
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use burn_core::{
|
||||
prelude::{Backend, Tensor},
|
||||
tensor::cast::ToElement,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
use std::{num::NonZeroUsize, sync::Arc};
|
||||
|
||||
/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric.
|
||||
///
|
||||
/// The `beta` parameter represents the ratio of recall importance to precision importance.
|
||||
/// `beta > 1` gives more weight to recall, while `beta < 1` favors precision.
|
||||
#[derive(Clone)]
|
||||
pub struct FBetaScoreMetric<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
_b: PhantomData<B>,
|
||||
config: ClassificationMetricConfig,
|
||||
beta: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for FBetaScoreMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new(Default::default(), Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> FBetaScoreMetric<B> {
|
||||
#[allow(dead_code)]
|
||||
fn new(config: ClassificationMetricConfig, beta: f64) -> Self {
|
||||
let name = Arc::new(format!(
|
||||
"FBetaScore ({}) @ {:?} [{:?}]",
|
||||
beta, config.decision_rule, config.class_reduction
|
||||
));
|
||||
Self {
|
||||
name,
|
||||
config,
|
||||
beta,
|
||||
state: Default::default(),
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// F-beta score metric for binary classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `beta` - Positive real factor to weight recall's importance.
|
||||
/// * `threshold` - The threshold to transform a probability into a binary prediction.
|
||||
#[allow(dead_code)]
|
||||
pub fn binary(beta: f64, threshold: f64) -> Self {
|
||||
Self::new(
|
||||
ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
// binary classification results are the same independently of class_reduction
|
||||
..Default::default()
|
||||
},
|
||||
beta,
|
||||
)
|
||||
}
|
||||
|
||||
/// F-beta score metric for multiclass classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `beta` - Positive real factor to weight recall's importance.
|
||||
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self {
|
||||
Self::new(
|
||||
ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::TopK(
|
||||
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
|
||||
),
|
||||
class_reduction,
|
||||
},
|
||||
beta,
|
||||
)
|
||||
}
|
||||
|
||||
/// F-beta score metric for multi-label classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `beta` - Positive real factor to weight recall's importance.
|
||||
/// * `threshold` - The threshold to transform a probability into a binary prediction.
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self {
|
||||
Self::new(
|
||||
ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
class_reduction,
|
||||
},
|
||||
beta,
|
||||
)
|
||||
}
|
||||
|
||||
fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
|
||||
use ClassReduction::{Macro, Micro};
|
||||
let avg_tensor = match self.config.class_reduction {
|
||||
Micro => aggregated_metric,
|
||||
Macro => {
|
||||
if aggregated_metric
|
||||
.clone()
|
||||
.contains_nan()
|
||||
.any()
|
||||
.into_scalar()
|
||||
.to_bool()
|
||||
{
|
||||
let nan_mask = aggregated_metric.clone().is_nan();
|
||||
aggregated_metric = aggregated_metric
|
||||
.clone()
|
||||
.select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))
|
||||
}
|
||||
aggregated_metric.mean()
|
||||
}
|
||||
};
|
||||
avg_tensor.into_scalar().to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for FBetaScoreMetric<B> {
|
||||
type Input = ConfusionStatsInput<B>;
|
||||
|
||||
fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let [sample_size, _] = input.predictions.dims();
|
||||
|
||||
let cf_stats = ConfusionStats::new(input, &self.config);
|
||||
let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2));
|
||||
let metric = self.class_average(
|
||||
scaled_true_positive.clone()
|
||||
/ (scaled_true_positive
|
||||
+ cf_stats.clone().false_negative() * self.beta.powi(2)
|
||||
+ cf_stats.false_positive()),
|
||||
);
|
||||
|
||||
self.state.update(
|
||||
100.0 * metric,
|
||||
sample_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for FBetaScoreMetric<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
ClassReduction::{self, *},
|
||||
FBetaScoreMetric, Metric, MetricMetadata,
|
||||
};
|
||||
use crate::metric::Numeric;
|
||||
use crate::{
|
||||
TestBackend,
|
||||
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
|
||||
};
|
||||
use burn_core::tensor::TensorData;
|
||||
use burn_core::tensor::Tolerance;
|
||||
use rstest::rstest;
|
||||
|
||||
#[rstest]
|
||||
#[case::binary_b1(1.0, THRESHOLD, 0.5)]
|
||||
#[case::binary_b2(2.0, THRESHOLD, 0.5)]
|
||||
fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) {
|
||||
let input = dummy_classification_input(&ClassificationType::Binary).into();
|
||||
let mut metric = FBetaScoreMetric::binary(beta, threshold);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)]
|
||||
#[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))]
|
||||
#[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)]
|
||||
#[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)]
|
||||
#[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)]
|
||||
#[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))]
|
||||
#[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)]
|
||||
#[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)]
|
||||
fn test_multiclass_fscore(
|
||||
#[case] beta: f64,
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] top_k: usize,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
|
||||
let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))]
|
||||
#[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)]
|
||||
#[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))]
|
||||
#[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)]
|
||||
fn test_multilabel_fscore(
|
||||
#[case] beta: f64,
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] threshold: f64,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
|
||||
let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameterized_unique_name() {
|
||||
let metric_a = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 1, ClassReduction::Macro);
|
||||
let metric_b = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 2, ClassReduction::Macro);
|
||||
let metric_c = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 1, ClassReduction::Macro);
|
||||
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
assert_eq!(metric_a.name(), metric_c.name());
|
||||
|
||||
let metric_a = FBetaScoreMetric::<TestBackend>::binary(0.5, 0.5);
|
||||
let metric_b = FBetaScoreMetric::<TestBackend>::binary(0.75, 0.5);
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use super::{MetricMetadata, SerializedEntry};
|
||||
use crate::metric::{
|
||||
Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
|
||||
};
|
||||
use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};
|
||||
|
||||
/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
|
||||
#[derive(Clone)]
|
||||
pub struct HammingScore<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
threshold: f32,
|
||||
sigmoid: bool,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [hamming score](HammingScore) input type.
|
||||
#[derive(new)]
|
||||
pub struct HammingScoreInput<B: Backend> {
|
||||
outputs: Tensor<B, 2>,
|
||||
targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> HammingScore<B> {
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn update_name(&mut self) {
|
||||
self.name = Arc::new(format!("Hamming Score @ Threshold({})", self.threshold));
|
||||
}
|
||||
|
||||
/// Sets the threshold.
|
||||
pub fn with_threshold(mut self, threshold: f32) -> Self {
|
||||
self.threshold = threshold;
|
||||
self.update_name();
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the sigmoid activation function usage.
|
||||
pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
|
||||
self.sigmoid = sigmoid;
|
||||
self.update_name();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for HammingScore<B> {
|
||||
/// Creates a new metric instance with default values.
|
||||
fn default() -> Self {
|
||||
let threshold = 0.5;
|
||||
let name = Arc::new(format!("Hamming Score @ Threshold({})", threshold));
|
||||
|
||||
Self {
|
||||
name,
|
||||
state: NumericMetricState::default(),
|
||||
threshold,
|
||||
sigmoid: false,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for HammingScore<B> {
|
||||
type Input = HammingScoreInput<B>;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
input: &HammingScoreInput<B>,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
let [batch_size, _n_classes] = input.outputs.dims();
|
||||
|
||||
let targets = input.targets.clone();
|
||||
|
||||
let mut outputs = input.outputs.clone();
|
||||
|
||||
if self.sigmoid {
|
||||
outputs = sigmoid(outputs);
|
||||
}
|
||||
|
||||
let score = outputs
|
||||
.greater_elem(self.threshold)
|
||||
.equal(targets.bool())
|
||||
.float()
|
||||
.mean()
|
||||
.into_scalar()
|
||||
.elem::<f64>();
|
||||
|
||||
self.state.update(
|
||||
100.0 * score,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for HammingScore<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn test_hamming_score() {
|
||||
let device = Default::default();
|
||||
let mut metric = HammingScore::<TestBackend>::new();
|
||||
|
||||
let x = Tensor::from_data(
|
||||
[
|
||||
[0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
|
||||
[0.43, 0.31, 0.21, 0.63, 0.53], // [0, 0, 0, 1, 1]
|
||||
[0.44, 0.25, 0.71, 0.39, 0.73], // [0, 0, 1, 0, 1]
|
||||
[0.49, 0.37, 0.68, 0.39, 0.31], // [0, 0, 1, 0, 0]
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let y = Tensor::from_data(
|
||||
[
|
||||
[0, 1, 0, 1, 1],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 0, 1],
|
||||
[0, 0, 1, 0, 0],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
|
||||
let _entry = metric.update(
|
||||
&HammingScoreInput::new(x.clone(), y.clone()),
|
||||
&MetricMetadata::fake(),
|
||||
);
|
||||
assert_eq!(100.0, metric.value().current());
|
||||
|
||||
// Invert all targets: y = (1 - y)
|
||||
let y = y.neg().add_scalar(1);
|
||||
let _entry = metric.update(
|
||||
&HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)
|
||||
&MetricMetadata::fake(),
|
||||
);
|
||||
assert_eq!(0.0, metric.value().current());
|
||||
|
||||
// Invert 5 target values -> 1 - (5/20) = 0.75
|
||||
let y = Tensor::from_data(
|
||||
[
|
||||
[0, 1, 1, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 1, 1, 0, 0],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let _entry = metric.update(
|
||||
&HammingScoreInput::new(x, y), // invert targets (1 - y)
|
||||
&MetricMetadata::fake(),
|
||||
);
|
||||
assert_eq!(75.0, metric.value().current());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameterized_unique_name() {
|
||||
let metric_a = HammingScore::<TestBackend>::new().with_threshold(0.5);
|
||||
let metric_b = HammingScore::<TestBackend>::new().with_threshold(0.75);
|
||||
let metric_c = HammingScore::<TestBackend>::new().with_threshold(0.5);
|
||||
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
assert_eq!(metric_a.name(), metric_c.name());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::MetricMetadata;
|
||||
use super::SerializedEntry;
|
||||
use super::state::FormatOptions;
|
||||
use super::state::NumericMetricState;
|
||||
use crate::metric::MetricName;
|
||||
use crate::metric::Numeric;
|
||||
use crate::metric::{Metric, MetricAttributes, NumericAttributes, NumericEntry};
|
||||
|
||||
/// The loss metric.
|
||||
#[derive(Clone)]
|
||||
pub struct IterationSpeedMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
instant: Option<std::time::Instant>,
|
||||
}
|
||||
|
||||
impl Default for IterationSpeedMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl IterationSpeedMetric {
|
||||
/// Create the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Iteration Speed".to_string()),
|
||||
state: Default::default(),
|
||||
instant: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for IterationSpeedMetric {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let raw = match self.instant {
|
||||
Some(val) => {
|
||||
// If iteration is not logged, compute the speed over the number of items processed.
|
||||
// 1 iteration should equal 1 item when iteration is not logged.
|
||||
metadata
|
||||
.iteration
|
||||
.unwrap_or(metadata.progress.items_processed) as f64
|
||||
/ val.elapsed().as_secs_f64()
|
||||
}
|
||||
None => {
|
||||
self.instant = Some(std::time::Instant::now());
|
||||
0.0
|
||||
}
|
||||
};
|
||||
|
||||
self.state.update(
|
||||
raw,
|
||||
1,
|
||||
FormatOptions::new(self.name())
|
||||
.unit("iter/sec")
|
||||
.precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.instant = None;
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("iter/sec".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for IterationSpeedMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Track the learning rate across iterations.
|
||||
#[derive(Clone)]
|
||||
pub struct LearningRateMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl LearningRateMetric {
|
||||
/// Creates a new learning rate metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Learning Rate".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LearningRateMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for LearningRateMetric {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let lr = metadata.lr.unwrap_or(0.0);
|
||||
|
||||
self.state
|
||||
.update(lr, 1, FormatOptions::new(self.name()).precision(2))
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: None,
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for LearningRateMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::MetricMetadata;
|
||||
use super::SerializedEntry;
|
||||
use super::state::FormatOptions;
|
||||
use super::state::NumericMetricState;
|
||||
use crate::metric::MetricName;
|
||||
use crate::metric::{Metric, MetricAttributes, Numeric, NumericAttributes, NumericEntry};
|
||||
use burn_core::tensor::Tensor;
|
||||
use burn_core::tensor::backend::Backend;
|
||||
|
||||
/// The loss metric.
|
||||
#[derive(Clone)]
|
||||
pub struct LossMetric<B: Backend> {
|
||||
name: Arc<String>,
|
||||
state: NumericMetricState,
|
||||
_b: B,
|
||||
}
|
||||
|
||||
/// The [loss metric](LossMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct LossInput<B: Backend> {
|
||||
tensor: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for LossMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> LossMetric<B> {
|
||||
/// Create the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Loss".to_string()),
|
||||
state: NumericMetricState::default(),
|
||||
_b: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for LossMetric<B> {
|
||||
type Input = LossInput<B>;
|
||||
|
||||
fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let [batch_size] = loss.tensor.dims();
|
||||
let loss = loss
|
||||
.tensor
|
||||
.clone()
|
||||
.mean()
|
||||
.into_data()
|
||||
.iter::<f64>()
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
self.state.update(
|
||||
loss,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: None,
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for LossMetric<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
/// RAM use metric
|
||||
use super::{MetricAttributes, MetricMetadata, NumericAttributes};
|
||||
use crate::metric::{Metric, Numeric, NumericEntry, SerializedEntry};
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use sysinfo::System;
|
||||
|
||||
/// Memory information
|
||||
pub struct CpuMemory {
|
||||
name: Arc<String>,
|
||||
last_refresh: Instant,
|
||||
refresh_frequency: Duration,
|
||||
sys: System,
|
||||
ram_bytes_total: u64,
|
||||
ram_bytes_used: u64,
|
||||
}
|
||||
|
||||
impl Clone for CpuMemory {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
name: self.name.clone(),
|
||||
last_refresh: self.last_refresh,
|
||||
refresh_frequency: self.refresh_frequency,
|
||||
sys: System::new(),
|
||||
ram_bytes_total: self.ram_bytes_total,
|
||||
ram_bytes_used: self.ram_bytes_used,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CpuMemory {
|
||||
/// Creates a new memory metric
|
||||
pub fn new() -> Self {
|
||||
let mut metric = Self {
|
||||
name: Arc::new("CPU Memory".into()),
|
||||
last_refresh: Instant::now(),
|
||||
refresh_frequency: Duration::from_millis(200),
|
||||
sys: System::new(),
|
||||
ram_bytes_total: 0,
|
||||
ram_bytes_used: 0,
|
||||
};
|
||||
metric.refresh();
|
||||
metric
|
||||
}
|
||||
|
||||
fn refresh(&mut self) {
|
||||
self.sys.refresh_memory();
|
||||
self.last_refresh = Instant::now();
|
||||
|
||||
// bytes of RAM available
|
||||
self.ram_bytes_total = self.sys.total_memory();
|
||||
|
||||
// bytes of RAM in use
|
||||
self.ram_bytes_used = self.sys.used_memory();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CpuMemory {
|
||||
fn default() -> Self {
|
||||
CpuMemory::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for CpuMemory {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
if self.last_refresh.elapsed() >= self.refresh_frequency {
|
||||
self.refresh();
|
||||
}
|
||||
|
||||
let raw = bytes2gb(self.ram_bytes_used);
|
||||
let formatted = format!(
|
||||
"RAM Used: {:.2} / {:.2} Gb",
|
||||
raw,
|
||||
bytes2gb(self.ram_bytes_total),
|
||||
);
|
||||
|
||||
SerializedEntry::new(formatted, raw.to_string())
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
fn name(&self) -> Arc<String> {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("Gb".to_string()),
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for CpuMemory {
|
||||
fn value(&self) -> NumericEntry {
|
||||
NumericEntry::Value(bytes2gb(self.ram_bytes_used))
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
NumericEntry::Value(bytes2gb(self.ram_bytes_used))
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes2gb(bytes: u64) -> f64 {
|
||||
bytes as f64 / 1e9
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
/// State module.
|
||||
pub mod state;
|
||||
/// Module responsible to save and exposes data collected during training.
|
||||
pub mod store;
|
||||
/// Metrics module for vision tasks.
|
||||
#[cfg(feature = "vision")]
|
||||
pub mod vision;
|
||||
|
||||
//Metrics for reinforcement learning.
|
||||
#[cfg(feature = "rl")]
|
||||
mod rl;
|
||||
#[cfg(feature = "rl")]
|
||||
pub use rl::*;
|
||||
|
||||
// System metrics
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
mod cpu_temp;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
mod cpu_use;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
mod cuda;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
mod memory_use;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
pub use cpu_temp::*;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
pub use cpu_use::*;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
pub use cuda::*;
|
||||
#[cfg(feature = "sys-metrics")]
|
||||
pub use memory_use::*;
|
||||
|
||||
// Training metrics
|
||||
mod acc;
|
||||
mod auroc;
|
||||
mod base;
|
||||
mod cer;
|
||||
mod confusion_stats;
|
||||
mod fbetascore;
|
||||
mod hamming;
|
||||
mod iteration;
|
||||
mod learning_rate;
|
||||
mod loss;
|
||||
mod perplexity;
|
||||
mod precision;
|
||||
mod recall;
|
||||
mod top_k_acc;
|
||||
mod wer;
|
||||
|
||||
pub use acc::*;
|
||||
pub use auroc::*;
|
||||
pub use base::*;
|
||||
pub use cer::*;
|
||||
pub use confusion_stats::ConfusionStatsInput;
|
||||
pub use fbetascore::*;
|
||||
pub use hamming::*;
|
||||
pub use iteration::*;
|
||||
pub use learning_rate::*;
|
||||
pub use loss::*;
|
||||
pub use perplexity::*;
|
||||
pub use precision::*;
|
||||
pub use recall::*;
|
||||
pub use top_k_acc::*;
|
||||
pub use wer::*;
|
||||
|
||||
pub(crate) mod classification;
|
||||
pub(crate) mod processor;
|
||||
|
||||
pub use crate::metric::classification::ClassReduction;
|
||||
// Expose `ItemLazy` so it can be implemented for custom types
|
||||
pub use processor::ItemLazy;
|
||||
@@ -0,0 +1,438 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use super::state::FormatOptions;
|
||||
use super::{MetricMetadata, NumericEntry, SerializedEntry, format_float};
|
||||
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{ElementConversion, Int, Tensor};
|
||||
|
||||
/// Custom state for perplexity metric that correctly accumulates negative log-likelihood.
|
||||
///
|
||||
/// Unlike other metrics that can be averaged, perplexity requires special handling:
|
||||
/// - Accumulate total negative log-likelihood across all tokens
|
||||
/// - Accumulate total number of effective tokens
|
||||
/// - Compute perplexity as exp(total_nll / total_tokens) only at the end
|
||||
#[derive(Clone)]
|
||||
struct PerplexityState {
|
||||
/// Sum of negative log-likelihood across all tokens
|
||||
sum_nll: f64,
|
||||
/// Total number of effective tokens (excluding padding)
|
||||
total_tokens: usize,
|
||||
/// Current batch perplexity (for display purposes)
|
||||
current: f64,
|
||||
}
|
||||
|
||||
impl PerplexityState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
sum_nll: 0.0,
|
||||
total_tokens: 0,
|
||||
current: f64::NAN,
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.sum_nll = 0.0;
|
||||
self.total_tokens = 0;
|
||||
self.current = f64::NAN;
|
||||
}
|
||||
|
||||
/// Update state with negative log-likelihood and token count from current batch
|
||||
fn update(
|
||||
&mut self,
|
||||
sum_log_prob: f64,
|
||||
effective_tokens: usize,
|
||||
format: FormatOptions,
|
||||
) -> SerializedEntry {
|
||||
// sum_log_prob is already the sum of log probabilities (negative values)
|
||||
// We need to negate it to get negative log-likelihood
|
||||
let batch_nll = -sum_log_prob;
|
||||
|
||||
// Accumulate across batches
|
||||
self.sum_nll += batch_nll;
|
||||
self.total_tokens += effective_tokens;
|
||||
|
||||
// Compute current batch perplexity for display
|
||||
let batch_perplexity = if effective_tokens > 0 {
|
||||
(batch_nll / effective_tokens as f64).exp()
|
||||
} else {
|
||||
f64::INFINITY
|
||||
};
|
||||
self.current = batch_perplexity;
|
||||
|
||||
// Compute running epoch perplexity
|
||||
let epoch_perplexity = if self.total_tokens > 0 {
|
||||
(self.sum_nll / self.total_tokens as f64).exp()
|
||||
} else {
|
||||
f64::INFINITY
|
||||
};
|
||||
|
||||
// Format for display
|
||||
let (formatted_current, formatted_running) = match format.precision_value() {
|
||||
Some(precision) => (
|
||||
format_float(batch_perplexity, precision),
|
||||
format_float(epoch_perplexity, precision),
|
||||
),
|
||||
None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")),
|
||||
};
|
||||
|
||||
let formatted = match format.unit_value() {
|
||||
Some(unit) => {
|
||||
format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
|
||||
}
|
||||
None => format!("epoch {formatted_running} - batch {formatted_current}"),
|
||||
};
|
||||
|
||||
// Serialize the state for aggregation
|
||||
let serialized = NumericEntry::Aggregated {
|
||||
aggregated_value: epoch_perplexity,
|
||||
count: self.total_tokens,
|
||||
}
|
||||
.serialize();
|
||||
|
||||
SerializedEntry::new(formatted, serialized)
|
||||
}
|
||||
|
||||
fn value(&self) -> NumericEntry {
|
||||
let perplexity = if self.total_tokens > 0 {
|
||||
(self.sum_nll / self.total_tokens as f64).exp()
|
||||
} else {
|
||||
f64::INFINITY
|
||||
};
|
||||
|
||||
NumericEntry::Aggregated {
|
||||
aggregated_value: perplexity,
|
||||
count: self.total_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.value()
|
||||
}
|
||||
}
|
||||
|
||||
/// The perplexity metric.
|
||||
///
|
||||
/// Perplexity is a measure of how well a probability distribution or probability model
|
||||
/// predicts a sample. It's commonly used to evaluate language models. A lower perplexity
|
||||
/// indicates that the model is more confident in its predictions.
|
||||
///
|
||||
/// Mathematically, perplexity is defined as the exponentiation of the cross-entropy loss:
|
||||
/// PPL = exp(H(p, q)) = exp(-1/N * Σ log(p(x_i)))
|
||||
///
|
||||
/// where:
|
||||
/// - H(p, q) is the cross-entropy between the true distribution p and predicted distribution q
|
||||
/// - N is the number of tokens
|
||||
/// - p(x_i) is the predicted probability of the i-th token
|
||||
///
|
||||
/// # Aggregation
|
||||
/// Unlike other metrics, perplexity cannot be simply averaged across batches.
|
||||
/// This implementation correctly accumulates the total negative log-likelihood and
|
||||
/// total token count across batches, then computes perplexity as exp(total_nll / total_tokens).
|
||||
#[derive(Clone)]
|
||||
pub struct PerplexityMetric<B: Backend> {
|
||||
name: MetricName,
|
||||
state: PerplexityState,
|
||||
pad_token: Option<usize>,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [perplexity metric](PerplexityMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct PerplexityInput<B: Backend> {
|
||||
/// Logits tensor of shape [batch_size * sequence_length, vocab_size]
|
||||
outputs: Tensor<B, 2>,
|
||||
/// Target tokens tensor of shape [batch_size * sequence_length]
|
||||
targets: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for PerplexityMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PerplexityMetric<B> {
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: MetricName::new("Perplexity".to_string()),
|
||||
state: PerplexityState::new(),
|
||||
pad_token: Default::default(),
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the pad token to exclude from perplexity calculation.
|
||||
///
|
||||
/// When a pad token is set, predictions for padding tokens are masked out
|
||||
/// and do not contribute to the perplexity calculation. This is important
|
||||
/// for variable-length sequences where padding is used.
|
||||
pub fn with_pad_token(mut self, index: usize) -> Self {
|
||||
self.pad_token = Some(index);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for PerplexityMetric<B> {
|
||||
type Input = PerplexityInput<B>;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
input: &PerplexityInput<B>,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
let targets = input.targets.clone();
|
||||
let outputs = input.outputs.clone();
|
||||
|
||||
let [total_tokens, _vocab_size] = outputs.dims();
|
||||
|
||||
// Convert logits to log probabilities using log_softmax for numerical stability
|
||||
let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);
|
||||
|
||||
// Gather the log probabilities for the target tokens
|
||||
let target_log_probs = log_probs
|
||||
.gather(1, targets.clone().unsqueeze_dim(1))
|
||||
.squeeze_dim(1);
|
||||
|
||||
let (sum_log_prob, effective_tokens) = match self.pad_token {
|
||||
Some(pad_token) => {
|
||||
// Create a mask for non-padding tokens
|
||||
let mask = targets.clone().not_equal_elem(pad_token as i64);
|
||||
|
||||
// Apply mask to log probabilities (set padding log probs to 0)
|
||||
let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);
|
||||
|
||||
// Sum the log probabilities and count effective tokens
|
||||
let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();
|
||||
let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;
|
||||
|
||||
(sum_log_prob, effective_tokens)
|
||||
}
|
||||
None => {
|
||||
// No padding, use all tokens
|
||||
let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();
|
||||
(sum_log_prob, total_tokens)
|
||||
}
|
||||
};
|
||||
|
||||
// Pass the sum_log_prob and effective_tokens to the state
|
||||
// The state will handle the correct accumulation and perplexity calculation
|
||||
self.state.update(
|
||||
sum_log_prob,
|
||||
effective_tokens,
|
||||
FormatOptions::new(self.name()).precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: None,
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for PerplexityMetric<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn test_perplexity_perfect_prediction() {
|
||||
let device = Default::default();
|
||||
let mut metric = PerplexityMetric::<TestBackend>::new();
|
||||
|
||||
// Perfect prediction: target is always the highest probability class
|
||||
let input = PerplexityInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[10.0, 0.0, 0.0], // Very confident prediction for class 0
|
||||
[0.0, 10.0, 0.0], // Very confident prediction for class 1
|
||||
[0.0, 0.0, 10.0], // Very confident prediction for class 2
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([0, 1, 2], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
let perplexity = metric.value().current();
|
||||
|
||||
// Perfect predictions should result in very low perplexity (close to 1.0)
|
||||
assert!(
|
||||
perplexity < 1.1,
|
||||
"Perfect predictions should have low perplexity, got {}",
|
||||
perplexity
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perplexity_uniform_prediction() {
|
||||
let device = Default::default();
|
||||
let mut metric = PerplexityMetric::<TestBackend>::new();
|
||||
|
||||
// Uniform prediction: all classes have equal probability
|
||||
let input = PerplexityInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.0, 0.0], // Uniform distribution (after softmax)
|
||||
[0.0, 0.0, 0.0], // Uniform distribution (after softmax)
|
||||
[0.0, 0.0, 0.0], // Uniform distribution (after softmax)
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([0, 1, 2], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
let perplexity = metric.value().current();
|
||||
|
||||
// Uniform distribution over 3 classes should have perplexity ≈ 3.0
|
||||
assert!(
|
||||
(perplexity - 3.0).abs() < 0.1,
|
||||
"Uniform distribution perplexity should be ~3.0, got {}",
|
||||
perplexity
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perplexity_with_padding() {
|
||||
let device = Default::default();
|
||||
let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);
|
||||
|
||||
let input = PerplexityInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[10.0, 0.0, 0.0, 0.0], // Good prediction for class 0
|
||||
[0.0, 10.0, 0.0, 0.0], // Good prediction for class 1
|
||||
[0.0, 0.0, 0.0, 1.0], // This is padding - should be ignored
|
||||
[0.0, 0.0, 0.0, 1.0], // This is padding - should be ignored
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([0, 1, 3, 3], &device), // 3 is pad token
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
let perplexity = metric.value().current();
|
||||
|
||||
// Should only consider the first two predictions, both of which are confident
|
||||
assert!(
|
||||
perplexity < 1.1,
|
||||
"Good predictions with padding should have low perplexity, got {}",
|
||||
perplexity
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perplexity_wrong_prediction() {
|
||||
let device = Default::default();
|
||||
let mut metric = PerplexityMetric::<TestBackend>::new();
|
||||
|
||||
// Wrong predictions: target class has very low probability
|
||||
let input = PerplexityInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 10.0, 0.0], // Predicts class 1, but target is 0
|
||||
[10.0, 0.0, 0.0], // Predicts class 0, but target is 1
|
||||
[0.0, 0.0, 10.0], // Predicts class 2, but target is 0
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([0, 1, 0], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
let perplexity = metric.value().current();
|
||||
|
||||
// Wrong predictions should result in high perplexity
|
||||
assert!(
|
||||
perplexity > 10.0,
|
||||
"Wrong predictions should have high perplexity, got {}",
|
||||
perplexity
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perplexity_multi_batch_aggregation() {
|
||||
let device = Default::default();
|
||||
let mut metric = PerplexityMetric::<TestBackend>::new();
|
||||
|
||||
// First batch: 2 tokens with uniform distribution (log_prob ≈ -1.0986 each)
|
||||
let input1 = PerplexityInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
|
||||
[0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([0, 1], &device),
|
||||
);
|
||||
|
||||
// Second batch: 1 token with uniform distribution
|
||||
let input2 = PerplexityInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2], &device),
|
||||
);
|
||||
|
||||
// Update with both batches
|
||||
let _entry1 = metric.update(&input1, &MetricMetadata::fake());
|
||||
let _entry2 = metric.update(&input2, &MetricMetadata::fake());
|
||||
|
||||
let aggregated_perplexity = metric.value().current();
|
||||
|
||||
// For uniform distribution over 3 classes: log_prob ≈ -log(3) ≈ -1.0986
|
||||
// Total negative log-likelihood: 3 * 1.0986 ≈ 3.2958
|
||||
// Total tokens: 3
|
||||
// Expected perplexity: exp(3.2958 / 3) = exp(1.0986) ≈ 3.0
|
||||
assert!(
|
||||
(aggregated_perplexity - 3.0).abs() < 0.1,
|
||||
"Multi-batch aggregated perplexity should be ~3.0, got {}",
|
||||
aggregated_perplexity
|
||||
);
|
||||
|
||||
// Compare with single batch containing all data
|
||||
let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
|
||||
let single_input = PerplexityInput::new(
|
||||
Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
|
||||
Tensor::from_data([0, 1, 2], &device),
|
||||
);
|
||||
|
||||
let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());
|
||||
let single_batch_perplexity = single_batch_metric.value().current();
|
||||
|
||||
// Multi-batch and single-batch should give the same result
|
||||
assert!(
|
||||
(aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
|
||||
"Multi-batch ({}) and single-batch ({}) perplexity should match",
|
||||
aggregated_perplexity,
|
||||
single_batch_perplexity
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
use crate::metric::{MetricName, Numeric};
|
||||
|
||||
use super::{
|
||||
Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,
|
||||
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
|
||||
confusion_stats::{ConfusionStats, ConfusionStatsInput},
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use burn_core::{
|
||||
prelude::{Backend, Tensor},
|
||||
tensor::cast::ToElement,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
use std::{num::NonZeroUsize, sync::Arc};
|
||||
|
||||
/// The Precision Metric
|
||||
#[derive(Clone)]
|
||||
pub struct PrecisionMetric<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
_b: PhantomData<B>,
|
||||
config: ClassificationMetricConfig,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for PrecisionMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PrecisionMetric<B> {
|
||||
fn new(config: ClassificationMetricConfig) -> Self {
|
||||
let state = Default::default();
|
||||
let name = Arc::new(format!(
|
||||
"Precision @ {:?} [{:?}]",
|
||||
config.decision_rule, config.class_reduction
|
||||
));
|
||||
|
||||
Self {
|
||||
state,
|
||||
config,
|
||||
name,
|
||||
_b: Default::default(),
|
||||
}
|
||||
}
|
||||
/// Precision metric for binary classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `threshold` - The threshold to transform a probability into a binary prediction.
|
||||
#[allow(dead_code)]
|
||||
pub fn binary(threshold: f64) -> Self {
|
||||
Self::new(ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
// binary classification results are the same independently of class_reduction
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Precision metric for multiclass classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
|
||||
Self::new(ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::TopK(
|
||||
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
|
||||
),
|
||||
class_reduction,
|
||||
})
|
||||
}
|
||||
|
||||
/// Precision metric for multi-label classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `threshold` - The threshold to transform a probability into a binary value.
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
|
||||
Self {
|
||||
config: ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
class_reduction,
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
|
||||
use ClassReduction::{Macro, Micro};
|
||||
let avg_tensor = match self.config.class_reduction {
|
||||
Micro => aggregated_metric,
|
||||
Macro => {
|
||||
if aggregated_metric
|
||||
.clone()
|
||||
.contains_nan()
|
||||
.any()
|
||||
.into_scalar()
|
||||
.to_bool()
|
||||
{
|
||||
let nan_mask = aggregated_metric.clone().is_nan();
|
||||
aggregated_metric = aggregated_metric
|
||||
.clone()
|
||||
.select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))
|
||||
}
|
||||
aggregated_metric.mean()
|
||||
}
|
||||
};
|
||||
avg_tensor.into_scalar().to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for PrecisionMetric<B> {
|
||||
type Input = ConfusionStatsInput<B>;
|
||||
|
||||
fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let [sample_size, _] = input.predictions.dims();
|
||||
|
||||
let cf_stats = ConfusionStats::new(input, &self.config);
|
||||
let metric =
|
||||
self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive());
|
||||
|
||||
self.state.update(
|
||||
100.0 * metric,
|
||||
sample_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for PrecisionMetric<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
ClassReduction::{self, *},
|
||||
Metric, MetricMetadata, PrecisionMetric,
|
||||
};
|
||||
use crate::metric::Numeric;
|
||||
use crate::{
|
||||
TestBackend,
|
||||
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
|
||||
};
|
||||
use burn_core::tensor::TensorData;
|
||||
use burn_core::tensor::Tolerance;
|
||||
use rstest::rstest;
|
||||
|
||||
#[rstest]
|
||||
#[case::binary(THRESHOLD, 0.5)]
|
||||
fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) {
|
||||
let input = dummy_classification_input(&ClassificationType::Binary).into();
|
||||
let mut metric = PrecisionMetric::binary(threshold);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]
|
||||
#[case::multiclass_micro_k2(Micro, 2, 4.0/10.0)]
|
||||
#[case::multiclass_macro_k1(Macro, 1, (0.5 + 0.5 + 1.0)/3.0)]
|
||||
#[case::multiclass_macro_k2(Macro, 2, (0.5 + 1.0/4.0 + 0.5)/3.0)]
|
||||
fn test_multiclass_precision(
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] top_k: usize,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
|
||||
let mut metric = PrecisionMetric::multiclass(top_k, class_reduction);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)]
|
||||
#[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)]
|
||||
fn test_multilabel_precision(
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] threshold: f64,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
|
||||
let mut metric = PrecisionMetric::multilabel(threshold, class_reduction);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameterized_unique_name() {
|
||||
let metric_a = PrecisionMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
|
||||
let metric_b = PrecisionMetric::<TestBackend>::multiclass(2, ClassReduction::Macro);
|
||||
let metric_c = PrecisionMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
|
||||
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
assert_eq!(metric_a.name(), metric_c.name());
|
||||
|
||||
let metric_a = PrecisionMetric::<TestBackend>::binary(0.5);
|
||||
let metric_b = PrecisionMetric::<TestBackend>::binary(0.75);
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
|
||||
|
||||
use super::EventProcessorTraining;
|
||||
use async_channel::{Receiver, Sender};
|
||||
|
||||
/// Event processor for the training process.
|
||||
pub struct AsyncProcessorTraining<ET, EV> {
|
||||
sender: Sender<Message<ET, EV>>,
|
||||
}
|
||||
|
||||
/// Event processor for the model evaluation.
|
||||
pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
|
||||
sender: Sender<EvalMessage<P>>,
|
||||
}
|
||||
|
||||
struct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {
|
||||
processor: P,
|
||||
rec: Receiver<Message<ET, EV>>,
|
||||
}
|
||||
|
||||
struct WorkerEvaluation<P: EventProcessorEvaluation> {
|
||||
processor: P,
|
||||
rec: Receiver<EvalMessage<P>>,
|
||||
}
|
||||
|
||||
impl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>
|
||||
WorkerTraining<ET, EV, P>
|
||||
{
|
||||
pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {
|
||||
let mut worker = Self { processor, rec };
|
||||
|
||||
std::thread::spawn(move || {
|
||||
while let Ok(msg) = worker.rec.recv_blocking() {
|
||||
match msg {
|
||||
Message::Train(event) => worker.processor.process_train(event),
|
||||
Message::Valid(event) => worker.processor.process_valid(event),
|
||||
Message::Renderer(callback) => {
|
||||
callback.send_blocking(worker.processor.renderer()).unwrap();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
|
||||
pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
|
||||
let mut worker = Self { processor, rec };
|
||||
|
||||
std::thread::spawn(move || {
|
||||
while let Ok(event) = worker.rec.recv_blocking() {
|
||||
match event {
|
||||
EvalMessage::Test(event) => worker.processor.process_test(event),
|
||||
EvalMessage::Renderer(sender) => {
|
||||
sender.send_blocking(worker.processor.renderer()).unwrap();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {
|
||||
/// Create an event processor for training.
|
||||
pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {
|
||||
let (sender, rec) = async_channel::bounded(1);
|
||||
|
||||
WorkerTraining::start(processor, rec);
|
||||
|
||||
Self { sender }
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
|
||||
/// Create an event processor for model evaluation.
|
||||
pub fn new(processor: P) -> Self {
|
||||
let (sender, rec) = async_channel::bounded(1);
|
||||
|
||||
WorkerEvaluation::start(processor, rec);
|
||||
|
||||
Self { sender }
|
||||
}
|
||||
}
|
||||
|
||||
enum Message<EventTrain, EventValid> {
|
||||
Train(EventTrain),
|
||||
Valid(EventValid),
|
||||
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
|
||||
}
|
||||
|
||||
enum EvalMessage<P: EventProcessorEvaluation> {
|
||||
Test(EvaluatorEvent<P::ItemTest>),
|
||||
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
|
||||
}
|
||||
|
||||
impl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {
|
||||
fn process_train(&mut self, event: ET) {
|
||||
self.sender.send_blocking(Message::Train(event)).unwrap();
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: EV) {
|
||||
self.sender.send_blocking(Message::Valid(event)).unwrap();
|
||||
}
|
||||
|
||||
fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
|
||||
let (sender, rec) = async_channel::bounded(1);
|
||||
self.sender
|
||||
.send_blocking(Message::Renderer(sender))
|
||||
.unwrap();
|
||||
|
||||
match rec.recv_blocking() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("{err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
|
||||
type ItemTest = P::ItemTest;
|
||||
|
||||
fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
|
||||
self.sender.send_blocking(EvalMessage::Test(event)).unwrap();
|
||||
}
|
||||
|
||||
fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
|
||||
let (sender, rec) = async_channel::bounded(1);
|
||||
self.sender
|
||||
.send_blocking(EvalMessage::Renderer(sender))
|
||||
.unwrap();
|
||||
|
||||
match rec.recv_blocking() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("{err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_optim::LearningRate;
|
||||
|
||||
use crate::{
|
||||
LearnerSummary,
|
||||
renderer::{EvaluationName, MetricsRenderer},
|
||||
};
|
||||
|
||||
/// Event happening during the training/validation process.
|
||||
pub enum LearnerEvent<T> {
|
||||
/// Signal the start of the process (e.g., training start)
|
||||
Start,
|
||||
/// Signal that an item have been processed.
|
||||
ProcessedItem(TrainingItem<T>),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(usize),
|
||||
/// Signal the end of the process (e.g., training end).
|
||||
End(Option<LearnerSummary>),
|
||||
}
|
||||
|
||||
/// Event happening during the evaluation process.
|
||||
pub enum EvaluatorEvent<T> {
|
||||
/// Signal the start of the process (e.g., evaluation start)
|
||||
Start,
|
||||
/// Signal that an item have been processed.
|
||||
ProcessedItem(EvaluationName, EvaluationItem<T>),
|
||||
/// Signal the end of the process (e.g., evaluation end).
|
||||
End(Option<LearnerSummary>),
|
||||
}
|
||||
|
||||
/// Items that are lazy are not ready to be processed by metrics.
|
||||
///
|
||||
/// We want to sync them on a different thread to avoid blocking training.
|
||||
pub trait ItemLazy: Send {
|
||||
/// Item that is properly synced and ready to be processed by metrics.
|
||||
type ItemSync: Send;
|
||||
|
||||
/// Sync the item.
|
||||
fn sync(self) -> Self::ItemSync;
|
||||
}
|
||||
|
||||
/// Process events happening during training and validation.
|
||||
pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
|
||||
/// Collect a training event.
|
||||
fn process_train(&mut self, event: TrainEvent);
|
||||
/// Collect a validation event.
|
||||
fn process_valid(&mut self, event: ValidEvent);
|
||||
/// Returns the renderer used for training.
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer>;
|
||||
}
|
||||
|
||||
/// Process events happening during evaluation.
|
||||
pub trait EventProcessorEvaluation: Send {
|
||||
/// The test item.
|
||||
type ItemTest: ItemLazy;
|
||||
|
||||
/// Collect a test event.
|
||||
fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
|
||||
|
||||
/// Returns the renderer used for evaluation.
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer>;
|
||||
}
|
||||
|
||||
/// A learner item.
|
||||
#[derive(new)]
|
||||
pub struct TrainingItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The global progress of the training (e.g. epochs).
|
||||
pub global_progress: Progress,
|
||||
|
||||
/// The iteration, if it it different from the items processed.
|
||||
pub iteration: Option<usize>,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
|
||||
type ItemSync = TrainingItem<T::ItemSync>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
TrainingItem {
|
||||
item: self.item.sync(),
|
||||
progress: self.progress,
|
||||
global_progress: self.global_progress,
|
||||
iteration: self.iteration,
|
||||
lr: self.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An evaluation item.
|
||||
#[derive(new)]
|
||||
pub struct EvaluationItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The iteration, if it it different from the items processed.
|
||||
pub iteration: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
|
||||
type ItemSync = EvaluationItem<T::ItemSync>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
EvaluationItem {
|
||||
item: self.item.sync(),
|
||||
progress: self.progress,
|
||||
iteration: self.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ItemLazy for () {
|
||||
type ItemSync = ();
|
||||
|
||||
fn sync(self) -> Self::ItemSync {}
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
|
||||
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation};
|
||||
use crate::metric::store::{EpochSummary, EventStoreClient, Split};
|
||||
use crate::renderer::{
|
||||
EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// An [event processor](EventProcessorTraining) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
/// - Render metrics using a [metrics renderer](MetricsRenderer).
|
||||
pub struct FullEventProcessorTraining<T: ItemLazy, V: ItemLazy> {
|
||||
metrics: MetricsTraining<T, V>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
/// An [event processor](EventProcessorEvaluation) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
/// - Render metrics using a [metrics renderer](MetricsRenderer).
|
||||
pub struct FullEventProcessorEvaluation<T: ItemLazy> {
|
||||
metrics: MetricsEvaluation<T>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> FullEventProcessorTraining<T, V> {
|
||||
pub(crate) fn new(
|
||||
metrics: MetricsTraining<T, V>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
) -> Self {
|
||||
Self {
|
||||
metrics,
|
||||
renderer,
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
|
||||
let mut indicators = vec![];
|
||||
indicators.push(ProgressType::Detailed {
|
||||
tag: String::from("Epoch"),
|
||||
progress: progress.global_progress.clone(),
|
||||
});
|
||||
|
||||
if let Some(iteration) = progress.iteration {
|
||||
indicators.push(ProgressType::Value {
|
||||
tag: String::from("Iteration"),
|
||||
value: iteration,
|
||||
});
|
||||
};
|
||||
|
||||
if let Some(p) = &progress.progress {
|
||||
indicators.push(ProgressType::Detailed {
|
||||
tag: String::from("Items"),
|
||||
progress: p.clone(),
|
||||
});
|
||||
};
|
||||
|
||||
indicators
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
|
||||
pub(crate) fn new(
|
||||
metrics: MetricsEvaluation<T>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
) -> Self {
|
||||
Self {
|
||||
metrics,
|
||||
renderer,
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec<ProgressType> {
|
||||
let mut indicators = vec![];
|
||||
if let Some(iteration) = progress.iteration {
|
||||
indicators.push(ProgressType::Value {
|
||||
tag: String::from("Iteration"),
|
||||
value: iteration,
|
||||
});
|
||||
};
|
||||
|
||||
indicators.push(ProgressType::Detailed {
|
||||
tag: String::from("Items"),
|
||||
progress: progress.progress.clone(),
|
||||
});
|
||||
|
||||
indicators
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
|
||||
type ItemTest = T;
|
||||
|
||||
fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
|
||||
match event {
|
||||
EvaluatorEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsInit(
|
||||
definitions.clone(),
|
||||
));
|
||||
definitions
|
||||
.iter()
|
||||
.for_each(|definition| self.renderer.register_metric(definition.clone()));
|
||||
}
|
||||
EvaluatorEvent::ProcessedItem(name, item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_test(&item, &metadata);
|
||||
|
||||
self.store.add_event_test(
|
||||
crate::metric::store::Event::MetricsUpdate(update.clone()),
|
||||
name.name.clone(),
|
||||
);
|
||||
|
||||
update.entries.into_iter().for_each(|entry| {
|
||||
self.renderer
|
||||
.update_test(name.clone(), MetricState::Generic(entry))
|
||||
});
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_test(
|
||||
name.clone(),
|
||||
MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
),
|
||||
)
|
||||
});
|
||||
|
||||
let indicators = self.progress_indicators(&progress);
|
||||
self.renderer.render_test(progress, indicators);
|
||||
}
|
||||
EvaluatorEvent::End(summary) => {
|
||||
self.renderer.on_test_end(summary).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer> {
|
||||
self.renderer
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
|
||||
for FullEventProcessorTraining<T, V>
|
||||
{
|
||||
fn process_train(&mut self, event: LearnerEvent<T>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsInit(
|
||||
definitions.clone(),
|
||||
));
|
||||
definitions
|
||||
.iter()
|
||||
.for_each(|definition| self.renderer.register_metric(definition.clone()));
|
||||
}
|
||||
LearnerEvent::ProcessedItem(item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_train(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_train(MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
))
|
||||
});
|
||||
|
||||
let indicators = self.progress_indicators(&progress);
|
||||
self.renderer.render_train(progress, indicators);
|
||||
}
|
||||
LearnerEvent::EndEpoch(epoch) => {
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new(
|
||||
epoch,
|
||||
Split::Train,
|
||||
)));
|
||||
self.metrics.end_epoch_train();
|
||||
}
|
||||
LearnerEvent::End(summary) => {
|
||||
self.renderer.on_train_end(summary).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: LearnerEvent<V>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {} // no-op for now
|
||||
LearnerEvent::ProcessedItem(item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_valid(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_valid(MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
))
|
||||
});
|
||||
|
||||
let indicators = self.progress_indicators(&progress);
|
||||
self.renderer.render_valid(progress, indicators);
|
||||
}
|
||||
LearnerEvent::EndEpoch(epoch) => {
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new(
|
||||
epoch,
|
||||
Split::Valid,
|
||||
)));
|
||||
self.metrics.end_epoch_valid();
|
||||
}
|
||||
LearnerEvent::End(_) => {} // no-op for now
|
||||
}
|
||||
}
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer> {
|
||||
self.renderer
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{ItemLazy, TrainingItem};
|
||||
use crate::{
|
||||
EvaluationItem,
|
||||
metric::{
|
||||
Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
|
||||
store::{MetricsUpdate, NumericMetricUpdate},
|
||||
},
|
||||
renderer::{EvaluationProgress, TrainingProgress},
|
||||
};
|
||||
|
||||
pub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {
|
||||
train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
|
||||
valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,
|
||||
train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
|
||||
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,
|
||||
metric_definitions: HashMap<MetricId, MetricDefinition>,
|
||||
}
|
||||
|
||||
pub(crate) struct MetricsEvaluation<T: ItemLazy> {
|
||||
test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
|
||||
test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
|
||||
metric_definitions: HashMap<MetricId, MetricDefinition>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> Default for MetricsEvaluation<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
test: Default::default(),
|
||||
test_numeric: Default::default(),
|
||||
metric_definitions: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
train: Vec::default(),
|
||||
valid: Vec::default(),
|
||||
train_numeric: Vec::default(),
|
||||
valid_numeric: Vec::default(),
|
||||
metric_definitions: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> MetricsEvaluation<T> {
|
||||
/// Register a testing metric.
|
||||
pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
T::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.test.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric testing metric.
|
||||
pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
T::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.test_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
|
||||
self.metric_definitions.insert(
|
||||
metric.id.clone(),
|
||||
MetricDefinition::new(metric.id.clone(), &metric.metric),
|
||||
);
|
||||
}
|
||||
|
||||
/// Get metric definitions.
|
||||
pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
|
||||
self.metric_definitions.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Update the testing information from the testing item.
|
||||
pub(crate) fn update_test(
|
||||
&mut self,
|
||||
item: &EvaluationItem<T::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.test.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
|
||||
|
||||
for metric in self.test.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.test_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
T::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.train.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation metric.
|
||||
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
V::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric training metric.
|
||||
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
T::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.train_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric validation metric.
|
||||
pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)
|
||||
where
|
||||
V::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
Me: Metric + Numeric + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
|
||||
self.metric_definitions.insert(
|
||||
metric.id.clone(),
|
||||
MetricDefinition::new(metric.id.clone(), &metric.metric),
|
||||
);
|
||||
}
|
||||
|
||||
/// Get metric definitions for all splits
|
||||
pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
|
||||
self.metric_definitions.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Update the training information from the training item.
|
||||
pub(crate) fn update_train(
|
||||
&mut self,
|
||||
item: &TrainingItem<T::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.train.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
|
||||
|
||||
for metric in self.train.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the training information from the validation item.
|
||||
pub(crate) fn update_valid(
|
||||
&mut self,
|
||||
item: &TrainingItem<V::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
|
||||
|
||||
for metric in self.valid.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Signal the end of a training epoch.
|
||||
pub(crate) fn end_epoch_train(&mut self) {
|
||||
for metric in self.train.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal the end of a validation epoch.
|
||||
pub(crate) fn end_epoch_valid(&mut self) {
|
||||
for metric in self.valid.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&TrainingItem<T>> for TrainingProgress {
|
||||
fn from(item: &TrainingItem<T>) -> Self {
|
||||
Self {
|
||||
progress: Some(item.progress.clone()),
|
||||
global_progress: item.global_progress.clone(),
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&EvaluationItem<T>> for TrainingProgress {
|
||||
fn from(item: &EvaluationItem<T>) -> Self {
|
||||
Self {
|
||||
progress: None,
|
||||
global_progress: item.progress.clone(),
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&EvaluationItem<T>> for EvaluationProgress {
|
||||
fn from(item: &EvaluationItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&TrainingItem<T>> for MetricMetadata {
|
||||
fn from(item: &TrainingItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
global_progress: item.global_progress.clone(),
|
||||
iteration: item.iteration,
|
||||
lr: item.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&EvaluationItem<T>> for MetricMetadata {
|
||||
fn from(item: &EvaluationItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
global_progress: item.progress.clone(),
|
||||
iteration: item.iteration,
|
||||
lr: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
pub(crate) trait MetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
pub(crate) struct MetricWrapper<M> {
|
||||
pub id: MetricId,
|
||||
pub metric: M,
|
||||
}
|
||||
|
||||
impl<M: Metric> MetricWrapper<M> {
|
||||
pub fn new(metric: M) -> Self {
|
||||
Self {
|
||||
id: MetricId::new(metric.name()),
|
||||
metric,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric + Numeric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {
|
||||
let serialized_entry = self.metric.update(&item.adapt(), metadata);
|
||||
let update = MetricEntry::new(self.id.clone(), serialized_entry);
|
||||
let numeric = self.metric.value();
|
||||
let running = self.metric.running_value();
|
||||
|
||||
NumericMetricUpdate {
|
||||
entry: update,
|
||||
numeric_entry: numeric,
|
||||
running_entry: running,
|
||||
}
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.metric.clear()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {
|
||||
let serialized_entry = self.metric.update(&item.adapt(), metadata);
|
||||
MetricEntry::new(self.id.clone(), serialized_entry)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.metric.clear()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
|
||||
use crate::{
|
||||
metric::store::{EpochSummary, EventStoreClient, Split},
|
||||
renderer::cli::CliMetricsRenderer,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// An [event processor](EventProcessor) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
#[allow(dead_code)]
|
||||
#[derive(new)]
|
||||
pub(crate) struct MinimalEventProcessor<T: ItemLazy, V: ItemLazy> {
|
||||
metrics: MetricsTraining<T, V>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
|
||||
for MinimalEventProcessor<T, V>
|
||||
{
|
||||
fn process_train(&mut self, event: LearnerEvent<T>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsInit(definitions));
|
||||
}
|
||||
|
||||
LearnerEvent::ProcessedItem(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_train(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsUpdate(update));
|
||||
}
|
||||
LearnerEvent::EndEpoch(epoch) => {
|
||||
self.metrics.end_epoch_train();
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new(
|
||||
epoch,
|
||||
Split::Train,
|
||||
)));
|
||||
}
|
||||
LearnerEvent::End(_summary) => {} // no-op for now
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: LearnerEvent<V>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {} // no-op for now
|
||||
LearnerEvent::ProcessedItem(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_valid(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update));
|
||||
}
|
||||
LearnerEvent::EndEpoch(epoch) => {
|
||||
self.metrics.end_epoch_valid();
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new(
|
||||
epoch,
|
||||
Split::Valid,
|
||||
)));
|
||||
}
|
||||
LearnerEvent::End(_) => {} // no-op for now
|
||||
}
|
||||
}
|
||||
fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
|
||||
// TODO: Check for another default.
|
||||
Box::new(CliMetricsRenderer::new())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
mod async_wrapper;
|
||||
mod base;
|
||||
mod full;
|
||||
mod metrics;
|
||||
mod minimal;
|
||||
#[cfg(feature = "rl")]
|
||||
mod rl_metrics;
|
||||
#[cfg(feature = "rl")]
|
||||
mod rl_processor;
|
||||
|
||||
pub use base::*;
|
||||
pub(crate) use full::*;
|
||||
pub(crate) use metrics::*;
|
||||
#[cfg(feature = "rl")]
|
||||
pub(crate) use rl_metrics::*;
|
||||
#[cfg(feature = "rl")]
|
||||
pub(crate) use rl_processor::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use minimal::*;
|
||||
|
||||
pub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining};
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test_utils {
|
||||
use crate::metric::{
|
||||
Adaptor, LossInput,
|
||||
processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem},
|
||||
};
|
||||
use burn_core::tensor::{ElementConversion, Tensor, backend::Backend};
|
||||
|
||||
use super::ItemLazy;
|
||||
|
||||
impl ItemLazy for f64 {
|
||||
type ItemSync = f64;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for f64 {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
let device = B::Device::default();
|
||||
LossInput::new(Tensor::from_data([self.elem::<B::FloatElem>()], &device))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn process_train(
|
||||
processor: &mut MinimalEventProcessor<f64, f64>,
|
||||
value: f64,
|
||||
epoch: usize,
|
||||
) {
|
||||
let dummy_progress = burn_core::data::dataloader::Progress {
|
||||
items_processed: 1,
|
||||
items_total: 10,
|
||||
};
|
||||
let dummy_global_progress = burn_core::data::dataloader::Progress {
|
||||
items_processed: epoch,
|
||||
items_total: 3,
|
||||
};
|
||||
let dummy_iteration = Some(1);
|
||||
|
||||
processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new(
|
||||
value,
|
||||
dummy_progress,
|
||||
dummy_global_progress,
|
||||
dummy_iteration,
|
||||
None,
|
||||
)));
|
||||
}
|
||||
|
||||
pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor<f64, f64>, epoch: usize) {
|
||||
processor.process_train(LearnerEvent::EndEpoch(epoch));
|
||||
processor.process_valid(LearnerEvent::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater,
|
||||
metric::{
|
||||
Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate,
|
||||
},
|
||||
};
|
||||
|
||||
pub(crate) struct RLMetrics<TS: ItemLazy, ES: ItemLazy> {
|
||||
train_step: Vec<Box<dyn MetricUpdater<TS::ItemSync>>>,
|
||||
env_step: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
|
||||
env_step_valid: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
|
||||
episode_end: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
|
||||
episode_end_valid: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
|
||||
|
||||
train_step_numeric: Vec<Box<dyn NumericMetricUpdater<TS::ItemSync>>>,
|
||||
env_step_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
|
||||
env_step_valid_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
|
||||
episode_end_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
|
||||
episode_end_valid_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
|
||||
|
||||
metric_definitions: HashMap<MetricId, MetricDefinition>,
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> Default for RLMetrics<TS, ES> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
train_step: Vec::default(),
|
||||
env_step: Vec::default(),
|
||||
env_step_valid: Vec::default(),
|
||||
episode_end: Vec::default(),
|
||||
episode_end_valid: Vec::default(),
|
||||
train_step_numeric: Vec::default(),
|
||||
env_step_numeric: Vec::default(),
|
||||
env_step_valid_numeric: Vec::default(),
|
||||
episode_end_numeric: Vec::default(),
|
||||
episode_end_valid_numeric: Vec::default(),
|
||||
metric_definitions: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> RLMetrics<TS, ES> {
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_text_metric_agent<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_agent_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_text_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
TS::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.train_step.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_metric_train<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
TS::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.train_step_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation env-step metric.
|
||||
pub(crate) fn register_text_metric_agent_valid<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step_valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation env-step numeric metric.
|
||||
pub(crate) fn register_agent_metric_valid<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step_valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end metric.
|
||||
pub(crate) fn register_text_metric_episode<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end numeric metric.
|
||||
pub(crate) fn register_episode_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end metric for validation.
|
||||
pub(crate) fn register_text_metric_episode_valid<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end_valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end numeric metric for validation.
|
||||
pub(crate) fn register_episode_metric_valid<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end_valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
|
||||
self.metric_definitions.insert(
|
||||
metric.id.clone(),
|
||||
MetricDefinition::new(metric.id.clone(), &metric.metric),
|
||||
);
|
||||
}
|
||||
|
||||
/// Get metric definitions for all splits
|
||||
pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
|
||||
self.metric_definitions.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Update the training information from the training item.
|
||||
pub(crate) fn update_train_step(
|
||||
&mut self,
|
||||
item: &EvaluationItem<TS::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.train_step.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len());
|
||||
|
||||
for metric in self.train_step.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.train_step_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the env-step metrics from an environment step item.
|
||||
pub(crate) fn update_env_step(
|
||||
&mut self,
|
||||
item: &EvaluationItem<ES::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.env_step.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len());
|
||||
|
||||
for metric in self.env_step.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.env_step_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the env-step metrics for validation from an environment step item.
|
||||
pub(crate) fn update_env_step_valid(
|
||||
&mut self,
|
||||
item: &EvaluationItem<ES::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.env_step_valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len());
|
||||
|
||||
for metric in self.env_step_valid.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.env_step_valid_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the episode-end metrics from an episode summary.
|
||||
pub(crate) fn update_episode_end(
|
||||
&mut self,
|
||||
item: &EvaluationItem<EpisodeSummary>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.episode_end.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len());
|
||||
|
||||
for metric in self.episode_end.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.episode_end_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the episode-end metrics for validation from an episode summary.
|
||||
pub(crate) fn update_episode_end_valid(
|
||||
&mut self,
|
||||
item: &EvaluationItem<EpisodeSummary>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.episode_end_valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len());
|
||||
|
||||
for metric in self.episode_end_valid.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.episode_end_valid_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
EpisodeSummary, EvaluationItem, EventProcessorTraining, ItemLazy, LearnerSummary, RLMetrics,
|
||||
metric::store::{Event, EventStoreClient, MetricsUpdate},
|
||||
renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress},
|
||||
};
|
||||
|
||||
/// Event happening during reinforcement learning.
|
||||
pub enum RLEvent<TS, ES> {
|
||||
/// Signal the start of the process (e.g., learning starts).
|
||||
Start,
|
||||
/// Signal an agent's training step.
|
||||
TrainStep(EvaluationItem<TS>),
|
||||
/// Signal a timestep of the agent-environment interface.
|
||||
TimeStep(EvaluationItem<ES>),
|
||||
/// Signal an episode end.
|
||||
EpisodeEnd(EvaluationItem<EpisodeSummary>),
|
||||
/// Signal the end of the process (e.g., learning ends).
|
||||
End(Option<LearnerSummary>),
|
||||
}
|
||||
|
||||
/// Event happening during evaluation of a reinforcement learning's agent.
|
||||
pub enum AgentEvaluationEvent<T> {
|
||||
/// Signal the start of the process (e.g., training start)
|
||||
Start,
|
||||
/// Signal a timestep of the agent-environment interface.
|
||||
TimeStep(EvaluationItem<T>),
|
||||
/// Signal an episode end.
|
||||
EpisodeEnd(EvaluationItem<EpisodeSummary>),
|
||||
/// Signal the end of the process (e.g., training end).
|
||||
End,
|
||||
}
|
||||
|
||||
/// An [event processor](EventProcessorTraining) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
/// - Render metrics using a [metrics renderer](MetricsRenderer).
|
||||
#[derive(new)]
|
||||
pub struct RLEventProcessor<TS: ItemLazy, ES: ItemLazy> {
|
||||
metrics: RLMetrics<TS, ES>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
|
||||
fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
|
||||
let indicators = vec![ProgressType::Detailed {
|
||||
tag: String::from("Step"),
|
||||
progress: progress.global_progress.clone(),
|
||||
}];
|
||||
|
||||
indicators
|
||||
}
|
||||
|
||||
fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
|
||||
let indicators = vec![ProgressType::Detailed {
|
||||
tag: String::from("Step"),
|
||||
progress: progress.global_progress.clone(),
|
||||
}];
|
||||
|
||||
indicators
|
||||
}
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
|
||||
fn process_update_train(&mut self, update: MetricsUpdate) {
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_train(MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
fn process_update_valid(&mut self, update: MetricsUpdate) {
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_valid(MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> EventProcessorTraining<RLEvent<TS, ES>, AgentEvaluationEvent<ES>>
|
||||
for RLEventProcessor<TS, ES>
|
||||
{
|
||||
fn process_train(&mut self, event: RLEvent<TS, ES>) {
|
||||
match event {
|
||||
RLEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
self.store
|
||||
.add_event_train(Event::MetricsInit(definitions.clone()));
|
||||
definitions
|
||||
.iter()
|
||||
.for_each(|definition| self.renderer.register_metric(definition.clone()));
|
||||
}
|
||||
RLEvent::TrainStep(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_train_step(&item, &metadata);
|
||||
self.process_update_train(update);
|
||||
}
|
||||
RLEvent::TimeStep(item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_env_step(&item, &metadata);
|
||||
self.process_update_train(update);
|
||||
let status = self.progress_indicators(&progress);
|
||||
self.renderer.render_train(progress, status);
|
||||
}
|
||||
RLEvent::EpisodeEnd(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_episode_end(&item, &metadata);
|
||||
self.process_update_train(update);
|
||||
}
|
||||
RLEvent::End(learner_summary) => {
|
||||
self.renderer.on_train_end(learner_summary).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: AgentEvaluationEvent<ES>) {
|
||||
match event {
|
||||
AgentEvaluationEvent::Start => {} // no-op for now
|
||||
AgentEvaluationEvent::TimeStep(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_env_step_valid(&item, &metadata);
|
||||
self.process_update_valid(update);
|
||||
}
|
||||
AgentEvaluationEvent::EpisodeEnd(item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_episode_end_valid(&item, &metadata);
|
||||
self.process_update_valid(update);
|
||||
let status = self.progress_indicators_eval(&progress);
|
||||
self.renderer.render_valid(progress, status);
|
||||
}
|
||||
AgentEvaluationEvent::End => {} // no-op for now
|
||||
}
|
||||
}
|
||||
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer> {
|
||||
self.renderer
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
use crate::metric::{MetricName, Numeric};
|
||||
|
||||
use super::{
|
||||
Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,
|
||||
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
|
||||
confusion_stats::{ConfusionStats, ConfusionStatsInput},
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use burn_core::{
|
||||
prelude::{Backend, Tensor},
|
||||
tensor::cast::ToElement,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
use std::{num::NonZeroUsize, sync::Arc};
|
||||
|
||||
///The Recall Metric
|
||||
#[derive(Clone)]
|
||||
pub struct RecallMetric<B: Backend> {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
_b: PhantomData<B>,
|
||||
config: ClassificationMetricConfig,
|
||||
}
|
||||
|
||||
impl<B: Backend> Default for RecallMetric<B> {
|
||||
fn default() -> Self {
|
||||
Self::new(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> RecallMetric<B> {
|
||||
fn new(config: ClassificationMetricConfig) -> Self {
|
||||
let state = Default::default();
|
||||
let name = Arc::new(format!(
|
||||
"Recall @ {:?} [{:?}]",
|
||||
config.decision_rule, config.class_reduction
|
||||
));
|
||||
|
||||
Self {
|
||||
state,
|
||||
config,
|
||||
name,
|
||||
_b: Default::default(),
|
||||
}
|
||||
}
|
||||
/// Recall metric for binary classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `threshold` - The threshold to transform a probability into a binary prediction.
|
||||
#[allow(dead_code)]
|
||||
pub fn binary(threshold: f64) -> Self {
|
||||
Self::new(ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
// binary classification results are the same independently of class_reduction
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Recall metric for multiclass classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
|
||||
Self::new(ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::TopK(
|
||||
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
|
||||
),
|
||||
class_reduction,
|
||||
})
|
||||
}
|
||||
|
||||
/// Recall metric for multi-label classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `threshold` - The threshold to transform a probability into a binary prediction.
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
|
||||
Self::new(ClassificationMetricConfig {
|
||||
decision_rule: DecisionRule::Threshold(threshold),
|
||||
class_reduction,
|
||||
})
|
||||
}
|
||||
|
||||
fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
|
||||
use ClassReduction::{Macro, Micro};
|
||||
let avg_tensor = match self.config.class_reduction {
|
||||
Micro => aggregated_metric,
|
||||
Macro => {
|
||||
if aggregated_metric
|
||||
.clone()
|
||||
.contains_nan()
|
||||
.any()
|
||||
.into_scalar()
|
||||
.to_bool()
|
||||
{
|
||||
let nan_mask = aggregated_metric.clone().is_nan();
|
||||
aggregated_metric = aggregated_metric
|
||||
.clone()
|
||||
.select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))
|
||||
}
|
||||
aggregated_metric.mean()
|
||||
}
|
||||
};
|
||||
avg_tensor.into_scalar().to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for RecallMetric<B> {
|
||||
type Input = ConfusionStatsInput<B>;
|
||||
|
||||
fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let [sample_size, _] = input.predictions.dims();
|
||||
|
||||
let cf_stats = ConfusionStats::new(input, &self.config);
|
||||
let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive());
|
||||
|
||||
self.state.update(
|
||||
100.0 * metric,
|
||||
sample_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for RecallMetric<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
ClassReduction::{self, *},
|
||||
Metric, MetricMetadata, RecallMetric,
|
||||
};
|
||||
use crate::metric::Numeric;
|
||||
use crate::{
|
||||
TestBackend,
|
||||
tests::{ClassificationType, THRESHOLD, dummy_classification_input},
|
||||
};
|
||||
use burn_core::tensor::{TensorData, Tolerance};
|
||||
use rstest::rstest;
|
||||
|
||||
#[rstest]
|
||||
#[case::binary(THRESHOLD, 0.5)]
|
||||
fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) {
|
||||
let input = dummy_classification_input(&ClassificationType::Binary).into();
|
||||
let mut metric = RecallMetric::binary(threshold);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]
|
||||
#[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)]
|
||||
#[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)]
|
||||
#[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)]
|
||||
fn test_multiclass_recall(
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] top_k: usize,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
|
||||
let mut metric = RecallMetric::multiclass(top_k, class_reduction);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)]
|
||||
#[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)]
|
||||
fn test_multilabel_recall(
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] threshold: f64,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
|
||||
let mut metric = RecallMetric::multilabel(threshold, class_reduction);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameterized_unique_name() {
|
||||
let metric_a = RecallMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
|
||||
let metric_b = RecallMetric::<TestBackend>::multiclass(2, ClassReduction::Macro);
|
||||
let metric_c = RecallMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);
|
||||
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
assert_eq!(metric_a.name(), metric_c.name());
|
||||
|
||||
let metric_a = RecallMetric::<TestBackend>::binary(0.5);
|
||||
let metric_b = RecallMetric::<TestBackend>::binary(0.75);
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Metric for the cumulative reward of the last completed episode.
|
||||
#[derive(Clone)]
|
||||
pub struct CumulativeRewardMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl CumulativeRewardMetric {
|
||||
/// Creates a new episode length metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Cum. Reward".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CumulativeRewardMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct CumulativeRewardInput {
|
||||
cum_reward: f64,
|
||||
}
|
||||
|
||||
impl Metric for CumulativeRewardMetric {
|
||||
type Input = CumulativeRewardInput;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
item: &CumulativeRewardInput,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
self.state.update(
|
||||
item.cum_reward,
|
||||
1,
|
||||
FormatOptions::new(self.name()).precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: None,
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for CumulativeRewardMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Metric for the length of the last completed episode.
|
||||
#[derive(Clone)]
|
||||
pub struct EpisodeLengthMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl EpisodeLengthMetric {
|
||||
/// Creates a new episode length metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Episode length".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EpisodeLengthMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct EpisodeLengthInput {
|
||||
ep_len: f64,
|
||||
}
|
||||
|
||||
impl Metric for EpisodeLengthMetric {
|
||||
type Input = EpisodeLengthInput;
|
||||
|
||||
fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
self.state
|
||||
.update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some(String::from("steps")),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for EpisodeLengthMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Metric for the length of the last completed episode.
|
||||
#[derive(Clone)]
|
||||
pub struct ExplorationRateMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl ExplorationRateMetric {
|
||||
/// Creates a new episode length metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Exploration rate".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ExplorationRateMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The [ExplorationRateMetric](ExplorationRateMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct ExplorationRateInput {
|
||||
exploration_rate: f64,
|
||||
}
|
||||
|
||||
impl Metric for ExplorationRateMetric {
|
||||
type Input = ExplorationRateInput;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
item: &ExplorationRateInput,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
self.state.update(
|
||||
item.exploration_rate,
|
||||
1,
|
||||
FormatOptions::new(self.name()).precision(3),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some(String::from("%")),
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for ExplorationRateMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod cum_reward;
|
||||
mod ep_len;
|
||||
mod exploration_rate;
|
||||
|
||||
pub use cum_reward::*;
|
||||
pub use ep_len::*;
|
||||
pub use exploration_rate::*;
|
||||
@@ -0,0 +1,144 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::metric::{MetricName, NumericEntry, SerializedEntry, format_float};
|
||||
|
||||
/// Useful utility to implement numeric metrics.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The numeric metric store values inside floats.
|
||||
/// Even if some metric are integers, their mean are floats.
|
||||
#[derive(Clone)]
|
||||
pub struct NumericMetricState {
|
||||
sum: f64,
|
||||
count: usize,
|
||||
current: f64,
|
||||
current_count: usize,
|
||||
}
|
||||
|
||||
/// Formatting options for the [numeric metric state](NumericMetricState).
|
||||
pub struct FormatOptions {
|
||||
name: Arc<String>,
|
||||
unit: Option<String>,
|
||||
precision: Option<usize>,
|
||||
}
|
||||
|
||||
impl FormatOptions {
|
||||
/// Create the [formatting options](FormatOptions) with a name.
|
||||
pub fn new(name: MetricName) -> Self {
|
||||
Self {
|
||||
name: name.clone(),
|
||||
unit: None,
|
||||
precision: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Specify the metric unit.
|
||||
pub fn unit(mut self, unit: &str) -> Self {
|
||||
self.unit = Some(unit.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify the floating point precision.
|
||||
pub fn precision(mut self, precision: usize) -> Self {
|
||||
self.precision = Some(precision);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the metric name.
|
||||
pub fn name(&self) -> &Arc<String> {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Get the metric unit.
|
||||
pub fn unit_value(&self) -> &Option<String> {
|
||||
&self.unit
|
||||
}
|
||||
|
||||
/// Get the precision.
|
||||
pub fn precision_value(&self) -> Option<usize> {
|
||||
self.precision
|
||||
}
|
||||
}
|
||||
|
||||
impl NumericMetricState {
|
||||
/// Create a new [numeric metric state](NumericMetricState).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sum: 0.0,
|
||||
count: 0,
|
||||
current: f64::NAN,
|
||||
current_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the state.
|
||||
pub fn reset(&mut self) {
|
||||
self.sum = 0.0;
|
||||
self.count = 0;
|
||||
self.current = f64::NAN;
|
||||
self.current_count = 0;
|
||||
}
|
||||
|
||||
/// Update the state.
|
||||
pub fn update(
|
||||
&mut self,
|
||||
value: f64,
|
||||
batch_size: usize,
|
||||
format: FormatOptions,
|
||||
) -> SerializedEntry {
|
||||
self.sum += value * batch_size as f64;
|
||||
self.count += batch_size;
|
||||
self.current = value;
|
||||
self.current_count = batch_size;
|
||||
|
||||
let value_current = value;
|
||||
let value_running = self.sum / self.count as f64;
|
||||
// Numeric metric state is an aggregated value
|
||||
let serialized = NumericEntry::Aggregated {
|
||||
aggregated_value: value_current,
|
||||
count: batch_size,
|
||||
}
|
||||
.serialize();
|
||||
|
||||
let (formatted_current, formatted_running) = match format.precision {
|
||||
Some(precision) => (
|
||||
format_float(value_current, precision),
|
||||
format_float(value_running, precision),
|
||||
),
|
||||
None => (format!("{value_current}"), format!("{value_running}")),
|
||||
};
|
||||
|
||||
// TODO: naming inconsistent with RL.
|
||||
let formatted = match format.unit {
|
||||
Some(unit) => {
|
||||
format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
|
||||
}
|
||||
None => format!("epoch {formatted_running} - batch {formatted_current}"),
|
||||
};
|
||||
|
||||
SerializedEntry::new(formatted, serialized)
|
||||
}
|
||||
|
||||
/// Get the numeric value.
|
||||
pub fn current_value(&self) -> NumericEntry {
|
||||
NumericEntry::Aggregated {
|
||||
aggregated_value: self.current,
|
||||
count: self.current_count,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the running aggregated value.
|
||||
pub fn running_value(&self) -> NumericEntry {
|
||||
NumericEntry::Aggregated {
|
||||
aggregated_value: self.sum / self.count as f64,
|
||||
count: self.count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NumericMetricState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
use crate::{
|
||||
logger::MetricLogger,
|
||||
metric::{NumericEntry, store::Split},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{Aggregate, Direction};
|
||||
|
||||
/// Type that can be used to fetch and use numeric metric aggregates.
|
||||
#[derive(Default, Debug)]
|
||||
pub(crate) struct NumericMetricsAggregate {
|
||||
value_for_each_epoch: HashMap<Key, f64>,
|
||||
}
|
||||
|
||||
#[derive(new, Hash, PartialEq, Eq, Debug)]
|
||||
struct Key {
|
||||
name: String,
|
||||
epoch: usize,
|
||||
split: Split,
|
||||
aggregate: Aggregate,
|
||||
}
|
||||
|
||||
impl NumericMetricsAggregate {
|
||||
pub(crate) fn aggregate(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
split: &Split,
|
||||
aggregate: Aggregate,
|
||||
loggers: &mut [Box<dyn MetricLogger>],
|
||||
) -> Option<f64> {
|
||||
let key = Key::new(name.to_string(), epoch, split.clone(), aggregate);
|
||||
|
||||
if let Some(value) = self.value_for_each_epoch.get(&key) {
|
||||
return Some(*value);
|
||||
}
|
||||
|
||||
let points = || {
|
||||
let mut errors = Vec::new();
|
||||
for logger in loggers {
|
||||
match logger.read_numeric(name, epoch, split) {
|
||||
Ok(points) => return Ok(points),
|
||||
Err(err) => errors.push(err),
|
||||
};
|
||||
}
|
||||
|
||||
Err(errors.join(" "))
|
||||
};
|
||||
|
||||
let points = points().expect("Can read values");
|
||||
|
||||
if points.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Accurately compute the aggregated value based on the *actual* number of points
|
||||
// since not all mini-batches are guaranteed to have the specified batch size
|
||||
let (sum, num_points) = points
|
||||
.into_iter()
|
||||
.map(|entry| match entry {
|
||||
NumericEntry::Value(v) => (v, 1),
|
||||
// Right now the mean is the only aggregate available, so we can assume that the sum
|
||||
// of an entry corresponds to (value * number of elements)
|
||||
NumericEntry::Aggregated {
|
||||
aggregated_value,
|
||||
count,
|
||||
} => (aggregated_value * count as f64, count),
|
||||
})
|
||||
.reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))
|
||||
.unwrap();
|
||||
let value = match aggregate {
|
||||
Aggregate::Mean => sum / num_points as f64,
|
||||
};
|
||||
|
||||
self.value_for_each_epoch.insert(key, value);
|
||||
Some(value)
|
||||
}
|
||||
|
||||
pub(crate) fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
split: &Split,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
loggers: &mut [Box<dyn MetricLogger>],
|
||||
) -> Option<usize> {
|
||||
let mut data = Vec::new();
|
||||
let mut current_epoch = 1;
|
||||
|
||||
while let Some(value) = self.aggregate(name, current_epoch, split, aggregate, loggers) {
|
||||
data.push(value);
|
||||
current_epoch += 1;
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut current_value = match &direction {
|
||||
Direction::Lowest => f64::MAX,
|
||||
Direction::Highest => f64::MIN,
|
||||
};
|
||||
|
||||
for (i, value) in data.into_iter().enumerate() {
|
||||
match &direction {
|
||||
Direction::Lowest => {
|
||||
if value < current_value {
|
||||
current_value = value;
|
||||
current_epoch = i + 1;
|
||||
}
|
||||
}
|
||||
Direction::Highest => {
|
||||
if value > current_value {
|
||||
current_value = value;
|
||||
current_epoch = i + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(current_epoch)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
logger::{FileMetricLogger, InMemoryMetricLogger},
|
||||
metric::{MetricDefinition, MetricEntry, MetricId, SerializedEntry, store::MetricsUpdate},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
struct TestLogger {
|
||||
logger: FileMetricLogger,
|
||||
epoch: usize,
|
||||
}
|
||||
const NAME: &str = "test-logger";
|
||||
|
||||
impl TestLogger {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
logger: FileMetricLogger::new("/tmp"),
|
||||
epoch: 1,
|
||||
}
|
||||
}
|
||||
fn log(&mut self, num: f64) {
|
||||
let entry = MetricEntry::new(
|
||||
MetricId::new(Arc::new(NAME.into())),
|
||||
SerializedEntry::new(num.to_string(), num.to_string()),
|
||||
);
|
||||
let entries = Vec::from([entry]);
|
||||
let metrics_update = MetricsUpdate::new(entries, vec![]);
|
||||
self.logger.log(metrics_update, self.epoch, &Split::Train);
|
||||
}
|
||||
fn log_definition(&mut self) {
|
||||
let definition = MetricDefinition {
|
||||
metric_id: MetricId::new(Arc::new(NAME.into())),
|
||||
name: NAME.into(),
|
||||
attributes: crate::metric::MetricAttributes::None,
|
||||
description: None,
|
||||
};
|
||||
self.logger.log_metric_definition(definition);
|
||||
}
|
||||
fn new_epoch(&mut self) {
|
||||
self.epoch += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_find_epoch() {
|
||||
let mut logger = TestLogger::new();
|
||||
let mut aggregate = NumericMetricsAggregate::default();
|
||||
logger.log_definition();
|
||||
|
||||
logger.log(500.); // Epoch 1
|
||||
logger.log(1000.); // Epoch 1
|
||||
logger.new_epoch();
|
||||
logger.log(200.); // Epoch 2
|
||||
logger.log(1000.); // Epoch 2
|
||||
logger.new_epoch();
|
||||
logger.log(10000.); // Epoch 3
|
||||
|
||||
let value = aggregate
|
||||
.find_epoch(
|
||||
NAME,
|
||||
&Split::Train,
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
&mut [Box::new(logger.logger)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(value, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_aggregate_numeric_entry() {
|
||||
let mut logger = InMemoryMetricLogger::default();
|
||||
let mut aggregate = NumericMetricsAggregate::default();
|
||||
let metric_name = Arc::new("Loss".to_string());
|
||||
let metric_id = MetricId::new(metric_name.clone());
|
||||
let definition = MetricDefinition {
|
||||
metric_id: metric_id.clone(),
|
||||
name: metric_name.to_string(),
|
||||
attributes: crate::metric::MetricAttributes::None,
|
||||
description: None,
|
||||
};
|
||||
logger.log_metric_definition(definition);
|
||||
|
||||
// Epoch 1
|
||||
let loss_1 = 0.5;
|
||||
let loss_2 = 1.25; // (1.5 + 1.0) / 2 = 2.5 / 2
|
||||
let entry = MetricEntry::new(
|
||||
metric_id.clone(),
|
||||
SerializedEntry::new(loss_1.to_string(), NumericEntry::Value(loss_1).serialize()),
|
||||
);
|
||||
let entries = Vec::from([entry]);
|
||||
let metrics_update = MetricsUpdate::new(entries, vec![]);
|
||||
logger.log(metrics_update, 1, &Split::Train);
|
||||
let entry = MetricEntry::new(
|
||||
metric_id.clone(),
|
||||
SerializedEntry::new(
|
||||
loss_2.to_string(),
|
||||
NumericEntry::Aggregated {
|
||||
aggregated_value: loss_2,
|
||||
count: 2,
|
||||
}
|
||||
.serialize(),
|
||||
),
|
||||
);
|
||||
let entries = Vec::from([entry]);
|
||||
let metrics_update = MetricsUpdate::new(entries, vec![]);
|
||||
logger.log(metrics_update, 1, &Split::Train);
|
||||
|
||||
let value = aggregate
|
||||
.aggregate(
|
||||
&metric_name,
|
||||
1,
|
||||
&Split::Train,
|
||||
Aggregate::Mean,
|
||||
&mut [Box::new(logger)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875
|
||||
assert_eq!(value, 1.0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::metric::{MetricDefinition, MetricEntry, NumericEntry};
|
||||
|
||||
/// Event happening during the training/validation process.
|
||||
pub enum Event {
|
||||
/// Signal the iniialization of the metrics
|
||||
MetricsInit(Vec<MetricDefinition>),
|
||||
/// Signal that metrics have been updated.
|
||||
MetricsUpdate(MetricsUpdate),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(EpochSummary),
|
||||
}
|
||||
|
||||
/// Contains all metric information.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct NumericMetricUpdate {
|
||||
/// Generic metric information.
|
||||
pub entry: MetricEntry,
|
||||
/// The numeric information.
|
||||
pub numeric_entry: NumericEntry,
|
||||
/// Numeric value averaged over the global step (epoch).
|
||||
pub running_entry: NumericEntry,
|
||||
}
|
||||
|
||||
/// Contains all metric information.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct MetricsUpdate {
|
||||
/// Metrics information related to non-numeric metrics.
|
||||
pub entries: Vec<MetricEntry>,
|
||||
/// Metrics information related to numeric metrics.
|
||||
pub entries_numeric: Vec<NumericMetricUpdate>,
|
||||
}
|
||||
|
||||
/// Summary information about a given epoch
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct EpochSummary {
|
||||
/// Epoch number.
|
||||
pub epoch_number: usize,
|
||||
/// Dataset split (train, valid, test).
|
||||
pub split: Split,
|
||||
}
|
||||
|
||||
/// Defines how training and validation events are collected and searched.
|
||||
///
|
||||
/// This trait also exposes methods that uses the collected data to compute useful information.
|
||||
pub trait EventStore: Send {
|
||||
/// Collect a training/validation event.
|
||||
fn add_event(&mut self, event: Event, split: Split);
|
||||
|
||||
/// Find the epoch following the given criteria from the collected data.
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: &Split,
|
||||
) -> Option<usize>;
|
||||
|
||||
/// Find the metric value for the current epoch following the given criteria.
|
||||
fn find_metric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
split: &Split,
|
||||
) -> Option<f64>;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
|
||||
/// How to aggregate the metric.
|
||||
pub enum Aggregate {
|
||||
/// Compute the average.
|
||||
Mean,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
/// The split to use.
|
||||
pub enum Split {
|
||||
/// The training split.
|
||||
Train,
|
||||
/// The validation split.
|
||||
Valid,
|
||||
/// The testing split, which might be tagged.
|
||||
Test(Option<Arc<String>>),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Split {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Split::Train => write!(f, "train"),
|
||||
Split::Valid => write!(f, "valid"),
|
||||
Split::Test(_) => write!(f, "test"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/// The direction of the query.
|
||||
pub enum Direction {
|
||||
/// Lower is better.
|
||||
Lowest,
|
||||
/// Higher is better.
|
||||
Highest,
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
use super::EventStore;
|
||||
use super::{Aggregate, Direction, Event, Split};
|
||||
use std::sync::Arc;
|
||||
use std::{sync::mpsc, thread::JoinHandle};
|
||||
|
||||
/// Type that allows to communicate with an [event store](EventStore).
|
||||
pub struct EventStoreClient {
|
||||
sender: mpsc::Sender<Message>,
|
||||
handler: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl EventStoreClient {
|
||||
/// Create a new [event store](EventStore) client.
|
||||
pub(crate) fn new<C>(store: C) -> Self
|
||||
where
|
||||
C: EventStore + 'static,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = WorkerThread::new(store, receiver);
|
||||
|
||||
let handler = std::thread::spawn(move || thread.run());
|
||||
let handler = Some(handler);
|
||||
|
||||
Self { sender, handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl EventStoreClient {
|
||||
/// Add a training event to the [event store](EventStore).
|
||||
pub(crate) fn add_event_train(&self, event: Event) {
|
||||
self.sender
|
||||
.send(Message::OnEventTrain(event))
|
||||
.expect("Can send event to event store thread.");
|
||||
}
|
||||
|
||||
/// Add a validation event to the [event store](EventStore).
|
||||
pub(crate) fn add_event_valid(&self, event: Event) {
|
||||
self.sender
|
||||
.send(Message::OnEventValid(event))
|
||||
.expect("Can send event to event store thread.");
|
||||
}
|
||||
|
||||
/// Add a testing event to the [event store](EventStore).
|
||||
pub(crate) fn add_event_test(&self, event: Event, tag: Arc<String>) {
|
||||
self.sender
|
||||
.send(Message::OnEventTest(event, tag))
|
||||
.expect("Can send event to event store thread.");
|
||||
}
|
||||
|
||||
/// Find the epoch following the given criteria from the collected data.
|
||||
pub fn find_epoch(
|
||||
&self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: &Split,
|
||||
) -> Option<usize> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::FindEpoch(
|
||||
name.to_string(),
|
||||
aggregate,
|
||||
direction,
|
||||
split.clone(),
|
||||
sender,
|
||||
))
|
||||
.expect("Can send event to event store thread.");
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("Event store thread crashed: {err:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the metric value for the current epoch following the given criteria.
|
||||
pub fn find_metric(
|
||||
&self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
split: &Split,
|
||||
) -> Option<f64> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::FindMetric(
|
||||
name.to_string(),
|
||||
epoch,
|
||||
aggregate,
|
||||
split.clone(),
|
||||
sender,
|
||||
))
|
||||
.expect("Can send event to event store thread.");
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("Event store thread crashed: {err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct WorkerThread<S> {
|
||||
store: S,
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
}
|
||||
|
||||
impl<C> WorkerThread<C>
|
||||
where
|
||||
C: EventStore,
|
||||
{
|
||||
fn run(mut self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
Message::FindEpoch(name, aggregate, direction, split, callback) => {
|
||||
let response = self.store.find_epoch(&name, aggregate, direction, &split);
|
||||
callback
|
||||
.send(response)
|
||||
.expect("Can send response using callback channel.");
|
||||
}
|
||||
Message::FindMetric(name, epoch, aggregate, split, callback) => {
|
||||
let response = self.store.find_metric(&name, epoch, aggregate, &split);
|
||||
callback
|
||||
.send(response)
|
||||
.expect("Can send response using callback channel.");
|
||||
}
|
||||
Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
|
||||
Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
|
||||
Message::OnEventTest(event, tag) => {
|
||||
self.store.add_event(event, Split::Test(Some(tag)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Message {
|
||||
OnEventTest(Event, Arc<String>),
|
||||
OnEventTrain(Event),
|
||||
OnEventValid(Event),
|
||||
End,
|
||||
FindEpoch(
|
||||
String,
|
||||
Aggregate,
|
||||
Direction,
|
||||
Split,
|
||||
mpsc::SyncSender<Option<usize>>,
|
||||
),
|
||||
FindMetric(
|
||||
String,
|
||||
usize,
|
||||
Aggregate,
|
||||
Split,
|
||||
mpsc::SyncSender<Option<f64>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl Drop for EventStoreClient {
|
||||
fn drop(&mut self) {
|
||||
self.sender
|
||||
.send(Message::End)
|
||||
.expect("Can send the end message to the event store thread.");
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler.join().expect("The event store thread should stop.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{Aggregate, Direction, Event, EventStore, Split, aggregate::NumericMetricsAggregate};
|
||||
use crate::logger::MetricLogger;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct LogEventStore {
|
||||
loggers: Vec<Box<dyn MetricLogger>>,
|
||||
aggregate: NumericMetricsAggregate,
|
||||
epochs: HashMap<Split, usize>,
|
||||
}
|
||||
|
||||
impl EventStore for LogEventStore {
|
||||
fn add_event(&mut self, event: Event, split: Split) {
|
||||
let epoch = *self.epochs.entry(split.clone()).or_insert(1);
|
||||
|
||||
match event {
|
||||
Event::MetricsInit(definitions) => {
|
||||
definitions.iter().for_each(|def| {
|
||||
self.loggers
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.log_metric_definition(def.clone()));
|
||||
});
|
||||
}
|
||||
Event::MetricsUpdate(update) => {
|
||||
self.loggers
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.log(update.clone(), epoch, &split));
|
||||
}
|
||||
Event::EndEpoch(summary) => {
|
||||
self.epochs.insert(split, summary.epoch_number + 1);
|
||||
self.loggers
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.log_epoch_summary(summary.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: &Split,
|
||||
) -> Option<usize> {
|
||||
self.aggregate
|
||||
.find_epoch(name, split, aggregate, direction, &mut self.loggers)
|
||||
}
|
||||
|
||||
fn find_metric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
split: &Split,
|
||||
) -> Option<f64> {
|
||||
self.aggregate
|
||||
.aggregate(name, epoch, split, aggregate, &mut self.loggers)
|
||||
}
|
||||
}
|
||||
|
||||
impl LogEventStore {
|
||||
/// Register a logger for metrics.
|
||||
pub(crate) fn register_logger<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers.push(Box::new(logger));
|
||||
}
|
||||
|
||||
/// Returns whether any loggers are registered.
|
||||
pub(crate) fn has_loggers(&self) -> bool {
|
||||
!self.loggers.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
pub(crate) mod aggregate;
|
||||
|
||||
mod base;
|
||||
mod client;
|
||||
mod log;
|
||||
|
||||
pub(crate) use self::log::*;
|
||||
pub use base::*;
|
||||
pub use client::*;
|
||||
@@ -0,0 +1,185 @@
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use super::{MetricMetadata, SerializedEntry};
|
||||
use crate::metric::{
|
||||
Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
|
||||
};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{ElementConversion, Int, Tensor};
|
||||
|
||||
/// The Top-K accuracy metric.
|
||||
///
|
||||
/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).
|
||||
#[derive(Default, Clone)]
|
||||
pub struct TopKAccuracyMetric<B: Backend> {
|
||||
name: Arc<String>,
|
||||
k: usize,
|
||||
state: NumericMetricState,
|
||||
/// If specified, targets equal to this value will be considered padding and will not count
|
||||
/// towards the metric
|
||||
pad_token: Option<usize>,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [top-k accuracy metric](TopKAccuracyMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct TopKAccuracyInput<B: Backend> {
|
||||
/// The outputs (batch_size, num_classes)
|
||||
outputs: Tensor<B, 2>,
|
||||
/// The labels (batch_size)
|
||||
targets: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TopKAccuracyMetric<B> {
|
||||
/// Creates the metric.
|
||||
pub fn new(k: usize) -> Self {
|
||||
Self {
|
||||
name: Arc::new(format!("Top-K Accuracy @ TopK({})", k)),
|
||||
k,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the pad token.
|
||||
pub fn with_pad_token(mut self, index: usize) -> Self {
|
||||
self.pad_token = Some(index);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for TopKAccuracyMetric<B> {
|
||||
type Input = TopKAccuracyInput<B>;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
input: &TopKAccuracyInput<B>,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
let [batch_size, _n_classes] = input.outputs.dims();
|
||||
|
||||
let targets = input.targets.clone().to_device(&B::Device::default());
|
||||
|
||||
let outputs = input
|
||||
.outputs
|
||||
.clone()
|
||||
.argsort_descending(1)
|
||||
.narrow(1, 0, self.k)
|
||||
.to_device(&B::Device::default())
|
||||
.reshape([batch_size, self.k]);
|
||||
|
||||
let (targets, num_pad) = match self.pad_token {
|
||||
Some(pad_token) => {
|
||||
// we ignore the samples where the target is equal to the pad token
|
||||
let mask = targets.clone().equal_elem(pad_token as i64);
|
||||
let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
|
||||
(targets.clone().mask_fill(mask, -1_i64), num_pad)
|
||||
}
|
||||
None => (targets.clone(), 0_f64),
|
||||
};
|
||||
|
||||
let accuracy = targets
|
||||
.reshape([batch_size, 1])
|
||||
.repeat_dim(1, self.k)
|
||||
.equal(outputs)
|
||||
.int()
|
||||
.sum()
|
||||
.into_scalar()
|
||||
.elem::<f64>()
|
||||
/ (batch_size as f64 - num_pad);
|
||||
|
||||
self.state.update(
|
||||
100.0 * accuracy,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some("%".to_string()),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn test_accuracy_without_padding() {
|
||||
let device = Default::default();
|
||||
let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
|
||||
let input = TopKAccuracyInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.2, 0.8], // 2, 1
|
||||
[1.0, 2.0, 0.5], // 1, 0
|
||||
[0.4, 0.1, 0.2], // 0, 2
|
||||
[0.6, 0.7, 0.2], // 1, 0
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2, 2, 1, 1], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(50.0, metric.value().current());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accuracy_with_padding() {
|
||||
let device = Default::default();
|
||||
let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
|
||||
let input = TopKAccuracyInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.0, 0.2, 0.8, 0.0], // 2, 1
|
||||
[1.0, 2.0, 0.5, 0.0], // 1, 0
|
||||
[0.4, 0.1, 0.2, 0.0], // 0, 2
|
||||
[0.6, 0.7, 0.2, 0.0], // 1, 0
|
||||
[0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
|
||||
[0.0, 0.1, 0.2, 0.0], // Error on padding should not count
|
||||
[0.6, 0.0, 0.2, 0.0], // Error on padding should not count
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(50.0, metric.value().current());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameterized_unique_name() {
|
||||
let metric_a = TopKAccuracyMetric::<TestBackend>::new(2);
|
||||
let metric_b = TopKAccuracyMetric::<TestBackend>::new(1);
|
||||
let metric_c = TopKAccuracyMetric::<TestBackend>::new(2);
|
||||
|
||||
assert_ne!(metric_a.name(), metric_b.name());
|
||||
assert_eq!(metric_a.name(), metric_c.name());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
use crate::metric::{MetricAttributes, MetricName, SerializedEntry};
|
||||
|
||||
use super::super::{
|
||||
Metric, MetricMetadata,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use burn_core::{
|
||||
prelude::{Backend, Tensor},
|
||||
tensor::{ElementConversion, Int, s},
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Input type for the [DiceMetric].
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `B`: Backend type.
|
||||
/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).
|
||||
pub struct DiceInput<B: Backend, const D: usize = 4> {
|
||||
/// Model outputs (predictions), as a tensor.
|
||||
outputs: Tensor<B, D, Int>,
|
||||
/// Ground truth targets, as a tensor.
|
||||
targets: Tensor<B, D, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> DiceInput<B, D> {
|
||||
/// Creates a new DiceInput with the given outputs and targets.
|
||||
///
|
||||
/// Inputs are expected to have the dimensions `[B, C, ...]`
|
||||
/// where `B` is the batch size, `C` is the number of classes,
|
||||
/// and `...` represents additional dimensions (e.g., height, width for images).
|
||||
///
|
||||
/// If `C` is more than 1, the first class (index 0) is considered the background.
|
||||
/// Additionally, one-hot encoding is the responsibility of the caller.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `outputs`: The model outputs as a tensor.
|
||||
/// - `targets`: The ground truth targets as a tensor.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new instance of `DiceInput`.
|
||||
///
|
||||
/// # Panics
|
||||
/// - If `D` is less than 3.
|
||||
/// - If `outputs` and `targets` do not have the same dimensions.
|
||||
/// - If `outputs` or `targets` do not have exactly `D` dimensions.
|
||||
/// - If `outputs` and `targets` do not have the same shape.
|
||||
pub fn new(outputs: Tensor<B, D, Int>, targets: Tensor<B, D, Int>) -> Self {
|
||||
assert!(D >= 3, "DiceInput requires at least 3 dimensions.");
|
||||
assert!(
|
||||
outputs.dims() == targets.dims(),
|
||||
"Outputs and targets must have the same dimensions. Got {:?} and {:?}",
|
||||
outputs.dims(),
|
||||
targets.dims()
|
||||
);
|
||||
Self { outputs, targets }
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the [DiceMetric].
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DiceMetricConfig {
|
||||
/// Epsilon value to avoid division by zero.
|
||||
pub epsilon: f64,
|
||||
/// Whether to include the background class in the metric calculation.
|
||||
/// The background is assumed to be the first class (index 0).
|
||||
/// if `true`, will panic if there are fewer than 2 classes.
|
||||
pub include_background: bool,
|
||||
}
|
||||
|
||||
impl Default for DiceMetricConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
epsilon: 1e-7,
|
||||
include_background: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The Dice-Sorenson coefficient (DSC) for evaluating overlap between two binary masks.
|
||||
/// The DSC is defined as:
|
||||
/// `DSC = 2 * (|X ∩ Y|) / (|X| + |Y|)`
|
||||
/// where `X` is the model output and `Y` is the ground truth target.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `B`: Backend type.
|
||||
/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).
|
||||
#[derive(Default, Clone)]
|
||||
pub struct DiceMetric<B: Backend, const D: usize = 4> {
|
||||
name: MetricName,
|
||||
/// Internal state for numeric metric aggregation.
|
||||
state: NumericMetricState,
|
||||
/// Marker for backend type.
|
||||
_b: PhantomData<B>,
|
||||
/// Configuration for the metric.
|
||||
config: DiceMetricConfig,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> DiceMetric<B, D> {
|
||||
/// Creates a new Dice metric instance with default config.
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(DiceMetricConfig::default())
|
||||
}
|
||||
|
||||
/// Creates a new Dice metric with a custom config.
|
||||
pub fn with_config(config: DiceMetricConfig) -> Self {
|
||||
let name = MetricName::new(format!("{D}D Dice Metric"));
|
||||
assert!(D >= 3, "DiceMetric requires at least 3 dimensions.");
|
||||
Self {
|
||||
name,
|
||||
config,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Metric for DiceMetric<B, D> {
|
||||
type Input = DiceInput<B, D>;
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
// Dice coefficient: 2 * (|X ∩ Y|) / (|X| + |Y|)
|
||||
if item.outputs.dims() != item.targets.dims() {
|
||||
panic!(
|
||||
"Outputs and targets must have the same dimensions. Got {:?} and {:?}",
|
||||
item.outputs.dims(),
|
||||
item.targets.dims()
|
||||
);
|
||||
}
|
||||
|
||||
let dims = item.outputs.dims();
|
||||
let batch_size = dims[0];
|
||||
let n_classes = dims[1];
|
||||
|
||||
let mut outputs = item.outputs.clone();
|
||||
let mut targets = item.targets.clone();
|
||||
|
||||
if !self.config.include_background && n_classes > 1 {
|
||||
// If not including background, we can ignore the first class
|
||||
outputs = outputs.slice(s![.., 1..]);
|
||||
targets = targets.slice(s![.., 1..]);
|
||||
} else if self.config.include_background && n_classes < 2 {
|
||||
// If including background, we need at least 2 classes
|
||||
panic!("Dice metric requires at least 2 classes when including background.");
|
||||
}
|
||||
|
||||
let intersection = (outputs.clone() * targets.clone()).sum();
|
||||
let outputs_sum = outputs.sum();
|
||||
let targets_sum = targets.sum();
|
||||
|
||||
// Convert to f64
|
||||
let intersection_val = intersection.into_scalar().elem::<f64>();
|
||||
let outputs_sum_val = outputs_sum.into_scalar().elem::<f64>();
|
||||
let targets_sum_val = targets_sum.into_scalar().elem::<f64>();
|
||||
|
||||
// Use epsilon from config
|
||||
let epsilon = self.config.epsilon;
|
||||
let dice =
|
||||
(2.0 * intersection_val + epsilon) / (outputs_sum_val + targets_sum_val + epsilon);
|
||||
|
||||
self.state.update(
|
||||
dice,
|
||||
batch_size,
|
||||
FormatOptions::new(self.name()).precision(4),
|
||||
)
|
||||
}
|
||||
|
||||
/// Clears the metric state.
|
||||
fn clear(&mut self) {
|
||||
self.state.reset();
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
crate::metric::NumericAttributes {
|
||||
unit: None,
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> crate::metric::Numeric for DiceMetric<B, D> {
|
||||
fn value(&self) -> crate::metric::NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> crate::metric::NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{TestBackend, metric::Numeric};
|
||||
use burn_core::tensor::{Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_dice_perfect_overlap() {
|
||||
let device = Default::default();
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::new();
|
||||
let input = DiceInput::new(
|
||||
Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
|
||||
Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert!((metric.value().current() - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dice_no_overlap() {
|
||||
let device = Default::default();
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::new();
|
||||
let input = DiceInput::new(
|
||||
Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
|
||||
Tensor::from_data([[[[0, 1], [0, 1]]]], &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert!(metric.value().current() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dice_partial_overlap() {
|
||||
let device = Default::default();
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::new();
|
||||
let input = DiceInput::new(
|
||||
Tensor::from_data([[[[1, 1], [0, 0]]]], &device),
|
||||
Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
// intersection = 1, sum = 2+2=4, dice = 2*1/4 = 0.5
|
||||
assert!((metric.value().current() - 0.5).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dice_empty_masks() {
|
||||
let device = Default::default();
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::new();
|
||||
let input = DiceInput::new(
|
||||
Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
|
||||
Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert!((metric.value().current() - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dice_no_background() {
|
||||
let device = Default::default();
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::new();
|
||||
let input = DiceInput::new(
|
||||
Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
|
||||
Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert!((metric.value().current() - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dice_with_background() {
|
||||
let device = Default::default();
|
||||
let config = DiceMetricConfig {
|
||||
epsilon: 1e-7,
|
||||
include_background: true,
|
||||
};
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
|
||||
let input = DiceInput::new(
|
||||
Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
|
||||
Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert!((metric.value().current() - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dice_ignored_background() {
|
||||
let device = Default::default();
|
||||
let config = DiceMetricConfig {
|
||||
epsilon: 1e-7,
|
||||
include_background: false,
|
||||
};
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
|
||||
let input = DiceInput::new(
|
||||
Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
|
||||
Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
|
||||
);
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert!((metric.value().current() - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "DiceInput requires at least 3 dimensions.")]
|
||||
fn test_invalid_input_dimensions() {
|
||||
let device = Default::default();
|
||||
// D = 2, should panic
|
||||
let _ = DiceInput::<TestBackend, 2>::new(
|
||||
Tensor::from_data([[0.0, 0.0]], &device),
|
||||
Tensor::from_data([[0.0, 0.0]], &device),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(
|
||||
expected = "Outputs and targets must have the same dimensions. Got [1, 1, 2, 2] and [1, 1, 2, 3]"
|
||||
)]
|
||||
fn test_mismatched_shape() {
|
||||
let device = Default::default();
|
||||
// shapes differ
|
||||
let _ = DiceInput::<TestBackend, 4>::new(
|
||||
Tensor::from_data([[[[0.0; 2]; 2]; 1]; 1], &device),
|
||||
Tensor::from_data([[[[0.0; 3]; 2]; 1]; 1], &device),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dice metric requires at least 2 classes when including background.")]
|
||||
fn test_include_background_panic() {
|
||||
let device = Default::default();
|
||||
let config = DiceMetricConfig {
|
||||
epsilon: 1e-7,
|
||||
include_background: true,
|
||||
};
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
|
||||
let input = DiceInput::new(
|
||||
Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device),
|
||||
Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device),
|
||||
);
|
||||
// n_classes = 2, should not panic
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
|
||||
let config = DiceMetricConfig {
|
||||
epsilon: 1e-7,
|
||||
include_background: true,
|
||||
};
|
||||
let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
|
||||
let input = DiceInput::new(
|
||||
Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),
|
||||
Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),
|
||||
);
|
||||
// n_classes = 1, should panic
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user