feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,138 @@
use crate::{
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
grads::Gradients,
tensor::AutodiffTensor,
};
use alloc::{format, string::String};
use burn_backend::{
backend::{AutodiffBackend, Backend, ExecutionError},
tensor::{BoolTensor, IntTensor, QuantizedTensor},
};
use core::marker::PhantomData;
/// Enable auto-differentiation on a backend.
///
/// This works as a backend decorator, extending the functionality of any backend with
/// backpropagation.
#[derive(Clone, Copy, Debug, Default)]
pub struct Autodiff<B, C = NoCheckpointing> {
_b: PhantomData<B>,
_checkpoint_strategy: PhantomData<C>,
}
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type Device = B::Device;
type FloatTensorPrimitive = AutodiffTensor<B>;
type FloatElem = B::FloatElem;
type IntTensorPrimitive = B::IntTensorPrimitive;
type IntElem = B::IntElem;
type BoolTensorPrimitive = B::BoolTensorPrimitive;
type BoolElem = B::BoolElem;
type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
fn ad_enabled(_device: &Self::Device) -> bool {
true
}
fn name(device: &Self::Device) -> String {
format!("autodiff<{}>", B::name(device))
}
fn seed(device: &B::Device, seed: u64) {
B::seed(device, seed)
}
fn sync(device: &B::Device) -> Result<(), ExecutionError> {
B::sync(device)
}
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
device: &Self::Device,
input: Input,
func: Func,
) -> Output {
B::memory_persistent_allocations(device, input, func)
}
fn memory_cleanup(device: &Self::Device) {
B::memory_cleanup(device)
}
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
where
Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
{
B::staging(data, device);
}
fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
B::supports_dtype(device, dtype)
}
fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
B::dtype_usage(device, dtype)
}
}
impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
type InnerBackend = B;
type Gradients = Gradients;
fn backward(tensor: AutodiffTensor<B>) -> Gradients {
tensor.backward()
}
fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
tensor.grad(grads)
}
fn grad_remove(
tensor: &AutodiffTensor<B>,
grads: &mut Gradients,
) -> Option<B::FloatTensorPrimitive> {
tensor.grad_remove(grads)
}
fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
tensor.primitive
}
fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
AutodiffTensor::new(tensor)
}
fn grad_replace(
tensor: &AutodiffTensor<B>,
grads: &mut Self::Gradients,
grad: B::FloatTensorPrimitive,
) {
tensor.grad_replace(grads, grad);
}
fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
tensor
}
fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
tensor
}
fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
tensor
}
fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
tensor
}
fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
tensor
}
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
tensor
}
}

View File

@@ -0,0 +1,82 @@
use super::{
retro_forward::RetroForwards,
state::{BackwardStates, State},
};
use crate::collections::HashMap;
use crate::graph::NodeId;
use alloc::{vec, vec::Vec};
#[derive(new, Debug)]
/// Links a [NodeId] to its autodiff graph [NodeRef]
pub(crate) struct NodeTree {
map: HashMap<NodeId, Vec<NodeId>>,
}
impl NodeTree {
/// Gives the parents of the node in the autodiff graph
pub(crate) fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {
self.map.get(node_id).cloned()
}
}
#[derive(new, Debug)]
/// Struct responsible of fetching the output for a node in the autodiff graph during a backward pass
pub struct Checkpointer {
backward_states: BackwardStates,
retro_forwards: RetroForwards,
node_tree: NodeTree,
}
impl Checkpointer {
/// Gives the output of the given node, by recursively asking parents to compute themselves
/// or give their pre-computed tensors.
pub fn retrieve_node_output<T>(&mut self, node_id: NodeId) -> T
where
T: Clone + Send + 'static,
{
self.topological_sort(node_id).into_iter().for_each(|node| {
self.retro_forwards
.execute_retro_forward(node, &mut self.backward_states)
});
self.backward_states.get_state::<T>(&node_id)
}
/// Sorts the ancestors of NodeId in a way such that all parents come before their children
/// Useful to avoid recursivity later when mutating the states
///
/// The sort on a compute bound state or a memory bound that is already computed is trivial.
/// The match on State::Computed also serves as a stopping criterion for the sort,
/// we don't need to look higher than that during recursivity.
fn topological_sort(&self, node_id: NodeId) -> Vec<NodeId> {
match self.backward_states.get_state_ref(&node_id) {
Some(state) => match state {
State::Recompute { n_required: _ } => {
let mut sorted = Vec::new();
let parents = self.node_tree.parents(&node_id).unwrap();
for parent_node in parents {
let parent_sorted = self.topological_sort(parent_node);
for ps in parent_sorted {
if !sorted.contains(&ps) {
sorted.push(ps)
}
}
}
sorted.push(node_id);
sorted
}
State::Computed {
state_content: _,
n_required: _,
} => vec![node_id],
},
None => panic!("Node {node_id:?} is not in the backward_states. "),
}
}
/// Checks if checkpointer has been drained adequately. Useful for testing
pub fn is_empty(&self) -> bool {
self.backward_states.is_empty() && self.retro_forwards.is_empty()
}
}

View File

@@ -0,0 +1,304 @@
use crate::{
collections::HashMap,
graph::{ComputingProperty, NodeId},
tensor::AutodiffTensor,
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use burn_backend::Backend;
use core::any::Any;
use super::{
base::{Checkpointer, NodeTree},
retro_forward::{RetroForward, RetroForwards},
state::{BackwardStates, State},
};
#[derive(Debug)]
/// Determines if a node should checkpoint its computed output or its retro_forward for recomputation
/// The action is normally created by the child of the node, once the node is determined to be needed
pub enum CheckpointingAction {
/// The node's already computed output should be saved
Computed {
/// The node
node_id: NodeId,
/// The node's output
state_content: Box<dyn Any + Send>,
},
/// The node should recompute itself when asked
Recompute {
/// The node
node_id: NodeId,
/// How the node should recompute itself
retro_forward: Arc<dyn RetroForward>,
},
}
// TODO: Remove that when proper client server.
unsafe impl Send for CheckpointingAction {}
impl CheckpointingAction {
/// Utility function to access the id of the node of the checkpointing action
pub fn id(&self) -> NodeId {
match self {
CheckpointingAction::Computed {
node_id: node_ref,
state_content: _,
} => *node_ref,
CheckpointingAction::Recompute {
node_id: node_ref,
retro_forward: _,
} => *node_ref,
}
}
}
#[derive(new, Debug, Default)]
/// Accumulates checkpoints as checkpointing actions during the forward pass,
/// and builds a checkpointer right before the backward pass
pub struct CheckpointerBuilder {
explicit_actions: Vec<CheckpointingAction>,
backup_actions: Vec<CheckpointingAction>,
}
/// Determines if a checkpoint should impact the n_required values (Main)
/// or if it should just keep the state in case it's required (Backup)
///
pub(crate) enum ActionType {
/// Explicit actions have been explicitly requested by some operation to retrieve their state
Explicit,
/// Backup actions are not always needed. They exist to save the output of an operation
/// whose child is memory bound, in case the state is indirectly needed when computing
/// the child's retro_forward. If no explicit action ever asks for the child's output, then
/// the backup output will go out of scope when the checkpointer is built.
Backup,
}
impl CheckpointerBuilder {
pub(crate) fn checkpoint<B: Backend>(
&mut self,
tensor: &AutodiffTensor<B>,
action_type: ActionType,
) {
let action_list = match action_type {
ActionType::Explicit => &mut self.explicit_actions,
ActionType::Backup => &mut self.backup_actions,
};
match &tensor.node.properties {
ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
action_list.push(CheckpointingAction::Computed {
node_id: tensor.node.id,
state_content: Box::new(tensor.primitive.clone()),
})
}
ComputingProperty::MemoryBound { retro_forward } => {
action_list.push(CheckpointingAction::Recompute {
node_id: tensor.node.id,
retro_forward: retro_forward.clone(),
})
}
}
}
pub(crate) fn extend(&mut self, other: CheckpointerBuilder) {
for other_action in other.explicit_actions {
self.explicit_actions.push(other_action)
}
for other_unsure in other.backup_actions {
self.backup_actions.push(other_unsure)
}
}
pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {
let mut backward_states_map = HashMap::new();
let mut retro_forwards_map = HashMap::new();
// Find recursion stopping points
let stop_nodes: Vec<NodeId> = self.find_stop_nodes();
// We start by identifying how many times each node will be required.
let n_required_map = self.build_n_required_map(&node_tree, stop_nodes);
// Then we checkpoint the nodes with the corresponding n_required value
self.insert_checkpoints(
&mut backward_states_map,
&mut retro_forwards_map,
n_required_map,
);
Checkpointer::new(
BackwardStates::new(backward_states_map),
RetroForwards::new(retro_forwards_map),
node_tree,
)
}
fn find_stop_nodes(&self) -> Vec<NodeId> {
let mut stop_nodes = Vec::default();
for action in self
.explicit_actions
.iter()
.chain(self.backup_actions.iter())
{
match action {
CheckpointingAction::Computed {
node_id: node_ref,
state_content: _,
} => stop_nodes.push(*node_ref),
CheckpointingAction::Recompute {
node_id: _,
retro_forward: _,
} => {}
}
}
stop_nodes
}
fn build_n_required_map(
&self,
node_tree: &NodeTree,
stop_nodes: Vec<NodeId>,
) -> HashMap<NodeId, usize> {
let mut n_required_map = HashMap::<NodeId, usize>::default();
for action in self.explicit_actions.iter() {
match action {
CheckpointingAction::Computed {
node_id: node_ref,
state_content: _,
} => {
let id = *node_ref;
match n_required_map.remove(&id) {
Some(n) => {
n_required_map.insert(id, n + 1);
}
None => {
n_required_map.insert(id, 1);
}
};
}
CheckpointingAction::Recompute {
node_id: node_ref,
retro_forward: _,
} => {
let id = *node_ref;
Self::update_n_required_of_parents(
id,
&mut n_required_map,
node_tree,
&stop_nodes,
);
}
}
}
n_required_map
}
fn insert_checkpoints(
mut self,
backward_states_map: &mut HashMap<NodeId, State>,
retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
n_required_map: HashMap<NodeId, usize>,
) {
// We do not loop over checkpointing actions anymore because they can contain
// duplicates or miss some that are in backup. We loop over the n_required_map
// from which we use the ids to find them again in the checkpointing actions
for (node_id, n_required) in n_required_map {
// We find the checkpointing action for node_id. It's likely in checkpointing_actions
// so we check there first, otherwise it will be in backup.
// Technically it can be there several times but can never be of both types, so we can assume the first we find is fine
let action = match self
.explicit_actions
.iter()
.position(|action| action.id() == node_id)
{
Some(pos) => self.explicit_actions.remove(pos),
None => {
let pos = self
.backup_actions
.iter()
.position(|action| action.id() == node_id);
self.backup_actions.remove(pos.unwrap_or_else(|| {
panic!("Node {:?} is needed but never checkpointed", &node_id)
}))
}
};
match action {
CheckpointingAction::Computed {
node_id: _,
state_content,
} => {
self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
}
CheckpointingAction::Recompute {
node_id: _,
retro_forward,
} => self.checkpoint_lazy(
backward_states_map,
retro_forward_map,
node_id,
retro_forward,
n_required,
),
};
}
}
fn update_n_required_of_parents(
id: NodeId,
n_required_map: &mut HashMap<NodeId, usize>,
node_tree: &NodeTree,
stop_nodes: &Vec<NodeId>,
) {
match n_required_map.remove(&id) {
Some(n) => {
n_required_map.insert(id, n + 1);
}
None => {
n_required_map.insert(id, 1);
if !stop_nodes.contains(&id)
&& let Some(parents) = node_tree.parents(&id)
{
for p in parents {
Self::update_n_required_of_parents(
p,
n_required_map,
node_tree,
stop_nodes,
);
}
}
}
}
}
fn checkpoint_compute(
&self,
backward_states_map: &mut HashMap<NodeId, State>,
node_id: NodeId,
state_content: Box<dyn Any + Send>,
n_required: usize,
) {
backward_states_map.insert(
node_id,
State::Computed {
state_content,
n_required,
},
);
}
fn checkpoint_lazy(
&self,
backward_states_map: &mut HashMap<NodeId, State>,
retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
node_id: NodeId,
retro_forward: Arc<dyn RetroForward>,
n_required: usize,
) {
retro_forward_map.insert(node_id, retro_forward);
backward_states_map.insert(node_id, State::Recompute { n_required });
}
}

View File

@@ -0,0 +1,9 @@
/// Checkpointer module
pub mod base;
pub(crate) mod builder;
/// RetroForward module
pub mod retro_forward;
/// BackwardStates module
pub mod state;
/// CheckpointStrategy module
pub mod strategy;

View File

@@ -0,0 +1,116 @@
use crate::collections::HashMap;
use crate::graph::NodeId;
use alloc::sync::Arc;
use core::fmt::Debug;
use super::state::{BackwardStates, State};
/// Definition of the forward function of a node, called during retropropagation only.
/// This is different from the normal forward function because it reads and writes from
/// the [BackwardStates] map instead of having a clear function signature.
pub trait RetroForward: Debug + Send + 'static {
/// Applies the forward pass for retropropagation.
fn forward(&self, states: &mut BackwardStates, out_node: NodeId);
}
#[derive(new, Debug)]
/// Links [NodeId]s to their corresponding [RetroForward]
pub(crate) struct RetroForwards {
map: HashMap<NodeId, Arc<dyn RetroForward>>,
}
impl RetroForwards {
/// Executes the [RetroForward] for a given [NodeId] if the node's
/// [State] is [State::Recompute], otherwise does nothing.
pub(crate) fn execute_retro_forward(
&mut self,
node_id: NodeId,
backward_states: &mut BackwardStates,
) {
if let State::Recompute { n_required: _ } = backward_states
.get_state_ref(&node_id)
.unwrap_or_else(|| panic!("Should find node {node_id:?}"))
{
// Retro forwards are always used only once because afterwards their state is computed
let retro_forward = self.map.remove(&node_id).unwrap();
retro_forward.forward(backward_states, node_id);
}
}
pub(crate) fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
#[macro_export]
/// Creates a RetroForward struct for unary scalar operations
macro_rules! retro_unary_scalar {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug, Clone)]
struct $name<B: Backend> {
lhs_id: NodeId,
rhs: Scalar,
_backend: PhantomData<B>,
}
impl<B: Backend> RetroForward for $name<B> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
let out = $ops(lhs, self.rhs);
states.save(out_node, out)
}
}
};
}
#[macro_export]
/// Creates a RetroForward struct for unary scalar operations
macro_rules! retro_unary {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug, Clone)]
struct $name<B: Backend> {
input_id: NodeId,
_backend: PhantomData<B>,
}
impl<B: Backend> RetroForward for $name<B> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);
let out = $ops(input);
states.save(out_node, out)
}
}
};
}
#[macro_export]
/// Creates a RetroForward struct for binary operations
macro_rules! retro_binary {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug, Clone)]
struct $name<B: Backend> {
lhs_id: NodeId,
rhs_id: NodeId,
_backend: PhantomData<B>,
}
impl<B: Backend> RetroForward for $name<B> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
let rhs = states.get_state::<B::FloatTensorPrimitive>(&self.rhs_id);
let out = $ops(lhs, rhs);
states.save(out_node, out)
}
}
};
}

View File

@@ -0,0 +1,144 @@
use core::any::Any;
use crate::collections::HashMap;
use crate::graph::NodeId;
use alloc::boxed::Box;
/// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.
pub(crate) type StateContent = Box<dyn Any + Send>;
#[derive(Debug)]
/// The state contained at one node. Encapsulates the node output if precomputed,
/// or clearly asks that it needs to be recomputed from the parents.
/// Also keeps track of the number of times the state is required so it can be removed
/// from the map of states on its last use.
pub(crate) enum State {
/// The state was not checkpointed, will need to recompute it from the node's parents
Recompute { n_required: usize },
/// The state was checkpointed or computed during retropropagation and can be directly accessed
Computed {
state_content: StateContent,
n_required: usize,
},
}
impl State {
/// Returns a reference to the (not yet) downcasted node output, if checkpointed
pub(crate) fn to_state_content(&self) -> &StateContent {
match self {
State::Recompute { n_required: _ } => {
unreachable!(
"Can't get state content of recompute state. A child has likely been accessed before its parents."
)
}
State::Computed {
state_content,
n_required: _,
} => state_content,
}
}
/// Returns a (not yet) downcasted node output, if checkpointed
pub(crate) fn into_state_content(self) -> StateContent {
match self {
State::Recompute { n_required: _ } => {
unreachable!(
"Can't get state content of recompute state. A child has likely been accessed before its parents."
)
}
State::Computed {
state_content,
n_required: _,
} => state_content,
}
}
/// Returns the number of time the state is required
pub(crate) fn n_required(&self) -> usize {
match self {
State::Recompute { n_required } => *n_required,
State::Computed {
state_content: _,
n_required,
} => *n_required,
}
}
}
#[derive(new, Default, Debug)]
/// Links [NodeId]s to their current state
pub struct BackwardStates {
map: HashMap<NodeId, State>,
}
impl BackwardStates {
/// Returns the output in the state of the given [NodeId],
/// and decrements the number of times this state is required.
/// This function always gives ownership of the output, but will clone it if needed for further uses.
pub fn get_state<T>(&mut self, node_id: &NodeId) -> T
where
T: Clone + Send + 'static,
{
// Fetch the state and decrement its number of required
let state = self.map.remove(node_id).unwrap();
let remaining_n_required = state.n_required() - 1;
// Downcast the state to whatever it is supposed to be
// If still needed after giving ownership, we copy it back to the hashmap
if remaining_n_required > 0 {
let new_stored_state = match state {
State::Recompute { n_required: _ } => unreachable!(),
State::Computed {
state_content,
n_required: _,
} => State::Computed {
state_content,
n_required: remaining_n_required,
},
};
let downcasted = new_stored_state
.to_state_content()
.downcast_ref::<T>()
.unwrap()
.clone();
self.insert_state(*node_id, new_stored_state);
downcasted
} else {
let downcasted = state.into_state_content().downcast::<T>().unwrap();
*downcasted
}
}
/// Returns a reference to the [State] of the given node
/// Useful when we need [State] information without needing the underlying tensor
pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> {
self.map.get(node_id)
}
/// Associates a [State] to its [NodeId]
pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) {
self.map.insert(node_id, state);
}
/// Saves the output to the state of the given [NodeId].
pub fn save<T>(&mut self, node_id: NodeId, saved_output: T)
where
T: Clone + Send + 'static,
{
let n_required = self.get_state_ref(&node_id).unwrap().n_required();
self.insert_state(
node_id,
State::Computed {
state_content: Box::new(saved_output),
n_required,
},
);
}
pub(crate) fn is_empty(&self) -> bool {
self.map.is_empty()
}
}

View File

@@ -0,0 +1,102 @@
use core::fmt::Debug;
use burn_backend::Backend;
use crate::{graph::ComputingProperty, tensor::AutodiffTensor};
use alloc::sync::Arc;
use super::{
builder::{ActionType, CheckpointerBuilder},
retro_forward::RetroForward,
};
/// Strategy for the amount of checkpointing to do during autodiff
pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'static {
/// May modify the compute property depending on the strategy
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty;
/// Checkpoints parents if necessary in the strategy
fn checkpoint_parents<'a, B2, A>(
parents: A,
builder: &mut CheckpointerBuilder,
) -> Result<(), CheckpointingError>
where
B2: Backend,
A: IntoIterator<Item = &'a AutodiffTensor<B2>>;
}
#[derive(Debug)]
/// Error that can happen when trying to checkpoint a tensor.
pub enum CheckpointingError {
/// When a parent is untracked, we can't easily checkpoint its state, since we don't know the
/// requirements in advanced.
UntrackedParent,
}
#[derive(Clone, Copy, Debug, Default)]
/// All operations are considered compute bound, notwithstanding how they are marked
pub struct NoCheckpointing {}
impl CheckpointStrategy for NoCheckpointing {
/// An operation marked as memory bound is actually compute bound.
fn compute_property<R: RetroForward>(_retro_forward: R) -> ComputingProperty {
ComputingProperty::ComputeBound
}
/// An operation marked as memory bound is actually compute bound.
/// It's therefore useless to checkpoint the parents
fn checkpoint_parents<'a, B2, A>(
_parents: A,
_builder: &mut CheckpointerBuilder,
) -> Result<(), CheckpointingError>
where
B2: Backend,
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
{
// Nothing to do here
Ok(())
}
}
#[derive(Clone, Copy, Debug, Default)]
/// Operation properties are as they are marked (compute or memory bound)
pub struct BalancedCheckpointing {}
impl CheckpointStrategy for BalancedCheckpointing {
/// An operation marked as memory bound is memory bound.
/// When memory bound, an operation needs to save its RetroForward
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty {
ComputingProperty::MemoryBound {
retro_forward: Arc::new(retro_forward),
}
}
/// An operation marked as memory bound is really memory bound.
/// Since the operation may not checkpoint its parents but may need them indirectly
/// if asked to recompute itself, the method needs to know the parent tensors to maybe checkpoint them
fn checkpoint_parents<'a, B2, A>(
parents: A,
builder: &mut CheckpointerBuilder,
) -> Result<(), CheckpointingError>
where
B2: Backend,
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
{
let mut can_checkpoint = true;
for tensor in parents.into_iter() {
if let crate::graph::Requirement::None = tensor.node.requirement {
can_checkpoint = false;
} else {
builder.checkpoint(tensor, ActionType::Backup);
}
}
if !can_checkpoint {
*builder = CheckpointerBuilder::default();
return Err(CheckpointingError::UntrackedParent);
}
Ok(())
}
}

View File

@@ -0,0 +1,85 @@
use burn_backend::{
Backend, TensorMetadata, TensorPrimitive,
tensor::{FloatTensor, TensorContainer},
};
use crate::{
NodeId,
graph::{NodeRef, Requirement},
tensor::AutodiffTensor,
};
/// Gradient identifier.
pub type GradID = u64;
/// Gradients container used during the backward pass.
pub struct Gradients {
container: TensorContainer<GradID>,
}
impl Gradients {
/// Creates a new gradients container.
pub fn new<B: Backend>(root_node: NodeRef, root_tensor: FloatTensor<B>) -> Self {
let mut gradients = Self {
container: TensorContainer::new(),
};
gradients.register::<B>(
root_node.id,
B::float_ones(
root_tensor.shape(),
&B::float_device(&root_tensor),
root_tensor.dtype().into(),
),
);
gradients
}
/// Consumes the gradients for a given tensor.
///
/// Each tensor should be consumed exactly 1 time if its gradients are only required during the
/// backward pass, otherwise, it may be consume multiple times.
pub fn consume<B: Backend>(&mut self, node: &NodeRef) -> FloatTensor<B> {
match node.requirement {
Requirement::Grad => self
.container
.get::<B>(&node.id.value)
.map(|tensor| tensor.tensor())
.expect("Can't consume the gradients before they are registered at least once."),
Requirement::GradInBackward => self
.container
.remove::<B>(&node.id.value)
.map(|tensor| tensor.tensor())
.expect("Can't consume the gradients before they are registered at least once."),
Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"),
}
}
/// Removes a grad tensor from the container.
pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
self.container
.remove::<B>(&tensor.node.id.value)
.map(|tensor| tensor.tensor())
}
/// Gets a grad tensor from the container.
pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
self.container
.get::<B>(&tensor.node.id.value)
.map(|tensor| tensor.tensor())
}
/// Register a grad tensor in the container.
///
/// If the tensor already exists, add both tensors together before saving the result.
pub fn register<B: Backend>(&mut self, node_id: NodeId, value: FloatTensor<B>) {
if let Some(tensor_old) = self.container.remove::<B>(&node_id.value) {
self.container.register::<B>(
node_id.value,
TensorPrimitive::Float(B::float_add(value, tensor_old.tensor())),
);
} else {
self.container
.register::<B>(node_id.value, TensorPrimitive::Float(value));
}
}
}

View File

@@ -0,0 +1,17 @@
use super::NodeId;
use crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent};
use alloc::boxed::Box;
/// Backward step for reverse mode autodiff.
pub trait Step: Send + core::fmt::Debug {
/// Executes the step and consumes it.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
/// Depth of the operation relative to the first node added to a graph.
fn depth(&self) -> usize;
/// The node associated to the step.
fn node(&self) -> NodeId;
/// The parents of the node associated to the step.
fn parents(&self) -> &[Parent];
}
pub type StepBoxed = Box<dyn Step>;

View File

@@ -0,0 +1,9 @@
mod base;
mod node;
mod requirement;
pub mod traversal;
pub use base::*;
pub use node::*;
pub use requirement::*;

View File

@@ -0,0 +1,87 @@
use alloc::{sync::Arc, vec::Vec};
#[cfg(target_has_atomic = "64")]
use core::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(target_has_atomic = "64"))]
use portable_atomic::{AtomicU64, Ordering};
use crate::checkpoint::retro_forward::RetroForward;
use crate::runtime::AutodiffClientImpl;
use super::Requirement;
#[derive(Debug, Clone)]
pub enum ComputingProperty {
ComputeBound,
MemoryBound {
retro_forward: Arc<dyn RetroForward>,
},
Ambiguous, // Maybe autotune someday
}
/// This is safe only because we only call RetroForward on the autodiff server.
/// Therefore, the trait will never be used by multiple threads at the same time.
///
/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the
/// Arc, which will make (dyn RetroForward) safely implement Send.
unsafe impl Send for ComputingProperty {}
/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.
unsafe impl Sync for ComputingProperty {}
/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
#[derive(new, Debug)]
pub struct Node {
pub parents: Vec<Parent>,
pub order: usize,
pub id: NodeId,
pub requirement: Requirement,
pub properties: ComputingProperty,
pub client: AutodiffClientImpl,
}
pub type NodeRef = Arc<Node>;
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct Parent {
pub id: NodeId,
}
impl Node {
/// Returns the [node](Node) only if gradients are required.
pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
match self.requirement.is_none() {
true => None,
false => Some(self.clone()),
}
}
}
/// Unique identifier generated for each node.
#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
pub struct NodeId {
/// The integer representation of the id
pub value: u64,
}
impl core::fmt::Display for NodeId {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!("NodeId({})", self.value))
}
}
impl NodeId {
/// Create a unique [node id](NodeId).
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(0);
let value = COUNTER.fetch_add(1, Ordering::Relaxed);
if value == u64::MAX {
panic!("NodeId overflowed");
}
Self { value }
}
}
impl Default for NodeId {
fn default() -> Self {
Self::new()
}
}

View File

@@ -0,0 +1,38 @@
use super::NodeRef;
/// Requirement for each tensor in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Requirement {
/// Operations that require gradients.
Grad,
/// Operations that require gradients only for backprop.
GradInBackward,
/// Operations that don't need gradients, therefore not to be included in the graph.
None,
}
impl Requirement {
/// Returns true if gradients are not required.
pub fn is_none(&self) -> bool {
matches!(self, Self::None)
}
/// Returns the right requirement from a list of nodes.
pub fn from_nodes(nodes: &[NodeRef]) -> Self {
if nodes.len() == 1 {
return nodes[0].requirement.infer(&Requirement::None);
}
nodes
.iter()
.map(|node| node.requirement)
.reduce(|acc, requirement| requirement.infer(&acc))
.unwrap_or(Requirement::None)
}
fn infer(&self, other: &Self) -> Self {
match self.is_none() && other.is_none() {
true => Self::None,
false => Self::GradInBackward,
}
}
}

View File

@@ -0,0 +1,74 @@
use super::{Step, StepBoxed};
use crate::{
NodeId,
collections::{HashMap, HashSet},
graph::Parent,
};
use alloc::vec::Vec;
/// Breadth for search algorithm.
pub struct BreadthFirstSearch;
pub trait TraversalItem {
fn id(&self) -> NodeId;
fn parents(&self) -> &[Parent];
fn parent_nodes(&self) -> Vec<NodeId> {
self.parents().iter().map(|p| p.id).collect()
}
}
impl BreadthFirstSearch {
/// Traverse the graph of backward steps from a root node.
pub fn traverse<F, I>(
&self,
root_id: NodeId,
root_step: I,
steps: &mut HashMap<NodeId, I>,
mut callback: F,
) where
F: FnMut(NodeId, I),
I: TraversalItem,
{
let mut visited = HashSet::new();
let mut parents = Vec::new();
visited.insert(root_id);
parents.append(&mut root_step.parent_nodes());
callback(root_id, root_step);
while let Some(id) = parents.pop() {
let step = match steps.remove(&id) {
Some(step) => step,
None => continue,
};
let step_node = step.id();
let step_parents = step.parent_nodes();
if visited.contains(&step_node) {
continue;
}
visited.insert(step_node);
for id in step_parents.iter() {
if !visited.contains(id) {
parents.push(*id);
}
}
callback(step_node, step);
}
}
}
impl TraversalItem for StepBoxed {
fn id(&self) -> NodeId {
Step::node(self.as_ref())
}
fn parents(&self) -> &[Parent] {
Step::parents(self.as_ref())
}
}

View File

@@ -0,0 +1,43 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # Burn Autodiff
//!
//! This autodiff library is a part of the Burn project. It is a standalone crate
//! that can be used to perform automatic differentiation on tensors. It is
//! designed to be used with the Burn Tensor crate, but it can be used with any
//! tensor library that implements the `Backend` trait.
#[macro_use]
extern crate derive_new;
extern crate alloc;
/// Checkpoint module.
pub mod checkpoint;
/// Gradients module.
pub mod grads;
/// Operation module.
pub mod ops;
pub(crate) mod graph;
// Exported for backend extension
pub use graph::NodeId;
pub(crate) mod tensor;
pub(crate) mod utils;
mod backend;
pub(crate) mod runtime;
pub use backend::*;
/// A facade around for HashMap and HashSet.
/// This avoids elaborate import wrangling having to happen in every module.
mod collections {
#[cfg(not(feature = "std"))]
pub use hashbrown::{HashMap, HashSet};
#[cfg(feature = "std")]
pub use std::collections::{HashMap, HashSet};
}

View File

@@ -0,0 +1,167 @@
use core::marker::PhantomData;
use crate::{
Autodiff,
checkpoint::{
base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,
strategy::CheckpointStrategy,
},
grads::Gradients,
graph::NodeId,
ops::{Backward, Ops, OpsKind, unary},
retro_unary,
};
use burn_backend::{Backend, ops::ActivationOps, tensor::FloatTensor};
impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodiff<B, C> {
fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Gelu;
retro_unary!(RetroGelu, B::gelu);
impl<B: Backend> Backward<B, 1> for Gelu {
type State = NodeId;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let input = checkpointer.retrieve_node_output(ops.state);
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
B::gelu_backward(input, grad)
});
}
}
match Gelu
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroGelu::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(mut prep) => {
let state = prep.checkpoint(&tensor);
prep.finish(state, B::gelu(tensor.primitive.clone()))
}
OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
}
}
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Relu;
retro_unary!(RetroRelu, B::relu);
impl<B: Backend> Backward<B, 1> for Relu {
type State = NodeId;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let state = checkpointer.retrieve_node_output(ops.state);
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
B::relu_backward(state, grad)
});
}
}
match Relu
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroRelu::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(mut prep) => {
let state = prep.checkpoint(&tensor);
prep.finish(state, B::relu(tensor.primitive))
}
OpsKind::UnTracked(prep) => prep.finish(B::relu(tensor.primitive)),
}
}
fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Sigmoid;
retro_unary!(RetroSigmoid, B::sigmoid);
impl<B: Backend> Backward<B, 1> for Sigmoid {
type State = NodeId;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let input = checkpointer.retrieve_node_output(ops.state);
let output = B::sigmoid(input);
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
B::sigmoid_backward(output, grad)
});
}
}
match Sigmoid
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSigmoid::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(mut prep) => {
let state = prep.checkpoint(&tensor);
prep.finish(state, B::sigmoid(tensor.primitive))
}
OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
}
}
fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct LogSigmoid;
retro_unary!(RetroLogSigmoid, B::log_sigmoid);
impl<B: Backend> Backward<B, 1> for LogSigmoid {
type State = NodeId;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let input = checkpointer.retrieve_node_output(ops.state);
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
B::log_sigmoid_backward(input, grad)
});
}
}
match LogSigmoid
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroLogSigmoid::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(mut prep) => {
let state = prep.checkpoint(&tensor);
prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))
}
OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
}
}
}

View File

@@ -0,0 +1,88 @@
use super::{Ops, OpsPrep};
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},
grads::Gradients,
graph::{ComputingProperty, NodeRef, Requirement},
utils::duplicate,
};
use burn_backend::Backend;
/// Trait for all operations.
///
/// # Notes
///
/// Concrete types implementing this trait should not have any state.
/// If a state is necessary during the backward pass,
/// they should be declared with the associated type 'State'.
pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
where
Self: Sized + 'static,
B: Backend,
{
/// Associated type to compute the backward pass.
type State: Clone + Send + core::fmt::Debug + 'static;
/// The backward pass.
fn backward(
self,
ops: Ops<Self::State, N>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
);
/// Prepare the backward ops.
fn prepare<C: CheckpointStrategy>(
self,
nodes: [NodeRef; N],
) -> OpsPrep<Self, B, Self::State, C, N> {
let requirement = Requirement::from_nodes(&nodes);
OpsPrep::new(
nodes,
requirement,
self,
ComputingProperty::Ambiguous, // If not specified we start with ambiguous
CheckpointerBuilder::default(),
)
}
}
/// Execute a binary operation during the backward step.
pub fn binary<B, FLhs, FRhs>(
parents: [Option<NodeRef>; 2],
node: NodeRef,
grads: &mut Gradients,
func_lhs: FLhs,
func_rhs: FRhs,
) where
B: Backend,
FLhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
FRhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
{
let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::<B>(&node)));
let [node_lhs, node_rhs] = parents;
if let Some(node) = node_lhs {
let grad = func_lhs(grad_4lhs.unwrap());
grads.register::<B>(node.id, grad)
}
if let Some(node) = node_rhs {
let grad = func_rhs(grad_4rhs.unwrap());
grads.register::<B>(node.id, grad)
}
}
/// Execute a unary operation during the backward step.
pub fn unary<B, F>(parents: [Option<NodeRef>; 1], node: NodeRef, grads: &mut Gradients, func: F)
where
B: Backend,
F: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
{
let [parent_node] = parents;
let grad = grads.consume::<B>(&node);
if let Some(node) = parent_node {
let grad = func(grad);
grads.register::<B>(node.id, grad)
}
}

View File

@@ -0,0 +1,317 @@
use super::Backward;
use crate::{
checkpoint::{
base::Checkpointer,
builder::{ActionType, CheckpointerBuilder},
retro_forward::RetroForward,
strategy::CheckpointStrategy,
},
grads::Gradients,
graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step},
tensor::AutodiffTensor,
};
use alloc::boxed::Box;
use burn_backend::{Backend, TensorMetadata, tensor::FloatTensor};
use burn_std::Shape;
use core::marker::PhantomData;
/// Operation in preparation.
///
/// Each mode has its own set of functions to minimize cloning for unused backward states.
#[derive(new)]
pub struct OpsPrep<Backward, B, S, C, const N: usize, Mode = Init> {
nodes: [NodeRef; N],
requirement: Requirement,
backward: Backward,
compute_property: ComputingProperty,
checkpointer_builder: CheckpointerBuilder,
checkpoint_strategy: PhantomData<C>,
phantom_backend: PhantomData<B>,
phantom_state: PhantomData<S>,
marker: PhantomData<Mode>,
}
/// Operation is initialized
pub struct Init;
/// Operation has been tagged as memory bound
pub struct MemoryBound;
/// Memory bound operation has received its RetroForward
pub struct MemoryBoundRetroForward;
/// Operation's compute property is fixed
pub struct ComputePropertyDone;
/// Tracked operation tag.
pub struct Tracked;
/// Untracked operation tag.
pub struct UnTracked;
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Init>
where
B: Backend,
BO: Backward<B, N, State = S>,
{
/// Indicates that the operation is compute bound, meaning its computation
/// is heavy and should not be recomputed
pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone> {
OpsPrep::new(
self.nodes,
self.requirement,
self.backward,
ComputingProperty::ComputeBound,
self.checkpointer_builder,
)
}
/// Indicates that the operation is memory bound, meaning its computation
/// is light and can be recomputed
pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, N, MemoryBound> {
OpsPrep::new(
self.nodes,
self.requirement,
self.backward,
self.compute_property,
self.checkpointer_builder,
)
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBound>
where
B: Backend,
BO: Backward<B, N, State = S>,
C: CheckpointStrategy,
{
/// Registers the retro forward, if needed
pub fn retro_forward<R: RetroForward>(
self,
retro_forward: R,
) -> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward> {
OpsPrep::new(
self.nodes,
self.requirement,
self.backward,
C::compute_property(retro_forward),
self.checkpointer_builder,
)
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward>
where
B: Backend,
BO: Backward<B, N, State = S>,
C: CheckpointStrategy,
{
/// Checkpoints the parents, if needed
pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
where
B2: Backend,
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
{
let compute_property = match C::checkpoint_parents(parents, &mut self.checkpointer_builder)
{
Ok(..) => self.compute_property,
Err(..) => ComputingProperty::ComputeBound,
};
OpsPrep::new(
self.nodes,
self.requirement,
self.backward,
compute_property,
self.checkpointer_builder,
)
}
}
impl<BO, B, C, const N: usize> OpsPrep<BO, B, (), C, N, ComputePropertyDone>
where
B: Backend,
BO: Backward<B, N, State = ()>,
{
/// Prepare a stateless operation.
pub fn stateless(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
match self.stateful() {
OpsKind::Tracked(prep) => prep.finish((), output),
OpsKind::UnTracked(prep) => prep.finish(output),
}
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
where
B: Backend,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Prepare an operation that requires a state during the backward pass.
pub fn stateful(self) -> OpsKind<BO, B, S, C, N> {
match self.requirement.is_none() {
false => OpsKind::Tracked(OpsPrep::new(
self.nodes,
self.requirement,
self.backward,
self.compute_property,
self.checkpointer_builder,
)),
true => OpsKind::UnTracked(OpsPrep::new(
self.nodes,
self.requirement,
self.backward,
self.compute_property,
self.checkpointer_builder,
)),
}
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>
where
B: Backend,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
pub fn finish(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.requirement,
self.compute_property,
);
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, output.node.clone(), ());
// We register the ops in the graph even if untracked, otherwise memory bound operations
// that have an untracked parent would not be able to retrieve it
output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>
where
B: Backend,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of a tracked operation and returns the output tensor.
pub fn finish(self, state: S, output: FloatTensor<B>) -> AutodiffTensor<B> {
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.requirement,
self.compute_property,
);
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, output.node.clone(), state);
output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)
}
/// Checkpoints the tensor
pub fn checkpoint(&mut self, tensor: &AutodiffTensor<B>) -> NodeId {
self.checkpointer_builder
.checkpoint(tensor, ActionType::Explicit);
tensor.node.id
}
}
/// Enum used before finishing tracked and untracked operations.
pub enum OpsKind<BO, B, S, C, const N: usize> {
/// Tracked operation preparation.
Tracked(OpsPrep<BO, B, S, C, N, Tracked>),
/// Untracked operation preparation.
UnTracked(OpsPrep<BO, B, S, C, N, UnTracked>),
}
/// Operation containing its parent nodes, its own node and the backward step state.
#[derive(new, Debug)]
pub struct Ops<S, const N: usize> {
/// Parents nodes.
pub parents: [Option<NodeRef>; N],
/// The node.
pub node: NodeRef,
/// The state.
pub state: S,
}
/// Operation implementing backward [step](Step) with type erasing.
#[derive(new, Debug)]
struct OpsStep<B, T, SB, const N: usize>
where
B: Backend,
T: Backward<B, N, State = SB>,
SB: Clone + Send + core::fmt::Debug + 'static,
{
ops: Ops<SB, N>,
backward: T,
phantom: PhantomData<B>,
}
impl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>
where
B: Backend,
T: Backward<B, N, State = SB>,
SB: Clone + Send + core::fmt::Debug + 'static,
{
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
self.backward.backward(self.ops, grads, checkpointer);
}
fn node(&self) -> NodeId {
self.ops.node.id
}
fn parents(&self) -> &[Parent] {
&self.ops.node.parents
}
fn depth(&self) -> usize {
self.ops.node.order
}
}
#[derive(new, Debug)]
struct UntrackedOpsStep<const N: usize> {
ops: Ops<(), N>,
}
impl<const N: usize> Step for UntrackedOpsStep<N> {
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
// Nothing to do
}
fn node(&self) -> NodeId {
self.ops.node.id
}
fn parents(&self) -> &[Parent] {
&self.ops.node.parents
}
fn depth(&self) -> usize {
self.ops.node.order
}
}
/// Make sure the grad tensor has the given shape.
///
/// If broadcasting happened during the forward pass, the gradients will be sum along the
/// broadcasted dimension.
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
let shape_grad = grad.shape();
let ndims = shape_grad.num_dims();
for i in 0..ndims {
if shape_grad[i] != shape[i] {
if shape[i] != 1 {
panic!(
"Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}",
shape, shape_grad, "Expected the shape of the next grad to be 1."
);
}
grad = B::float_sum_dim(grad, i);
}
}
grad
}

View File

@@ -0,0 +1,161 @@
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
use alloc::vec::Vec;
use burn_backend::{
Backend, ExecutionError, Scalar, TensorData,
ops::BoolTensorOps,
tensor::{BoolTensor, Device, IntTensor},
};
use burn_std::Shape;
impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B> {
B::bool_from_data(data, device)
}
async fn bool_into_data(tensor: BoolTensor<B>) -> Result<TensorData, ExecutionError> {
B::bool_into_data(tensor).await
}
fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B> {
B::bool_into_int(tensor)
}
fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B> {
B::bool_to_device(tensor, device)
}
fn bool_device(tensor: &BoolTensor<B>) -> Device<B> {
B::bool_device(tensor)
}
fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
B::bool_reshape(tensor, shape)
}
fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_std::Slice]) -> BoolTensor<B> {
B::bool_slice(tensor, slices)
}
fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
B::bool_empty(shape, device)
}
fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
B::bool_zeros(shape, device)
}
fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
B::bool_ones(shape, device)
}
fn bool_slice_assign(
tensor: BoolTensor<Self>,
slices: &[burn_std::Slice],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_slice_assign(tensor, slices, value)
}
fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
B::bool_cat(tensors, dim)
}
fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_equal(lhs, rhs)
}
fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B> {
B::bool_not(tensor)
}
fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_and(lhs, rhs)
}
fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_or(lhs, rhs)
}
fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_xor(lhs, rhs)
}
fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
AutodiffTensor::new(B::bool_into_float(tensor))
}
fn bool_swap_dims(
tensor: <Autodiff<B> as Backend>::BoolTensorPrimitive,
dim1: usize,
dim2: usize,
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive {
B::bool_swap_dims(tensor, dim1, dim2)
}
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
B::bool_permute(tensor, axes)
}
fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B> {
B::bool_flip(tensor, axes)
}
async fn bool_argwhere(tensor: BoolTensor<B>) -> IntTensor<B> {
B::bool_argwhere(tensor).await
}
fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
B::bool_expand(tensor, shape)
}
fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
B::bool_repeat_dim(tensor, dim, times)
}
fn bool_unfold(
tensor: BoolTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> BoolTensor<Self> {
B::bool_unfold(tensor, dim, size, step)
}
fn bool_mask_where(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
source: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_mask_where(tensor, mask, source)
}
fn bool_mask_fill(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> BoolTensor<Self> {
B::bool_mask_fill(tensor, mask, value)
}
fn bool_gather(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
) -> BoolTensor<Self> {
B::bool_gather(dim, tensor, indices)
}
fn bool_scatter_or(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_scatter_or(dim, tensor, indices, value)
}
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
B::bool_equal_elem(lhs, rhs)
}
}

View File

@@ -0,0 +1,406 @@
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
use alloc::vec::Vec;
use burn_backend::{
Backend, Distribution, ExecutionError, Scalar, TensorData,
ops::IntTensorOps,
tensor::{BoolTensor, Device, IntTensor},
};
use burn_std::{IntDType, Shape};
impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
B::int_from_data(data, device)
}
async fn int_into_data(tensor: IntTensor<B>) -> Result<TensorData, ExecutionError> {
B::int_into_data(tensor).await
}
fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {
B::int_to_device(tensor, device)
}
fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
B::int_device(tensor)
}
fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
B::int_reshape(tensor, shape)
}
fn int_slice(tensor: IntTensor<B>, slices: &[burn_std::Slice]) -> IntTensor<B> {
B::int_slice(tensor, slices)
}
fn int_empty(
shape: Shape,
device: &<Autodiff<B> as Backend>::Device,
dtype: IntDType,
) -> IntTensor<B> {
B::int_empty(shape, device, dtype)
}
fn int_slice_assign(
tensor: IntTensor<B>,
slices: &[burn_std::Slice],
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_slice_assign(tensor, slices, value)
}
fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
B::int_cat(tensors, dim)
}
fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_equal(lhs, rhs)
}
fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_equal_elem(lhs, rhs)
}
fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_add(lhs, rhs)
}
fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_add_scalar(lhs, rhs)
}
fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
B::int_clamp_min(tensor, min)
}
fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
B::int_clamp_max(tensor, max)
}
fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
B::int_clamp(tensor, min, max)
}
fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_sub(lhs, rhs)
}
fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_sub_scalar(lhs, rhs)
}
fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_mul(lhs, rhs)
}
fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_mul_scalar(lhs, rhs)
}
fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_div(lhs, rhs)
}
fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_div_scalar(lhs, rhs)
}
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_remainder(lhs, rhs)
}
fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_remainder_scalar(lhs, rhs)
}
fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_matmul(lhs, rhs)
}
fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
B::int_neg(tensor)
}
fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
B::int_zeros(shape, device, dtype)
}
fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
B::int_ones(shape, device, dtype)
}
fn int_full(
shape: Shape,
fill_value: Scalar,
device: &Device<Self>,
dtype: IntDType,
) -> IntTensor<B> {
B::int_full(shape, fill_value, device, dtype)
}
fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
B::int_sum(tensor)
}
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_sum_dim(tensor, dim)
}
fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
B::int_mean(tensor)
}
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_mean_dim(tensor, dim)
}
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cumsum(tensor, dim)
}
fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cumprod(tensor, dim)
}
fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cummin(tensor, dim)
}
fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cummax(tensor, dim)
}
fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
B::int_repeat_dim(tensor, dim, times)
}
fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_greater(lhs, rhs)
}
fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_greater_elem(lhs, rhs)
}
fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_greater_equal(lhs, rhs)
}
fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_greater_equal_elem(lhs, rhs)
}
fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_lower(lhs, rhs)
}
fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_lower_elem(lhs, rhs)
}
fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_lower_equal(lhs, rhs)
}
fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_lower_equal_elem(lhs, rhs)
}
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {
B::int_gather(dim, tensor, indices)
}
fn int_scatter_add(
dim: usize,
tensor: IntTensor<B>,
indices: IntTensor<B>,
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_scatter_add(dim, tensor, indices, value)
}
fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {
B::int_select(tensor, dim, indices)
}
fn int_select_add(
tensor: IntTensor<B>,
dim: usize,
indices: IntTensor<B>,
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_select_add(tensor, dim, indices, value)
}
fn int_mask_where(
tensor: IntTensor<B>,
mask: BoolTensor<B>,
value: IntTensor<B>,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
B::int_mask_where(tensor, mask, value)
}
fn int_mask_fill(
tensor: IntTensor<B>,
mask: BoolTensor<B>,
value: Scalar,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
B::int_mask_fill(tensor, mask, value)
}
fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_argmax(tensor, dim)
}
fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_argmin(tensor, dim)
}
fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
B::int_max(tensor)
}
fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
B::int_max_dim(tensor, dim)
}
fn int_max_dim_with_indices(
tensor: B::IntTensorPrimitive,
dim: usize,
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
B::int_max_dim_with_indices(tensor, dim)
}
fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
B::int_min(tensor)
}
fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
B::int_min_dim(tensor, dim)
}
fn int_min_dim_with_indices(
tensor: B::IntTensorPrimitive,
dim: usize,
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
B::int_min_dim_with_indices(tensor, dim)
}
fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
B::int_abs(tensor)
}
fn int_into_float(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
AutodiffTensor::new(B::int_into_float(tensor))
}
fn int_swap_dims(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
dim1: usize,
dim2: usize,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
B::int_swap_dims(tensor, dim1, dim2)
}
fn int_random(
shape: Shape,
distribution: Distribution,
device: &Device<Self>,
) -> IntTensor<Self> {
B::int_random(shape, distribution, device)
}
fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
B::int_arange(range, device)
}
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
B::int_permute(tensor, axes)
}
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
B::int_flip(tensor, axes)
}
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::int_sign(tensor)
}
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::int_prod(tensor)
}
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
B::int_prod_dim(tensor, dim)
}
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
B::int_expand(tensor, shape)
}
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
B::int_sort(tensor, dim, descending)
}
fn int_sort_with_indices(
tensor: IntTensor<Self>,
dim: usize,
descending: bool,
) -> (IntTensor<Self>, IntTensor<Self>) {
B::int_sort_with_indices(tensor, dim, descending)
}
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
B::int_argsort(tensor, dim, descending)
}
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_and(lhs, rhs)
}
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_and_scalar(lhs, rhs)
}
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_or(lhs, rhs)
}
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_or_scalar(lhs, rhs)
}
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_xor(lhs, rhs)
}
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_xor_scalar(lhs, rhs)
}
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_not(tensor)
}
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_left_shift(lhs, rhs)
}
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_left_shift_scalar(lhs, rhs)
}
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_right_shift(lhs, rhs)
}
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_right_shift_scalar(lhs, rhs)
}
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
B::int_cast(tensor, dtype)
}
fn int_unfold(
tensor: IntTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> IntTensor<Self> {
B::int_unfold(tensor, dim, size, step)
}
}

View File

@@ -0,0 +1,27 @@
use super::{Backward, Ops, unary};
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use burn_backend::{Backend, TensorMetadata};
use burn_std::Shape;
#[derive(Debug)]
pub(crate) struct MaxMinDim;
impl<B: Backend> Backward<B, 1> for MaxMinDim {
type State = (B::IntTensorPrimitive, Shape, usize);
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
let (indices, shape, dim) = ops.state;
let device = B::float_device(&grad);
let dtype = grad.dtype();
let zeros = B::float_zeros(shape, &device, dtype.into());
B::float_scatter_add(dim, zeros, indices, grad)
});
}
}

View File

@@ -0,0 +1,15 @@
mod activation;
mod backward;
mod base;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;
pub(crate) mod maxmin;
pub(crate) mod sort;
pub use backward::*;
pub use base::*;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
use burn_backend::{
Backend, ExecutionError, TensorData,
ops::QTensorOps,
tensor::{
Device, FloatTensor, IntTensor, QuantizedTensor,
quantization::QuantizationParametersPrimitive,
},
};
use burn_std::{QuantScheme, Shape};
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
todo!()
}
fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
todo!() // required for QAT
}
fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
) -> QuantizedTensor<Self> {
todo!()
}
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
todo!()
}
fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
B::q_device(tensor)
}
fn q_to_device(
_tensor: QuantizedTensor<Self>,
_device: &Device<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
B::q_reshape(tensor, shape)
}
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
B::q_into_data(tensor).await
}
fn q_swap_dims(
_tensor: QuantizedTensor<Self>,
_dim1: usize,
_dim2: usize,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_gather(
_dim: usize,
_tensor: QuantizedTensor<Self>,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_select(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_slice(
_tensor: QuantizedTensor<Self>,
_slices: &[burn_std::Slice],
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
B::q_argmax(tensor, dim)
}
fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
B::q_argmin(tensor, dim)
}
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
}

View File

@@ -0,0 +1,27 @@
use super::{Backward, Ops, unary};
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use burn_backend::{Backend, TensorMetadata};
use burn_std::Shape;
#[derive(Debug)]
pub(crate) struct SortDim;
impl<B: Backend> Backward<B, 1> for SortDim {
type State = (B::IntTensorPrimitive, Shape, usize);
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
let (indices, shape, dim) = ops.state;
let device = B::float_device(&grad);
let dtype = grad.dtype();
let zeros = B::float_zeros(shape, &device, dtype.into());
B::float_scatter_add(dim, zeros, indices, grad)
});
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,24 @@
use burn_backend::{
Backend, ExecutionError,
ops::{TransactionOps, TransactionPrimitive},
};
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
impl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C> {
async fn tr_execute(
transaction: TransactionPrimitive<Self>,
) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
B::tr_execute(TransactionPrimitive::new(
transaction
.read_floats
.into_iter()
.map(|t| t.primitive)
.collect(),
transaction.read_qfloats,
transaction.read_ints,
transaction.read_bools,
))
.await
}
}

View File

@@ -0,0 +1,18 @@
use crate::{
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::StepBoxed,
tensor::{AutodiffTensor, NodeRefCount},
};
use burn_backend::Backend;
/// Client used to communicate with the autodiff server.
pub trait AutodiffClient: Send + Clone {
/// Register a new step.
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
/// Call backpropagation from the given tensor.
fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
}
/// Client implementation in used.
pub type AutodiffClientImpl = super::graph::GraphMutexClient;

View File

@@ -0,0 +1,335 @@
use super::{AutodiffClient, server::AutodiffServer};
use crate::{
NodeId,
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::{Parent, StepBoxed},
runtime::server::NodeCleaner,
tensor::{AutodiffTensor, NodeRefCount},
};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_backend::Backend;
use hashbrown::{HashMap, HashSet};
#[cfg(feature = "std")]
use parking_lot::{Mutex, MutexGuard};
#[cfg(not(feature = "std"))]
use spin::{Mutex, MutexGuard};
/// A client for managing multiple graphs using mutex-based synchronization.
///
/// The biggest benefit of using this client implementation is that each graph can modify its own
/// data without blocking other graphs, which is essential for multi-device training.
///
/// # Notes
///
/// The [AutodiffServer] fully supports multiple graphs with sharing nodes, however those type of
/// graphs will be stored under a single mutex-protected graph by the client, limiting
/// parallelisation.
#[derive(Clone, new, Debug)]
pub struct GraphMutexClient;
/// Manages a collection of graphs, mapping [node ids](NodeId) to their respective graph.
///
/// The `GraphLocator` is responsible for selecting and merging graphs based on their IDs and parent
/// dependencies, ensuring proper synchronization and server allocation.
///
/// # Notes
///
/// Multiple node ids can point to the same graph, where the autodiff graph is stored.
#[derive(Default)]
pub struct GraphLocator {
graphs: HashMap<NodeId, Arc<Graph>>,
/// We keep a mapping of each original node id (graph id) => all nodes that point to that graph.
/// This is to ensure that when merging graphs, we correctly move all previous graphs to
/// the new merged one.
keys: HashMap<NodeId, HashSet<NodeId>>,
}
/// Represents a single computation graph with a mutex-protected server.
///
/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
/// first created.
pub(crate) struct Graph {
origin: NodeId,
state: Mutex<GraphState>,
}
#[derive(Default)]
struct GraphState {
server: AutodiffServer,
}
impl core::fmt::Debug for Graph {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Graph")
.field("origin", &self.origin)
.finish()
}
}
static STATE: Mutex<Option<GraphLocator>> = Mutex::new(None);
impl GraphMutexClient {
/// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
///
/// # Parameters
/// - `node`: The unique identifier for the stream.
/// - `parents`: A slice of parent nodes that the stream depends on.
///
/// # Returns
/// An `Arc<Graph>` representing the selected or newly created stream.
fn graph(node: NodeId, parents: &[Parent]) -> Arc<Graph> {
let mut state = STATE.lock();
match state.as_mut() {
Some(locator) => locator.select(node, parents),
None => {
let mut locator = GraphLocator::default();
let stream = locator.select(node, parents);
*state = Some(locator);
stream
}
}
}
}
impl AutodiffClient for GraphMutexClient {
fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
let node_id = *node_id_ref;
let graph = GraphMutexClient::graph(node_id, step.parents());
let mut state = graph.state.lock();
state.server.register(node_id_ref, step, actions);
}
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
let node_id = root.node.id;
let graph = GraphMutexClient::graph(root.node.id, &[]);
let grads = Gradients::new::<B>(root.node, root.primitive);
let grads = {
let mut state = graph.state.lock();
state.server.backward::<GraphCleaner>(grads, node_id)
}; // lock released
GraphCleaner::cleanup_orphaned_entries();
grads
}
}
struct GraphCleaner<'a> {
guard: MutexGuard<'a, Option<GraphLocator>>,
}
impl<'a> GraphCleaner<'a> {
fn cleanup_orphaned_entries() {
let graphs = {
// Get the available graphs and release the lock
match STATE.lock().as_ref() {
Some(state) => state.graphs.clone(),
None => return,
}
};
let mut should_remove = Vec::new();
for graph in graphs.values() {
{
let mut guard = graph.state.lock();
// Double safety: in case it was marked as no longer useful, but other
// nodes are still relevant, we only check which nodes can safely be removed.
if !guard.server.maybe_useful() {
guard
.server
.free_unused_roots(|node| should_remove.push(*node));
}
}
}
if !should_remove.is_empty() {
let mut state = STATE.lock();
if let Some(state) = state.as_mut() {
for node in should_remove {
state.remove_entry(&node);
}
}
}
}
}
impl<'a> NodeCleaner for GraphCleaner<'a> {
fn init() -> Self {
let guard = STATE.lock();
Self { guard }
}
fn clean(&mut self, node: &NodeId) {
if let Some(state) = self.guard.as_mut() {
state.remove_entry(node);
}
}
}
impl GraphLocator {
/// Selects a single graph for the given [NodeId], considering parent dependencies.
///
/// If multiple graphs are found, they are merged into a single one.
///
/// # Parameters
/// - `node`: The node ID of the graph to select.
/// - `parents`: A slice of parent nodes that the graph depends on.
///
/// # Returns
///
/// An `Arc<Graph>` representing the selected or merged graph.
pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
match self.analyse(node, parents) {
GraphAnalysis::NoCollision(graph) => {
if graph.origin != node {
self.graphs.insert(node, graph.clone());
self.register_key(graph.origin, node);
}
graph
}
GraphAnalysis::Collisions(graphs) => self.merge(node, graphs),
}
}
/// Analyses the graph for a given node and its parents, returning the associated `GraphAnalysis`.
fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalysis {
// If no parents, there is no collision, therefore a single graph is ok.
if parents.is_empty() {
let graph = match self.graphs.get(&node) {
Some(val) => val.clone(),
None => self.new_graph(node),
};
return GraphAnalysis::NoCollision(graph);
};
// We collect all graphs of parents and of the current node based on their origin node id.
let mut graphs = HashMap::<NodeId, Arc<Graph>>::new();
if let Some(val) = self.graphs.get(&node) {
graphs.insert(val.origin, val.clone());
}
for parent in parents {
match self.graphs.get(&parent.id) {
Some(graph) => graphs.insert(graph.origin, graph.clone()),
None => continue,
};
}
if graphs.is_empty() {
return match self.graphs.get(&node) {
Some(old) => GraphAnalysis::NoCollision(old.clone()),
None => GraphAnalysis::NoCollision(self.new_graph(node)),
};
}
if graphs.len() == 1 {
return GraphAnalysis::NoCollision(graphs.drain().next().unwrap().1);
}
GraphAnalysis::Collisions(graphs)
}
/// Merges multiple graphs associated with a node into a single graph.
fn merge(&mut self, node: NodeId, mut graphs: HashMap<NodeId, Arc<Graph>>) -> Arc<Graph> {
let mut graphs = graphs.drain().map(|g| g.1);
let main = graphs.next().expect("At least one graph");
self.register_key(main.origin, node);
let mut state = main.state.lock();
for graph in graphs {
self.merge_two(&mut state, &main, graph);
}
self.graphs.insert(main.origin, main.clone());
self.graphs.insert(node, main.clone());
core::mem::drop(state);
main
}
/// Registers a key for a given origin node.
fn register_key(&mut self, origin: NodeId, key: NodeId) {
if !self.keys.contains_key(&origin) {
// Ensure an entry exists for this origin
self.keys.insert(origin, HashSet::new());
}
if origin != key {
// Register this node to point to the origin graph
self.keys.get_mut(&origin).unwrap().insert(key);
}
}
/// Merges two graphs by combining their states and updating graph mappings.
fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>, merged: Arc<Graph>) {
let mut locked = merged.state.lock();
let mut state_old = GraphState::default();
core::mem::swap(&mut state_old, &mut locked);
main_state.server.extend(state_old.server);
// Re-map merged origin to the main graph
self.graphs.insert(merged.origin, main.clone());
// Move all keys (node IDs) from the merged graph to the main graph
if let Some(locator_keys) = self.keys.remove(&merged.origin) {
for k in locator_keys.iter() {
self.graphs.insert(*k, main.clone());
}
let locator_keys_main = self
.keys
.get_mut(&main.origin)
.expect("Should be init before the merge.");
locator_keys_main.extend(locator_keys);
}
}
/// Creates a new graph for a given node.
fn new_graph(&mut self, origin: NodeId) -> Arc<Graph> {
let graph = Arc::new(Graph {
origin,
state: Mutex::new(GraphState::default()),
});
self.graphs.insert(origin, graph.clone());
self.keys.insert(origin, HashSet::new());
graph
}
fn remove_entry(&mut self, node: &NodeId) {
if let Some(graph) = self.graphs.remove(node) {
let mut remove = false;
if let Some(entry) = self.keys.get_mut(&graph.origin) {
entry.remove(node);
if entry.is_empty() {
remove = true;
}
}
if remove {
self.keys.remove(&graph.origin);
}
}
}
}
/// Represents the analysis result of graph operations for a given node and its parents.
#[derive(Debug)]
enum GraphAnalysis {
/// No collision detected, contains the graph associated with the node.
NoCollision(Arc<Graph>),
/// Collision detected, contains a map of node IDs to their associated graphs.
Collisions(HashMap<NodeId, Arc<Graph>>),
}

View File

@@ -0,0 +1,294 @@
use crate::{
NodeId,
collections::{HashMap, HashSet},
graph::Parent,
tensor::NodeRefCount,
};
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
use core::mem;
#[derive(Default, Debug)]
pub struct GraphMemoryManagement {
nodes: HashMap<NodeRefCount, Vec<NodeId>>,
leaves: HashSet<NodeId>,
statuses: HashMap<NodeId, NodeMemoryStatus>,
}
#[derive(Debug, Clone, PartialEq)]
enum NodeMemoryStatus {
Useful,
Unavailable,
Unknown,
}
impl GraphMemoryManagement {
pub fn extend(&mut self, other: Self) {
self.nodes.extend(other.nodes);
self.leaves.extend(other.leaves);
self.statuses.extend(other.statuses);
}
/// Register a new node with its parent.
pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
let node_id = *node.as_ref();
for parent in parents.iter() {
self.leaves.remove(&parent.id);
}
self.leaves.insert(node_id);
self.nodes
.insert(node, parents.iter().map(|p| p.id).collect());
}
/// Free the node from the state.
pub fn consume_node(&mut self, node_id: NodeId) {
if !self.is_referenced(node_id) {
self.leaves.remove(&node_id);
self.nodes.remove(&node_id);
}
}
/// Free all nodes whose backward call has become impossible
///
/// This function goes into three steps, which must happen for all leaves
/// before going into the next step. Then it deletes what can be safely deleted
pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
let leaves = self.leaves.clone();
let mut new_leaves = HashSet::new();
let mut deletables = Vec::new();
// When consuming nodes with a backward pass, some other backward passes become
// unavailable because some of their parents have been consumed. They are
// identified here.
for leaf in leaves.clone() {
self.unavailable_propagation(leaf);
}
// Among the available nodes that remain, some may be useless if no
// available node with a tensor reference exist in their descendance.
// But some may seem useless from some leaf but be useful from another one,
// hence the need to iterate on all leaves.
self.useful_propagation(leaves.clone());
// New leaves are the roots of a useful backward sub-tree.
// Deletables are everything not marked as useful.
for leaf in leaves {
self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
}
// Replace leaves by the new ones and delete everything not useful anymore
mem::swap(&mut self.leaves, &mut new_leaves);
self.clear_unused_roots(&mut deletables);
self.statuses.clear();
for node_to_delete in deletables {
self.nodes.remove(&node_to_delete);
on_free_graph(&node_to_delete)
}
}
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
let mut deletables = Vec::new();
self.clear_unused_roots(&mut deletables);
for node_id in deletables {
self.nodes.remove(&node_id);
on_free_graph(&node_id);
}
}
fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
for (id, parents) in self.nodes.iter() {
let is_useful = matches!(
self.statuses.get(id.as_ref()),
Some(NodeMemoryStatus::Useful)
);
// Check if parents are either empty or absent from self.nodes
let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));
if !is_useful && Arc::strong_count(id) == 1 && parents_absent {
to_delete.push(*id.as_ref())
}
}
}
fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {
// If already visited
if let Some(status) = self.statuses.get(&node_id) {
return status.clone();
}
match self.nodes.get(&node_id).cloned() {
// If node exists and any of its parents is unavailable, it is unavailable as well
// If node exists but the parents vec is empty, it is a tensor that never had parents;
// the status remains unknown
Some(parents) => {
let mut node_status = NodeMemoryStatus::Unknown;
for parent in parents {
let parent_status = self.unavailable_propagation(parent);
if let NodeMemoryStatus::Unavailable = parent_status {
node_status = NodeMemoryStatus::Unavailable;
}
}
self.statuses.insert(node_id, node_status.clone());
node_status
}
// If node does not exist, it was
// deleted, so this and all its descendants are unavailable
None => {
self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
NodeMemoryStatus::Unavailable
}
}
}
fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
// Accumulate visited nodes
let mut explored = HashSet::new();
let mut tagged_useful = HashSet::new();
// Queue of nodes to visit
let mut to_tag_useful = PopNodeSet::default();
let mut to_explore = PopNodeSet::new(leaves);
// Utility function to iterate over a node's parents
let parents = |node_id| {
self.nodes
.get(&node_id)
.cloned()
.unwrap_or_default()
.into_iter()
};
loop {
// Pop a node id, greedily looking at tag_useful ones first
let (node_id, status) = match to_tag_useful.pop() {
Some(node_id) => (node_id, NodeMemoryStatus::Useful),
None => match to_explore.pop() {
Some(node_id) => {
let node_status = self
.statuses
.get(&node_id)
.expect("All nodes should have received a status during unavailable_propagation")
.to_owned();
if let NodeMemoryStatus::Unknown = node_status {
match self.is_referenced(node_id) {
true => (node_id, NodeMemoryStatus::Useful),
false => (node_id, NodeMemoryStatus::Unknown),
}
} else {
(node_id, node_status)
}
}
None => {
// There are no nodes in the queues anymore
break;
}
},
};
match status {
NodeMemoryStatus::Useful => {
tagged_useful.insert(node_id);
for parent in parents(node_id) {
// The node can be explored, as long as it's not already tagged useful
if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
to_tag_useful.insert(parent);
}
}
}
_ => {
explored.insert(node_id);
for parent in parents(node_id) {
if !(explored.contains(&parent) || to_explore.contains(&parent)) {
to_explore.insert(parent);
}
}
}
}
self.statuses.insert(node_id, status);
}
}
fn identify_leaves_and_deletables(
&self,
leaf_id: NodeId,
new_leaves: &mut HashSet<NodeId>,
to_delete: &mut Vec<NodeId>,
) {
let mut visited = HashSet::new();
let mut to_visit = vec![leaf_id];
while let Some(node_id) = to_visit.pop() {
visited.insert(node_id);
match self
.statuses
.get(&node_id)
.expect("Node should have status")
{
NodeMemoryStatus::Useful => {
new_leaves.insert(node_id);
}
_ => {
to_delete.push(node_id);
for parent in self
.nodes
.get(&node_id)
.cloned()
.unwrap_or_default()
.into_iter()
{
if !visited.contains(&parent) {
to_visit.push(parent);
}
}
}
};
}
}
fn is_referenced(&self, node_id: NodeId) -> bool {
match self.nodes.get_key_value(&node_id) {
Some((key, _value)) => Arc::strong_count(key) > 1,
None => panic!("Node should be in the nodes map"),
}
}
pub(crate) fn maybe_useful(&self) -> bool {
self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
}
}
/// Wrapper over hash set for fast popping of any node
#[derive(new, Default)]
struct PopNodeSet {
hash_set: HashSet<NodeId>,
}
impl PopNodeSet {
#[inline(always)]
fn pop(&mut self) -> Option<NodeId> {
self.hash_set
.iter()
.next()
.copied()
.and_then(|node_id| self.hash_set.take(&node_id))
}
#[inline(always)]
fn contains(&self, node_id: &NodeId) -> bool {
self.hash_set.contains(node_id)
}
#[inline(always)]
fn insert(&mut self, node_id: NodeId) {
self.hash_set.insert(node_id);
}
}

View File

@@ -0,0 +1,6 @@
mod client;
mod memory_management;
mod server;
pub mod graph;
pub use client::*;

View File

@@ -0,0 +1,143 @@
use super::memory_management::GraphMemoryManagement;
use crate::{
NodeId,
checkpoint::{
base::{Checkpointer, NodeTree},
builder::CheckpointerBuilder,
},
collections::HashMap,
grads::Gradients,
graph::{StepBoxed, traversal::BreadthFirstSearch},
tensor::NodeRefCount,
};
use alloc::vec::Vec;
#[derive(Default)]
pub struct AutodiffServer {
steps: HashMap<NodeId, StepBoxed>,
actions_builder: HashMap<NodeId, CheckpointerBuilder>,
memory_management: GraphMemoryManagement,
}
/// Defines how nodes are clean.
pub trait NodeCleaner {
/// Initialize a new cleaner.
fn init() -> Self;
/// Cleans a single [node](NodeId).
fn clean(&mut self, node: &NodeId);
}
impl AutodiffServer {
pub fn extend(&mut self, other: AutodiffServer) {
self.steps.extend(other.steps);
self.actions_builder.extend(other.actions_builder);
self.memory_management.extend(other.memory_management);
}
pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
let parents = step.parents();
let node_id = *rc.as_ref();
self.memory_management.register(rc, parents);
self.steps.insert(node_id, step);
self.actions_builder.insert(node_id, actions);
}
pub fn backward<NC: NodeCleaner>(&mut self, grads: Gradients, node_id: NodeId) -> Gradients {
let step = self.steps.remove(&node_id).expect(
"Node should have a step registered, did you forget to call \
`Tensor::register_grad` on the tensor where you need gradients?",
);
let builder = self.actions_builder.remove(&node_id).unwrap();
let mut consumed = Vec::new();
let (tape, checkpointer) = self.build_tape(node_id, step, builder, &mut consumed);
let gradients = Self::execute_steps(tape, grads, checkpointer);
// Cleanup
let mut cleaner = NC::init();
self.memory_management
.free_unavailable_nodes(|node_id: &NodeId| {
self.steps.remove(node_id);
self.actions_builder.remove(node_id);
NC::clean(&mut cleaner, node_id);
});
for node_id in consumed {
cleaner.clean(&node_id)
}
gradients
}
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
self.memory_management.free_unused_roots(|node_id| {
self.steps.remove(node_id);
self.actions_builder.remove(node_id);
on_free_graph(node_id);
});
}
fn build_tape(
&mut self,
node: NodeId,
node_step: StepBoxed,
mut builder: CheckpointerBuilder,
consumed: &mut Vec<NodeId>,
) -> (Vec<Vec<StepBoxed>>, Checkpointer) {
let mut tape = (0..node_step.depth())
.map(|_| Vec::with_capacity(1))
.collect::<Vec<_>>();
let mut tree = HashMap::default();
BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {
self.memory_management.consume_node(id);
// Clean up consumed node
consumed.push(id);
let depth = step.depth();
if depth == 0 {
return;
}
if let Some(steps) = tape.get_mut(depth - 1) {
let parents = step.parents().iter().map(|p| p.id).filter(|s| *s != id);
tree.insert(id, parents.collect());
steps.push(step);
}
if let Some(node_builder) = self.actions_builder.remove(&id) {
builder.extend(node_builder);
}
});
let checkpointer = builder.build(NodeTree::new(tree));
(tape, checkpointer)
}
fn execute_steps(
tape: Vec<Vec<StepBoxed>>,
mut grads: Gradients,
mut checkpointer: Checkpointer,
) -> Gradients {
tape.into_iter().rev().for_each(|steps| {
steps
.into_iter()
.for_each(|step| step.step(&mut grads, &mut checkpointer))
});
// For checkpointing tests
#[cfg(feature = "export_tests")]
assert!(checkpointer.is_empty());
grads
}
pub(crate) fn maybe_useful(&self) -> bool {
self.memory_management.maybe_useful()
}
}

View File

@@ -0,0 +1,189 @@
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
graph::{ComputingProperty, Node, NodeId, NodeRef, Parent, Requirement, Step},
runtime::{AutodiffClient, AutodiffClientImpl},
};
use alloc::{boxed::Box, sync::Arc, vec};
use burn_backend::{Backend, TensorMetadata};
#[derive(Debug, Clone)]
pub struct AutodiffTensor<B: Backend> {
pub primitive: B::FloatTensorPrimitive,
pub node: NodeRef,
pub rc: NodeRefCount,
}
impl<B: Backend> TensorMetadata for AutodiffTensor<B> {
fn dtype(&self) -> burn_std::DType {
self.primitive.dtype()
}
fn shape(&self) -> burn_std::Shape {
self.primitive.shape()
}
fn rank(&self) -> usize {
self.primitive.rank()
}
}
pub type NodeRefCount = Arc<NodeId>;
#[derive(new, Debug)]
pub(crate) struct RootStep {
node: NodeRef,
}
impl Step for RootStep {
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
// Nothing to do
}
fn node(&self) -> NodeId {
self.node.id
}
fn parents(&self) -> &[Parent] {
&self.node.parents
}
fn depth(&self) -> usize {
self.node.order
}
}
impl<B: Backend> AutodiffTensor<B> {
/// Create a new leaf tensor.
pub fn new(primitive: B::FloatTensorPrimitive) -> Self {
let id = NodeId::new();
let node: NodeRef = Node::new(
vec![],
0,
id,
Requirement::None,
ComputingProperty::Ambiguous,
AutodiffClientImpl::new(),
)
.into();
Self {
rc: Arc::new(node.id),
primitive,
node: node.clone(),
}
}
pub fn is_tracked(&self) -> bool {
!self.node.requirement.is_none()
}
/// Mark the tensor as requiring gradients.
///
/// # Panics
///
/// It panics if the tensor is not a leaf.
pub fn require_grad(mut self) -> Self {
match self.node.requirement {
Requirement::Grad => self,
Requirement::GradInBackward => {
panic!("Can't convert a non leaf tensor into a tracked tensor")
}
Requirement::None => {
self.node = Node::new(
vec![],
0,
self.node.id,
Requirement::Grad,
self.node.properties.clone(),
self.node.client.clone(),
)
.into();
let step = RootStep::new(self.node.clone());
self.register_step(step, CheckpointerBuilder::default())
}
}
}
/// Create a tensor from parent infos.
pub fn from_parents(
primitive: B::FloatTensorPrimitive,
parent_nodes: &[NodeRef],
requirement: Requirement,
computing_properties: ComputingProperty,
) -> Self {
let order = parent_nodes
.iter()
.map(|node| node.order)
.reduce(usize::max)
.unwrap_or(0)
+ 1;
let client = parent_nodes
.first()
.map(|node| node.client.clone())
.unwrap_or_else(AutodiffClientImpl::new);
let node: NodeRef = Node::new(
parent_nodes
.iter()
.filter_map(|node| node.clone_if_require_grad())
.map(|node| Parent::new(node.id))
.collect(),
order,
NodeId::new(),
requirement,
computing_properties,
client,
)
.into();
Self {
rc: Arc::new(node.id),
primitive,
node,
}
}
/// Register a step into a graph for that tensor.
///
/// # Warning
///
/// This should be called only once per tensor.
pub fn register_step<S: Step + 'static>(
self,
step_that_created_the_tensor: S,
actions: CheckpointerBuilder,
) -> Self {
self.node.client.register(
self.rc.clone(),
Box::new(step_that_created_the_tensor),
actions,
);
self
}
pub fn into_primitive(self) -> B::FloatTensorPrimitive {
self.primitive
}
pub fn backward(self) -> Gradients {
let client = self.node.client.clone();
AutodiffClient::backward::<B>(&client, self)
}
pub fn grad(&self, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
grads.get::<B>(self)
}
pub fn grad_remove(&self, grads: &mut Gradients) -> Option<B::FloatTensorPrimitive> {
grads.remove::<B>(self)
}
pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPrimitive) {
grads.remove::<B>(self);
grads.register::<B>(self.node.id, grad);
}
}

View File

@@ -0,0 +1,25 @@
use alloc::vec::Vec;
use crate::graph::NodeRef;
/// Duplicate the given object for each node that requires gradients.
///
/// # Notes
///
/// This is useful since you don't have to keep N cloned references alive event if just 1 node
/// will be updated.
///
/// If the object is a tensor and if one reference exists, it can be updated inplace.
pub fn duplicate<T: Clone + core::fmt::Debug, const N: usize>(
nodes: &[Option<NodeRef>; N],
obj: Option<T>,
) -> [Option<T>; N] {
nodes
.iter()
.map(|node| match node {
Some(_) => obj.clone(),
None => None,
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}