feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Automatic differentiation backend for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-autodiff"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-autodiff"
|
||||
documentation = "https://docs.rs/burn-autodiff"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std", "tracing"]
|
||||
std = ["dep:parking_lot"]
|
||||
export_tests = [] # check checkpointer is_empty in tests
|
||||
|
||||
tracing = [
|
||||
"dep:tracing",
|
||||
"burn-std/tracing",
|
||||
"burn-backend/tracing",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
parking_lot = { workspace = true, optional = true }
|
||||
log = { workspace = true }
|
||||
hashbrown = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
portable-atomic = { workspace = true }
|
||||
tracing = { workspace = true, optional = true, features = ["default"] }
|
||||
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["default"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,8 @@
|
||||
# Burn Autodiff
|
||||
|
||||
> [Burn](https://github.com/tracel-ai/burn) autodiff backend
|
||||
|
||||
[](https://crates.io/crates/burn-autodiff)
|
||||
[](https://github.com/tracel-ai/burn-autodiff/blob/master/README.md)
|
||||
|
||||
For now only first order reverse mode autodiff is supported.
|
||||
@@ -0,0 +1,138 @@
|
||||
use crate::{
|
||||
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
|
||||
grads::Gradients,
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use alloc::{format, string::String};
|
||||
use burn_backend::{
|
||||
backend::{AutodiffBackend, Backend, ExecutionError},
|
||||
tensor::{BoolTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Enable auto-differentiation on a backend.
|
||||
///
|
||||
/// This works as a backend decorator, extending the functionality of any backend with
|
||||
/// backpropagation.
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct Autodiff<B, C = NoCheckpointing> {
|
||||
_b: PhantomData<B>,
|
||||
_checkpoint_strategy: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
|
||||
type Device = B::Device;
|
||||
|
||||
type FloatTensorPrimitive = AutodiffTensor<B>;
|
||||
type FloatElem = B::FloatElem;
|
||||
|
||||
type IntTensorPrimitive = B::IntTensorPrimitive;
|
||||
type IntElem = B::IntElem;
|
||||
|
||||
type BoolTensorPrimitive = B::BoolTensorPrimitive;
|
||||
type BoolElem = B::BoolElem;
|
||||
|
||||
type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
format!("autodiff<{}>", B::name(device))
|
||||
}
|
||||
|
||||
fn seed(device: &B::Device, seed: u64) {
|
||||
B::seed(device, seed)
|
||||
}
|
||||
|
||||
fn sync(device: &B::Device) -> Result<(), ExecutionError> {
|
||||
B::sync(device)
|
||||
}
|
||||
|
||||
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
|
||||
device: &Self::Device,
|
||||
input: Input,
|
||||
func: Func,
|
||||
) -> Output {
|
||||
B::memory_persistent_allocations(device, input, func)
|
||||
}
|
||||
|
||||
fn memory_cleanup(device: &Self::Device) {
|
||||
B::memory_cleanup(device)
|
||||
}
|
||||
|
||||
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
|
||||
where
|
||||
Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
|
||||
{
|
||||
B::staging(data, device);
|
||||
}
|
||||
|
||||
fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
|
||||
B::supports_dtype(device, dtype)
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
|
||||
B::dtype_usage(device, dtype)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
|
||||
type InnerBackend = B;
|
||||
type Gradients = Gradients;
|
||||
|
||||
fn backward(tensor: AutodiffTensor<B>) -> Gradients {
|
||||
tensor.backward()
|
||||
}
|
||||
|
||||
fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
|
||||
tensor.grad(grads)
|
||||
}
|
||||
|
||||
fn grad_remove(
|
||||
tensor: &AutodiffTensor<B>,
|
||||
grads: &mut Gradients,
|
||||
) -> Option<B::FloatTensorPrimitive> {
|
||||
tensor.grad_remove(grads)
|
||||
}
|
||||
fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
|
||||
tensor.primitive
|
||||
}
|
||||
|
||||
fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
|
||||
AutodiffTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn grad_replace(
|
||||
tensor: &AutodiffTensor<B>,
|
||||
grads: &mut Self::Gradients,
|
||||
grad: B::FloatTensorPrimitive,
|
||||
) {
|
||||
tensor.grad_replace(grads, grad);
|
||||
}
|
||||
|
||||
fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::{
|
||||
retro_forward::RetroForwards,
|
||||
state::{BackwardStates, State},
|
||||
};
|
||||
use crate::collections::HashMap;
|
||||
use crate::graph::NodeId;
|
||||
|
||||
use alloc::{vec, vec::Vec};
|
||||
|
||||
#[derive(new, Debug)]
|
||||
/// Links a [NodeId] to its autodiff graph [NodeRef]
|
||||
pub(crate) struct NodeTree {
|
||||
map: HashMap<NodeId, Vec<NodeId>>,
|
||||
}
|
||||
|
||||
impl NodeTree {
|
||||
/// Gives the parents of the node in the autodiff graph
|
||||
pub(crate) fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {
|
||||
self.map.get(node_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
/// Struct responsible of fetching the output for a node in the autodiff graph during a backward pass
|
||||
pub struct Checkpointer {
|
||||
backward_states: BackwardStates,
|
||||
retro_forwards: RetroForwards,
|
||||
node_tree: NodeTree,
|
||||
}
|
||||
|
||||
impl Checkpointer {
|
||||
/// Gives the output of the given node, by recursively asking parents to compute themselves
|
||||
/// or give their pre-computed tensors.
|
||||
pub fn retrieve_node_output<T>(&mut self, node_id: NodeId) -> T
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
self.topological_sort(node_id).into_iter().for_each(|node| {
|
||||
self.retro_forwards
|
||||
.execute_retro_forward(node, &mut self.backward_states)
|
||||
});
|
||||
|
||||
self.backward_states.get_state::<T>(&node_id)
|
||||
}
|
||||
|
||||
/// Sorts the ancestors of NodeId in a way such that all parents come before their children
|
||||
/// Useful to avoid recursivity later when mutating the states
|
||||
///
|
||||
/// The sort on a compute bound state or a memory bound that is already computed is trivial.
|
||||
/// The match on State::Computed also serves as a stopping criterion for the sort,
|
||||
/// we don't need to look higher than that during recursivity.
|
||||
fn topological_sort(&self, node_id: NodeId) -> Vec<NodeId> {
|
||||
match self.backward_states.get_state_ref(&node_id) {
|
||||
Some(state) => match state {
|
||||
State::Recompute { n_required: _ } => {
|
||||
let mut sorted = Vec::new();
|
||||
let parents = self.node_tree.parents(&node_id).unwrap();
|
||||
for parent_node in parents {
|
||||
let parent_sorted = self.topological_sort(parent_node);
|
||||
for ps in parent_sorted {
|
||||
if !sorted.contains(&ps) {
|
||||
sorted.push(ps)
|
||||
}
|
||||
}
|
||||
}
|
||||
sorted.push(node_id);
|
||||
sorted
|
||||
}
|
||||
State::Computed {
|
||||
state_content: _,
|
||||
n_required: _,
|
||||
} => vec![node_id],
|
||||
},
|
||||
None => panic!("Node {node_id:?} is not in the backward_states. "),
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if checkpointer has been drained adequately. Useful for testing
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.backward_states.is_empty() && self.retro_forwards.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
use crate::{
|
||||
collections::HashMap,
|
||||
graph::{ComputingProperty, NodeId},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use alloc::{boxed::Box, sync::Arc, vec::Vec};
|
||||
use burn_backend::Backend;
|
||||
use core::any::Any;
|
||||
|
||||
use super::{
|
||||
base::{Checkpointer, NodeTree},
|
||||
retro_forward::{RetroForward, RetroForwards},
|
||||
state::{BackwardStates, State},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Determines if a node should checkpoint its computed output or its retro_forward for recomputation
|
||||
/// The action is normally created by the child of the node, once the node is determined to be needed
|
||||
pub enum CheckpointingAction {
|
||||
/// The node's already computed output should be saved
|
||||
Computed {
|
||||
/// The node
|
||||
node_id: NodeId,
|
||||
/// The node's output
|
||||
state_content: Box<dyn Any + Send>,
|
||||
},
|
||||
/// The node should recompute itself when asked
|
||||
Recompute {
|
||||
/// The node
|
||||
node_id: NodeId,
|
||||
/// How the node should recompute itself
|
||||
retro_forward: Arc<dyn RetroForward>,
|
||||
},
|
||||
}
|
||||
|
||||
// TODO: Remove that when proper client server.
|
||||
unsafe impl Send for CheckpointingAction {}
|
||||
|
||||
impl CheckpointingAction {
|
||||
/// Utility function to access the id of the node of the checkpointing action
|
||||
pub fn id(&self) -> NodeId {
|
||||
match self {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: node_ref,
|
||||
state_content: _,
|
||||
} => *node_ref,
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: node_ref,
|
||||
retro_forward: _,
|
||||
} => *node_ref,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug, Default)]
|
||||
/// Accumulates checkpoints as checkpointing actions during the forward pass,
|
||||
/// and builds a checkpointer right before the backward pass
|
||||
pub struct CheckpointerBuilder {
|
||||
explicit_actions: Vec<CheckpointingAction>,
|
||||
backup_actions: Vec<CheckpointingAction>,
|
||||
}
|
||||
|
||||
/// Determines if a checkpoint should impact the n_required values (Main)
|
||||
/// or if it should just keep the state in case it's required (Backup)
|
||||
///
|
||||
pub(crate) enum ActionType {
|
||||
/// Explicit actions have been explicitly requested by some operation to retrieve their state
|
||||
Explicit,
|
||||
/// Backup actions are not always needed. They exist to save the output of an operation
|
||||
/// whose child is memory bound, in case the state is indirectly needed when computing
|
||||
/// the child's retro_forward. If no explicit action ever asks for the child's output, then
|
||||
/// the backup output will go out of scope when the checkpointer is built.
|
||||
Backup,
|
||||
}
|
||||
|
||||
impl CheckpointerBuilder {
|
||||
pub(crate) fn checkpoint<B: Backend>(
|
||||
&mut self,
|
||||
tensor: &AutodiffTensor<B>,
|
||||
action_type: ActionType,
|
||||
) {
|
||||
let action_list = match action_type {
|
||||
ActionType::Explicit => &mut self.explicit_actions,
|
||||
ActionType::Backup => &mut self.backup_actions,
|
||||
};
|
||||
match &tensor.node.properties {
|
||||
ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
|
||||
action_list.push(CheckpointingAction::Computed {
|
||||
node_id: tensor.node.id,
|
||||
state_content: Box::new(tensor.primitive.clone()),
|
||||
})
|
||||
}
|
||||
ComputingProperty::MemoryBound { retro_forward } => {
|
||||
action_list.push(CheckpointingAction::Recompute {
|
||||
node_id: tensor.node.id,
|
||||
retro_forward: retro_forward.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn extend(&mut self, other: CheckpointerBuilder) {
|
||||
for other_action in other.explicit_actions {
|
||||
self.explicit_actions.push(other_action)
|
||||
}
|
||||
for other_unsure in other.backup_actions {
|
||||
self.backup_actions.push(other_unsure)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {
|
||||
let mut backward_states_map = HashMap::new();
|
||||
let mut retro_forwards_map = HashMap::new();
|
||||
|
||||
// Find recursion stopping points
|
||||
let stop_nodes: Vec<NodeId> = self.find_stop_nodes();
|
||||
|
||||
// We start by identifying how many times each node will be required.
|
||||
let n_required_map = self.build_n_required_map(&node_tree, stop_nodes);
|
||||
|
||||
// Then we checkpoint the nodes with the corresponding n_required value
|
||||
self.insert_checkpoints(
|
||||
&mut backward_states_map,
|
||||
&mut retro_forwards_map,
|
||||
n_required_map,
|
||||
);
|
||||
|
||||
Checkpointer::new(
|
||||
BackwardStates::new(backward_states_map),
|
||||
RetroForwards::new(retro_forwards_map),
|
||||
node_tree,
|
||||
)
|
||||
}
|
||||
|
||||
fn find_stop_nodes(&self) -> Vec<NodeId> {
|
||||
let mut stop_nodes = Vec::default();
|
||||
for action in self
|
||||
.explicit_actions
|
||||
.iter()
|
||||
.chain(self.backup_actions.iter())
|
||||
{
|
||||
match action {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: node_ref,
|
||||
state_content: _,
|
||||
} => stop_nodes.push(*node_ref),
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: _,
|
||||
retro_forward: _,
|
||||
} => {}
|
||||
}
|
||||
}
|
||||
stop_nodes
|
||||
}
|
||||
|
||||
fn build_n_required_map(
|
||||
&self,
|
||||
node_tree: &NodeTree,
|
||||
stop_nodes: Vec<NodeId>,
|
||||
) -> HashMap<NodeId, usize> {
|
||||
let mut n_required_map = HashMap::<NodeId, usize>::default();
|
||||
|
||||
for action in self.explicit_actions.iter() {
|
||||
match action {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: node_ref,
|
||||
state_content: _,
|
||||
} => {
|
||||
let id = *node_ref;
|
||||
match n_required_map.remove(&id) {
|
||||
Some(n) => {
|
||||
n_required_map.insert(id, n + 1);
|
||||
}
|
||||
None => {
|
||||
n_required_map.insert(id, 1);
|
||||
}
|
||||
};
|
||||
}
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: node_ref,
|
||||
retro_forward: _,
|
||||
} => {
|
||||
let id = *node_ref;
|
||||
Self::update_n_required_of_parents(
|
||||
id,
|
||||
&mut n_required_map,
|
||||
node_tree,
|
||||
&stop_nodes,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n_required_map
|
||||
}
|
||||
|
||||
fn insert_checkpoints(
|
||||
mut self,
|
||||
backward_states_map: &mut HashMap<NodeId, State>,
|
||||
retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
|
||||
n_required_map: HashMap<NodeId, usize>,
|
||||
) {
|
||||
// We do not loop over checkpointing actions anymore because they can contain
|
||||
// duplicates or miss some that are in backup. We loop over the n_required_map
|
||||
// from which we use the ids to find them again in the checkpointing actions
|
||||
for (node_id, n_required) in n_required_map {
|
||||
// We find the checkpointing action for node_id. It's likely in checkpointing_actions
|
||||
// so we check there first, otherwise it will be in backup.
|
||||
// Technically it can be there several times but can never be of both types, so we can assume the first we find is fine
|
||||
|
||||
let action = match self
|
||||
.explicit_actions
|
||||
.iter()
|
||||
.position(|action| action.id() == node_id)
|
||||
{
|
||||
Some(pos) => self.explicit_actions.remove(pos),
|
||||
None => {
|
||||
let pos = self
|
||||
.backup_actions
|
||||
.iter()
|
||||
.position(|action| action.id() == node_id);
|
||||
self.backup_actions.remove(pos.unwrap_or_else(|| {
|
||||
panic!("Node {:?} is needed but never checkpointed", &node_id)
|
||||
}))
|
||||
}
|
||||
};
|
||||
|
||||
match action {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: _,
|
||||
state_content,
|
||||
} => {
|
||||
self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
|
||||
}
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: _,
|
||||
retro_forward,
|
||||
} => self.checkpoint_lazy(
|
||||
backward_states_map,
|
||||
retro_forward_map,
|
||||
node_id,
|
||||
retro_forward,
|
||||
n_required,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn update_n_required_of_parents(
|
||||
id: NodeId,
|
||||
n_required_map: &mut HashMap<NodeId, usize>,
|
||||
node_tree: &NodeTree,
|
||||
stop_nodes: &Vec<NodeId>,
|
||||
) {
|
||||
match n_required_map.remove(&id) {
|
||||
Some(n) => {
|
||||
n_required_map.insert(id, n + 1);
|
||||
}
|
||||
None => {
|
||||
n_required_map.insert(id, 1);
|
||||
if !stop_nodes.contains(&id)
|
||||
&& let Some(parents) = node_tree.parents(&id)
|
||||
{
|
||||
for p in parents {
|
||||
Self::update_n_required_of_parents(
|
||||
p,
|
||||
n_required_map,
|
||||
node_tree,
|
||||
stop_nodes,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn checkpoint_compute(
|
||||
&self,
|
||||
backward_states_map: &mut HashMap<NodeId, State>,
|
||||
node_id: NodeId,
|
||||
state_content: Box<dyn Any + Send>,
|
||||
n_required: usize,
|
||||
) {
|
||||
backward_states_map.insert(
|
||||
node_id,
|
||||
State::Computed {
|
||||
state_content,
|
||||
n_required,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn checkpoint_lazy(
|
||||
&self,
|
||||
backward_states_map: &mut HashMap<NodeId, State>,
|
||||
retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
|
||||
node_id: NodeId,
|
||||
retro_forward: Arc<dyn RetroForward>,
|
||||
n_required: usize,
|
||||
) {
|
||||
retro_forward_map.insert(node_id, retro_forward);
|
||||
backward_states_map.insert(node_id, State::Recompute { n_required });
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
/// Checkpointer module
|
||||
pub mod base;
|
||||
pub(crate) mod builder;
|
||||
/// RetroForward module
|
||||
pub mod retro_forward;
|
||||
/// BackwardStates module
|
||||
pub mod state;
|
||||
/// CheckpointStrategy module
|
||||
pub mod strategy;
|
||||
@@ -0,0 +1,116 @@
|
||||
use crate::collections::HashMap;
|
||||
use crate::graph::NodeId;
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use core::fmt::Debug;
|
||||
|
||||
use super::state::{BackwardStates, State};
|
||||
|
||||
/// Definition of the forward function of a node, called during retropropagation only.
|
||||
/// This is different from the normal forward function because it reads and writes from
|
||||
/// the [BackwardStates] map instead of having a clear function signature.
|
||||
pub trait RetroForward: Debug + Send + 'static {
|
||||
/// Applies the forward pass for retropropagation.
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId);
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
/// Links [NodeId]s to their corresponding [RetroForward]
|
||||
pub(crate) struct RetroForwards {
|
||||
map: HashMap<NodeId, Arc<dyn RetroForward>>,
|
||||
}
|
||||
|
||||
impl RetroForwards {
|
||||
/// Executes the [RetroForward] for a given [NodeId] if the node's
|
||||
/// [State] is [State::Recompute], otherwise does nothing.
|
||||
pub(crate) fn execute_retro_forward(
|
||||
&mut self,
|
||||
node_id: NodeId,
|
||||
backward_states: &mut BackwardStates,
|
||||
) {
|
||||
if let State::Recompute { n_required: _ } = backward_states
|
||||
.get_state_ref(&node_id)
|
||||
.unwrap_or_else(|| panic!("Should find node {node_id:?}"))
|
||||
{
|
||||
// Retro forwards are always used only once because afterwards their state is computed
|
||||
let retro_forward = self.map.remove(&node_id).unwrap();
|
||||
retro_forward.forward(backward_states, node_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Creates a RetroForward struct for unary scalar operations
|
||||
macro_rules! retro_unary_scalar {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct $name<B: Backend> {
|
||||
lhs_id: NodeId,
|
||||
rhs: Scalar,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RetroForward for $name<B> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
|
||||
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
|
||||
let out = $ops(lhs, self.rhs);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Creates a RetroForward struct for unary scalar operations
|
||||
macro_rules! retro_unary {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct $name<B: Backend> {
|
||||
input_id: NodeId,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RetroForward for $name<B> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
|
||||
let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);
|
||||
let out = $ops(input);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Creates a RetroForward struct for binary operations
|
||||
macro_rules! retro_binary {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct $name<B: Backend> {
|
||||
lhs_id: NodeId,
|
||||
rhs_id: NodeId,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RetroForward for $name<B> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
|
||||
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
|
||||
let rhs = states.get_state::<B::FloatTensorPrimitive>(&self.rhs_id);
|
||||
let out = $ops(lhs, rhs);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
use core::any::Any;
|
||||
|
||||
use crate::collections::HashMap;
|
||||
use crate::graph::NodeId;
|
||||
use alloc::boxed::Box;
|
||||
|
||||
/// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.
|
||||
pub(crate) type StateContent = Box<dyn Any + Send>;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// The state contained at one node. Encapsulates the node output if precomputed,
|
||||
/// or clearly asks that it needs to be recomputed from the parents.
|
||||
/// Also keeps track of the number of times the state is required so it can be removed
|
||||
/// from the map of states on its last use.
|
||||
pub(crate) enum State {
|
||||
/// The state was not checkpointed, will need to recompute it from the node's parents
|
||||
Recompute { n_required: usize },
|
||||
/// The state was checkpointed or computed during retropropagation and can be directly accessed
|
||||
Computed {
|
||||
state_content: StateContent,
|
||||
n_required: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl State {
|
||||
/// Returns a reference to the (not yet) downcasted node output, if checkpointed
|
||||
pub(crate) fn to_state_content(&self) -> &StateContent {
|
||||
match self {
|
||||
State::Recompute { n_required: _ } => {
|
||||
unreachable!(
|
||||
"Can't get state content of recompute state. A child has likely been accessed before its parents."
|
||||
)
|
||||
}
|
||||
State::Computed {
|
||||
state_content,
|
||||
n_required: _,
|
||||
} => state_content,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a (not yet) downcasted node output, if checkpointed
|
||||
pub(crate) fn into_state_content(self) -> StateContent {
|
||||
match self {
|
||||
State::Recompute { n_required: _ } => {
|
||||
unreachable!(
|
||||
"Can't get state content of recompute state. A child has likely been accessed before its parents."
|
||||
)
|
||||
}
|
||||
State::Computed {
|
||||
state_content,
|
||||
n_required: _,
|
||||
} => state_content,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of time the state is required
|
||||
pub(crate) fn n_required(&self) -> usize {
|
||||
match self {
|
||||
State::Recompute { n_required } => *n_required,
|
||||
State::Computed {
|
||||
state_content: _,
|
||||
n_required,
|
||||
} => *n_required,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Default, Debug)]
|
||||
/// Links [NodeId]s to their current state
|
||||
pub struct BackwardStates {
|
||||
map: HashMap<NodeId, State>,
|
||||
}
|
||||
|
||||
impl BackwardStates {
|
||||
/// Returns the output in the state of the given [NodeId],
|
||||
/// and decrements the number of times this state is required.
|
||||
/// This function always gives ownership of the output, but will clone it if needed for further uses.
|
||||
pub fn get_state<T>(&mut self, node_id: &NodeId) -> T
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
// Fetch the state and decrement its number of required
|
||||
let state = self.map.remove(node_id).unwrap();
|
||||
let remaining_n_required = state.n_required() - 1;
|
||||
|
||||
// Downcast the state to whatever it is supposed to be
|
||||
// If still needed after giving ownership, we copy it back to the hashmap
|
||||
if remaining_n_required > 0 {
|
||||
let new_stored_state = match state {
|
||||
State::Recompute { n_required: _ } => unreachable!(),
|
||||
State::Computed {
|
||||
state_content,
|
||||
n_required: _,
|
||||
} => State::Computed {
|
||||
state_content,
|
||||
n_required: remaining_n_required,
|
||||
},
|
||||
};
|
||||
|
||||
let downcasted = new_stored_state
|
||||
.to_state_content()
|
||||
.downcast_ref::<T>()
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
self.insert_state(*node_id, new_stored_state);
|
||||
|
||||
downcasted
|
||||
} else {
|
||||
let downcasted = state.into_state_content().downcast::<T>().unwrap();
|
||||
*downcasted
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the [State] of the given node
|
||||
/// Useful when we need [State] information without needing the underlying tensor
|
||||
pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> {
|
||||
self.map.get(node_id)
|
||||
}
|
||||
|
||||
/// Associates a [State] to its [NodeId]
|
||||
pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) {
|
||||
self.map.insert(node_id, state);
|
||||
}
|
||||
|
||||
/// Saves the output to the state of the given [NodeId].
|
||||
pub fn save<T>(&mut self, node_id: NodeId, saved_output: T)
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
let n_required = self.get_state_ref(&node_id).unwrap().n_required();
|
||||
self.insert_state(
|
||||
node_id,
|
||||
State::Computed {
|
||||
state_content: Box::new(saved_output),
|
||||
n_required,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
use core::fmt::Debug;
|
||||
|
||||
use burn_backend::Backend;
|
||||
|
||||
use crate::{graph::ComputingProperty, tensor::AutodiffTensor};
|
||||
use alloc::sync::Arc;
|
||||
|
||||
use super::{
|
||||
builder::{ActionType, CheckpointerBuilder},
|
||||
retro_forward::RetroForward,
|
||||
};
|
||||
|
||||
/// Strategy for the amount of checkpointing to do during autodiff
|
||||
pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'static {
|
||||
/// May modify the compute property depending on the strategy
|
||||
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty;
|
||||
|
||||
/// Checkpoints parents if necessary in the strategy
|
||||
fn checkpoint_parents<'a, B2, A>(
|
||||
parents: A,
|
||||
builder: &mut CheckpointerBuilder,
|
||||
) -> Result<(), CheckpointingError>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Error that can happen when trying to checkpoint a tensor.
|
||||
pub enum CheckpointingError {
|
||||
/// When a parent is untracked, we can't easily checkpoint its state, since we don't know the
|
||||
/// requirements in advanced.
|
||||
UntrackedParent,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
/// All operations are considered compute bound, notwithstanding how they are marked
|
||||
pub struct NoCheckpointing {}
|
||||
|
||||
impl CheckpointStrategy for NoCheckpointing {
|
||||
/// An operation marked as memory bound is actually compute bound.
|
||||
fn compute_property<R: RetroForward>(_retro_forward: R) -> ComputingProperty {
|
||||
ComputingProperty::ComputeBound
|
||||
}
|
||||
|
||||
/// An operation marked as memory bound is actually compute bound.
|
||||
/// It's therefore useless to checkpoint the parents
|
||||
fn checkpoint_parents<'a, B2, A>(
|
||||
_parents: A,
|
||||
_builder: &mut CheckpointerBuilder,
|
||||
) -> Result<(), CheckpointingError>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
|
||||
{
|
||||
// Nothing to do here
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
/// Operation properties are as they are marked (compute or memory bound)
|
||||
pub struct BalancedCheckpointing {}
|
||||
|
||||
impl CheckpointStrategy for BalancedCheckpointing {
|
||||
/// An operation marked as memory bound is memory bound.
|
||||
/// When memory bound, an operation needs to save its RetroForward
|
||||
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty {
|
||||
ComputingProperty::MemoryBound {
|
||||
retro_forward: Arc::new(retro_forward),
|
||||
}
|
||||
}
|
||||
|
||||
/// An operation marked as memory bound is really memory bound.
|
||||
/// Since the operation may not checkpoint its parents but may need them indirectly
|
||||
/// if asked to recompute itself, the method needs to know the parent tensors to maybe checkpoint them
|
||||
fn checkpoint_parents<'a, B2, A>(
|
||||
parents: A,
|
||||
builder: &mut CheckpointerBuilder,
|
||||
) -> Result<(), CheckpointingError>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
|
||||
{
|
||||
let mut can_checkpoint = true;
|
||||
|
||||
for tensor in parents.into_iter() {
|
||||
if let crate::graph::Requirement::None = tensor.node.requirement {
|
||||
can_checkpoint = false;
|
||||
} else {
|
||||
builder.checkpoint(tensor, ActionType::Backup);
|
||||
}
|
||||
}
|
||||
|
||||
if !can_checkpoint {
|
||||
*builder = CheckpointerBuilder::default();
|
||||
return Err(CheckpointingError::UntrackedParent);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
use burn_backend::{
|
||||
Backend, TensorMetadata, TensorPrimitive,
|
||||
tensor::{FloatTensor, TensorContainer},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
NodeId,
|
||||
graph::{NodeRef, Requirement},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
|
||||
/// Gradient identifier.
|
||||
pub type GradID = u64;
|
||||
|
||||
/// Gradients container used during the backward pass.
|
||||
pub struct Gradients {
|
||||
container: TensorContainer<GradID>,
|
||||
}
|
||||
|
||||
impl Gradients {
|
||||
/// Creates a new gradients container.
|
||||
pub fn new<B: Backend>(root_node: NodeRef, root_tensor: FloatTensor<B>) -> Self {
|
||||
let mut gradients = Self {
|
||||
container: TensorContainer::new(),
|
||||
};
|
||||
gradients.register::<B>(
|
||||
root_node.id,
|
||||
B::float_ones(
|
||||
root_tensor.shape(),
|
||||
&B::float_device(&root_tensor),
|
||||
root_tensor.dtype().into(),
|
||||
),
|
||||
);
|
||||
gradients
|
||||
}
|
||||
|
||||
/// Consumes the gradients for a given tensor.
|
||||
///
|
||||
/// Each tensor should be consumed exactly 1 time if its gradients are only required during the
|
||||
/// backward pass, otherwise, it may be consume multiple times.
|
||||
pub fn consume<B: Backend>(&mut self, node: &NodeRef) -> FloatTensor<B> {
|
||||
match node.requirement {
|
||||
Requirement::Grad => self
|
||||
.container
|
||||
.get::<B>(&node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
.expect("Can't consume the gradients before they are registered at least once."),
|
||||
Requirement::GradInBackward => self
|
||||
.container
|
||||
.remove::<B>(&node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
.expect("Can't consume the gradients before they are registered at least once."),
|
||||
Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a grad tensor from the container.
|
||||
pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
|
||||
self.container
|
||||
.remove::<B>(&tensor.node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
}
|
||||
|
||||
/// Gets a grad tensor from the container.
|
||||
pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
|
||||
self.container
|
||||
.get::<B>(&tensor.node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
}
|
||||
|
||||
/// Register a grad tensor in the container.
|
||||
///
|
||||
/// If the tensor already exists, add both tensors together before saving the result.
|
||||
pub fn register<B: Backend>(&mut self, node_id: NodeId, value: FloatTensor<B>) {
|
||||
if let Some(tensor_old) = self.container.remove::<B>(&node_id.value) {
|
||||
self.container.register::<B>(
|
||||
node_id.value,
|
||||
TensorPrimitive::Float(B::float_add(value, tensor_old.tensor())),
|
||||
);
|
||||
} else {
|
||||
self.container
|
||||
.register::<B>(node_id.value, TensorPrimitive::Float(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
use super::NodeId;
|
||||
use crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent};
|
||||
use alloc::boxed::Box;
|
||||
|
||||
/// Backward step for reverse mode autodiff.
|
||||
pub trait Step: Send + core::fmt::Debug {
|
||||
/// Executes the step and consumes it.
|
||||
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
|
||||
/// Depth of the operation relative to the first node added to a graph.
|
||||
fn depth(&self) -> usize;
|
||||
/// The node associated to the step.
|
||||
fn node(&self) -> NodeId;
|
||||
/// The parents of the node associated to the step.
|
||||
fn parents(&self) -> &[Parent];
|
||||
}
|
||||
|
||||
pub type StepBoxed = Box<dyn Step>;
|
||||
@@ -0,0 +1,9 @@
|
||||
mod base;
|
||||
mod node;
|
||||
mod requirement;
|
||||
|
||||
pub mod traversal;
|
||||
|
||||
pub use base::*;
|
||||
pub use node::*;
|
||||
pub use requirement::*;
|
||||
@@ -0,0 +1,87 @@
|
||||
use alloc::{sync::Arc, vec::Vec};
|
||||
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
use core::sync::atomic::{AtomicU64, Ordering};
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
use portable_atomic::{AtomicU64, Ordering};
|
||||
|
||||
use crate::checkpoint::retro_forward::RetroForward;
|
||||
use crate::runtime::AutodiffClientImpl;
|
||||
|
||||
use super::Requirement;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ComputingProperty {
|
||||
ComputeBound,
|
||||
MemoryBound {
|
||||
retro_forward: Arc<dyn RetroForward>,
|
||||
},
|
||||
Ambiguous, // Maybe autotune someday
|
||||
}
|
||||
|
||||
/// This is safe only because we only call RetroForward on the autodiff server.
|
||||
/// Therefore, the trait will never be used by multiple threads at the same time.
|
||||
///
|
||||
/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the
|
||||
/// Arc, which will make (dyn RetroForward) safely implement Send.
|
||||
unsafe impl Send for ComputingProperty {}
|
||||
/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.
|
||||
unsafe impl Sync for ComputingProperty {}
|
||||
|
||||
/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
|
||||
#[derive(new, Debug)]
|
||||
pub struct Node {
|
||||
pub parents: Vec<Parent>,
|
||||
pub order: usize,
|
||||
pub id: NodeId,
|
||||
pub requirement: Requirement,
|
||||
pub properties: ComputingProperty,
|
||||
pub client: AutodiffClientImpl,
|
||||
}
|
||||
pub type NodeRef = Arc<Node>;
|
||||
|
||||
#[derive(new, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Parent {
|
||||
pub id: NodeId,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Returns the [node](Node) only if gradients are required.
|
||||
pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
|
||||
match self.requirement.is_none() {
|
||||
true => None,
|
||||
false => Some(self.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier generated for each node.
|
||||
#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
|
||||
pub struct NodeId {
|
||||
/// The integer representation of the id
|
||||
pub value: u64,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for NodeId {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_fmt(format_args!("NodeId({})", self.value))
|
||||
}
|
||||
}
|
||||
|
||||
impl NodeId {
|
||||
/// Create a unique [node id](NodeId).
|
||||
pub fn new() -> Self {
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
let value = COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
if value == u64::MAX {
|
||||
panic!("NodeId overflowed");
|
||||
}
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use super::NodeRef;
|
||||
|
||||
/// Requirement for each tensor in the graph.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Requirement {
|
||||
/// Operations that require gradients.
|
||||
Grad,
|
||||
/// Operations that require gradients only for backprop.
|
||||
GradInBackward,
|
||||
/// Operations that don't need gradients, therefore not to be included in the graph.
|
||||
None,
|
||||
}
|
||||
|
||||
impl Requirement {
|
||||
/// Returns true if gradients are not required.
|
||||
pub fn is_none(&self) -> bool {
|
||||
matches!(self, Self::None)
|
||||
}
|
||||
/// Returns the right requirement from a list of nodes.
|
||||
pub fn from_nodes(nodes: &[NodeRef]) -> Self {
|
||||
if nodes.len() == 1 {
|
||||
return nodes[0].requirement.infer(&Requirement::None);
|
||||
}
|
||||
|
||||
nodes
|
||||
.iter()
|
||||
.map(|node| node.requirement)
|
||||
.reduce(|acc, requirement| requirement.infer(&acc))
|
||||
.unwrap_or(Requirement::None)
|
||||
}
|
||||
|
||||
fn infer(&self, other: &Self) -> Self {
|
||||
match self.is_none() && other.is_none() {
|
||||
true => Self::None,
|
||||
false => Self::GradInBackward,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
use super::{Step, StepBoxed};
|
||||
use crate::{
|
||||
NodeId,
|
||||
collections::{HashMap, HashSet},
|
||||
graph::Parent,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Breadth for search algorithm.
|
||||
pub struct BreadthFirstSearch;
|
||||
|
||||
pub trait TraversalItem {
|
||||
fn id(&self) -> NodeId;
|
||||
fn parents(&self) -> &[Parent];
|
||||
fn parent_nodes(&self) -> Vec<NodeId> {
|
||||
self.parents().iter().map(|p| p.id).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl BreadthFirstSearch {
|
||||
/// Traverse the graph of backward steps from a root node.
|
||||
pub fn traverse<F, I>(
|
||||
&self,
|
||||
root_id: NodeId,
|
||||
root_step: I,
|
||||
steps: &mut HashMap<NodeId, I>,
|
||||
mut callback: F,
|
||||
) where
|
||||
F: FnMut(NodeId, I),
|
||||
I: TraversalItem,
|
||||
{
|
||||
let mut visited = HashSet::new();
|
||||
let mut parents = Vec::new();
|
||||
|
||||
visited.insert(root_id);
|
||||
parents.append(&mut root_step.parent_nodes());
|
||||
|
||||
callback(root_id, root_step);
|
||||
|
||||
while let Some(id) = parents.pop() {
|
||||
let step = match steps.remove(&id) {
|
||||
Some(step) => step,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let step_node = step.id();
|
||||
let step_parents = step.parent_nodes();
|
||||
|
||||
if visited.contains(&step_node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
visited.insert(step_node);
|
||||
|
||||
for id in step_parents.iter() {
|
||||
if !visited.contains(id) {
|
||||
parents.push(*id);
|
||||
}
|
||||
}
|
||||
|
||||
callback(step_node, step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TraversalItem for StepBoxed {
|
||||
fn id(&self) -> NodeId {
|
||||
Step::node(self.as_ref())
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
Step::parents(self.as_ref())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! # Burn Autodiff
|
||||
//!
|
||||
//! This autodiff library is a part of the Burn project. It is a standalone crate
|
||||
//! that can be used to perform automatic differentiation on tensors. It is
|
||||
//! designed to be used with the Burn Tensor crate, but it can be used with any
|
||||
//! tensor library that implements the `Backend` trait.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Checkpoint module.
|
||||
pub mod checkpoint;
|
||||
/// Gradients module.
|
||||
pub mod grads;
|
||||
/// Operation module.
|
||||
pub mod ops;
|
||||
|
||||
pub(crate) mod graph;
|
||||
// Exported for backend extension
|
||||
pub use graph::NodeId;
|
||||
pub(crate) mod tensor;
|
||||
pub(crate) mod utils;
|
||||
|
||||
mod backend;
|
||||
|
||||
pub(crate) mod runtime;
|
||||
|
||||
pub use backend::*;
|
||||
|
||||
/// A facade around for HashMap and HashSet.
|
||||
/// This avoids elaborate import wrangling having to happen in every module.
|
||||
mod collections {
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub use hashbrown::{HashMap, HashSet};
|
||||
#[cfg(feature = "std")]
|
||||
pub use std::collections::{HashMap, HashSet};
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
Autodiff,
|
||||
checkpoint::{
|
||||
base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,
|
||||
strategy::CheckpointStrategy,
|
||||
},
|
||||
grads::Gradients,
|
||||
graph::NodeId,
|
||||
ops::{Backward, Ops, OpsKind, unary},
|
||||
retro_unary,
|
||||
};
|
||||
use burn_backend::{Backend, ops::ActivationOps, tensor::FloatTensor};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodiff<B, C> {
|
||||
fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct Gelu;
|
||||
|
||||
retro_unary!(RetroGelu, B::gelu);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for Gelu {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let input = checkpointer.retrieve_node_output(ops.state);
|
||||
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::gelu_backward(input, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match Gelu
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroGelu::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::gelu(tensor.primitive.clone()))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
|
||||
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct Relu;
|
||||
|
||||
retro_unary!(RetroRelu, B::relu);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for Relu {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let state = checkpointer.retrieve_node_output(ops.state);
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::relu_backward(state, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match Relu
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroRelu::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::relu(tensor.primitive))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::relu(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct Sigmoid;
|
||||
|
||||
retro_unary!(RetroSigmoid, B::sigmoid);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for Sigmoid {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let input = checkpointer.retrieve_node_output(ops.state);
|
||||
let output = B::sigmoid(input);
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::sigmoid_backward(output, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match Sigmoid
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSigmoid::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::sigmoid(tensor.primitive))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct LogSigmoid;
|
||||
|
||||
retro_unary!(RetroLogSigmoid, B::log_sigmoid);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for LogSigmoid {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let input = checkpointer.retrieve_node_output(ops.state);
|
||||
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::log_sigmoid_backward(input, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match LogSigmoid
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroLogSigmoid::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use super::{Ops, OpsPrep};
|
||||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, NodeRef, Requirement},
|
||||
utils::duplicate,
|
||||
};
|
||||
use burn_backend::Backend;
|
||||
|
||||
/// Trait for all operations.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Concrete types implementing this trait should not have any state.
|
||||
/// If a state is necessary during the backward pass,
|
||||
/// they should be declared with the associated type 'State'.
|
||||
pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
/// Associated type to compute the backward pass.
|
||||
type State: Clone + Send + core::fmt::Debug + 'static;
|
||||
|
||||
/// The backward pass.
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, N>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
);
|
||||
|
||||
/// Prepare the backward ops.
|
||||
fn prepare<C: CheckpointStrategy>(
|
||||
self,
|
||||
nodes: [NodeRef; N],
|
||||
) -> OpsPrep<Self, B, Self::State, C, N> {
|
||||
let requirement = Requirement::from_nodes(&nodes);
|
||||
OpsPrep::new(
|
||||
nodes,
|
||||
requirement,
|
||||
self,
|
||||
ComputingProperty::Ambiguous, // If not specified we start with ambiguous
|
||||
CheckpointerBuilder::default(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a binary operation during the backward step.
|
||||
pub fn binary<B, FLhs, FRhs>(
|
||||
parents: [Option<NodeRef>; 2],
|
||||
node: NodeRef,
|
||||
grads: &mut Gradients,
|
||||
func_lhs: FLhs,
|
||||
func_rhs: FRhs,
|
||||
) where
|
||||
B: Backend,
|
||||
FLhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
|
||||
FRhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
|
||||
{
|
||||
let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::<B>(&node)));
|
||||
let [node_lhs, node_rhs] = parents;
|
||||
|
||||
if let Some(node) = node_lhs {
|
||||
let grad = func_lhs(grad_4lhs.unwrap());
|
||||
grads.register::<B>(node.id, grad)
|
||||
}
|
||||
|
||||
if let Some(node) = node_rhs {
|
||||
let grad = func_rhs(grad_4rhs.unwrap());
|
||||
grads.register::<B>(node.id, grad)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a unary operation during the backward step.
|
||||
pub fn unary<B, F>(parents: [Option<NodeRef>; 1], node: NodeRef, grads: &mut Gradients, func: F)
|
||||
where
|
||||
B: Backend,
|
||||
F: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
|
||||
{
|
||||
let [parent_node] = parents;
|
||||
let grad = grads.consume::<B>(&node);
|
||||
|
||||
if let Some(node) = parent_node {
|
||||
let grad = func(grad);
|
||||
grads.register::<B>(node.id, grad)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
use super::Backward;
|
||||
use crate::{
|
||||
checkpoint::{
|
||||
base::Checkpointer,
|
||||
builder::{ActionType, CheckpointerBuilder},
|
||||
retro_forward::RetroForward,
|
||||
strategy::CheckpointStrategy,
|
||||
},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use alloc::boxed::Box;
|
||||
use burn_backend::{Backend, TensorMetadata, tensor::FloatTensor};
|
||||
use burn_std::Shape;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Operation in preparation.
|
||||
///
|
||||
/// Each mode has its own set of functions to minimize cloning for unused backward states.
|
||||
#[derive(new)]
|
||||
pub struct OpsPrep<Backward, B, S, C, const N: usize, Mode = Init> {
|
||||
nodes: [NodeRef; N],
|
||||
requirement: Requirement,
|
||||
backward: Backward,
|
||||
compute_property: ComputingProperty,
|
||||
checkpointer_builder: CheckpointerBuilder,
|
||||
checkpoint_strategy: PhantomData<C>,
|
||||
phantom_backend: PhantomData<B>,
|
||||
phantom_state: PhantomData<S>,
|
||||
marker: PhantomData<Mode>,
|
||||
}
|
||||
|
||||
/// Operation is initialized
|
||||
pub struct Init;
|
||||
/// Operation has been tagged as memory bound
|
||||
pub struct MemoryBound;
|
||||
/// Memory bound operation has received its RetroForward
|
||||
pub struct MemoryBoundRetroForward;
|
||||
/// Operation's compute property is fixed
|
||||
pub struct ComputePropertyDone;
|
||||
/// Tracked operation tag.
|
||||
pub struct Tracked;
|
||||
/// Untracked operation tag.
|
||||
pub struct UnTracked;
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Init>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Indicates that the operation is compute bound, meaning its computation
|
||||
/// is heavy and should not be recomputed
|
||||
pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
ComputingProperty::ComputeBound,
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
|
||||
/// Indicates that the operation is memory bound, meaning its computation
|
||||
/// is light and can be recomputed
|
||||
pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, N, MemoryBound> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBound>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = S>,
|
||||
C: CheckpointStrategy,
|
||||
{
|
||||
/// Registers the retro forward, if needed
|
||||
pub fn retro_forward<R: RetroForward>(
|
||||
self,
|
||||
retro_forward: R,
|
||||
) -> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
C::compute_property(retro_forward),
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = S>,
|
||||
C: CheckpointStrategy,
|
||||
{
|
||||
/// Checkpoints the parents, if needed
|
||||
pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
|
||||
{
|
||||
let compute_property = match C::checkpoint_parents(parents, &mut self.checkpointer_builder)
|
||||
{
|
||||
Ok(..) => self.compute_property,
|
||||
Err(..) => ComputingProperty::ComputeBound,
|
||||
};
|
||||
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
compute_property,
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, C, const N: usize> OpsPrep<BO, B, (), C, N, ComputePropertyDone>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = ()>,
|
||||
{
|
||||
/// Prepare a stateless operation.
|
||||
pub fn stateless(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
|
||||
match self.stateful() {
|
||||
OpsKind::Tracked(prep) => prep.finish((), output),
|
||||
OpsKind::UnTracked(prep) => prep.finish(output),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + core::fmt::Debug + 'static,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Prepare an operation that requires a state during the backward pass.
|
||||
pub fn stateful(self) -> OpsKind<BO, B, S, C, N> {
|
||||
match self.requirement.is_none() {
|
||||
false => OpsKind::Tracked(OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
)),
|
||||
true => OpsKind::UnTracked(OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + core::fmt::Debug + 'static,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Finish the preparation of an untracked operation and returns the output tensor.
|
||||
pub fn finish(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
|
||||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&self.nodes,
|
||||
self.requirement,
|
||||
self.compute_property,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, output.node.clone(), ());
|
||||
|
||||
// We register the ops in the graph even if untracked, otherwise memory bound operations
|
||||
// that have an untracked parent would not be able to retrieve it
|
||||
output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + core::fmt::Debug + 'static,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Finish the preparation of a tracked operation and returns the output tensor.
|
||||
pub fn finish(self, state: S, output: FloatTensor<B>) -> AutodiffTensor<B> {
|
||||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&self.nodes,
|
||||
self.requirement,
|
||||
self.compute_property,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, output.node.clone(), state);
|
||||
|
||||
output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)
|
||||
}
|
||||
|
||||
/// Checkpoints the tensor
|
||||
pub fn checkpoint(&mut self, tensor: &AutodiffTensor<B>) -> NodeId {
|
||||
self.checkpointer_builder
|
||||
.checkpoint(tensor, ActionType::Explicit);
|
||||
|
||||
tensor.node.id
|
||||
}
|
||||
}
|
||||
|
||||
/// Enum used before finishing tracked and untracked operations.
|
||||
pub enum OpsKind<BO, B, S, C, const N: usize> {
|
||||
/// Tracked operation preparation.
|
||||
Tracked(OpsPrep<BO, B, S, C, N, Tracked>),
|
||||
/// Untracked operation preparation.
|
||||
UnTracked(OpsPrep<BO, B, S, C, N, UnTracked>),
|
||||
}
|
||||
|
||||
/// Operation containing its parent nodes, its own node and the backward step state.
|
||||
#[derive(new, Debug)]
|
||||
pub struct Ops<S, const N: usize> {
|
||||
/// Parents nodes.
|
||||
pub parents: [Option<NodeRef>; N],
|
||||
/// The node.
|
||||
pub node: NodeRef,
|
||||
/// The state.
|
||||
pub state: S,
|
||||
}
|
||||
|
||||
/// Operation implementing backward [step](Step) with type erasing.
|
||||
#[derive(new, Debug)]
|
||||
struct OpsStep<B, T, SB, const N: usize>
|
||||
where
|
||||
B: Backend,
|
||||
T: Backward<B, N, State = SB>,
|
||||
SB: Clone + Send + core::fmt::Debug + 'static,
|
||||
{
|
||||
ops: Ops<SB, N>,
|
||||
backward: T,
|
||||
phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>
|
||||
where
|
||||
B: Backend,
|
||||
T: Backward<B, N, State = SB>,
|
||||
SB: Clone + Send + core::fmt::Debug + 'static,
|
||||
{
|
||||
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
|
||||
self.backward.backward(self.ops, grads, checkpointer);
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeId {
|
||||
self.ops.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
&self.ops.node.parents
|
||||
}
|
||||
|
||||
fn depth(&self) -> usize {
|
||||
self.ops.node.order
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct UntrackedOpsStep<const N: usize> {
|
||||
ops: Ops<(), N>,
|
||||
}
|
||||
|
||||
impl<const N: usize> Step for UntrackedOpsStep<N> {
|
||||
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeId {
|
||||
self.ops.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
&self.ops.node.parents
|
||||
}
|
||||
fn depth(&self) -> usize {
|
||||
self.ops.node.order
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure the grad tensor has the given shape.
|
||||
///
|
||||
/// If broadcasting happened during the forward pass, the gradients will be sum along the
|
||||
/// broadcasted dimension.
|
||||
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
|
||||
let shape_grad = grad.shape();
|
||||
let ndims = shape_grad.num_dims();
|
||||
|
||||
for i in 0..ndims {
|
||||
if shape_grad[i] != shape[i] {
|
||||
if shape[i] != 1 {
|
||||
panic!(
|
||||
"Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}",
|
||||
shape, shape_grad, "Expected the shape of the next grad to be 1."
|
||||
);
|
||||
}
|
||||
grad = B::float_sum_dim(grad, i);
|
||||
}
|
||||
}
|
||||
|
||||
grad
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn_backend::{
|
||||
Backend, ExecutionError, Scalar, TensorData,
|
||||
ops::BoolTensorOps,
|
||||
tensor::{BoolTensor, Device, IntTensor},
|
||||
};
|
||||
use burn_std::Shape;
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
|
||||
fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_from_data(data, device)
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: BoolTensor<B>) -> Result<TensorData, ExecutionError> {
|
||||
B::bool_into_data(tensor).await
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B> {
|
||||
B::bool_into_int(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &BoolTensor<B>) -> Device<B> {
|
||||
B::bool_device(tensor)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
|
||||
B::bool_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_std::Slice]) -> BoolTensor<B> {
|
||||
B::bool_slice(tensor, slices)
|
||||
}
|
||||
|
||||
fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_empty(shape, device)
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_zeros(shape, device)
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_ones(shape, device)
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
slices: &[burn_std::Slice],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
|
||||
B::bool_cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_not(tensor)
|
||||
}
|
||||
|
||||
fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_and(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_or(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_xor(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
|
||||
AutodiffTensor::new(B::bool_into_float(tensor))
|
||||
}
|
||||
|
||||
fn bool_swap_dims(
|
||||
tensor: <Autodiff<B> as Backend>::BoolTensorPrimitive,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive {
|
||||
B::bool_swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
B::bool_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B> {
|
||||
B::bool_flip(tensor, axes)
|
||||
}
|
||||
|
||||
async fn bool_argwhere(tensor: BoolTensor<B>) -> IntTensor<B> {
|
||||
B::bool_argwhere(tensor).await
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
|
||||
B::bool_expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
|
||||
B::bool_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
source: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_mask_where(tensor, mask, source)
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_mask_fill(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_scatter_or(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
B::bool_equal_elem(lhs, rhs)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn_backend::{
|
||||
Backend, Distribution, ExecutionError, Scalar, TensorData,
|
||||
ops::IntTensorOps,
|
||||
tensor::{BoolTensor, Device, IntTensor},
|
||||
};
|
||||
use burn_std::{IntDType, Shape};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
|
||||
fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
|
||||
B::int_from_data(data, device)
|
||||
}
|
||||
|
||||
async fn int_into_data(tensor: IntTensor<B>) -> Result<TensorData, ExecutionError> {
|
||||
B::int_into_data(tensor).await
|
||||
}
|
||||
|
||||
fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {
|
||||
B::int_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
|
||||
B::int_device(tensor)
|
||||
}
|
||||
|
||||
fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
|
||||
B::int_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_slice(tensor: IntTensor<B>, slices: &[burn_std::Slice]) -> IntTensor<B> {
|
||||
B::int_slice(tensor, slices)
|
||||
}
|
||||
|
||||
fn int_empty(
|
||||
shape: Shape,
|
||||
device: &<Autodiff<B> as Backend>::Device,
|
||||
dtype: IntDType,
|
||||
) -> IntTensor<B> {
|
||||
B::int_empty(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: IntTensor<B>,
|
||||
slices: &[burn_std::Slice],
|
||||
value: IntTensor<B>,
|
||||
) -> IntTensor<B> {
|
||||
B::int_slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_equal_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_add(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_add_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
|
||||
B::int_clamp_min(tensor, min)
|
||||
}
|
||||
|
||||
fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
|
||||
B::int_clamp_max(tensor, max)
|
||||
}
|
||||
|
||||
fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
|
||||
B::int_clamp(tensor, min, max)
|
||||
}
|
||||
|
||||
fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_sub_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_mul_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_div(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_div_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_remainder(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_remainder_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_matmul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_neg(tensor)
|
||||
}
|
||||
|
||||
fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
|
||||
B::int_zeros(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
|
||||
B::int_ones(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_full(
|
||||
shape: Shape,
|
||||
fill_value: Scalar,
|
||||
device: &Device<Self>,
|
||||
dtype: IntDType,
|
||||
) -> IntTensor<B> {
|
||||
B::int_full(shape, fill_value, device, dtype)
|
||||
}
|
||||
|
||||
fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_sum(tensor)
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_sum_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_mean(tensor)
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cumsum(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cumprod(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cummin(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cummax(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
|
||||
B::int_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_greater(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_greater_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_greater_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_greater_equal_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_lower(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_lower_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_lower_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_lower_equal_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter_add(
|
||||
dim: usize,
|
||||
tensor: IntTensor<B>,
|
||||
indices: IntTensor<B>,
|
||||
value: IntTensor<B>,
|
||||
) -> IntTensor<B> {
|
||||
B::int_scatter_add(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_select_add(
|
||||
tensor: IntTensor<B>,
|
||||
dim: usize,
|
||||
indices: IntTensor<B>,
|
||||
value: IntTensor<B>,
|
||||
) -> IntTensor<B> {
|
||||
B::int_select_add(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn int_mask_where(
|
||||
tensor: IntTensor<B>,
|
||||
mask: BoolTensor<B>,
|
||||
value: IntTensor<B>,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
|
||||
B::int_mask_where(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn int_mask_fill(
|
||||
tensor: IntTensor<B>,
|
||||
mask: BoolTensor<B>,
|
||||
value: Scalar,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
|
||||
B::int_mask_fill(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_argmax(tensor, dim)
|
||||
}
|
||||
fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_argmin(tensor, dim)
|
||||
}
|
||||
fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
|
||||
B::int_max(tensor)
|
||||
}
|
||||
fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
|
||||
B::int_max_dim(tensor, dim)
|
||||
}
|
||||
fn int_max_dim_with_indices(
|
||||
tensor: B::IntTensorPrimitive,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
|
||||
B::int_max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
|
||||
B::int_min(tensor)
|
||||
}
|
||||
fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
|
||||
B::int_min_dim(tensor, dim)
|
||||
}
|
||||
fn int_min_dim_with_indices(
|
||||
tensor: B::IntTensorPrimitive,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
|
||||
B::int_min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
|
||||
B::int_abs(tensor)
|
||||
}
|
||||
fn int_into_float(
|
||||
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
|
||||
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
|
||||
AutodiffTensor::new(B::int_into_float(tensor))
|
||||
}
|
||||
|
||||
fn int_swap_dims(
|
||||
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
|
||||
B::int_swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
B::int_random(shape, distribution, device)
|
||||
}
|
||||
|
||||
fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
|
||||
B::int_arange(range, device)
|
||||
}
|
||||
|
||||
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
B::int_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
B::int_flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::int_sign(tensor)
|
||||
}
|
||||
|
||||
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::int_prod(tensor)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
B::int_prod_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
|
||||
B::int_expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
B::int_sort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn int_sort_with_indices(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (IntTensor<Self>, IntTensor<Self>) {
|
||||
B::int_sort_with_indices(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
B::int_argsort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_and(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_and_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_or(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_or_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_xor(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_xor_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_not(tensor)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_left_shift(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_left_shift_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_right_shift(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_right_shift_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
B::int_cast(tensor, dtype)
|
||||
}
|
||||
|
||||
fn int_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
B::int_unfold(tensor, dim, size, step)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::{Backward, Ops, unary};
|
||||
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
|
||||
use burn_backend::{Backend, TensorMetadata};
|
||||
use burn_std::Shape;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct MaxMinDim;
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for MaxMinDim {
|
||||
type State = (B::IntTensorPrimitive, Shape, usize);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
_checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let (indices, shape, dim) = ops.state;
|
||||
let device = B::float_device(&grad);
|
||||
let dtype = grad.dtype();
|
||||
let zeros = B::float_zeros(shape, &device, dtype.into());
|
||||
|
||||
B::float_scatter_add(dim, zeros, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
mod activation;
|
||||
mod backward;
|
||||
mod base;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
|
||||
pub(crate) mod maxmin;
|
||||
pub(crate) mod sort;
|
||||
|
||||
pub use backward::*;
|
||||
pub use base::*;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,106 @@
|
||||
use burn_backend::{
|
||||
Backend, ExecutionError, TensorData,
|
||||
ops::QTensorOps,
|
||||
tensor::{
|
||||
Device, FloatTensor, IntTensor, QuantizedTensor,
|
||||
quantization::QuantizationParametersPrimitive,
|
||||
},
|
||||
};
|
||||
use burn_std::{QuantScheme, Shape};
|
||||
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
|
||||
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
_qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
todo!() // required for QAT
|
||||
}
|
||||
|
||||
fn quantize_dynamic(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
|
||||
B::q_device(tensor)
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_device: &Device<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
B::q_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
B::q_into_data(tensor).await
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim1: usize,
|
||||
_dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_gather(
|
||||
_dim: usize,
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_slice(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_slices: &[burn_std::Slice],
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
B::q_argmax(tensor, dim)
|
||||
}
|
||||
|
||||
fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
B::q_argmin(tensor, dim)
|
||||
}
|
||||
|
||||
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::{Backward, Ops, unary};
|
||||
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
|
||||
use burn_backend::{Backend, TensorMetadata};
|
||||
use burn_std::Shape;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct SortDim;
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for SortDim {
|
||||
type State = (B::IntTensorPrimitive, Shape, usize);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
_checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let (indices, shape, dim) = ops.state;
|
||||
let device = B::float_device(&grad);
|
||||
let dtype = grad.dtype();
|
||||
let zeros = B::float_zeros(shape, &device, dtype.into());
|
||||
|
||||
B::float_scatter_add(dim, zeros, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
use burn_backend::{
|
||||
Backend, ExecutionError,
|
||||
ops::{TransactionOps, TransactionPrimitive},
|
||||
};
|
||||
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C> {
|
||||
async fn tr_execute(
|
||||
transaction: TransactionPrimitive<Self>,
|
||||
) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
|
||||
B::tr_execute(TransactionPrimitive::new(
|
||||
transaction
|
||||
.read_floats
|
||||
.into_iter()
|
||||
.map(|t| t.primitive)
|
||||
.collect(),
|
||||
transaction.read_qfloats,
|
||||
transaction.read_ints,
|
||||
transaction.read_bools,
|
||||
))
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use crate::{
|
||||
checkpoint::builder::CheckpointerBuilder,
|
||||
grads::Gradients,
|
||||
graph::StepBoxed,
|
||||
tensor::{AutodiffTensor, NodeRefCount},
|
||||
};
|
||||
use burn_backend::Backend;
|
||||
|
||||
/// Client used to communicate with the autodiff server.
|
||||
pub trait AutodiffClient: Send + Clone {
|
||||
/// Register a new step.
|
||||
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
|
||||
/// Call backpropagation from the given tensor.
|
||||
fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
|
||||
}
|
||||
|
||||
/// Client implementation in used.
|
||||
pub type AutodiffClientImpl = super::graph::GraphMutexClient;
|
||||
@@ -0,0 +1,335 @@
|
||||
use super::{AutodiffClient, server::AutodiffServer};
|
||||
use crate::{
|
||||
NodeId,
|
||||
checkpoint::builder::CheckpointerBuilder,
|
||||
grads::Gradients,
|
||||
graph::{Parent, StepBoxed},
|
||||
runtime::server::NodeCleaner,
|
||||
tensor::{AutodiffTensor, NodeRefCount},
|
||||
};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::Backend;
|
||||
use hashbrown::{HashMap, HashSet};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use spin::{Mutex, MutexGuard};
|
||||
|
||||
/// A client for managing multiple graphs using mutex-based synchronization.
|
||||
///
|
||||
/// The biggest benefit of using this client implementation is that each graph can modify its own
|
||||
/// data without blocking other graphs, which is essential for multi-device training.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The [AutodiffServer] fully supports multiple graphs with sharing nodes, however those type of
|
||||
/// graphs will be stored under a single mutex-protected graph by the client, limiting
|
||||
/// parallelisation.
|
||||
#[derive(Clone, new, Debug)]
|
||||
pub struct GraphMutexClient;
|
||||
|
||||
/// Manages a collection of graphs, mapping [node ids](NodeId) to their respective graph.
|
||||
///
|
||||
/// The `GraphLocator` is responsible for selecting and merging graphs based on their IDs and parent
|
||||
/// dependencies, ensuring proper synchronization and server allocation.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Multiple node ids can point to the same graph, where the autodiff graph is stored.
|
||||
#[derive(Default)]
|
||||
pub struct GraphLocator {
|
||||
graphs: HashMap<NodeId, Arc<Graph>>,
|
||||
/// We keep a mapping of each original node id (graph id) => all nodes that point to that graph.
|
||||
/// This is to ensure that when merging graphs, we correctly move all previous graphs to
|
||||
/// the new merged one.
|
||||
keys: HashMap<NodeId, HashSet<NodeId>>,
|
||||
}
|
||||
|
||||
/// Represents a single computation graph with a mutex-protected server.
|
||||
///
|
||||
/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
|
||||
/// first created.
|
||||
pub(crate) struct Graph {
|
||||
origin: NodeId,
|
||||
state: Mutex<GraphState>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GraphState {
|
||||
server: AutodiffServer,
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for Graph {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.debug_struct("Graph")
|
||||
.field("origin", &self.origin)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
static STATE: Mutex<Option<GraphLocator>> = Mutex::new(None);
|
||||
|
||||
impl GraphMutexClient {
|
||||
/// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `node`: The unique identifier for the stream.
|
||||
/// - `parents`: A slice of parent nodes that the stream depends on.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `Arc<Graph>` representing the selected or newly created stream.
|
||||
fn graph(node: NodeId, parents: &[Parent]) -> Arc<Graph> {
|
||||
let mut state = STATE.lock();
|
||||
|
||||
match state.as_mut() {
|
||||
Some(locator) => locator.select(node, parents),
|
||||
None => {
|
||||
let mut locator = GraphLocator::default();
|
||||
let stream = locator.select(node, parents);
|
||||
*state = Some(locator);
|
||||
stream
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutodiffClient for GraphMutexClient {
|
||||
fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
|
||||
let node_id = *node_id_ref;
|
||||
let graph = GraphMutexClient::graph(node_id, step.parents());
|
||||
let mut state = graph.state.lock();
|
||||
|
||||
state.server.register(node_id_ref, step, actions);
|
||||
}
|
||||
|
||||
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
|
||||
let node_id = root.node.id;
|
||||
let graph = GraphMutexClient::graph(root.node.id, &[]);
|
||||
|
||||
let grads = Gradients::new::<B>(root.node, root.primitive);
|
||||
let grads = {
|
||||
let mut state = graph.state.lock();
|
||||
state.server.backward::<GraphCleaner>(grads, node_id)
|
||||
}; // lock released
|
||||
|
||||
GraphCleaner::cleanup_orphaned_entries();
|
||||
|
||||
grads
|
||||
}
|
||||
}
|
||||
|
||||
struct GraphCleaner<'a> {
|
||||
guard: MutexGuard<'a, Option<GraphLocator>>,
|
||||
}
|
||||
|
||||
impl<'a> GraphCleaner<'a> {
|
||||
fn cleanup_orphaned_entries() {
|
||||
let graphs = {
|
||||
// Get the available graphs and release the lock
|
||||
match STATE.lock().as_ref() {
|
||||
Some(state) => state.graphs.clone(),
|
||||
None => return,
|
||||
}
|
||||
};
|
||||
|
||||
let mut should_remove = Vec::new();
|
||||
for graph in graphs.values() {
|
||||
{
|
||||
let mut guard = graph.state.lock();
|
||||
// Double safety: in case it was marked as no longer useful, but other
|
||||
// nodes are still relevant, we only check which nodes can safely be removed.
|
||||
if !guard.server.maybe_useful() {
|
||||
guard
|
||||
.server
|
||||
.free_unused_roots(|node| should_remove.push(*node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !should_remove.is_empty() {
|
||||
let mut state = STATE.lock();
|
||||
if let Some(state) = state.as_mut() {
|
||||
for node in should_remove {
|
||||
state.remove_entry(&node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> NodeCleaner for GraphCleaner<'a> {
|
||||
fn init() -> Self {
|
||||
let guard = STATE.lock();
|
||||
Self { guard }
|
||||
}
|
||||
|
||||
fn clean(&mut self, node: &NodeId) {
|
||||
if let Some(state) = self.guard.as_mut() {
|
||||
state.remove_entry(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphLocator {
|
||||
/// Selects a single graph for the given [NodeId], considering parent dependencies.
|
||||
///
|
||||
/// If multiple graphs are found, they are merged into a single one.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `node`: The node ID of the graph to select.
|
||||
/// - `parents`: A slice of parent nodes that the graph depends on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An `Arc<Graph>` representing the selected or merged graph.
|
||||
pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
|
||||
match self.analyse(node, parents) {
|
||||
GraphAnalysis::NoCollision(graph) => {
|
||||
if graph.origin != node {
|
||||
self.graphs.insert(node, graph.clone());
|
||||
self.register_key(graph.origin, node);
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
GraphAnalysis::Collisions(graphs) => self.merge(node, graphs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyses the graph for a given node and its parents, returning the associated `GraphAnalysis`.
|
||||
fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalysis {
|
||||
// If no parents, there is no collision, therefore a single graph is ok.
|
||||
if parents.is_empty() {
|
||||
let graph = match self.graphs.get(&node) {
|
||||
Some(val) => val.clone(),
|
||||
None => self.new_graph(node),
|
||||
};
|
||||
return GraphAnalysis::NoCollision(graph);
|
||||
};
|
||||
|
||||
// We collect all graphs of parents and of the current node based on their origin node id.
|
||||
let mut graphs = HashMap::<NodeId, Arc<Graph>>::new();
|
||||
|
||||
if let Some(val) = self.graphs.get(&node) {
|
||||
graphs.insert(val.origin, val.clone());
|
||||
}
|
||||
|
||||
for parent in parents {
|
||||
match self.graphs.get(&parent.id) {
|
||||
Some(graph) => graphs.insert(graph.origin, graph.clone()),
|
||||
None => continue,
|
||||
};
|
||||
}
|
||||
|
||||
if graphs.is_empty() {
|
||||
return match self.graphs.get(&node) {
|
||||
Some(old) => GraphAnalysis::NoCollision(old.clone()),
|
||||
None => GraphAnalysis::NoCollision(self.new_graph(node)),
|
||||
};
|
||||
}
|
||||
|
||||
if graphs.len() == 1 {
|
||||
return GraphAnalysis::NoCollision(graphs.drain().next().unwrap().1);
|
||||
}
|
||||
|
||||
GraphAnalysis::Collisions(graphs)
|
||||
}
|
||||
|
||||
/// Merges multiple graphs associated with a node into a single graph.
|
||||
fn merge(&mut self, node: NodeId, mut graphs: HashMap<NodeId, Arc<Graph>>) -> Arc<Graph> {
|
||||
let mut graphs = graphs.drain().map(|g| g.1);
|
||||
|
||||
let main = graphs.next().expect("At least one graph");
|
||||
self.register_key(main.origin, node);
|
||||
|
||||
let mut state = main.state.lock();
|
||||
|
||||
for graph in graphs {
|
||||
self.merge_two(&mut state, &main, graph);
|
||||
}
|
||||
|
||||
self.graphs.insert(main.origin, main.clone());
|
||||
self.graphs.insert(node, main.clone());
|
||||
|
||||
core::mem::drop(state);
|
||||
|
||||
main
|
||||
}
|
||||
|
||||
/// Registers a key for a given origin node.
|
||||
fn register_key(&mut self, origin: NodeId, key: NodeId) {
|
||||
if !self.keys.contains_key(&origin) {
|
||||
// Ensure an entry exists for this origin
|
||||
self.keys.insert(origin, HashSet::new());
|
||||
}
|
||||
|
||||
if origin != key {
|
||||
// Register this node to point to the origin graph
|
||||
self.keys.get_mut(&origin).unwrap().insert(key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Merges two graphs by combining their states and updating graph mappings.
|
||||
fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>, merged: Arc<Graph>) {
|
||||
let mut locked = merged.state.lock();
|
||||
let mut state_old = GraphState::default();
|
||||
core::mem::swap(&mut state_old, &mut locked);
|
||||
main_state.server.extend(state_old.server);
|
||||
|
||||
// Re-map merged origin to the main graph
|
||||
self.graphs.insert(merged.origin, main.clone());
|
||||
|
||||
// Move all keys (node IDs) from the merged graph to the main graph
|
||||
if let Some(locator_keys) = self.keys.remove(&merged.origin) {
|
||||
for k in locator_keys.iter() {
|
||||
self.graphs.insert(*k, main.clone());
|
||||
}
|
||||
|
||||
let locator_keys_main = self
|
||||
.keys
|
||||
.get_mut(&main.origin)
|
||||
.expect("Should be init before the merge.");
|
||||
locator_keys_main.extend(locator_keys);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new graph for a given node.
|
||||
fn new_graph(&mut self, origin: NodeId) -> Arc<Graph> {
|
||||
let graph = Arc::new(Graph {
|
||||
origin,
|
||||
state: Mutex::new(GraphState::default()),
|
||||
});
|
||||
self.graphs.insert(origin, graph.clone());
|
||||
self.keys.insert(origin, HashSet::new());
|
||||
graph
|
||||
}
|
||||
|
||||
fn remove_entry(&mut self, node: &NodeId) {
|
||||
if let Some(graph) = self.graphs.remove(node) {
|
||||
let mut remove = false;
|
||||
|
||||
if let Some(entry) = self.keys.get_mut(&graph.origin) {
|
||||
entry.remove(node);
|
||||
if entry.is_empty() {
|
||||
remove = true;
|
||||
}
|
||||
}
|
||||
|
||||
if remove {
|
||||
self.keys.remove(&graph.origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the analysis result of graph operations for a given node and its parents.
|
||||
#[derive(Debug)]
|
||||
enum GraphAnalysis {
|
||||
/// No collision detected, contains the graph associated with the node.
|
||||
NoCollision(Arc<Graph>),
|
||||
/// Collision detected, contains a map of node IDs to their associated graphs.
|
||||
Collisions(HashMap<NodeId, Arc<Graph>>),
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
use crate::{
|
||||
NodeId,
|
||||
collections::{HashMap, HashSet},
|
||||
graph::Parent,
|
||||
tensor::NodeRefCount,
|
||||
};
|
||||
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
|
||||
use core::mem;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct GraphMemoryManagement {
|
||||
nodes: HashMap<NodeRefCount, Vec<NodeId>>,
|
||||
leaves: HashSet<NodeId>,
|
||||
statuses: HashMap<NodeId, NodeMemoryStatus>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NodeMemoryStatus {
|
||||
Useful,
|
||||
Unavailable,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl GraphMemoryManagement {
|
||||
pub fn extend(&mut self, other: Self) {
|
||||
self.nodes.extend(other.nodes);
|
||||
self.leaves.extend(other.leaves);
|
||||
self.statuses.extend(other.statuses);
|
||||
}
|
||||
|
||||
/// Register a new node with its parent.
|
||||
pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
|
||||
let node_id = *node.as_ref();
|
||||
|
||||
for parent in parents.iter() {
|
||||
self.leaves.remove(&parent.id);
|
||||
}
|
||||
|
||||
self.leaves.insert(node_id);
|
||||
self.nodes
|
||||
.insert(node, parents.iter().map(|p| p.id).collect());
|
||||
}
|
||||
|
||||
/// Free the node from the state.
|
||||
pub fn consume_node(&mut self, node_id: NodeId) {
|
||||
if !self.is_referenced(node_id) {
|
||||
self.leaves.remove(&node_id);
|
||||
self.nodes.remove(&node_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Free all nodes whose backward call has become impossible
|
||||
///
|
||||
/// This function goes into three steps, which must happen for all leaves
|
||||
/// before going into the next step. Then it deletes what can be safely deleted
|
||||
pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
|
||||
let leaves = self.leaves.clone();
|
||||
let mut new_leaves = HashSet::new();
|
||||
let mut deletables = Vec::new();
|
||||
|
||||
// When consuming nodes with a backward pass, some other backward passes become
|
||||
// unavailable because some of their parents have been consumed. They are
|
||||
// identified here.
|
||||
for leaf in leaves.clone() {
|
||||
self.unavailable_propagation(leaf);
|
||||
}
|
||||
|
||||
// Among the available nodes that remain, some may be useless if no
|
||||
// available node with a tensor reference exist in their descendance.
|
||||
// But some may seem useless from some leaf but be useful from another one,
|
||||
// hence the need to iterate on all leaves.
|
||||
self.useful_propagation(leaves.clone());
|
||||
|
||||
// New leaves are the roots of a useful backward sub-tree.
|
||||
// Deletables are everything not marked as useful.
|
||||
for leaf in leaves {
|
||||
self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
|
||||
}
|
||||
|
||||
// Replace leaves by the new ones and delete everything not useful anymore
|
||||
mem::swap(&mut self.leaves, &mut new_leaves);
|
||||
|
||||
self.clear_unused_roots(&mut deletables);
|
||||
|
||||
self.statuses.clear();
|
||||
for node_to_delete in deletables {
|
||||
self.nodes.remove(&node_to_delete);
|
||||
on_free_graph(&node_to_delete)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
|
||||
let mut deletables = Vec::new();
|
||||
self.clear_unused_roots(&mut deletables);
|
||||
|
||||
for node_id in deletables {
|
||||
self.nodes.remove(&node_id);
|
||||
on_free_graph(&node_id);
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
|
||||
for (id, parents) in self.nodes.iter() {
|
||||
let is_useful = matches!(
|
||||
self.statuses.get(id.as_ref()),
|
||||
Some(NodeMemoryStatus::Useful)
|
||||
);
|
||||
|
||||
// Check if parents are either empty or absent from self.nodes
|
||||
let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));
|
||||
|
||||
if !is_useful && Arc::strong_count(id) == 1 && parents_absent {
|
||||
to_delete.push(*id.as_ref())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {
|
||||
// If already visited
|
||||
if let Some(status) = self.statuses.get(&node_id) {
|
||||
return status.clone();
|
||||
}
|
||||
|
||||
match self.nodes.get(&node_id).cloned() {
|
||||
// If node exists and any of its parents is unavailable, it is unavailable as well
|
||||
// If node exists but the parents vec is empty, it is a tensor that never had parents;
|
||||
// the status remains unknown
|
||||
Some(parents) => {
|
||||
let mut node_status = NodeMemoryStatus::Unknown;
|
||||
for parent in parents {
|
||||
let parent_status = self.unavailable_propagation(parent);
|
||||
if let NodeMemoryStatus::Unavailable = parent_status {
|
||||
node_status = NodeMemoryStatus::Unavailable;
|
||||
}
|
||||
}
|
||||
self.statuses.insert(node_id, node_status.clone());
|
||||
node_status
|
||||
}
|
||||
// If node does not exist, it was
|
||||
// deleted, so this and all its descendants are unavailable
|
||||
None => {
|
||||
self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
|
||||
NodeMemoryStatus::Unavailable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
|
||||
// Accumulate visited nodes
|
||||
let mut explored = HashSet::new();
|
||||
let mut tagged_useful = HashSet::new();
|
||||
|
||||
// Queue of nodes to visit
|
||||
let mut to_tag_useful = PopNodeSet::default();
|
||||
let mut to_explore = PopNodeSet::new(leaves);
|
||||
|
||||
// Utility function to iterate over a node's parents
|
||||
let parents = |node_id| {
|
||||
self.nodes
|
||||
.get(&node_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
};
|
||||
|
||||
loop {
|
||||
// Pop a node id, greedily looking at tag_useful ones first
|
||||
let (node_id, status) = match to_tag_useful.pop() {
|
||||
Some(node_id) => (node_id, NodeMemoryStatus::Useful),
|
||||
None => match to_explore.pop() {
|
||||
Some(node_id) => {
|
||||
let node_status = self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("All nodes should have received a status during unavailable_propagation")
|
||||
.to_owned();
|
||||
|
||||
if let NodeMemoryStatus::Unknown = node_status {
|
||||
match self.is_referenced(node_id) {
|
||||
true => (node_id, NodeMemoryStatus::Useful),
|
||||
false => (node_id, NodeMemoryStatus::Unknown),
|
||||
}
|
||||
} else {
|
||||
(node_id, node_status)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// There are no nodes in the queues anymore
|
||||
break;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
match status {
|
||||
NodeMemoryStatus::Useful => {
|
||||
tagged_useful.insert(node_id);
|
||||
for parent in parents(node_id) {
|
||||
// The node can be explored, as long as it's not already tagged useful
|
||||
if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
|
||||
to_tag_useful.insert(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
explored.insert(node_id);
|
||||
for parent in parents(node_id) {
|
||||
if !(explored.contains(&parent) || to_explore.contains(&parent)) {
|
||||
to_explore.insert(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.statuses.insert(node_id, status);
|
||||
}
|
||||
}
|
||||
|
||||
fn identify_leaves_and_deletables(
|
||||
&self,
|
||||
leaf_id: NodeId,
|
||||
new_leaves: &mut HashSet<NodeId>,
|
||||
to_delete: &mut Vec<NodeId>,
|
||||
) {
|
||||
let mut visited = HashSet::new();
|
||||
let mut to_visit = vec![leaf_id];
|
||||
|
||||
while let Some(node_id) = to_visit.pop() {
|
||||
visited.insert(node_id);
|
||||
|
||||
match self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("Node should have status")
|
||||
{
|
||||
NodeMemoryStatus::Useful => {
|
||||
new_leaves.insert(node_id);
|
||||
}
|
||||
_ => {
|
||||
to_delete.push(node_id);
|
||||
|
||||
for parent in self
|
||||
.nodes
|
||||
.get(&node_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
{
|
||||
if !visited.contains(&parent) {
|
||||
to_visit.push(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn is_referenced(&self, node_id: NodeId) -> bool {
|
||||
match self.nodes.get_key_value(&node_id) {
|
||||
Some((key, _value)) => Arc::strong_count(key) > 1,
|
||||
None => panic!("Node should be in the nodes map"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_useful(&self) -> bool {
|
||||
self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper over hash set for fast popping of any node
|
||||
#[derive(new, Default)]
|
||||
struct PopNodeSet {
|
||||
hash_set: HashSet<NodeId>,
|
||||
}
|
||||
|
||||
impl PopNodeSet {
|
||||
#[inline(always)]
|
||||
fn pop(&mut self) -> Option<NodeId> {
|
||||
self.hash_set
|
||||
.iter()
|
||||
.next()
|
||||
.copied()
|
||||
.and_then(|node_id| self.hash_set.take(&node_id))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn contains(&self, node_id: &NodeId) -> bool {
|
||||
self.hash_set.contains(node_id)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn insert(&mut self, node_id: NodeId) {
|
||||
self.hash_set.insert(node_id);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
mod client;
|
||||
mod memory_management;
|
||||
mod server;
|
||||
|
||||
pub mod graph;
|
||||
pub use client::*;
|
||||
@@ -0,0 +1,143 @@
|
||||
use super::memory_management::GraphMemoryManagement;
|
||||
use crate::{
|
||||
NodeId,
|
||||
checkpoint::{
|
||||
base::{Checkpointer, NodeTree},
|
||||
builder::CheckpointerBuilder,
|
||||
},
|
||||
collections::HashMap,
|
||||
grads::Gradients,
|
||||
graph::{StepBoxed, traversal::BreadthFirstSearch},
|
||||
tensor::NodeRefCount,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AutodiffServer {
|
||||
steps: HashMap<NodeId, StepBoxed>,
|
||||
actions_builder: HashMap<NodeId, CheckpointerBuilder>,
|
||||
memory_management: GraphMemoryManagement,
|
||||
}
|
||||
|
||||
/// Defines how nodes are clean.
|
||||
pub trait NodeCleaner {
|
||||
/// Initialize a new cleaner.
|
||||
fn init() -> Self;
|
||||
/// Cleans a single [node](NodeId).
|
||||
fn clean(&mut self, node: &NodeId);
|
||||
}
|
||||
|
||||
impl AutodiffServer {
|
||||
pub fn extend(&mut self, other: AutodiffServer) {
|
||||
self.steps.extend(other.steps);
|
||||
self.actions_builder.extend(other.actions_builder);
|
||||
self.memory_management.extend(other.memory_management);
|
||||
}
|
||||
|
||||
pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
|
||||
let parents = step.parents();
|
||||
let node_id = *rc.as_ref();
|
||||
|
||||
self.memory_management.register(rc, parents);
|
||||
|
||||
self.steps.insert(node_id, step);
|
||||
self.actions_builder.insert(node_id, actions);
|
||||
}
|
||||
|
||||
pub fn backward<NC: NodeCleaner>(&mut self, grads: Gradients, node_id: NodeId) -> Gradients {
|
||||
let step = self.steps.remove(&node_id).expect(
|
||||
"Node should have a step registered, did you forget to call \
|
||||
`Tensor::register_grad` on the tensor where you need gradients?",
|
||||
);
|
||||
let builder = self.actions_builder.remove(&node_id).unwrap();
|
||||
|
||||
let mut consumed = Vec::new();
|
||||
let (tape, checkpointer) = self.build_tape(node_id, step, builder, &mut consumed);
|
||||
|
||||
let gradients = Self::execute_steps(tape, grads, checkpointer);
|
||||
|
||||
// Cleanup
|
||||
let mut cleaner = NC::init();
|
||||
self.memory_management
|
||||
.free_unavailable_nodes(|node_id: &NodeId| {
|
||||
self.steps.remove(node_id);
|
||||
self.actions_builder.remove(node_id);
|
||||
NC::clean(&mut cleaner, node_id);
|
||||
});
|
||||
for node_id in consumed {
|
||||
cleaner.clean(&node_id)
|
||||
}
|
||||
|
||||
gradients
|
||||
}
|
||||
|
||||
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
|
||||
self.memory_management.free_unused_roots(|node_id| {
|
||||
self.steps.remove(node_id);
|
||||
self.actions_builder.remove(node_id);
|
||||
on_free_graph(node_id);
|
||||
});
|
||||
}
|
||||
|
||||
fn build_tape(
|
||||
&mut self,
|
||||
node: NodeId,
|
||||
node_step: StepBoxed,
|
||||
mut builder: CheckpointerBuilder,
|
||||
consumed: &mut Vec<NodeId>,
|
||||
) -> (Vec<Vec<StepBoxed>>, Checkpointer) {
|
||||
let mut tape = (0..node_step.depth())
|
||||
.map(|_| Vec::with_capacity(1))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut tree = HashMap::default();
|
||||
|
||||
BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {
|
||||
self.memory_management.consume_node(id);
|
||||
// Clean up consumed node
|
||||
consumed.push(id);
|
||||
|
||||
let depth = step.depth();
|
||||
|
||||
if depth == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(steps) = tape.get_mut(depth - 1) {
|
||||
let parents = step.parents().iter().map(|p| p.id).filter(|s| *s != id);
|
||||
tree.insert(id, parents.collect());
|
||||
steps.push(step);
|
||||
}
|
||||
|
||||
if let Some(node_builder) = self.actions_builder.remove(&id) {
|
||||
builder.extend(node_builder);
|
||||
}
|
||||
});
|
||||
|
||||
let checkpointer = builder.build(NodeTree::new(tree));
|
||||
|
||||
(tape, checkpointer)
|
||||
}
|
||||
|
||||
fn execute_steps(
|
||||
tape: Vec<Vec<StepBoxed>>,
|
||||
mut grads: Gradients,
|
||||
mut checkpointer: Checkpointer,
|
||||
) -> Gradients {
|
||||
tape.into_iter().rev().for_each(|steps| {
|
||||
steps
|
||||
.into_iter()
|
||||
.for_each(|step| step.step(&mut grads, &mut checkpointer))
|
||||
});
|
||||
|
||||
// For checkpointing tests
|
||||
#[cfg(feature = "export_tests")]
|
||||
assert!(checkpointer.is_empty());
|
||||
|
||||
grads
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_useful(&self) -> bool {
|
||||
self.memory_management.maybe_useful()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, Node, NodeId, NodeRef, Parent, Requirement, Step},
|
||||
runtime::{AutodiffClient, AutodiffClientImpl},
|
||||
};
|
||||
use alloc::{boxed::Box, sync::Arc, vec};
|
||||
use burn_backend::{Backend, TensorMetadata};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutodiffTensor<B: Backend> {
|
||||
pub primitive: B::FloatTensorPrimitive,
|
||||
pub node: NodeRef,
|
||||
pub rc: NodeRefCount,
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorMetadata for AutodiffTensor<B> {
|
||||
fn dtype(&self) -> burn_std::DType {
|
||||
self.primitive.dtype()
|
||||
}
|
||||
|
||||
fn shape(&self) -> burn_std::Shape {
|
||||
self.primitive.shape()
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.primitive.rank()
|
||||
}
|
||||
}
|
||||
|
||||
pub type NodeRefCount = Arc<NodeId>;
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub(crate) struct RootStep {
|
||||
node: NodeRef,
|
||||
}
|
||||
|
||||
impl Step for RootStep {
|
||||
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeId {
|
||||
self.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
&self.node.parents
|
||||
}
|
||||
|
||||
fn depth(&self) -> usize {
|
||||
self.node.order
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> AutodiffTensor<B> {
|
||||
/// Create a new leaf tensor.
|
||||
pub fn new(primitive: B::FloatTensorPrimitive) -> Self {
|
||||
let id = NodeId::new();
|
||||
let node: NodeRef = Node::new(
|
||||
vec![],
|
||||
0,
|
||||
id,
|
||||
Requirement::None,
|
||||
ComputingProperty::Ambiguous,
|
||||
AutodiffClientImpl::new(),
|
||||
)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
rc: Arc::new(node.id),
|
||||
primitive,
|
||||
node: node.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tracked(&self) -> bool {
|
||||
!self.node.requirement.is_none()
|
||||
}
|
||||
|
||||
/// Mark the tensor as requiring gradients.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// It panics if the tensor is not a leaf.
|
||||
pub fn require_grad(mut self) -> Self {
|
||||
match self.node.requirement {
|
||||
Requirement::Grad => self,
|
||||
Requirement::GradInBackward => {
|
||||
panic!("Can't convert a non leaf tensor into a tracked tensor")
|
||||
}
|
||||
Requirement::None => {
|
||||
self.node = Node::new(
|
||||
vec![],
|
||||
0,
|
||||
self.node.id,
|
||||
Requirement::Grad,
|
||||
self.node.properties.clone(),
|
||||
self.node.client.clone(),
|
||||
)
|
||||
.into();
|
||||
let step = RootStep::new(self.node.clone());
|
||||
|
||||
self.register_step(step, CheckpointerBuilder::default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor from parent infos.
|
||||
pub fn from_parents(
|
||||
primitive: B::FloatTensorPrimitive,
|
||||
parent_nodes: &[NodeRef],
|
||||
requirement: Requirement,
|
||||
computing_properties: ComputingProperty,
|
||||
) -> Self {
|
||||
let order = parent_nodes
|
||||
.iter()
|
||||
.map(|node| node.order)
|
||||
.reduce(usize::max)
|
||||
.unwrap_or(0)
|
||||
+ 1;
|
||||
|
||||
let client = parent_nodes
|
||||
.first()
|
||||
.map(|node| node.client.clone())
|
||||
.unwrap_or_else(AutodiffClientImpl::new);
|
||||
|
||||
let node: NodeRef = Node::new(
|
||||
parent_nodes
|
||||
.iter()
|
||||
.filter_map(|node| node.clone_if_require_grad())
|
||||
.map(|node| Parent::new(node.id))
|
||||
.collect(),
|
||||
order,
|
||||
NodeId::new(),
|
||||
requirement,
|
||||
computing_properties,
|
||||
client,
|
||||
)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
rc: Arc::new(node.id),
|
||||
primitive,
|
||||
node,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a step into a graph for that tensor.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This should be called only once per tensor.
|
||||
pub fn register_step<S: Step + 'static>(
|
||||
self,
|
||||
step_that_created_the_tensor: S,
|
||||
actions: CheckpointerBuilder,
|
||||
) -> Self {
|
||||
self.node.client.register(
|
||||
self.rc.clone(),
|
||||
Box::new(step_that_created_the_tensor),
|
||||
actions,
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn into_primitive(self) -> B::FloatTensorPrimitive {
|
||||
self.primitive
|
||||
}
|
||||
|
||||
pub fn backward(self) -> Gradients {
|
||||
let client = self.node.client.clone();
|
||||
|
||||
AutodiffClient::backward::<B>(&client, self)
|
||||
}
|
||||
|
||||
pub fn grad(&self, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
|
||||
grads.get::<B>(self)
|
||||
}
|
||||
|
||||
pub fn grad_remove(&self, grads: &mut Gradients) -> Option<B::FloatTensorPrimitive> {
|
||||
grads.remove::<B>(self)
|
||||
}
|
||||
|
||||
pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPrimitive) {
|
||||
grads.remove::<B>(self);
|
||||
grads.register::<B>(self.node.id, grad);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use crate::graph::NodeRef;
|
||||
/// Duplicate the given object for each node that requires gradients.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This is useful since you don't have to keep N cloned references alive event if just 1 node
|
||||
/// will be updated.
|
||||
///
|
||||
/// If the object is a tensor and if one reference exists, it can be updated inplace.
|
||||
pub fn duplicate<T: Clone + core::fmt::Debug, const N: usize>(
|
||||
nodes: &[Option<NodeRef>; N],
|
||||
obj: Option<T>,
|
||||
) -> [Option<T>; N] {
|
||||
nodes
|
||||
.iter()
|
||||
.map(|node| match node {
|
||||
Some(_) => obj.clone(),
|
||||
None => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
Reference in New Issue
Block a user